In [51]:
# Automatically re-import edited modules (e.g. model, dataset) before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
from model import NoteComposeNet
from dataset import MidiDataset, VOCABULARY
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import librosa 
import torch

In [53]:
# Dummy single-track dataset: each feature holds one sequence of 255 ones.
# dtype int64 is what torch.tensor() infers from a list of Python ints.
_DUMMY_LEN = 255
df = {
    feature: [torch.ones(_DUMMY_LEN, dtype=torch.int64)]
    for feature in ('notes', 'velocities', 'durations', 'times')
}

In [14]:
# Instantiate the model with its default constructor arguments.
model = NoteComposeNet()

In [16]:
# Model Specifications: memory held by parameters plus registered buffers, in MiB.
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 0.079MB


In [17]:
# Build a one-sample loader over the dummy `df` (run the `df` cell first) and move a
# single batch onto the model's device.
# NOTE(review): the training section unpacks batches as a 3-tuple `(inputs, attn, gt)`,
# while this cell indexes the batch like a dict. If MidiDataset now returns tuples, use
# `data[0]['notes']` etc. — confirm against dataset.py.
midi = MidiDataset(df, context_len=model._context_len)
train_loader = DataLoader(midi, batch_size=1)

# Use the model's own device (as unpack_batch in the training section does) rather
# than hard-coding "cuda", so this cell also runs on CPU-only machines.
device = model._device
data = next(iter(train_loader))
note_tensor = data['notes'].to(device)
velocity_tensor = data['velocities'].to(device)
duration_tensor = data['durations'].to(device)
time_tensor = data['times'].to(device)

# NOTE: `input` shadows the Python builtin of the same name; the name is kept so any
# later cell reading it keeps working.
input = {
    "notes": note_tensor,
    "velocities": velocity_tensor,
    "durations": duration_tensor,
    "times": time_tensor,
}

NameError: name 'df' is not defined

In [19]:
# Generate up to 100 tokens from a random 255-token prompt and print them decoded.
# Both numpy (prompt) and torch (any sampling inside model.generate) are seeded so the
# printed sequence is reproducible across kernel restarts.
GEN_SEED = 0
torch.manual_seed(GEN_SEED)
rng = np.random.default_rng(GEN_SEED)
test_inputs = rng.integers(0, len(VOCABULARY), size=255, dtype=int)
next_notes = model.generate(test_inputs, max_len=100)
print("Note = ", model.detokenize(next_notes))

  def detokenize(self, inputs):


Note =  ['A♯-1', 'D♯8', '<EOS>', 'D♯5', 'A1', 'E9', 'A♯0', 'D♯8', 'B2', 'E1', 'A6', 'C♯5', 'G♯3', 'F9', 'C5', 'A6', 'C9', 'B3', 'E1', 'C1', 'C9', 'F5', 'E1', 'A6', 'C9', 'D♯5', 'A1', 'E1', 'D♯5', 'D♯8', 'D♯8', 'B7', 'C♯0', 'F♯9', 'B7', 'F♯1', 'F2', 'D♯5', 'C♯2', 'D7', 'G2', 'A♯0', 'C♯5', 'G5', 'F0', 'D1', 'G♯0', 'E1', 'G♯3', 'E1', 'A1', 'C7', 'F0', 'G0', 'C5', 'F5', 'B3', 'F5', 'E1', 'G2', 'G3', 'A6', 'F2', 'F0', 'B0', 'A6', 'C♯5', 'F♯6', 'C♯2', 'E1', 'A6', 'A1', 'A6', 'A1', 'A6', 'F5', 'G0', 'C5', 'G0', 'A6', 'F5', 'B3', 'D♯-1', 'A6', 'C♯5', 'B3', 'E3', 'A6', 'G0', 'A♯-1', 'F0', 'E0', 'B2', 'E1', 'D2', 'C♯5', 'C♯-1', 'F5', 'E1', 'A6']


# Training

In [1]:
# Automatically re-import edited modules (e.g. model, dataset) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from datetime import datetime

import pandas as pd
import torch
from torch.utils.data import DataLoader
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

from model import NoteComposeNet
from dataset import MidiDataset

In [7]:
BATCH_SIZE = 32          # samples per batch (DataLoader batch_size)
SAMPLES_PER_TRACK = 128  # forwarded to MidiDataset(samples_per_track=...)
EPOCHS = 100             # number of passes of the training loop over train_loader

In [3]:
# Instantiate the model with its default constructor arguments.
model = NoteComposeNet()

In [8]:
CSV_PATH = r'midi-dataset-mini.csv'

# The raw table is only needed to build the dataset; delete it right afterwards so it
# does not stay resident in the kernel.
df = pd.read_csv(CSV_PATH)
midi = MidiDataset(
    df,
    context_len=model._context_len,
    samples_per_track=SAMPLES_PER_TRACK,
)
del df

In [9]:
# Cross-entropy between the model logits and the ground-truth notes (see unpack_batch).
loss_fn = torch.nn.CrossEntropyLoss()

# AdamW hyper-parameters, named so they are easy to find and tune.
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.004
ADAM_BETAS = (0.9, 0.999)
ADAM_EPS = 1e-07

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=ADAM_BETAS,
    eps=ADAM_EPS,
    weight_decay=WEIGHT_DECAY,
    amsgrad=False,
)

In [10]:
# Reshuffle every epoch; one worker subprocess does the loading.
train_loader = DataLoader(midi, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

In [11]:
def unpack_batch(batch):
    """Run the notebook-global ``model`` on one batch; return ``(output_logits, notes_gt)``.

    ``batch`` is unpacked as a 3-tuple ``(inputs, attn, gt)`` where ``inputs`` and
    ``gt`` are mappings holding a ``'notes'`` tensor. Only the notes are used; both
    tensors are moved to ``model._device`` before the forward pass.

    NOTE(review): the attention mask ``attn`` is unpacked but passed to neither the
    model nor the loss — confirm whether padded positions should be masked.
    """
    inputs, _attn, gt = batch

    notes = inputs['notes'].to(model._device)
    notes_gt = gt['notes'].to(model._device)

    # Call the module rather than .forward() so registered forward hooks still run.
    output_logits = model(notes)

    return output_logits, notes_gt
    

In [12]:
def train_one_epoch(epoch_index, tb_writer, print_every=1):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Uses the notebook-level ``train_loader``, ``optimizer``, ``loss_fn`` and
    ``unpack_batch``. Every batch's loss is logged to ``tb_writer`` under
    ``'Loss/train'`` at step ``epoch_index * len(train_loader) + batch_number``.

    Args:
        epoch_index: zero-based epoch number, used to compute the logging step.
        tb_writer: object with ``add_scalar(tag, value, step)`` (e.g. ``SummaryWriter``).
        print_every: print the batch loss every ``print_every`` batches; ``<= 0``
            disables printing. Defaults to 1 (every batch).

    Returns:
        float: arithmetic mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    total_loss = 0.0
    seen = 0

    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()

        output_logits, notes_gt = unpack_batch(batch)

        loss = loss_fn(output_logits, notes_gt)
        loss.backward()

        optimizer.step()

        # Read the scalar once per batch (.item() is a sync point when loss is on GPU).
        batch_loss = loss.item()
        total_loss += batch_loss
        seen += 1

        if print_every > 0 and (i + 1) % print_every == 0:
            print('  batch {} loss: {}'.format(i + 1, batch_loss))
        tb_writer.add_scalar('Loss/train', batch_loss, epoch_index * num_batches + i + 1)

    if seen == 0:
        raise ValueError('train_loader yielded no batches')

    # Return the epoch mean, not just the last batch: the caller stores it as avg_loss.
    return total_loss / seen

In [13]:
# Set up a timestamped TensorBoard run and best-checkpoint tracking, then train for
# EPOCHS epochs. Set-up and loop share this cell, so re-running it starts a new run
# and resets epoch_number and best_vloss.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/composer_{}'.format(timestamp))
epoch_number = 0

# +inf guarantees the first evaluated epoch is checkpointed, whatever the loss scale.
best_vloss = float('inf')
# torch.save does not create missing parent directories.
os.makedirs('checkpoints', exist_ok=True)

for _ in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Switch to training mode (affects dropout / batch-norm layers, if any), then do
    # one optimisation pass over the data.
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # Evaluation pass: eval mode for dropout / batch-norm layers, and no_grad so no
    # autograd graph is built (lower memory).
    # NOTE(review): this iterates train_loader, i.e. the *training* data, so "valid"
    # below is training-set loss in eval mode, not a held-out metric. Swap in a
    # separate validation loader once one exists.
    model.eval()
    running_vloss = 0.0
    num_vbatches = 0
    with torch.no_grad():
        for vdata in train_loader:
            voutputs, vgt = unpack_batch(vdata)
            # .item() keeps the accumulator a plain float instead of a device tensor.
            running_vloss += loss_fn(voutputs, vgt).item()
            num_vbatches += 1

    if num_vbatches == 0:
        raise ValueError('train_loader yielded no batches during evaluation')
    avg_vloss = running_vloss / num_vbatches
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log training vs. evaluation loss per epoch on a single TensorBoard chart.
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training': avg_loss, 'Validation': avg_vloss},
                       epoch_number + 1)
    writer.flush()

    # Save a checkpoint whenever the evaluation loss improves.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 1 loss: 4.887590438127518
  batch 2 loss: 4.887490317225456
  batch 3 loss: 4.887367427349091
  batch 4 loss: 4.8870381116867065
  batch 5 loss: 4.886938378214836
  batch 6 loss: 4.887105941772461
  batch 7 loss: 4.8870065957307816
  batch 8 loss: 4.886593982577324
  batch 9 loss: 4.886748388409615
  batch 10 loss: 4.88700906932354
  batch 11 loss: 4.886569797992706
  batch 12 loss: 4.886604607105255
  batch 13 loss: 4.886535182595253
  batch 14 loss: 4.886377707123756
  batch 15 loss: 4.886149913072586
  batch 16 loss: 4.886054769158363
  batch 17 loss: 4.885629251599312
  batch 18 loss: 4.885914504528046
  batch 19 loss: 4.88531257212162
  batch 20 loss: 4.885616093873978
  batch 21 loss: 4.885588496923447
  batch 22 loss: 4.885253548622131
  batch 23 loss: 4.885059744119644
  batch 24 loss: 4.884831979870796
  batch 25 loss: 4.884999170899391
  batch 26 loss: 4.884742856025696
  batch 27 loss: 4.884558469057083
  batch 28 loss: 4.884173110127449
  batch 29 loss: 4.8

KeyboardInterrupt: 