In [1]:
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import music21 as m21
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Work from the chapter directory so the relative '../data/...' paths used in later cells resolve
# (and, presumably, so the local modules star-imported at the end of this cell are found).
# Only chdir when we are not already there: os.chdir with a relative path is not idempotent, so
# re-running this cell (or starting the kernel inside the directory) would otherwise raise.
CHAPTER_DIR_NAME = '7 - Putting it together'
if Path.cwd().name != CHAPTER_DIR_NAME:
    os.chdir(CHAPTER_DIR_NAME)

# Point music21's MusicXML and direct-PNG settings at MuseScore.
# Defaults to the original Linux location; set MUSESCORE_PATH to override on other machines.
musescore_path = os.environ.get('MUSESCORE_PATH', '/usr/bin/mscore')
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)

# Project-local modules. The star imports are kept because later cells use names they export
# (MusicVocab, MidiDataset, ContiguousBatchSampler, DecoderTransformer_KNN_XL, idx_to_stream_enc, ...).
from midi_encoding import *
from data_loading import *
from model import *

# Choose the compute device by preference: NVIDIA CUDA, then Apple's MPS backend, then the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}.")

Using cuda.


In [2]:
# Show the GPU model, driver / CUDA versions and current memory use (NVIDIA GPUs only).
!nvidia-smi

Tue Sep 10 19:21:45 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.02              Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
| 30%   33C    P2             48W /  450W |    3144MiB /  24564MiB |     38%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
vg_large_path = Path('../data/midi/vg_large')
# os.listdir returns names in arbitrary, filesystem-dependent order, so sort them before the seeded
# shuffle; otherwise the same seed can still give a different train/valid/test split on another machine.
# NOTE: this can change the split relative to runs made with the unsorted listing. Don't resume an
# older checkpoint across this change, or previously held-out files may end up in training.
vg_large_file_names = sorted(
    f for f in os.listdir(vg_large_path) if os.path.isfile(os.path.join(vg_large_path, f))
)

# Ensure files are shuffled directly after assignment.
# If they are shuffled in a different cell, and that cell is run multiple times, the order will change as we are shuffling the already-shuffled list.
random.seed(42)
random.shuffle(vg_large_file_names)

len(vg_large_file_names)

3839

In [4]:
sample_length = 256
max_file_length = 32

# MIDI files are read from the same directory that was listed in the previous cell.
midi_path = vg_large_path
# Cache location for the unmerged (basic-vocab) encoded samples of every file.
score_path = Path('../data/numpy/vg_large/all')

# 80 / 10 / 10 train / valid / test split of the seeded-shuffled file list.
n1 = int(0.8 * len(vg_large_file_names))
n2 = int(0.9 * len(vg_large_file_names))
train_filenames = vg_large_file_names[:n1]
valid_filenames = vg_large_file_names[n1:n2]
test_filenames = vg_large_file_names[n2:]
# Guard against a (near-)empty data directory: every split needs at least one file, otherwise
# loss estimation on that split has nothing to evaluate.
assert train_filenames and valid_filenames and test_filenames, (
    f'Too few files ({len(vg_large_file_names)}) for an 80/10/10 split'
)

print(f'Train file names: {len(train_filenames)}, Valid file names: {len(valid_filenames)}, Test file names: {len(test_filenames)}')

Train file names: 3071, Valid file names: 384, Test file names: 384


In [5]:
vocab_size = 400  # Also try 500, 750, 1000 vocab size
vocab_name = f'vg_large-bpe_{vocab_size}'
vocab_state_dict_path = Path('../data/vocab') / f'{vocab_name}.pkl'

model_name = f'midi_transformer_knn-xl_{vocab_name}'
# Training resumes from, and checkpoints back to, the same file.
model_save_path = Path('../data/checkpoints') / f'{model_name}.dat'
model_load_path = model_save_path

log_dir = Path('../tensorboard') / model_name
tensorboard_writer = SummaryWriter(log_dir)

`MidiDataset` looks for the encoded samples of the given file names at the given score path: it loads them if they exist and creates them if they don't.

This lets us encode the files with the trained vocab: pass the vocab to `load_samples` together with a new score path, and the merged samples are created (and cached) there.

In [6]:
if vocab_state_dict_path.exists():
    print("Loading vocab...")
    trained_vocab = MusicVocab.load(vocab_state_dict_path)
else:
    # No saved vocab yet, so learn the BPE merges from scratch.
    # NOTE(review): this uses vg_large_file_names, i.e. every split including valid/test, so the
    # merges have seen held-out files. Probably minor for a tokenizer, but confirm it's intended.

    # Start from the basic (unmerged) vocab.
    dataset_vocab = MusicVocab()

    # Encode every file with the basic vocab (cached under score_path).
    vocab_training_dataset = MidiDataset(vg_large_file_names, midi_path, score_path, sample_length, max_file_length)
    print("Loading dataset samples...")
    vocab_training_dataset.load_samples(dataset_vocab, "cpu")

    # Learn the merges from the unmerged samples, up to vocab_size entries.
    print("Training vocab...")
    trained_vocab = MusicVocab()
    trained_vocab.train(vocab_training_dataset, max_vocab_size=vocab_size)

    print("Saving vocab...")
    trained_vocab.save(vocab_state_dict_path)

Loading dataset samples...


In [None]:
merged_score_path = Path('../data/numpy') / vocab_name

# A score path with no cached samples yet makes MidiDataset create the merged (trained-vocab)
# samples there; later runs load them instead of re-encoding.
train_dataset, valid_dataset, test_dataset = (
    MidiDataset(split_filenames, midi_path, merged_score_path, sample_length, max_file_length)
    for split_filenames in (train_filenames, valid_filenames, test_filenames)
)

# Encode / load each split with the trained vocab directly onto the chosen device.
for split_label, split_dataset in (('train', train_dataset), ('valid', valid_dataset), ('test', test_dataset)):
    print(f'Loading {split_label} samples')
    split_dataset.load_samples(trained_vocab, device)

In [None]:
# Number of files that ended up in each split's dataset.
split_file_counts = (len(ds.file_lengths) for ds in (train_dataset, valid_dataset, test_dataset))
print('Train files: {}, Valid files: {}, Test files: {}'.format(*split_file_counts))

In [None]:
# Changing batch_size later (e.g. for a second training phase) only needs the sampler indices
# to be re-computed; the loaded samples themselves can be reused.
batch_size = 32
train_sampler, valid_sampler, test_sampler = (
    ContiguousBatchSampler(split_dataset)
    for split_dataset in (train_dataset, valid_dataset, test_dataset)
)

print('Precomputing indices')
for split_sampler in (train_sampler, valid_sampler, test_sampler):
    split_sampler.precompute_indices(batch_size)

train_data_loader = DataLoader(train_dataset, batch_sampler=train_sampler)
valid_data_loader = DataLoader(valid_dataset, batch_sampler=valid_sampler)
test_data_loader = DataLoader(test_dataset, batch_sampler=test_sampler)

In [None]:
# NOTE(review): the KNN database path is fixed ('knn-demo') rather than derived from model_name,
# so presumably every vocab-size variant of this model shares it; confirm that's intended.
model = DecoderTransformer_KNN_XL(
    db_filepath=Path('../data/numpy/knn-demo'),
    vocab=trained_vocab,
    sample_length=sample_length,
    max_file_length=max_file_length,
    use_knn=True,
)

# Total number of parameter elements in the model.
print(sum(param.numel() for param in model.parameters()))

In [None]:
print("Model's state_dict:")
# Build the state dict once and iterate over its items. Calling model.state_dict() inside the loop
# rebuilt the whole dict for every entry (quadratic in the number of entries); the output is the same.
for param_name, param_value in model.state_dict().items():
    print(param_name, "\t", param_value.size())

In [None]:
# Optimiser hyper-parameters, plus the number of batches averaged by each loss estimate.
learning_rate = 3e-4
weight_decay = 1e-3
eval_iters = 100

# Move the model to the device before handing its parameters to the optimiser.
model.to(device)
# NOTE(review): torch.optim.Adam applies weight_decay as a coupled L2 penalty; AdamW is the
# decoupled variant. Switching would change training dynamics and the saved optimizer state.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)

@torch.no_grad()
def estimate_loss(data_loader):
    """Estimate the global ``model``'s mean loss over the first ``eval_iters`` batches of ``data_loader``.

    Transformer-XL memories start empty and are threaded through consecutive batches, as in the
    training loop. ``model`` is put in eval mode for the duration and is always switched back to
    train mode afterwards, even if a forward pass raises.

    Args:
        data_loader: iterable yielding ``(file_idxs, batch)`` pairs. ``batch[:, :-1]`` is fed to the
            model and ``batch[:, 1:, 0]`` is the target (the absolute-position column is dropped).

    Returns:
        A 0-dim tensor holding the mean loss. If the loader runs out before ``eval_iters`` batches,
        the mean covers only the batches that were available.

    Raises:
        ValueError: if no batch was evaluated (empty loader, or ``eval_iters`` < 1).
    """
    model.eval()
    try:
        losses = torch.zeros(eval_iters)
        xl_memories = None
        n_evaluated = 0

        # Always evaluates from the start of the loader (an earlier random-offset attempt did not work).
        data_iter = iter(data_loader)
        for k in range(eval_iters):
            try:
                file_idxs, batch = next(data_iter)
            except StopIteration:
                # Fewer than eval_iters batches available: average over what we have.
                break
            X, Y = batch[:, :-1], batch[:, 1:, 0]  # drop absolute position from Y
            _, loss, xl_memories = model(file_idxs, X, xl_memories, Y)
            losses[k] = loss.item()
            n_evaluated = k + 1

        if n_evaluated == 0:
            raise ValueError('estimate_loss: no batches evaluated (empty data_loader or eval_iters < 1)')
        return losses[:n_evaluated].mean()
    finally:
        # Restore training mode even on error, so a failed evaluation can't silently leave
        # the model in eval mode for the rest of training.
        model.train()

In [None]:
# log10 of each checkpoint's estimated train / validation loss, appended by save_checkpoint.
average_log_losses = {
    "train": [],
    "val": []
}

epochs = 0

def save_checkpoint(iterations):
    """Evaluate train/val loss, log it, and write a full training checkpoint to ``model_save_path``.

    Args:
        iterations: step index recorded in TensorBoard, in the printed summary and as the
            checkpoint's ``'iter'`` entry.

    The checkpoint holds the model, optimizer and vocab state dicts, the log10 loss history and the
    current (global) ``epochs`` count.
    """
    train_loss = estimate_loss(train_data_loader)
    val_loss = estimate_loss(valid_data_loader)
    tensorboard_writer.add_scalar('Loss/train', train_loss, iterations)
    tensorboard_writer.add_scalar('Loss/val', val_loss, iterations)
    average_log_losses['train'].append(train_loss.log10().item())
    average_log_losses['val'].append(val_loss.log10().item())
    print(f'Epoch {epochs} / Iteration {iterations}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    checkpoint = {
        'iter': iterations,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'vocab_state_dict': trained_vocab.state_dict(),
        'losses': average_log_losses,
        'epochs': epochs
    }
    # Write next to the target first and atomically swap it in, so an interruption mid-save
    # cannot leave a truncated file in place of the last good checkpoint.
    tmp_path = model_save_path.with_name(model_save_path.name + '.tmp')
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, model_save_path)

In [None]:
eval_interval = 200
total_iterations = 100000
start_iter = 0

if model_load_path.exists():
    # map_location puts every loaded tensor on the device picked in the first cell, so a checkpoint
    # written on a CUDA machine can also be resumed on an MPS- or CPU-only one.
    checkpoint = torch.load(model_load_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    average_log_losses = checkpoint['losses']
    iterations = checkpoint['iter']
    epochs = checkpoint['epochs']
    # 'iter' is the number of optimizer steps already taken (the training loop saves before the
    # step with that index, and once more after its last step). Resuming at iterations + 1 therefore
    # skips one step index. It is kept because the training cell's remaining_iters guard expects a
    # finished run to resume at total_iterations + 1.
    start_iter = iterations + 1
    print(f"Loaded model from iteration {iterations}")

In [None]:
model.train()

remaining_iters = total_iterations - start_iter
# Train only when there is work left. A finished run resumes past total_iterations (negative
# remaining_iters). The old `!= -1` test caught only the exact -1 case; for 0 or other negative
# values it would skip the loop and still write a checkpoint with a bogus iteration number.
if remaining_iters > 0:

    print(f"Training from epoch {epochs} for {remaining_iters} iterations")

    xl_memories = None
    initial_file_idxs = None
    previous_matched_initial = False
    train_data = iter(train_data_loader)
    offset_iter = start_iter

    for iteration in range(remaining_iters):
        offset_iter = iteration + start_iter

        if offset_iter % eval_interval == 0:
            # Because we don't explicitly clear xl and knn mem here, there is always a risk that the eval loop leaves the file idx
            # the same as the train loop, but with memories of the 'future' which aren't cleared. It could also break the epoch counter.
            # The risk would be much smaller with a bigger data set, but with vg_large we loop through the data quite quickly.
            save_checkpoint(offset_iter)

        # Configure minibatch
        file_idxs, batch = next(train_data)

        if initial_file_idxs is None:
            initial_file_idxs = file_idxs

        # Count an epoch when the loader comes back round to the first batch's files. Only the
        # transition onto those files counts, so several consecutive batches that share the same
        # file_idxs (e.g. successive segments of the same files, if the sampler yields them that way)
        # are not counted as several epochs.
        # NOTE(review): if one epoch were a single batch, this would count only the first epoch.
        matches_initial = torch.equal(initial_file_idxs, file_idxs)
        if matches_initial and not previous_matched_initial:
            epochs += 1
        previous_matched_initial = matches_initial

        X, Y = batch[:, :-1], batch[:, 1:, 0]  # drop absolute position from Y

        # Forward pass (logits are not needed here)
        _, loss, xl_memories = model(file_idxs, X, xl_memories, Y)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    # Final checkpoint records the total number of steps taken.
    save_checkpoint(offset_iter + 1)
else:
    print(f"No iterations remaining (start_iter={start_iter}, total_iterations={total_iterations})")

In [None]:
# Losses are logged as log10, so undo that for the most recent training estimate.
print('Final training loss:', pow(10, average_log_losses['train'][-1]))

In [None]:
# Losses are logged as log10, so undo that for the most recent validation estimate.
print('Final validation loss:', pow(10, average_log_losses['val'][-1]))

In [None]:
# One log10 loss per checkpoint. The first entry (the evaluation at iteration 0, before any training
# step) is skipped, presumably so the untrained loss doesn't dominate the y-axis.
fig, ax = plt.subplots()
ax.plot(average_log_losses['train'][1:])
ax.set(title='Training loss', xlabel='Checkpoint (after the first)', ylabel='log10(loss)');

In [None]:
# One log10 loss per checkpoint. The first entry (the evaluation at iteration 0, before any training
# step) is skipped, presumably so the untrained loss doesn't dominate the y-axis.
fig, ax = plt.subplots()
ax.plot(average_log_losses['val'][1:])
ax.set(title='Validation loss', xlabel='Checkpoint (after the first)', ylabel='log10(loss)');

In [None]:
# Sample in inference mode. Training leaves the model in train mode, which would keep train-only
# behaviour (e.g. any dropout) active; no_grad avoids building an autograd graph while sampling.
model.eval()
# A single all-zeros (token, position) pair as the prompt; presumably token 0 at position 0.
# TODO confirm what index 0 means in trained_vocab.
init_idx = torch.zeros((1, 1, 2), dtype=torch.long, device=device)
with torch.no_grad():
    generated_tokens = model.generate(init_idx, max_new_tokens=512).cpu()

In [None]:
# Dimensions of the generated index tensor.
generated_tokens.size()

In [None]:
# Token-index column (0) of the first generated sequence, shown as readable tokens.
score = generated_tokens[0][:, 0]
trained_vocab.to_tokens(score)

In [None]:
# Decode the generated token indices back into a (presumably music21) stream and plot it.
score_indices = np.array(score)
generated_stream = idx_to_stream_enc(score_indices, trained_vocab)
generated_stream.plot()

In [None]:
# Render the generated piece as MIDI for playback.
generated_stream.show(fmt='midi')

In [None]:
# Pick a random held-out test file and encode it with the trained vocab.
random_test_midi_file = random.choice(test_filenames)
random_test_path = midi_path / random_test_midi_file
random_test_idx_score = midifile_to_idx_score(random_test_path, trained_vocab)

# Its first sample_length rows serve as the prompt; column 0 holds the token indices.
random_test_intro = random_test_idx_score[:sample_length]
intro_token_indices = np.array(random_test_intro[:, 0])
random_test_intro_stream = idx_to_stream_enc(intro_token_indices, trained_vocab)
random_test_intro_stream.plot()

In [None]:
# Render the test file's intro as MIDI for playback.
random_test_intro_stream.show(fmt='midi')

In [None]:
# Sample in inference mode. Training leaves the model in train mode, which would keep train-only
# behaviour (e.g. any dropout) active; no_grad avoids building an autograd graph while sampling.
model.eval()
# Prompt with the whole intro (both columns, as in training batches) as a batch of one. Use long
# indices, matching the unconditional prompt used for generation earlier.
random_test_init = torch.tensor(random_test_intro, dtype=torch.long, device=device).unsqueeze(0)
with torch.no_grad():
    random_test_continued = model.generate(random_test_init, max_new_tokens=512).cpu()[0, :, 0]
random_test_continued_stream = idx_to_stream_enc(np.array(random_test_continued), trained_vocab)
random_test_continued_stream.plot()

In [None]:
# Render the prompt plus its generated continuation as MIDI for playback.
random_test_continued_stream.show(fmt='midi')