In [1]:
import os
from datetime import datetime
import pickle

import numpy as np
from sacred import Experiment
from sacred.commands import print_config, save_config
from sacred.observers import FileStorageObserver
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Restrict CUDA to the device with index 1. This is assigned before the
# project imports below, which (per the original author's note) need a GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from evaluate import evaluate, evaluate_wo_velocity # This import and the next one require a GPU
# NOTE(review): `torch` is used later in this notebook but never imported
# explicitly; presumably it arrives through this star import - confirm.
from onsets_and_frames import *
from onsets_and_frames.transcriber import OnsetsAndFrames_TCN, OnsetsAndFrames_biTCN
# interactive=True lets Sacred create an Experiment from an interactive
# (notebook) session.
ex = Experiment('train_transcriber', interactive=True)

STFT filter created, time used = 0.2073 seconds
Mel filter created, time used = 0.0053 seconds


In [2]:
# NOTE(review): this value is overwritten by `batch_size = 8` in the
# configuration cell that follows, so it has no effect on a top-to-bottom run.
batch_size = 1

In [3]:
def save_dict(obj, name):
    """Pickle ``obj`` to ``./runs/config/<name>.pkl``.

    The parent directory of the target file is created on demand, so the
    first save in a fresh working directory no longer fails with
    ``FileNotFoundError``.

    Args:
        obj: Any picklable object (e.g. a configuration dict).
        name: File stem, without the ``.pkl`` extension.

    Raises:
        pickle.PicklingError: If ``obj`` cannot be pickled.
        OSError: If the directory or file cannot be created.
    """
    path = './runs/config/' + name + '.pkl'
    # open(..., 'wb') does not create missing directories, so do it here.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    """Return the object stored in ``./runs/config/<name>.pkl``.

    Args:
        name: File stem, without the ``.pkl`` extension.

    Returns:
        Whatever object was pickled into that file.

    Note:
        Unpickling can execute arbitrary code, so only load files that
        come from a trusted source.
    """
    pickle_path = './runs/config/' + name + '.pkl'
    with open(pickle_path, 'rb') as handle:
        restored = pickle.load(handle)
    return restored

# ---- Model / run identification ------------------------------------------
TCN_layers = [600, 500, 400, 300, 200, 100, 90]
# Log directory name encodes the layer widths and a start timestamp.
logdir = (
    'runs/BiTCN-'
    + ','.join(map(str, TCN_layers))
    + '-'
    + datetime.now().strftime('%y%m%d-%H%M%S')
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- Training schedule ---------------------------------------------------
iterations = 500000
resume_iteration = None
checkpoint_interval = 1000
train_on = 'MAESTRO'

# ---- Batch geometry ------------------------------------------------------
batch_size = 8
sequence_length = 327680
model_complexity = 48
if torch.cuda.is_available():
    # Halve both batch size and sequence length when the current GPU reports
    # less than 10e9 bytes of total memory.
    if torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory < 10e9:
        batch_size = batch_size // 2
        sequence_length = sequence_length // 2
        print(f'Reducing batch size to {batch_size} and sequence_length to {sequence_length} to save memory')

# ---- Optimisation --------------------------------------------------------
learning_rate = 5e-4
learning_rate_decay_steps = 300
learning_rate_decay_rate = 0.98

leave_one_out = None

clip_gradient_norm = 3

# ---- Validation ----------------------------------------------------------
# Validation excerpts use the (possibly reduced) training sequence length.
validation_length = sequence_length
validation_interval = 500

# Forwarded as the `refresh` keyword when the datasets are constructed.
refresh = True

In [4]:
# Create the run directory (no error if it already exists) and attach a
# TensorBoard writer to it.
os.makedirs(logdir, exist_ok=True)
writer = SummaryWriter(log_dir=logdir)

# Names of the training and validation groups.
train_groups = ['train']
validation_groups = ['validation']

In [5]:
# NOTE(review): the configuration cell sets train_on = 'MAESTRO', but this
# cell loads MAPS for both training and validation - confirm which is intended.
# Training set: seven MAPS groups.
dataset = MAPS(
    groups=['AkPnBcht', 'AkPnBsdf', 'AkPnCGdD', 'AkPnStgb', 'SptkBGAm', 'SptkBGCl', 'StbgTGd2'],
    sequence_length=sequence_length,
    refresh=refresh,
)
# Validation set: the two ENSTDk* MAPS groups.
validation_dataset = MAPS(
    groups=['ENSTDkAm', 'ENSTDkCl'],
    sequence_length=validation_length,
    refresh=refresh,
)

Loading group AkPnBcht:   0%|          | 0/30 [00:00<?, ?it/s]

Loading 7 groups of MAPS at data/MAPS


Loading group AkPnBcht: 100%|██████████| 30/30 [00:06<00:00,  4.85it/s]
Loading group AkPnBsdf: 100%|██████████| 30/30 [00:06<00:00,  2.52it/s]
Loading group AkPnCGdD: 100%|██████████| 30/30 [00:05<00:00,  5.43it/s]
Loading group AkPnStgb: 100%|██████████| 30/30 [00:09<00:00,  2.43it/s]
Loading group SptkBGAm: 100%|██████████| 30/30 [00:05<00:00,  5.85it/s]
Loading group SptkBGCl: 100%|██████████| 30/30 [00:07<00:00,  2.64it/s]
Loading group StbgTGd2: 100%|██████████| 30/30 [00:08<00:00,  5.01it/s]
Loading group ENSTDkAm:   0%|          | 0/30 [00:00<?, ?it/s]

Loading 2 groups of MAPS at data/MAPS


Loading group ENSTDkAm: 100%|██████████| 30/30 [00:08<00:00,  2.65it/s]
Loading group ENSTDkCl: 100%|██████████| 30/30 [00:08<00:00,  5.07it/s]


In [6]:
# Shuffled training loader; drop_last=True discards a final incomplete batch.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [7]:
# Pull a single batch from the loader as a sanity check. The recorded output
# shows a dict with 'path', 'audio' and 'label' entries.
next(iter(loader))

step_begin = 476
result['audio'].shape = torch.Size([327680])
result['label'].shape = torch.Size([1280, 88])
step_begin = 8707
result['audio'].shape = torch.Size([327680])
result['label'].shape = torch.Size([1280, 88])
step_begin = 30101
result['audio'].shape = torch.Size([327680])
result['label'].shape = torch.Size([1280, 88])
step_begin = 1012
result['audio'].shape = torch.Size([327680])
result['label'].shape = torch.Size([1280, 88])
step_begin = 430
result['audio'].shape = torch.Size([327680])
result['label'].shape = torch.Size([1280, 88])
step_begin = 19244
result['audio'].shape = torch.Size([327680])
result['label'].shape = torch.Size([1280, 88])
step_begin = 4310
result['audio'].shape = torch.Size([327680])
result['label'].shape = torch.Size([1280, 88])
step_begin = 8728
result['audio'].shape = torch.Size([327680])
result['label'].shape = torch.Size([1280, 88])


{'path': ['data/MAPS/flac/MAPS_MUS-burg_quelle_SptkBGAm.flac',
  'data/MAPS/flac/MAPS_MUS-bach_850_AkPnBsdf.flac',
  'data/MAPS/flac/MAPS_MUS-mz_330_1_SptkBGCl.flac',
  'data/MAPS/flac/MAPS_MUS-chpn-p6_SptkBGCl.flac',
  'data/MAPS/flac/MAPS_MUS-chpn-p3_AkPnBcht.flac',
  'data/MAPS/flac/MAPS_MUS-bk_xmas1_StbgTGd2.flac',
  'data/MAPS/flac/MAPS_MUS-chpn-p8_AkPnBcht.flac',
  'data/MAPS/flac/MAPS_MUS-appass_3_AkPnStgb.flac'],
 'audio': tensor([[ 0.0127,  0.0175,  0.0226,  ...,  0.0351,  0.0340,  0.0323],
         [ 0.0033,  0.0027,  0.0021,  ..., -0.0040,  0.0010,  0.0087],
         [-0.0097, -0.0088, -0.0070,  ..., -0.0490, -0.0471, -0.0354],
         ...,
         [-0.0040, -0.0046, -0.0054,  ..., -0.0107, -0.0119, -0.0129],
         [ 0.0106,  0.0139,  0.0150,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0164, -0.0121, -0.0092,  ..., -0.0155, -0.0148, -0.0136]],
        device='cuda:0'),
 'label': tensor([[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0