# Training

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the local project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from model import NoteComposeNet
from dataset import MidiDataset, VOCABULARY
from torch.utils.data import DataLoader
from train import TrainPipeline

import torch
import pandas as pd

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from datetime import datetime

In [3]:
# --- Run configuration -------------------------------------------------------

# Training length. EPOCHS is passed to pipeline.train(); EPOCHS_SO_FAR is the
# number of scheduler steps replayed before training starts.
EPOCHS = 8
EPOCHS_SO_FAR = 0

# Forwarded to TrainPipeline as batch_size.
BATCH_SIZE = 32

# Forwarded to MidiDataset as train_samples / validate_samples.
TRAIN_SAMPLES_PER_TRACK = 8
VALIDATE_SAMPLES_PER_TRACK = 1

In [4]:
# Path to a saved state dict to resume from, e.g.
# r'./checkpoints/model_20230801_152146_0'. Leave as None to start from the
# model's default initialisation.
CHECKPOINT_PATH = None

model = NoteComposeNet()
if CHECKPOINT_PATH is not None:
    # Explicit opt-in replaces the previous comment/uncomment toggle.
    CHECKPOINT = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(CHECKPOINT)

# Memory footprint of the model: bytes held by its parameters plus its
# registered buffers (element count * bytes per element for each tensor).
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

# Note: the divisor is 1024**2, so the printed figure is in MiB even though
# the label says "MB".
size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 16.142MB


In [5]:
# Ask the caching allocator to release unused cached blocks, then print the
# memory currently occupied by tensors on the current CUDA device, in MiB.
torch.cuda.empty_cache()
allocated_bytes = torch.cuda.memory_allocated()
print(allocated_bytes / 1024 ** 2)

16.298828125


In [7]:
CSV_PATH = r'datasets/midi-dataset-mini.csv'

# The CSV is read into a temporary frame that is only used to construct the
# dataset; the name is deleted straight afterwards so the notebook does not
# keep its own reference to it.
track_table = pd.read_csv(CSV_PATH)
train_midi = MidiDataset(
    track_table,
    context_len=model._context_len,  # context length is taken from the model
    train_samples=TRAIN_SAMPLES_PER_TRACK,
    validate_samples=VALIDATE_SAMPLES_PER_TRACK,
)
del track_table

In [8]:
# Cross-entropy criterion with label smoothing of 0.1.
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

# Peak and floor of the learning-rate schedule.
# The floor was previously 1e-4, identical to the peak. Cosine annealing
# interpolates between the optimizer's lr and eta_min, so with both equal the
# learning rate never changed and the scheduler was a silent no-op. The floor
# is now one tenth of the peak so the warm-restart schedule actually anneals.
# NOTE(review): 1e-5 is a conventional choice (10% of peak); tune as needed.
LEARNING_RATE = 1e-4
MIN_LEARNING_RATE = 1e-5

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-05,
    amsgrad=False,
)

# CosineAnnealingWarmRestarts is already imported in the notebook's import
# cell, so it is not re-imported here.
# T_0 is the number of scheduler steps until the first restart; whether a step
# corresponds to an epoch or a batch depends on how TrainPipeline calls
# scheduler.step() -- TODO confirm.
scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=2,
    eta_min=MIN_LEARNING_RATE,
)

# Fast-forward the schedule by the number of steps already completed, so a
# resumed run continues from the right point of the cosine cycle.
for _epoch in range(EPOCHS_SO_FAR):
    scheduler.step()

In [9]:
# Bundle the dataset, model, criterion, optimizer and LR scheduler into the
# project's training pipeline, then train for EPOCHS.
pipeline = TrainPipeline(
    train_midi,
    model,
    loss_fn,
    optimizer,
    validate=True,
    batch_size=BATCH_SIZE,
    scheduler=scheduler,
    grad_acc=2,  # presumably gradient-accumulation steps -- verify in train.py
)
pipeline.train(EPOCHS)

Epoch 1: 100%|██████████| 34/34 [00:13<00:00,  2.54batch/s, train loss=5.12884]
Epoch 1: 100%|██████████| 5/5 [00:08<00:00,  1.63s/batch, val loss=5.16644]
  0%|          | 0/34 [00:06<?, ?batch/s]


KeyboardInterrupt: 