In [51]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
from model import NoteComposeNet
from dataset import MidiDataset, VOCABULARY
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import librosa 
import torch

In [53]:
# Minimal in-memory stand-in for the MIDI dataset: a single "track" whose four
# feature sequences all consist of the token value 1. Used only to smoke-test
# MidiDataset / DataLoader below.
# NOTE(review): the name `df` is kept because a later cell passes it to
# MidiDataset, but this is a dict of lists of tensors, not a pandas DataFrame.
DUMMY_TRACK_LEN = 255
df = {
    # The comprehension body runs once per key, so every field gets its own
    # (non-shared) int64 tensor, exactly like the original per-field literals.
    key: [torch.ones(DUMMY_TRACK_LEN, dtype=torch.long)]
    for key in ('notes', 'velocities', 'durations', 'times')
}

In [14]:
# Model used by the sanity-check cells below: the size report, the dataset
# context length (read via `model._context_len`), and token generation.
model = NoteComposeNet()

In [16]:
# Model Specifications: memory footprint of the model, counted as
# (number of elements x bytes per element) over all parameters and all
# registered buffers, reported in MiB.
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

total_bytes = param_size + buffer_size
size_all_mb = total_bytes / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 0.079MB


In [17]:
midi = MidiDataset(df, context_len = model._context_len)
train_loader = DataLoader(midi, batch_size=1)

# Pull one example batch from the dummy dataset.
data = next(iter(train_loader))

# Place the batch on whatever device the model's parameters live on. The
# previous hard-coded "cuda" crashed on CPU-only machines. Nothing in this
# notebook moves the model to CUDA, so it would also mismatch the model's
# device if this batch were fed to it.
device = next(model.parameters()).device
# NOTE(review): `input` shadows Python's builtin input(); the name is kept for
# compatibility — rename it if no later cell reads it.
input = {
    key: data[key].to(device)
    for key in ("notes", "velocities", "durations", "times")
}

NameError: name 'df' is not defined

In [19]:
# Smoke-test generation from a random prompt of 255 token ids drawn uniformly
# from [0, len(VOCABULARY)). Both RNGs are seeded so the prompt is reproducible
# across runs. The torch seed covers the case where generate() samples with
# torch (not visible here — confirm in model.py).
torch.manual_seed(0)
rng = np.random.default_rng(0)
test_inputs = rng.integers(0, len(VOCABULARY), size=255, dtype=int)
next_notes = model.generate(test_inputs, max_len=100)
print("Note = ", model.detokenize(next_notes))

  def detokenize(self, inputs):


Note =  ['A♯-1', 'D♯8', '<EOS>', 'D♯5', 'A1', 'E9', 'A♯0', 'D♯8', 'B2', 'E1', 'A6', 'C♯5', 'G♯3', 'F9', 'C5', 'A6', 'C9', 'B3', 'E1', 'C1', 'C9', 'F5', 'E1', 'A6', 'C9', 'D♯5', 'A1', 'E1', 'D♯5', 'D♯8', 'D♯8', 'B7', 'C♯0', 'F♯9', 'B7', 'F♯1', 'F2', 'D♯5', 'C♯2', 'D7', 'G2', 'A♯0', 'C♯5', 'G5', 'F0', 'D1', 'G♯0', 'E1', 'G♯3', 'E1', 'A1', 'C7', 'F0', 'G0', 'C5', 'F5', 'B3', 'F5', 'E1', 'G2', 'G3', 'A6', 'F2', 'F0', 'B0', 'A6', 'C♯5', 'F♯6', 'C♯2', 'E1', 'A6', 'A1', 'A6', 'A1', 'A6', 'F5', 'G0', 'C5', 'G0', 'A6', 'F5', 'B3', 'D♯-1', 'A6', 'C♯5', 'B3', 'E3', 'A6', 'G0', 'A♯-1', 'F0', 'E0', 'B2', 'E1', 'D2', 'C♯5', 'C♯-1', 'F5', 'E1', 'A6']


# Training

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from model import NoteComposeNet
from dataset import MidiDataset
from torch.utils.data import DataLoader
from train import TrainPipeline

import torch
import pandas as pd

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

In [3]:
# Training configuration (tunable).
EPOCHS: int = 100                    # number of epochs passed to pipeline.train()
BATCH_SIZE: int = 32                 # DataLoader batch size for training and validation
TRAIN_SAMPLES_PER_TRACK: int = 128   # samples_per_track for the training MidiDataset
VALIDATE_SAMPLES_PER_TRACK: int = 4  # samples_per_track for the validation MidiDataset

In [4]:
# Model to be trained below. Its `_context_len` sets the window length of the
# MidiDatasets built in the next cell.
# NOTE(review): no checkpoint is loaded anywhere in this notebook, so training
# presumably starts from NoteComposeNet's freshly initialised weights.
model = NoteComposeNet()

In [5]:
CSV_PATH = r'datasets/midi-dataset-mini.csv'
VALIDATION_FRACTION = 0.1  # share of rows (tracks) held out for validation
SPLIT_SEED = 42            # makes the train/validation split reproducible

df = pd.read_csv(CSV_PATH)

# Previously both datasets were built from the same dataframe, so the
# validation loss was measured on training tracks and could not reveal
# overfitting. Hold out a disjoint, reproducible subset of rows instead.
# NOTE(review): assumes one CSV row per track — confirm against MidiDataset.
if len(df) < 2:
    raise ValueError(
        f"need at least 2 rows in {CSV_PATH} for a train/validation split, got {len(df)}"
    )
shuffled_df = df.sample(frac=1.0, random_state=SPLIT_SEED).reset_index(drop=True)
# At least one validation row, and always leave at least one training row.
n_vali = min(len(shuffled_df) - 1, max(1, round(len(shuffled_df) * VALIDATION_FRACTION)))
vali_df = shuffled_df.iloc[:n_vali].reset_index(drop=True)
train_df = shuffled_df.iloc[n_vali:].reset_index(drop=True)

train_midi = MidiDataset(train_df, context_len=model._context_len, samples_per_track=TRAIN_SAMPLES_PER_TRACK)
vali_midi = MidiDataset(vali_df, context_len=model._context_len, samples_per_track=VALIDATE_SAMPLES_PER_TRACK)

# Drop the notebook's references to the raw frames, as the original `del df` did.
del df, shuffled_df, train_df, vali_df

In [6]:
# Optimisation setup; hyperparameters are named so they are easy to find and tune.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 4e-3
NUM_WORKERS = 4

loss_fn = torch.nn.CrossEntropyLoss()
# NOTE(review): eps=1e-07 and weight_decay=4e-3 differ from PyTorch's AdamW
# defaults (1e-8 and 1e-2); kept as-is — confirm they are intentional.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.999),
    eps=1e-07,
    amsgrad=False,
)
train_loader = DataLoader(
    train_midi,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,  # reshuffle training batches every epoch
)
# The validation loader is presumably used only for evaluation by
# TrainPipeline, where sample order should not matter. Reading it in a fixed
# order keeps validation runs reproducible; it was previously shuffled.
validation_loader = DataLoader(
    vali_midi,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
)

In [7]:
# Run training for EPOCHS epochs with the loaders, model, loss and optimizer
# configured above (TrainPipeline is imported from train.py).
# NOTE(review): in the saved run below, the per-batch loss stays around 4.88
# for the first 29 batches before the run was interrupted. That equals ln(V)
# for V ≈ 132, i.e. chance level if the vocabulary has about that many tokens.
# Check len(VOCABULARY) and whether the model is learning before committing
# to all epochs.
pipeline = TrainPipeline(train_loader, validation_loader, model, loss_fn, optimizer)
pipeline.train(EPOCHS)

EPOCH 1:
  batch 1 loss: 4.88645900785923
  batch 2 loss: 4.886491045355797
  batch 3 loss: 4.88607856631279
  batch 4 loss: 4.885797053575516
  batch 5 loss: 4.885819911956787
  batch 6 loss: 4.885721489787102
  batch 7 loss: 4.886048048734665
  batch 8 loss: 4.885497733950615
  batch 9 loss: 4.885235980153084
  batch 10 loss: 4.885341718792915
  batch 11 loss: 4.885132327675819
  batch 12 loss: 4.884786814451218
  batch 13 loss: 4.884548336267471
  batch 14 loss: 4.884933724999428
  batch 15 loss: 4.884401947259903
  batch 16 loss: 4.884033352136612
  batch 17 loss: 4.884138315916061
  batch 18 loss: 4.884180814027786
  batch 19 loss: 4.883368223905563
  batch 20 loss: 4.883040085434914
  batch 21 loss: 4.883487179875374
  batch 22 loss: 4.8831798285245895
  batch 23 loss: 4.882603242993355
  batch 24 loss: 4.882429331541061
  batch 25 loss: 4.881964534521103
  batch 26 loss: 4.882133200764656
  batch 27 loss: 4.881354928016663
  batch 28 loss: 4.881146058440208
  batch 29 loss: 4.88

KeyboardInterrupt: 