# Training

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from models.data_split import MidiDataset, VOCABULARY
from models.model import NoteComposeNet
from torch.utils.data import DataLoader
from train import TrainPipeline

import torch
import pandas as pd

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from datetime import datetime

import math

In [3]:
BATCH_SIZE = 32
TRAIN_SAMPLES_PER_TRACK = int(1e6)
VALIDATE_SAMPLES_PER_TRACK = int(1e5)
EPOCHS = 20
EPOCHS_SO_FAR = 0
GRADIENT_ACC = 32

In [4]:
# CHECKPOINT = torch.load(r'checkpoints/model_20230811_063842_0')
model = NoteComposeNet()
# model.load_state_dict(CHECKPOINT)

param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 0.386MB


In [5]:
CSV_PATH = r'datasets/midi-dataset-flat.csv'
df = pd.read_csv(CSV_PATH)
train_midi = MidiDataset(df, context_len = model._context_len, train_samples=TRAIN_SAMPLES_PER_TRACK, validate_samples=VALIDATE_SAMPLES_PER_TRACK, start_length=-1)
del df

In [6]:
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.01)

optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=0.0001,
    weight_decay=0.1, 
    betas=(0.9, 0.95),
    eps=1e-05,
    amsgrad=False
    )

scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer, 
    T_0= int(math.ceil(len(train_midi) / (GRADIENT_ACC * BATCH_SIZE))),
    eta_min=1e-6,
)

In [7]:
print("Total samples: ", EPOCHS * len(train_midi))
print("Total tokens:", EPOCHS * len(train_midi) * 2048)
print("Total updates per epoch", int(math.ceil(len(train_midi) / (GRADIENT_ACC * BATCH_SIZE))))

Total samples:  20000000
Total tokens: 40960000000
Total updates per epoch 977


In [8]:
pipeline = TrainPipeline(train_midi, model, loss_fn, optimizer, validate=True, 
                         batch_size=BATCH_SIZE, scheduler = scheduler, grad_acc = GRADIENT_ACC)
pipeline.train(EPOCHS)

Epoch 1: 100%|██████████| 31250/31250 [37:03<00:00, 14.06batch/s, train loss=4.09873, lr=[1.001023633088236e-06]] 
Epoch 1: 100%|██████████| 3125/3125 [01:25<00:00, 36.50batch/s, val loss=4.15298]
Epoch 2: 100%|██████████| 31250/31250 [35:59<00:00, 14.47batch/s, train loss=3.98249, lr=[1.0023031645259293e-06]]
Epoch 2: 100%|██████████| 3125/3125 [01:24<00:00, 37.07batch/s, val loss=3.97534]
Epoch 3: 100%|██████████| 31250/31250 [35:55<00:00, 14.50batch/s, train loss=3.92029, lr=[1.0040944900166022e-06]]
Epoch 3: 100%|██████████| 3125/3125 [01:23<00:00, 37.21batch/s, val loss=3.94870]
Epoch 4: 100%|██████████| 31250/31250 [36:11<00:00, 14.39batch/s, train loss=3.96700, lr=[1.0063975910383878e-06]]
Epoch 4: 100%|██████████| 3125/3125 [01:24<00:00, 36.99batch/s, val loss=3.94328]
Epoch 5: 100%|██████████| 31250/31250 [35:56<00:00, 14.49batch/s, train loss=3.97733, lr=[1.009212443777784e-06]] 
Epoch 5: 100%|██████████| 3125/3125 [01:24<00:00, 37.14batch/s, val loss=3.93598]
Epoch 6: 100%|█