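# Training

Train `NoteComposeNet` on the MIDI tracks listed in `datasets/midi-dataset.csv` using `TrainPipeline`, the AdamW optimizer and a cosine warm-restart learning-rate schedule.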

In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
# Standard library
import random
from datetime import datetime

# Third-party
import numpy as np
import pandas as pd
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

# Local modules
from dataset import MidiDataset
from model import NoteComposeNet
from train import TrainPipeline

In [15]:
BATCH_SIZE = 64                  # passed to TrainPipeline as batch_size
TRAIN_SAMPLES_PER_TRACK = 16     # passed to MidiDataset as train_samples
VALIDATE_SAMPLES_PER_TRACK = 1   # passed to MidiDataset as validate_samples
EPOCHS = 8                       # epochs to run in this session (pipeline.train)
EPOCHS_SO_FAR = 0                # epochs completed in earlier runs; used to fast-forward the LR scheduler on resume
SEED = 42

# Seed every RNG training might draw from so runs are reproducible. This must run
# before the model is constructed, because weight initialisation consumes torch's RNG.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

In [16]:
# Set CHECKPOINT_PATH to a saved state_dict (e.g. r'./checkpoints/model_20230801_152146_0')
# to resume training, or leave it as None to start from freshly initialised weights.
# When resuming, also set EPOCHS_SO_FAR so the LR scheduler is fast-forwarded.
CHECKPOINT_PATH = None

model = NoteComposeNet()
if CHECKPOINT_PATH is not None:
    # torch.load unpickles the file, so only load checkpoints you created yourself.
    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine;
    # load_state_dict then copies the tensors into the model's own parameters.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))


def module_size_mb(module: torch.nn.Module) -> float:
    """Return the in-memory size of `module`'s parameters plus buffers, in MiB (1024**2 bytes)."""
    param_bytes = sum(p.nelement() * p.element_size() for p in module.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in module.buffers())
    return (param_bytes + buffer_bytes) / 1024**2


size_all_mb = module_size_mb(model)
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 16.142MB


In [17]:
# Hand cached-but-unused blocks back to the CUDA caching allocator, then report
# how much device memory live tensors still occupy, in MiB (1 << 20 bytes).
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated() / (1 << 20))

49.181640625


In [18]:
CSV_PATH = r'datasets/midi-dataset.csv'
# Hand the CSV index straight to MidiDataset instead of binding it to a global name.
# The frame can then be freed once nothing else references it
# (unless MidiDataset keeps its own reference).
train_midi = MidiDataset(
    pd.read_csv(CSV_PATH),
    context_len=model._context_len,
    train_samples=TRAIN_SAMPLES_PER_TRACK,
    validate_samples=VALIDATE_SAMPLES_PER_TRACK,
)

In [19]:
loss_fn = torch.nn.CrossEntropyLoss()

# Peak and floor of the cosine learning-rate schedule. The floor must be strictly
# below the peak: if eta_min equals the optimizer's lr, CosineAnnealingWarmRestarts
# yields the same LR at every step and the schedule does nothing.
# The 10x ratio is a conventional starting point; tune as needed.
LEARNING_RATE = 1e-4
MIN_LEARNING_RATE = 1e-5
assert MIN_LEARNING_RATE < LEARNING_RATE, "eta_min must be below lr for the schedule to anneal"

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-05,
    amsgrad=False,
)

# Anneal from LEARNING_RATE down to MIN_LEARNING_RATE over T_0=2 scheduler steps,
# then restart at LEARNING_RATE.
scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=2,
    eta_min=MIN_LEARNING_RATE,
)

# When resuming, replay one scheduler step per completed epoch so the LR matches
# where the previous run stopped. PyTorch may warn that lr_scheduler.step() was
# called before optimizer.step(); that is expected for this fast-forward.
# NOTE(review): `scheduler` is not passed to TrainPipeline, so unless the pipeline
# steps it some other way, the LR stays fixed for the whole training run -- confirm.
for _ in range(EPOCHS_SO_FAR):
    scheduler.step()

In [20]:
# Train for EPOCHS epochs on train_midi (validate=False, so presumably no validation pass).
# NOTE(review): `scheduler` is not handed to TrainPipeline, so unless the pipeline
# steps it by some other route, the learning rate never changes while training -- confirm.
pipeline = TrainPipeline(
    train_midi,
    model,
    loss_fn,
    optimizer,
    batch_size=BATCH_SIZE,
    validate=False,
)
pipeline.train(EPOCHS)

Epoch 1: 100%|██████████| 21306/21306 [59:39<00:00,  5.95batch/s, train loss=4.64574] 
Epoch 2: 100%|██████████| 21306/21306 [59:29<00:00,  5.97batch/s, train loss=4.39574] 
Epoch 3: 100%|██████████| 21306/21306 [1:00:42<00:00,  5.85batch/s, train loss=4.64574]
Epoch 4: 100%|██████████| 21306/21306 [59:38<00:00,  5.95batch/s, train loss=4.39574] 
Epoch 5:   2%|▏         | 446/21306 [04:19<55:11,  6.30batch/s, train loss=4.52074]    