# Training

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from models.data_split import MidiDataset, VOCABULARY
from models.model import NoteComposeNet
from torch.utils.data import DataLoader
from train import TrainPipeline

import torch
import pandas as pd

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from datetime import datetime

import math

In [3]:
BATCH_SIZE = 32
TRAIN_SAMPLES_PER_TRACK = int(1e6)
VALIDATE_SAMPLES_PER_TRACK = int(1e5)
EPOCHS = 20
EPOCHS_SO_FAR = 0
GRADIENT_ACC = 32

In [4]:
CHECKPOINT = torch.load(r'./checkpoints/model_20230805_182913_8')
model = NoteComposeNet()
model.load_state_dict(CHECKPOINT)

param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 16.142MB


In [5]:
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated() / 1024 ** 2)

32.44140625


In [6]:
CSV_PATH = r'datasets/maestro.csv'
df = pd.read_csv(CSV_PATH)
train_midi = MidiDataset(df, context_len = model._context_len, train_samples=TRAIN_SAMPLES_PER_TRACK, validate_samples=VALIDATE_SAMPLES_PER_TRACK)
del df

In [7]:
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=0.0005,
    weight_decay=0.1, 
    betas=(0.9, 0.95),
    eps=1e-05,
    amsgrad=False
    )

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer, 
    T_0= int(math.ceil(len(train_midi) / (GRADIENT_ACC * BATCH_SIZE))),
    eta_min=1e-6,
)

In [8]:
print("Total samples: ", EPOCHS * len(train_midi))
print("Total tokens:", EPOCHS * len(train_midi) * 2048)
print("Total updates per epoch", int(math.ceil(len(train_midi) / (GRADIENT_ACC * BATCH_SIZE))))

Total samples:  20000000
Total tokens: 40960000000
Total updates per epoch 977


In [9]:
pipeline = TrainPipeline(train_midi, model, loss_fn, optimizer, validate=True, 
                         batch_size=BATCH_SIZE, scheduler = scheduler, grad_acc = GRADIENT_ACC)
pipeline.train(EPOCHS)

Epoch 1: 100%|██████████| 31250/31250 [1:27:16<00:00,  5.97batch/s, train loss=3.92915, lr=[1.0051595243538356e-06]] 
Epoch 1: 100%|██████████| 3125/3125 [10:01<00:00,  5.20batch/s, val loss=3.87160]
Epoch 2: 100%|██████████| 31250/31250 [1:26:52<00:00,  5.99batch/s, train loss=3.88268, lr=[1.0116088797822096e-06]]
Epoch 2: 100%|██████████| 3125/3125 [10:02<00:00,  5.19batch/s, val loss=3.85902]
Epoch 3:  55%|█████▍    | 17077/31250 [50:38<40:07,  5.89batch/s, train loss=3.85861, lr=[0.0002173041167590045]]   