In [1]:
import os
from datetime import datetime

import numpy as np
from sacred import Experiment
from sacred.commands import print_config
from sacred.observers import FileStorageObserver
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from evaluate import evaluate, evaluate_wo_velocity  # These two imports require a GPU
from onsets_and_frames import *

STFT filter created, time used = 0.2081 seconds
Mel filter created, time used = 0.0051 seconds


In [2]:
# ---- Run configuration ----------------------------------------------------
# Every execution of this cell creates a new timestamped log directory.
logdir = 'runs/transcriber-' + datetime.now().strftime('%y%m%d-%H%M%S')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
iterations = 500000
resume_iteration = None  # set to an int to reload `model-<n>.pt` from `logdir`
checkpoint_interval = 1000
train_on = 'MAPS'  # 'MAESTRO' selects MAESTRO; any other value falls back to MAPS

batch_size = 8
sequence_length = 327680
model_complexity = 48

# Halve the batch and sequence sizes when the current GPU reports < 10e9 bytes.
if torch.cuda.is_available() and torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory < 10e9:
    batch_size //= 2
    sequence_length //= 2
    print(f'Reducing batch size to {batch_size} and sequence_length to {sequence_length} to save memory')

learning_rate = 0.0006
learning_rate_decay_steps = 10000
learning_rate_decay_rate = 0.98

leave_one_out = None  # when set, that group is held out for validation

clip_gradient_norm = 3

# Defined after the memory adjustment above, so it follows the reduced length.
validation_length = sequence_length
validation_interval = 500

# ---- Logging --------------------------------------------------------------
os.makedirs(logdir, exist_ok=True)
writer = SummaryWriter(logdir)

# ---- Datasets -------------------------------------------------------------
train_groups, validation_groups = ['train'], ['validation']

if leave_one_out is not None:
    all_years = {'2004', '2006', '2008', '2009', '2011', '2013', '2014', '2015', '2017'}
    train_groups = list(all_years - {str(leave_one_out)})
    validation_groups = [str(leave_one_out)]

if train_on == 'MAESTRO':
    dataset = MAESTRO(groups=train_groups, sequence_length=sequence_length)
    # Use validation_length here as the MAPS branch does, so changing
    # validation_length affects both datasets the same way.
    validation_dataset = MAESTRO(groups=validation_groups, sequence_length=validation_length)
else:
    dataset = MAPS(
        groups=['AkPnBcht', 'AkPnBsdf', 'AkPnCGdD', 'AkPnStgb', 'SptkBGAm', 'SptkBGCl', 'StbgTGd2'],
        sequence_length=sequence_length,
    )
    validation_dataset = MAPS(groups=['ENSTDkAm', 'ENSTDkCl'], sequence_length=validation_length)

loader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True)

# ---- Model and optimizer --------------------------------------------------
if resume_iteration is None:
    model = OnsetsAndFrames(N_MELS, MAX_MIDI - MIN_MIDI + 1, model_complexity).to(device)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)
    resume_iteration = 0
else:
    # NOTE(review): `logdir` above is always a fresh timestamped directory, so
    # this lookup only succeeds if `logdir` is pointed at the run being resumed.
    model_path = os.path.join(logdir, f'model-{resume_iteration}.pt')
    # map_location keeps the checkpoint loadable when it was saved on a device
    # that is not available in this session.
    model = torch.load(model_path, map_location=device)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)
    optimizer.load_state_dict(
        torch.load(os.path.join(logdir, 'last-optimizer-state.pt'), map_location=device)
    )

summary(model)
scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)

Loading group AkPnBcht: 100%|██████████| 30/30 [00:00<00:00, 291.29it/s]
Loading group AkPnBsdf:   0%|          | 0/30 [00:00<?, ?it/s]

Loading 7 groups of MAPS at data/MAPS


Loading group AkPnBsdf: 100%|██████████| 30/30 [00:00<00:00, 253.12it/s]
Loading group AkPnCGdD: 100%|██████████| 30/30 [00:00<00:00, 271.50it/s]
Loading group AkPnStgb: 100%|██████████| 30/30 [00:00<00:00, 199.04it/s]
Loading group SptkBGAm: 100%|██████████| 30/30 [00:00<00:00, 299.94it/s]
Loading group SptkBGCl: 100%|██████████| 30/30 [00:00<00:00, 219.53it/s]
Loading group StbgTGd2: 100%|██████████| 30/30 [00:00<00:00, 208.81it/s]
Loading group ENSTDkAm: 100%|██████████| 30/30 [00:00<00:00, 226.55it/s]
Loading group ENSTDkCl:   0%|          | 0/30 [00:00<?, ?it/s]

Loading 2 groups of MAPS at data/MAPS


Loading group ENSTDkCl: 100%|██████████| 30/30 [00:00<00:00, 208.69it/s]


OnsetsAndFrames(
  (onset_stack): Sequential(
    (0): ConvStack(
      (cnn): Sequential(
        (0): Conv2d(1, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), [92m480[0m params
        (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), [92m96[0m params
        (2): ReLU(), [92m0[0m params
        (3): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), [92m20,784[0m params
        (4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), [92m96[0m params
        (5): ReLU(), [92m0[0m params
        (6): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False), [92m0[0m params
        (7): Dropout(p=0.25, inplace=False), [92m0[0m params
        (8): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), [92m41,568[0m params
        (9): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), [92m192[0m params
      

In [9]:
# Evaluate in inference mode, as the training cell does: the model contains
# Dropout and BatchNorm layers, and no autograd graph is needed for metrics.
# The model is left in eval mode; the training loop calls model.train() itself.
model.eval()
with torch.no_grad():
    # Despite the name, this holds the (key, values) items view that the
    # metric-printing loop iterates over.
    eval_dict = evaluate_wo_velocity(validation_dataset, model).items()

In [34]:
# Scratch inspection: show the current value of `category`, a name bound by the metric-printing loops.
category

'frame'

In [36]:
# Scratch inspection (repeat): show the current value of `category`.
category

'frame'

In [37]:
# Scratch inspection: show the current value of `name`.
# (The cell previously read `if name`, an incomplete statement that cannot run.)
name

'chroma_total_error'

In [41]:
# Scratch inspection: show the current value of `name`, a name bound by the metric-printing loops.
name

'chroma_total_error'

In [46]:
# Sanity check: `in` between two strings is a substring test, which the metric-name filter relies on.
'f1' in 'note f1'

True

In [50]:
# Report mean ± std for the precision / recall / f1 metrics only,
# leaving out every chroma variant and any non-metric entry.
for key, values in eval_dict:
    if not key.startswith('metric/'):
        continue
    _, category, name = key.split('/')
    if 'chroma' in name:
        continue
    if any(tag in name for tag in ('precision', 'recall', 'f1')):
        print(f'{category:>32} {name:25}: {np.mean(values):.3f} ± {np.std(values):.3f}')

                            note precision                : 0.000 ± 0.000
                            note recall                   : 0.000 ± 0.000
                            note f1                       : 0.000 ± 0.000
               note-with-offsets precision                : 0.000 ± 0.000
               note-with-offsets recall                   : 0.000 ± 0.000
               note-with-offsets f1                       : 0.000 ± 0.000
                           frame f1                       : 0.000 ± 0.000
                           frame precision                : 0.000 ± 0.000
                           frame recall                   : 0.000 ± 0.000


In [5]:
# Epoch-based training loop (replaces the original iteration-based loop below).
# loop = tqdm(range(resume_iteration + 1, iterations + 1))
epoches = 100
# Number of training samples (not batches); only used for the progress display.
total_batch = len(loader.dataset)
# Inclusive upper bound so that exactly `epoches` epochs are run.
for ep in range(1, epoches + 1):
    model.train()
    total_loss = 0
    batch_idx = 0
    for batch in loader:
        predictions, losses = model.run_on_batch(batch)

        # Total loss is the sum of the component losses returned by the model.
        loss = sum(losses.values())
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()

        # Clip between backward() and step(): clipping after the optimizer
        # step would leave the applied update unclipped.
        if clip_gradient_norm:
            clip_grad_norm_(model.parameters(), clip_gradient_norm)

        optimizer.step()
        # The scheduler is stepped once per batch, so its step_size counts batches.
        scheduler.step()

        batch_idx += 1
        # In-place progress line; '\r' returns to the line start so the next print overwrites it.
        print(f'Train Epoch: {ep} [{batch_idx*batch_size}/{total_batch}'
                f'({100. * batch_idx*batch_size / total_batch:.0f}%)]'
                f'\tLoss: {loss.item():.6f}'
                , end='\r')
    # Blank out the progress line before printing the epoch summary.
    print(' '*100, end = '\r')
    print(f'Train Epoch: {ep}\tLoss: {total_loss/len(loader):.6f}')

    # Validate every 10th epoch; model.train() at the top of the loop restores training mode.
    if ep%10 == 0:
        model.eval()
        with torch.no_grad():
            for key, values in evaluate_wo_velocity(validation_dataset, model).items():
                if key.startswith('metric/'):
                    _, category, name = key.split('/')
                    print(f'{category:>32} {name:25}: {np.mean(values):.3f} ± {np.std(values):.3f}')


    # NOTE(review): the disabled blocks below come from the iteration-based loop
    # and reference a step counter `i` that this epoch loop does not define;
    # they need porting (e.g. to `ep`) before being re-enabled.

    # for key, value in {'loss': loss, **losses}.items():
    #     writer.add_scalar(key, value.item(), global_step=i)

    # if i % validation_interval == 0:
    #     model.eval()
    #     with torch.no_grad():
    #         for key, value in evaluate(validation_dataset, model).items():
    #             writer.add_scalar('validation/' + key.replace(' ', '_'), np.mean(value), global_step=i)
    #     model.train()

    # if i % checkpoint_interval == 0:
    #     torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
    #     torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt'))

Train Epoch: 1	Loss: 0.362543                                                                       
Train Epoch: 2	Loss: 0.156357                                                                       
Train Epoch: 3	Loss: 0.155357                                                                       
Train Epoch: 4	Loss: 0.154539                                                                       
Train Epoch: 5	Loss: 0.151398                                                                       

KeyboardInterrupt: 

In [None]:
# Scratch check: list the keys of the summed loss merged with the per-component losses.
all_losses = {'loss': loss, **losses}
for key, value in all_losses.items():
    print(key)

In [None]:
# Scratch inspection: items view of the summed loss merged with the per-component losses.
{'loss': loss, **losses}.items()

In [16]:
# Scratch inspection: the summed loss merged with the per-component losses in a single dict.
{'loss': loss, **losses}

{'loss': tensor(0.1346, device='cuda:0', grad_fn=<AddBackward0>),
 'loss/onset': tensor(0.0255, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>),
 'loss/frame': tensor(0.1091, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)}

In [17]:
# Scratch inspection: the per-component loss dict as last assigned by the training loop.
losses

{'loss/onset': tensor(0.0255, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>),
 'loss/frame': tensor(0.1091, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)}

In [14]:
# Iterating a dict yields its keys, so despite the variable name this prints
# the loss names rather than the loss tensors.
for value in losses:
    print(value)

loss/onset
loss/frame
