# Training

In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to the project's .py files (models/, train.py)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from models.data_split import MidiDataset, VOCABULARY
from models.model import NoteComposeNet
from torch.utils.data import DataLoader
from train import TrainPipeline

import torch
import pandas as pd

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from datetime import datetime

import math

In [3]:
# --- Run configuration ---
# Mini-batch size and gradient-accumulation factor passed to TrainPipeline.
BATCH_SIZE = 32
GRADIENT_ACC = 32

# Sample counts passed to MidiDataset (train_samples / validate_samples).
TRAIN_SAMPLES_PER_TRACK = 1_000_000
VALIDATE_SAMPLES_PER_TRACK = 100_000

# Number of epochs requested from pipeline.train().
EPOCHS = 20
# Presumably the number of epochs already completed when resuming a run;
# not referenced by the visible cells -- TODO confirm usage.
EPOCHS_SO_FAR = 0

In [4]:
# Build the network with its default constructor. To start from saved weights
# instead, uncomment the two checkpoint lines (the checkpoint file must exist).
# CHECKPOINT = torch.load(r'checkpoints/maestro_tuned_16_last')
model = NoteComposeNet()
# model.load_state_dict(CHECKPOINT)

# Footprint in bytes: element count times bytes per element, summed over
# every parameter tensor and every registered buffer.
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

# Convert bytes to MiB for the report.
size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/maestro_tuned_16_last'

In [None]:
# Release the CUDA caching allocator's unused blocks, then print the memory
# currently allocated on the device, converted from bytes to MiB.
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated() / (1024 * 1024))

32.59765625


In [None]:
# Load the flat MIDI CSV and wrap it in the project's dataset class.
CSV_PATH = r'datasets/midi-dataset-flat.csv'
midi_frame = pd.read_csv(CSV_PATH)

train_midi = MidiDataset(
    midi_frame,
    context_len=model._context_len,  # window length is taken from the model
    train_samples=TRAIN_SAMPLES_PER_TRACK,
    validate_samples=VALIDATE_SAMPLES_PER_TRACK,
)

# Drop the notebook-level reference to the DataFrame; whether the memory is
# actually released depends on MidiDataset not keeping its own reference.
del midi_frame

In [None]:
# Cross-entropy training objective.
loss_fn = torch.nn.CrossEntropyLoss()

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.0001,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-05,
    amsgrad=False,
)

# CosineAnnealingWarmRestarts comes from the notebook's import cell; it is not
# re-imported here so that all dependencies stay declared in one place.
#
# T_0 (scheduler steps until the first restart) is set to the number of
# optimizer updates in one pass over train_midi, i.e. samples divided by the
# effective batch size (GRADIENT_ACC * BATCH_SIZE), rounded up.
# NOTE(review): this yields one cosine cycle per epoch only if TrainPipeline
# steps the scheduler once per optimizer update -- confirm in train.py.
scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=int(math.ceil(len(train_midi) / (GRADIENT_ACC * BATCH_SIZE))),
    eta_min=1e-6,  # floor the learning rate anneals down to
)

In [None]:
# Summarise the size of the planned run.
total_samples = EPOCHS * len(train_midi)
# Optimizer updates per epoch: samples / effective batch size, rounded up.
updates_per_epoch = int(math.ceil(len(train_midi) / (GRADIENT_ACC * BATCH_SIZE)))

print("Total samples: ", total_samples)
# 2048 is presumably the number of tokens per sample (the context length);
# TODO confirm it matches model._context_len.
print("Total tokens:", total_samples * 2048)
print("Total updates per epoch", updates_per_epoch)

Total samples:  20000000
Total tokens: 40960000000
Total updates per epoch 977


In [None]:
# Assemble the training pipeline from the dataset, model, loss and optimizer
# built in the cells above, with validation enabled.
pipeline = TrainPipeline(
    train_midi,
    model,
    loss_fn,
    optimizer,
    validate=True,
    batch_size=BATCH_SIZE,
    scheduler=scheduler,
    grad_acc=GRADIENT_ACC,
)

# Run training for the configured number of epochs.
pipeline.train(EPOCHS)

Epoch 1:  43%|████▎     | 13386/31250 [38:23<51:51,  5.74batch/s, train loss=3.62088, lr=[6.178063871676258e-05]]  