In [2]:
%load_ext autoreload
%autoreload 2
import torch
from tqdm.auto import tqdm
from src.common import MaestroSplitType
from torch.utils.data import DataLoader
from src.maestro2 import MaestroDatasetSplit, FrameContextDataset, DynamicBatchIterableDataset2, custom_collate_fn
from torch.nn import MSELoss


dataset = MaestroDatasetSplit(MaestroSplitType.TRAIN)


Creating 2006/MIDI-Unprocessed_17_R1_2006_01-06_ORIG_MID--AUDIO_17_R1_2006_04_Track04_wav.wav (140292477785328)
Creating 2009/MIDI-Unprocessed_07_R1_2009_04-05_ORIG_MID--AUDIO_07_R1_2009_07_R1_2009_04_WAV.wav (140292474063952)
Creating 2009/MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_05_WAV.wav (140292478688704)
Creating 2009/MIDI-Unprocessed_11_R1_2009_06-09_ORIG_MID--AUDIO_11_R1_2009_11_R1_2009_06_WAV.wav (140292474279776)
Creating 2013/ORIG-MIDI_01_7_6_13_Group__MID--AUDIO_03_R1_2013_wav--4.wav (140297121328416)
Creating 2014/MIDI-UNPROCESSED_14-15_R1_2014_MID--AUDIO_15_R1_2014_wav--3.wav (140297121327264)
Creating 2014/MIDI-UNPROCESSED_14-15_R1_2014_MID--AUDIO_15_R1_2014_wav--4.wav (140297121328320)
Creating 2006/MIDI-Unprocessed_22_R1_2006_01-04_ORIG_MID--AUDIO_22_R1_2006_02_Track02_wav.wav (140297121327312)
Creating 2018/MIDI-Unprocessed_Recital8_MID--AUDIO_08_R1_2018_wav--2.wav (140297121317760)
Creating 2018/MIDI-Unprocessed_Recital9-11_MID--AUDIO_10

In [3]:
n_context = 21
n_predict = 3
dataset2 = FrameContextDataset(dataset, n_context, n_predict)

In [4]:

batch_size = 32
num_workers = 4
epochs = 10
learning_rate = 0.001
device = 'cuda' if torch.cuda.is_available() else 'cpu'
wrapped_dataset = DynamicBatchIterableDataset2(dataset2, batch_size)
data_loader = DataLoader(
    wrapped_dataset, 
    batch_size=1,  # Let the collate_fn handle the final batching
    collate_fn=custom_collate_fn,
    num_workers=num_workers,
    prefetch_factor=(dataset2[0][0].shape[0]*4) // (batch_size * num_workers),
    multiprocessing_context='spawn',
    pin_memory=True,
    pin_memory_device=device,
)

Loading audio 2006/MIDI-Unprocessed_17_R1_2006_01-06_ORIG_MID--AUDIO_17_R1_2006_04_Track04_wav.wav (140292477785328)
Batch size: 12466


In [7]:
from model import get_model
model = get_model(n_predict, device)

loss = MSELoss()
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
## Pass dummy batch
x, y = next(iter(dataset2))
x = x.unsqueeze(1)[0:1].cuda()
y = y[0:1].cuda()
y_ = model.forward(x)

Batch size: 12466
torch.Size([1, 1, 229, 21])
torch.Size([1, 128, 3]) torch.Size([1, 128, 3])


In [8]:
import wandb


In [9]:
wandb.init(
    # set the wandb project where this run will be logged
    project="realtime-piano-transcription",

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "Basic-CNN-v2",
        "dataset": "MAESTRO-Validation",
        "epochs": epochs,
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfedericowilliamson[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
from wandb import watch

watch(model)

[]

In [12]:
model = model.to(device)

for epoch in range(epochs):
    # Train
    with tqdm(total=len(data_loader)) as pbar:
        total_loss = 0
        for idx, (x, y) in enumerate(data_loader):
            x = x.unsqueeze(1).cuda()
            y = y.cuda()
            optimizer.zero_grad()
            output = model(x)
            loss_val = loss(output, y)
            loss_val.backward()
            optimizer.step()
            total_loss += loss_val.item()
            pbar.update(1)
            wandb.log({'train_loss': loss_val.item(), 'epoch': epoch, 'batch': idx})
            if idx % 100 == 0:
                text = f'Epoch {epoch} - Loss: {total_loss / 100}'
                pbar.set_description(text)
                print(text)
                total_loss = 0


  0%|          | 0/68456 [00:00<?, ?it/s]

Epoch 0 - Loss: 0.5473227691650391
Epoch 0 - Loss: 101.984193983078
Epoch 0 - Loss: 118.9992817889154
Epoch 0 - Loss: 121.81392496585846
Epoch 0 - Loss: 112.85234426498413
Epoch 0 - Loss: 93.02914060279727
Epoch 0 - Loss: 130.32944813728332
Epoch 0 - Loss: 149.05300729751588
Epoch 0 - Loss: 99.72540678501129
Epoch 0 - Loss: 97.53247800827026
Epoch 0 - Loss: 109.93645703315735
Epoch 0 - Loss: 118.76010772705078
Epoch 0 - Loss: 143.4345231437683
Epoch 0 - Loss: 197.1595188140869
Epoch 0 - Loss: 180.51504091262817
Epoch 0 - Loss: 216.58167566299437
Epoch 0 - Loss: 160.08410074710847
Epoch 0 - Loss: 92.95043471306562
Epoch 0 - Loss: 150.02301519393922
Epoch 0 - Loss: 131.79983839035035
Epoch 0 - Loss: 142.65334775924683
Epoch 0 - Loss: 151.31156542539597
Epoch 0 - Loss: 169.43529619693757
Epoch 0 - Loss: 193.39922052383423
Epoch 0 - Loss: 312.2622915649414
Epoch 0 - Loss: 197.85100730895996
Epoch 0 - Loss: 144.86228741168975
Epoch 0 - Loss: 79.7053432226181
Epoch 0 - Loss: 77.3953785216808

  0%|          | 0/68456 [00:00<?, ?it/s]

Epoch 1 - Loss: 0.5347832870483399
Epoch 1 - Loss: 101.98178053855897
Epoch 1 - Loss: 118.9992817889154
Epoch 1 - Loss: 121.81392496585846
Epoch 1 - Loss: 112.85234426498413
Epoch 1 - Loss: 93.02914060279727
Epoch 1 - Loss: 130.32944813728332
Epoch 1 - Loss: 149.05300729751588
Epoch 1 - Loss: 99.72540678501129
Epoch 1 - Loss: 97.53247800827026
Epoch 1 - Loss: 109.93645703315735
Epoch 1 - Loss: 118.76010772705078
Epoch 1 - Loss: 143.4345231437683
Epoch 1 - Loss: 197.1595188140869
Epoch 1 - Loss: 180.51504091262817
Epoch 1 - Loss: 216.58167566299437
Epoch 1 - Loss: 160.08410074710847
Epoch 1 - Loss: 92.95043471306562
Epoch 1 - Loss: 150.02301519393922
Epoch 1 - Loss: 131.79983839035035
Epoch 1 - Loss: 142.65334775924683
Epoch 1 - Loss: 151.31156542539597
Epoch 1 - Loss: 169.43529619693757
Epoch 1 - Loss: 193.39922052383423
Epoch 1 - Loss: 312.2622915649414
Epoch 1 - Loss: 197.85100730895996
Epoch 1 - Loss: 144.86228741168975
Epoch 1 - Loss: 79.7053432226181
Epoch 1 - Loss: 77.39537852168

  0%|          | 0/68456 [00:00<?, ?it/s]

Epoch 2 - Loss: 0.5347832870483399
Epoch 2 - Loss: 101.98178053855897
Epoch 2 - Loss: 118.9992817889154
Epoch 2 - Loss: 121.81392496585846
Epoch 2 - Loss: 112.85234426498413
Epoch 2 - Loss: 93.02914060279727
Epoch 2 - Loss: 130.32944813728332
Epoch 2 - Loss: 149.05300729751588
Epoch 2 - Loss: 99.72540678501129
Epoch 2 - Loss: 97.53247800827026
Epoch 2 - Loss: 109.93645703315735
Epoch 2 - Loss: 118.76010772705078
Epoch 2 - Loss: 143.4345231437683
Epoch 2 - Loss: 197.1595188140869
Epoch 2 - Loss: 180.51504091262817
Epoch 2 - Loss: 216.58167566299437
Epoch 2 - Loss: 160.08410074710847
Epoch 2 - Loss: 92.95043471306562
Epoch 2 - Loss: 150.02301519393922
Epoch 2 - Loss: 131.79983839035035
Epoch 2 - Loss: 142.65334775924683
Epoch 2 - Loss: 151.31156542539597
Epoch 2 - Loss: 169.43529619693757
Epoch 2 - Loss: 193.39922052383423
Epoch 2 - Loss: 312.2622915649414
Epoch 2 - Loss: 197.85100730895996
Epoch 2 - Loss: 144.86228741168975
Epoch 2 - Loss: 79.7053432226181
Epoch 2 - Loss: 77.39537852168

KeyboardInterrupt: 

In [15]:
wandb.finish()

VBox(children=(Label(value='0.109 MB of 0.109 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch,▁▁▂▂▂▂▂▃▄▄▄▄▄▄▄▅▆▆▆▆▇▇███▁▂▂▂▃▃▃▄▄▅▆▇▇▇▃
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅███
train_loss,▃▂▂▃▁▂▂▂▄▁█▅▃▃▄▄▄▂▃▆▂▂▂▃▂▅▁▆▃▁▂▃▃▂▃▁▂▂▁▄

0,1
batch,16808.0
epoch,2.0
train_loss,1853.11316


In [None]:
dataset.split.entries[1].csv_duration

In [None]:
(388 + 2)*32