In [None]:
%load_ext autoreload
%autoreload 2
import torch
from tqdm.auto import tqdm
from src.common import MaestroSplitType
from torch.utils.data import DataLoader
from src.maestro2 import MaestroDatasetSplit, FrameContextDataset, DynamicBatchIterableDataset2, custom_collate_fn
from torch.nn import MSELoss, BCEWithLogitsLoss, CrossEntropyLoss

# Load the MAESTRO training split, report how many entries it has in total,
# then truncate it to the first 10 entries for quick experimentation.
dataset = MaestroDatasetSplit(MaestroSplitType.TRAIN)
split = dataset.split
print(len(split.entries))
head = slice(None, 10)  # one slice object so both containers are cut identically
split.entries = split.entries[head]
split.df_entries = split.df_entries[head]

In [None]:
# Window sizes handed to FrameContextDataset; presumably the number of input
# (context) frames and target (predicted) frames per sample -- TODO confirm.
n_context, n_predict = 32, 32
dataset2 = FrameContextDataset(dataset, n_context, n_predict)

In [None]:

# Loader / training configuration shared by the cells below.
batch_size = 32
num_workers = 2
epochs = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The wrapper receives batch_size, so the DataLoader itself runs with
# batch_size=1 and custom_collate_fn produces the final batch.
wrapped_dataset = DynamicBatchIterableDataset2(dataset2, batch_size)

# Per-worker prefetch depth, derived from the leading dimension of the first
# sample. Clamp to >= 1: for a short first sample the integer division floors
# to 0, which would leave the worker processes with nothing queued.
prefetch_factor = max(1, (dataset2[0][0].shape[0] * 4) // (batch_size * num_workers))

# Page-locked host memory is only useful when batches are copied to an
# accelerator, so request pinning only when CUDA is the target device.
pin_kwargs = {'pin_memory': True, 'pin_memory_device': device} if device == 'cuda' else {}

data_loader = DataLoader(
    wrapped_dataset,
    batch_size=1,  # Let the collate_fn handle the final batching
    collate_fn=custom_collate_fn,
    num_workers=num_workers,
    prefetch_factor=prefetch_factor,
    multiprocessing_context='spawn',
    **pin_kwargs,
)
len(data_loader)

In [None]:
from model import get_model2
# Build the network with the project's get_model2 factory, passing the
# prediction length and the target device chosen earlier.
model = get_model2(n_predict, device)
def init_weights(m):
    """Re-initialise a Conv2d or Linear layer in place (for ``model.apply``).

    Weights are drawn from a standard normal distribution and the bias, when
    the layer has one, is set to the constant 0.2. Modules whose exact type
    is neither ``torch.nn.Conv2d`` nor ``torch.nn.Linear`` are left untouched.

    Args:
        m: The module visited by ``torch.nn.Module.apply``.
    """
    if type(m) not in (torch.nn.Conv2d, torch.nn.Linear):
        return
    torch.nn.init.normal_(m.weight)
    # Layers constructed with bias=False expose ``bias`` as None; skip them
    # instead of failing on the attribute access.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.2)
        

# Apply the custom initialisation to every sub-module of the model.
model.apply(init_weights)

lr = 1e-5
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Sanity check: run a small slice of the first sample through the model.
x, y = next(iter(dataset2))
# Use `device` rather than .cuda() so the cell also runs on CPU-only hosts.
x = x[1000:1101].transpose(1, 2).unsqueeze(1).to(device)
y = y[1000:1101].transpose(1, 2).unsqueeze(1).to(device)
# Standardise the input; the clamp avoids a division by zero (and NaNs)
# when the slice happens to be constant.
x = (x - x.mean()) / x.std().clamp_min(1e-8)
# Presumably rescales the targets to roughly unit range -- TODO confirm.
y = y / 100.
print(x.shape, y.shape)
loss = MSELoss()

# No gradients are needed for this check, so skip building the graph.
with torch.no_grad():
    # Call the module itself (not .forward) so registered hooks still run.
    y_ = model(x)
    # (prediction, target) order, matching the training loop.
    l = loss(y_, y)
print(y_.shape)

total_params = sum(p.numel() for p in model.parameters())
total_params, l.item()

In [None]:
# Value range of the targets, then of the predictions: (min, max, min, max).
tuple(stat() for t in (y, y_) for stat in (t.min, t.max))

In [None]:
# Shapes of the input, the prediction and the target, in that order.
tuple(t.shape for t in (x, y_, y))

In [None]:
# Inspect index 29 along the last axis of the first target item (channel 0).
y[0, 0][:, 29]

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
# Show the audio widget for the first dataset entry (arguments 0. and 10.
# are presumably a start/end time in seconds -- TODO confirm).
dataset.split.entries[0].load_audio.display_ipython(0., 10.)

# Bring input, prediction and target to the CPU, drop the singleton channel
# axis that unsqueeze(1) added, and keep index 0 of the next axis so each
# tensor becomes 2-D.
xc, y_c, yc = (t.detach().cpu().squeeze(1)[:, 0, :] for t in (x, y_, y))

# One grayscale image per tensor, in the order: input, prediction, target.
for panel in (xc, y_c, yc):
    plt.figure(figsize=(20, 10))
    plt.imshow(panel[:5000, :].transpose(0, 1), cmap='gray', interpolation='none')
    plt.show()

In [None]:
# Largest standard deviation of the predictions, taken along the first axis.
y_c.std(dim=0).max()

In [None]:
import wandb


In [None]:
# Start a Weights & Biases run and record the settings used by this notebook.
wandb.init(
    # set the wandb project where this run will be logged
    project="realtime-piano-transcription",

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "Basic-CNN-v2",
        # The data comes from MaestroSplitType.TRAIN (truncated to a subset),
        # not from the validation split.
        "dataset": "MAESTRO-Train",
        "num_entries": len(dataset.split.entries),
        "epochs": epochs,
        "batch_size": batch_size,
        "n_context": n_context,
        "n_predict": n_predict,
    }
)

In [None]:
from wandb import watch

# Register the model with the active wandb run via wandb.watch.
watch(model)

In [None]:
model = model.to(device)
# Make the mode explicit so re-running this cell always trains in train mode.
model.train()

for epoch in range(epochs):
    # Train
    with tqdm(total=len(data_loader)) as pbar:
        total_loss = 0
        for idx, (x, y) in enumerate(data_loader):
            # Swap the last two axes and add a singleton channel axis; use
            # `device` rather than .cuda() so CPU-only hosts work too.
            x = x.transpose(1, 2).unsqueeze(1).to(device)
            # Normalize x; the clamp avoids NaNs when a batch is constant
            # (a NaN loss would otherwise corrupt the weights).
            x = (x - x.mean()) / x.std().clamp_min(1e-8)
            # Targets get the same axis layout as x, then a fixed rescale
            # (presumably to roughly unit range -- TODO confirm).
            y = y.transpose(1, 2).unsqueeze(1).to(device)
            y = y / 100.
            optimizer.zero_grad()
            output = model(x)
            loss_val = loss(output, y)
            loss_val.backward()
            # Cap the global gradient norm at 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Read the scalar once; each .item() forces a device sync.
            batch_loss = loss_val.item()
            total_loss += batch_loss
            pbar.update(1)
            wandb.log({'train_loss': batch_loss, 'epoch': epoch, 'batch': idx})
        # total_loss is the sum (not the mean) of the per-batch losses.
        text = f'Epoch {epoch} - Loss: {total_loss}'
        pbar.set_description(text)
        print(text)


In [None]:
# Close the active wandb run now that training is done.
wandb.finish()

In [None]:
# Print the gradient norm of every parameter, labelled by name. Parameters
# that have no gradient (grad is None) are reported instead of raising.
for name, param in model.named_parameters():
    if param.grad is None:
        print(f'{name}: no gradient')
    else:
        print(f'{name}: {param.grad.norm()}')

In [None]:
# Re-run the same slice of the first sample through the trained model.
fx, fy = next(iter(dataset2))
# Use `device` rather than .cuda() so the cell also runs on CPU-only hosts.
x = fx[1000:1101].transpose(1, 2).unsqueeze(1).to(device)
y = fy[1000:1101].transpose(1, 2).unsqueeze(1).to(device)
# Same preprocessing as in training; the clamp avoids a division by zero.
x = (x - x.mean()) / x.std().clamp_min(1e-8)
y = y / 100.
# NOTE(review): the model is left in whatever train/eval mode it is in;
# call model.eval() first if it contains dropout or batch-norm -- confirm.
with torch.no_grad():
    # Call the module itself (not .forward) so registered hooks still run.
    y_ = model(x)
    # (prediction, target) order, matching the training loop.
    l = loss(y_, y)
total_params = sum(p.numel() for p in model.parameters())
total_params, l.item()

In [None]:
# Value range of the targets, then of the predictions: (min, max, min, max).
tuple(stat() for t in (y, y_) for stat in (t.min, t.max))

In [None]:
# Ratio of total elements to non-zero elements in the full target tensor.
fy.numel() / fy.count_nonzero()

In [None]:
# Detached CPU copies of the prediction and the target for plotting.
y_c, yc = (t.detach().cpu() for t in (y_, y))

In [None]:
import matplotlib.pyplot as plt
# Grayscale image of the first predicted item (channel 0), axes swapped.
display(plt.imshow(y_c[0, 0].T, cmap='gray', interpolation='none'))

In [None]:
# Grayscale image of the first target item (channel 0), axes swapped.
display(plt.imshow(yc[0, 0].T, cmap='gray', interpolation='none'))

In [None]:
# Uncomment to persist the trained weights to disk:
# torch.save(model.state_dict(), 'model2.pth')