In [1]:
import os
os.chdir(r'9 - Rebuild')
import torch
from miditok import REMI, TokenizerConfig  # here we choose to use REMI
from pathlib import Path
import random
from miditok.utils import split_files_for_training
from miditok.data_augmentation import augment_dataset
from miditok.pytorch_data import DatasetMIDI, DataCollator
from torch.utils.data import DataLoader
from miditok import TokSequence
from multiprocessing import Pool
from memorizing_transformers_pytorch import MemorizingTransformer
import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device}.")

Using cuda.


In [2]:
if device == "cuda":
    print(f"Device: {torch.cuda.get_device_name()}.")

Device: NVIDIA GeForce RTX 4090.


# Memorizing Transformers

From our MIDITok research, we know that we we want to 

- Train a BPE tokenizer on the entire dataset.
- Save it / Load it (BPE is deterministic so data doesn't need to be decoded with a the same tokenizer it was encoded with, providing whatever is used was trained on the same data with the same config. Unigram is *not* deterministic however so would require the same exact tokenizer for encode / decode).
- Shuffle file names, so songs aren't biased to a set.
- Split into test / train / validation sets.
- Split the files into chunks for each set.
- Optionally augment the dataset with pitch / velocity / duration shifted versions
- Shuffle the chunks when loading, so that parts of a single song aren't biased to a batch.
- Load the chunks with `max_seq_len` equal to that used when splitting files, to minimise padding / truncated data.
- Split the chunks into context-length sequences and feed them through contiguously.
- Manually reset memories between chunks rather than auto-reset on BOS / EOS tokens.

In [3]:
CHUNK_LENGTH = 2048
SEGMENTS = 8
TIMESTEPS = CHUNK_LENGTH // SEGMENTS # Context length
BATCH_SIZE = 16
VOCAB_SIZE = 411 # REMI untrained token count is 411
N_EMBED = 512
N_LAYER = 8
N_HEAD = 8
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-3
NUM_BATCHES = int(1e5)
MAX_GRAD_CLIP_NORM = 0.5
VALIDATE_EVERY  = 64
DIM_HEAD = N_EMBED // N_HEAD
VERSION_LABEL = "Base_vocab"
TOKENIZER_CONFIG = 'Basic'

midi_path = Path(f'../data/midi')
dataset_name = 'vg_large'
midi_dataset_path = Path(f'{midi_path}/{dataset_name}')
midi_file_paths = [p.resolve() for p in midi_dataset_path.glob("**/*.mid")]

tokenizer_save_path = Path(f'../data/vocab/MidiTok/{dataset_name}_{VOCAB_SIZE}_{TOKENIZER_CONFIG}.json')

if not tokenizer_save_path.exists():
    TOKENIZER_PARAMS = {
        "pitch_range": (21, 109),
        "beat_res": {(0, 4): 8, (4, 12): 4},
        "num_velocities": 32,
        "use_programs": True
        # "use_chords": True,
        # "use_time_signatures": True,
        # "use_tempos": True,
        # "num_tempos": 32,  # number of tempo bins
        # "tempo_range": (40, 250)
    }
    tokenizer_confg = TokenizerConfig(**TOKENIZER_PARAMS)
    tokenizer = REMI(tokenizer_confg)
    print(f"Untrained token count: {tokenizer.len}")
    tokenizer.train(vocab_size=VOCAB_SIZE, files_paths=midi_file_paths)
    tokenizer.save(tokenizer_save_path)
else:
    tokenizer = REMI(params=tokenizer_save_path)

Untrained token count: 411


  super().__init__(tokenizer_config, params)
  tokenizer.train(vocab_size=VOCAB_SIZE, files_paths=midi_file_paths)


In [4]:
random.seed(42)
random.shuffle(midi_file_paths)
len(midi_file_paths)

3839

In [5]:
n1 = int(0.8 * len(midi_file_paths))
n2 = int(0.9 * len(midi_file_paths))
train_filepaths = midi_file_paths[:n1]
valid_filepaths = midi_file_paths[n1:n2]
test_filepaths = midi_file_paths[n2:]

print(f'Train files: {len(train_filepaths)}, Valid files: {len(valid_filepaths)}, Test files: {len(test_filepaths)}')

Train files: 3071, Valid files: 384, Test files: 384


In [6]:
chunk_path = Path(f'{midi_path}/mtok_split/{dataset_name}/v-{VOCAB_SIZE}_t-{TOKENIZER_CONFIG}_c-{CHUNK_LENGTH}')
train_chunk_path = Path(f'{chunk_path}/train')
valid_chunk_path = Path(f'{chunk_path}/valid')
test_chunk_path = Path(f'{chunk_path}/test')

split_data = [
    (train_filepaths, train_chunk_path),
    (valid_filepaths, valid_chunk_path),
    (test_filepaths, test_chunk_path)
]

def chunk_files(filepaths, tokenizer, chunks_dir, max_seq_len):
    split_files_for_training(
        files_paths=filepaths,
        tokenizer=tokenizer,
        save_dir=chunks_dir,
        max_seq_len=max_seq_len,
        num_overlap_bars=1
    )

if not chunk_path.exists():
    with Pool(processes=3) as pool:
        pool.starmap(chunk_files, [(filepaths, tokenizer, chunks_dir, CHUNK_LENGTH) for filepaths, chunks_dir in split_data])

In [7]:
train_chunk_filepaths = list(train_chunk_path.glob("**/*.mid"))
valid_chunk_filepaths = list(valid_chunk_path.glob("**/*.mid"))
test_chunk_filepaths = list(test_chunk_path.glob("**/*.mid"))

print(f'Train chunks: {len(train_chunk_filepaths)}, Valid chunks: {len(valid_chunk_filepaths)}, Test chunks: {len(test_chunk_filepaths)}')

Train chunks: 7973, Valid chunks: 1063, Test chunks: 943


In [8]:
chunk_paths = [train_chunk_filepaths, valid_chunk_filepaths, test_chunk_filepaths]

for chunk_path in chunk_paths:
    random.shuffle(chunk_path)

In [9]:
def create_data_loader(chunks_path, tokenizer, max_seq_len, batch_size):
    collator = DataCollator(tokenizer.pad_token_id) # copy_inputs_as_labels and shift_labels not needed as done by the transformer
    dataset = DatasetMIDI(
        pre_tokenize=True,
        files_paths=chunks_path,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"])
    return DataLoader(dataset=dataset, collate_fn=collator, batch_size=batch_size)

In [10]:
from itertools import cycle

train_loader, valid_loader, test_loader = map(
    lambda chunk_filepaths: cycle(create_data_loader(chunk_filepaths, tokenizer, CHUNK_LENGTH, BATCH_SIZE)),
    chunk_paths
)

Pre-tokenizing: 100%|██████████| 7973/7973 [01:22<00:00, 96.07it/s] 
Pre-tokenizing: 100%|██████████| 1063/1063 [00:10<00:00, 96.85it/s] 
Pre-tokenizing: 100%|██████████| 943/943 [00:09<00:00, 97.35it/s] 


In [11]:
model_name = f'memorizing_miditok_{dataset_name}_t-{TIMESTEPS}_v-{VOCAB_SIZE}_{VERSION_LABEL}'
model_load_path = Path(f'../data/checkpoints/{model_name}.dat')
model_save_path = Path(f'../data/checkpoints/{model_name}.dat')
log_dir = Path(f'../tensorboard/{model_name}')
tensorboard_writer = SummaryWriter(log_dir)

Now we can create a transformer and set up our training loop

In [12]:
model = MemorizingTransformer(
    num_tokens = VOCAB_SIZE,
    dim = N_EMBED,
    depth = N_LAYER,
    heads = N_HEAD,
    dim_head = DIM_HEAD,
    attn_dropout = 0.2,
    ff_dropout = 0.2,
    memorizing_layers = (4, 5),
    max_knn_memories = CHUNK_LENGTH, # No point in having more meories than the chunk length as we clear them at the end of each chunk
    num_retrieved_memories = 32, # Top K
    xl_memory_layers = (2, 3, 4, 5),
    xl_max_memories = TIMESTEPS, # One context-length of XL memory
    pad_id = tokenizer.pad_token_id,
    # shift_knn_memories_down = 1,
    # shift_xl_memories_down = 1
).to(device)

print(sum(p.numel() for p in model.parameters()))

21958571


In [13]:
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model's state_dict:
token_emb.weight 	 torch.Size([411, 512])
rel_pos_bias.relative_attention_bias.weight 	 torch.Size([32, 8])
knn_rel_pos_bias.relative_attention_bias.weight 	 torch.Size([32, 8])
layers.0.0.fn.to_q.weight 	 torch.Size([512, 512])
layers.0.0.fn.to_kv.weight 	 torch.Size([128, 512])
layers.0.0.fn.to_out.weight 	 torch.Size([512, 512])
layers.0.0.fn.to_out.bias 	 torch.Size([512])
layers.0.0.norm.weight 	 torch.Size([512])
layers.0.0.norm.bias 	 torch.Size([512])
layers.0.1.fn.net.0.weight 	 torch.Size([2048, 512])
layers.0.1.fn.net.0.bias 	 torch.Size([2048])
layers.0.1.fn.net.3.weight 	 torch.Size([512, 2048])
layers.0.1.fn.net.3.bias 	 torch.Size([512])
layers.0.1.norm.weight 	 torch.Size([512])
layers.0.1.norm.bias 	 torch.Size([512])
layers.1.0.fn.to_q.weight 	 torch.Size([512, 512])
layers.1.0.fn.to_kv.weight 	 torch.Size([128, 512])
layers.1.0.fn.to_out.weight 	 torch.Size([512, 512])
layers.1.0.fn.to_out.bias 	 torch.Size([512])
layers.1.0.norm.weight 	 torch.Si

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [15]:
def save_checkpoint(optimizer, completed_iterations, train_loss, val_loss):
    tensorboard_writer.add_scalar('Loss/train', train_loss, completed_iterations)
    tensorboard_writer.add_scalar('Loss/val', val_loss, completed_iterations)
    print(f'Writing to Tensorboard: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    torch.save({
        'iter': completed_iterations,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, model_save_path)

completed_iterations = 0
if model_load_path.exists():
    checkpoint = torch.load(model_load_path)
    completed_iterations = checkpoint['iter']
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"Loaded model from iteration {completed_iterations}")

In [None]:
for i in tqdm.tqdm(range(NUM_BATCHES - completed_iterations), mininterval = 10., desc = 'training'):
    model.train()

    data = next(train_loader)["input_ids"].to(device)
    if data.shape[0] != BATCH_SIZE:
        print(f'Skipping batch {i} as it is not of size {BATCH_SIZE}, but {data.shape[0]}')
        data = next(train_loader)["input_ids"].to(device)

    train_loss = 0.
    with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        xl_memories = None    
        seq, labels = data[:, :-1], data[:, 1:]

        for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
            loss, xl_memories = model(
                seq_segment,
                labels = labels_segment,
                knn_memories = knn_memories,
                xl_memories = xl_memories
            )

            train_loss += loss.item() / SEGMENTS
            (loss / SEGMENTS).backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optimizer.step()
    optimizer.zero_grad()

    if not (i % VALIDATE_EVERY):
        model.eval()

        valid_data = next(valid_loader)["input_ids"].to(device)
        if valid_data.shape[0] != BATCH_SIZE:
            print(f'Skipping validation batch {i} as it is not of size {BATCH_SIZE}, but {valid_data.shape[0]}')
            valid_data = next(valid_loader)["input_ids"].to(device)

        valid_loss = 0.

        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
            xl_memories = None    
            seq, labels = data[:, :-1], data[:, 1:]

            for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
                loss, xl_memories = model(
                    seq_segment,
                    labels = labels_segment,
                    knn_memories = knn_memories,
                    xl_memories = xl_memories
                )

                valid_loss += loss.item() / SEGMENTS

        save_checkpoint(optimizer, i + completed_iterations, train_loss, valid_loss)

  x.storage().data_ptr() + x.storage_offset() * 4)


Writing to Tensorboard: Train Loss: 6.2693, Val Loss: 4.1836


training:   0%|          | 52/100000 [00:30<16:01:51,  1.73it/s]

Writing to Tensorboard: Train Loss: 2.1497, Val Loss: 2.1087


training:   0%|          | 126/100000 [01:12<15:34:38,  1.78it/s]

Writing to Tensorboard: Train Loss: 1.9683, Val Loss: 1.9291


training:   0%|          | 184/100000 [01:45<15:12:08,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.7628, Val Loss: 1.7350


training:   0%|          | 242/100000 [02:17<15:08:48,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.9042, Val Loss: 1.8593


training:   0%|          | 301/100000 [02:49<15:06:42,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.7302, Val Loss: 1.7346


training:   0%|          | 379/100000 [03:32<14:57:55,  1.85it/s]

Writing to Tensorboard: Train Loss: 1.7596, Val Loss: 1.7101


training:   0%|          | 437/100000 [04:05<15:10:49,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.7866, Val Loss: 1.7593


training:   0%|          | 494/100000 [04:36<15:03:53,  1.83it/s]

Skipping batch 498 as it is not of size 16, but 5
Writing to Tensorboard: Train Loss: 1.7670, Val Loss: 1.7274


training:   1%|          | 570/100000 [05:22<15:42:45,  1.76it/s]

Writing to Tensorboard: Train Loss: 1.7356, Val Loss: 1.6817


training:   1%|          | 628/100000 [05:55<15:21:08,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.8315, Val Loss: 1.7716


training:   1%|          | 686/100000 [06:27<15:03:33,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.5568, Val Loss: 1.5042


training:   1%|          | 765/100000 [07:09<14:39:21,  1.88it/s]

Writing to Tensorboard: Train Loss: 1.7403, Val Loss: 1.6859


training:   1%|          | 824/100000 [07:41<14:46:55,  1.86it/s]

Writing to Tensorboard: Train Loss: 1.6864, Val Loss: 1.6440


training:   1%|          | 883/100000 [08:14<14:56:14,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.6408, Val Loss: 1.5967


training:   1%|          | 941/100000 [08:46<14:56:43,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.6232, Val Loss: 1.5786


training:   1%|          | 980/100000 [09:09<15:19:01,  1.80it/s]

Skipping batch 996 as it is not of size 16, but 5


training:   1%|          | 1020/100000 [09:30<15:01:37,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.6134, Val Loss: 1.5679


training:   1%|          | 1078/100000 [10:02<15:02:44,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.5843, Val Loss: 1.5448


training:   1%|          | 1137/100000 [10:35<14:56:10,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.6038, Val Loss: 1.5572


training:   1%|          | 1214/100000 [11:16<14:31:05,  1.89it/s]

Writing to Tensorboard: Train Loss: 1.6108, Val Loss: 1.5599


training:   1%|▏         | 1273/100000 [11:48<14:39:59,  1.87it/s]

Writing to Tensorboard: Train Loss: 1.6550, Val Loss: 1.5968


training:   1%|▏         | 1332/100000 [12:21<14:43:01,  1.86it/s]

Writing to Tensorboard: Train Loss: 1.6314, Val Loss: 1.5829


training:   1%|▏         | 1391/100000 [12:54<14:51:44,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.6107, Val Loss: 1.5549


training:   1%|▏         | 1469/100000 [13:36<14:37:57,  1.87it/s]

Writing to Tensorboard: Train Loss: 1.5759, Val Loss: 1.5232


training:   1%|▏         | 1489/100000 [13:48<15:10:01,  1.80it/s]

Skipping batch 1494 as it is not of size 16, but 5


training:   2%|▏         | 1527/100000 [14:09<15:06:34,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6062, Val Loss: 1.5665


training:   2%|▏         | 1584/100000 [14:41<15:06:43,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6650, Val Loss: 1.6294


training:   2%|▏         | 1661/100000 [15:22<14:37:38,  1.87it/s]

Writing to Tensorboard: Train Loss: 1.5821, Val Loss: 1.5234


training:   2%|▏         | 1720/100000 [15:56<14:53:46,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.5094, Val Loss: 1.4769


training:   2%|▏         | 1778/100000 [16:28<15:01:34,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.5480, Val Loss: 1.5117


training:   2%|▏         | 1856/100000 [17:10<14:34:04,  1.87it/s]

Writing to Tensorboard: Train Loss: 1.5104, Val Loss: 1.4660


training:   2%|▏         | 1915/100000 [17:43<14:48:19,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.3809, Val Loss: 1.3376


training:   2%|▏         | 1973/100000 [18:16<14:53:56,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.5811, Val Loss: 1.5266
Skipping batch 1992 as it is not of size 16, but 5


training:   2%|▏         | 2032/100000 [18:49<14:55:29,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.6714, Val Loss: 1.6277


training:   2%|▏         | 2111/100000 [19:31<14:32:09,  1.87it/s]

Writing to Tensorboard: Train Loss: 1.6124, Val Loss: 1.5666


training:   2%|▏         | 2170/100000 [20:04<14:40:32,  1.85it/s]

Writing to Tensorboard: Train Loss: 1.5650, Val Loss: 1.5105


training:   2%|▏         | 2228/100000 [20:36<14:41:43,  1.85it/s]

Writing to Tensorboard: Train Loss: 1.5388, Val Loss: 1.4898


training:   2%|▏         | 2287/100000 [21:08<14:34:12,  1.86it/s]

Writing to Tensorboard: Train Loss: 1.5978, Val Loss: 1.5428


training:   2%|▏         | 2366/100000 [21:51<14:21:39,  1.89it/s]

Writing to Tensorboard: Train Loss: 1.6212, Val Loss: 1.5591


training:   2%|▏         | 2425/100000 [22:24<14:46:33,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.5522, Val Loss: 1.4981


training:   2%|▏         | 2484/100000 [22:56<14:43:36,  1.84it/s]

Skipping batch 2490 as it is not of size 16, but 5
Writing to Tensorboard: Train Loss: 1.7039, Val Loss: 1.6566


training:   3%|▎         | 2543/100000 [23:29<14:44:00,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.5505, Val Loss: 1.5049


training:   3%|▎         | 2622/100000 [24:11<14:20:25,  1.89it/s]

Writing to Tensorboard: Train Loss: 1.6266, Val Loss: 1.5674


training:   3%|▎         | 2681/100000 [24:44<14:27:56,  1.87it/s]

Writing to Tensorboard: Train Loss: 1.5359, Val Loss: 1.4862


training:   3%|▎         | 2740/100000 [25:16<14:25:34,  1.87it/s]

Writing to Tensorboard: Train Loss: 1.5189, Val Loss: 1.4679


training:   3%|▎         | 2799/100000 [25:48<14:25:01,  1.87it/s]

Writing to Tensorboard: Train Loss: 1.5337, Val Loss: 1.4880


training:   3%|▎         | 2878/100000 [26:30<14:16:52,  1.89it/s]

Writing to Tensorboard: Train Loss: 1.6196, Val Loss: 1.5672


training:   3%|▎         | 2936/100000 [27:03<14:41:50,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.4966, Val Loss: 1.4472


training:   3%|▎         | 2975/100000 [27:25<15:03:25,  1.79it/s]

Skipping batch 2988 as it is not of size 16, but 5


training:   3%|▎         | 2994/100000 [27:36<14:50:05,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.5470, Val Loss: 1.4887


training:   3%|▎         | 3069/100000 [28:18<14:46:58,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.5640, Val Loss: 1.5051


training:   3%|▎         | 3126/100000 [28:50<14:53:40,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.5297, Val Loss: 1.4861


training:   3%|▎         | 3184/100000 [29:23<14:49:33,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.5502, Val Loss: 1.4947


training:   3%|▎         | 3261/100000 [30:05<14:36:49,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.5692, Val Loss: 1.5202


training:   3%|▎         | 3315/100000 [30:36<14:49:37,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6574, Val Loss: 1.6061


training:   3%|▎         | 3391/100000 [31:18<14:38:10,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.6146, Val Loss: 1.5726


training:   3%|▎         | 3449/100000 [31:51<14:45:56,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.4500, Val Loss: 1.4004


training:   3%|▎         | 3469/100000 [32:03<15:13:44,  1.76it/s]

Skipping batch 3486 as it is not of size 16, but 5


training:   4%|▎         | 3507/100000 [32:24<15:09:19,  1.77it/s]

Writing to Tensorboard: Train Loss: 1.5067, Val Loss: 1.4602


training:   4%|▎         | 3583/100000 [33:06<14:37:28,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.6486, Val Loss: 1.5975


training:   4%|▎         | 3640/100000 [33:38<14:47:06,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.5458, Val Loss: 1.4873


training:   4%|▎         | 3697/100000 [34:10<14:45:55,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.5010, Val Loss: 1.4597


training:   4%|▍         | 3776/100000 [34:53<14:29:04,  1.85it/s]

Writing to Tensorboard: Train Loss: 1.5004, Val Loss: 1.4435


training:   4%|▍         | 3833/100000 [35:26<14:46:38,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.5763, Val Loss: 1.5204


training:   4%|▍         | 3889/100000 [35:58<14:54:44,  1.79it/s]

Writing to Tensorboard: Train Loss: 1.4177, Val Loss: 1.3616


training:   4%|▍         | 3966/100000 [36:40<14:29:24,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.5300, Val Loss: 1.4810
Skipping batch 3984 as it is not of size 16, but 5


training:   4%|▍         | 4024/100000 [37:13<14:40:13,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.4803, Val Loss: 1.4205


training:   4%|▍         | 4082/100000 [37:45<14:40:37,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.4798, Val Loss: 1.4286


training:   4%|▍         | 4158/100000 [38:27<14:19:48,  1.86it/s]

Writing to Tensorboard: Train Loss: 1.5253, Val Loss: 1.4601


training:   4%|▍         | 4217/100000 [39:00<14:29:33,  1.84it/s]

Skipping validation batch 4224 as it is not of size 16, but 7
Writing to Tensorboard: Train Loss: 1.5809, Val Loss: 1.5338


training:   4%|▍         | 4276/100000 [39:33<14:33:51,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.6798, Val Loss: 1.6271


training:   4%|▍         | 4335/100000 [40:06<14:33:45,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.4930, Val Loss: 1.4388


training:   4%|▍         | 4412/100000 [40:48<14:20:59,  1.85it/s]

Writing to Tensorboard: Train Loss: 1.5761, Val Loss: 1.5179


training:   4%|▍         | 4470/100000 [41:21<14:41:28,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.5812, Val Loss: 1.5333
Skipping batch 4482 as it is not of size 16, but 5


training:   5%|▍         | 4529/100000 [41:54<14:46:01,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.5890, Val Loss: 1.5166


training:   5%|▍         | 4604/100000 [42:36<14:28:47,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.5632, Val Loss: 1.5066


training:   5%|▍         | 4662/100000 [43:09<14:37:07,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6153, Val Loss: 1.5588


training:   5%|▍         | 4719/100000 [43:41<14:43:41,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.5519, Val Loss: 1.4992


training:   5%|▍         | 4795/100000 [44:23<14:25:07,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.5201, Val Loss: 1.4642


training:   5%|▍         | 4852/100000 [44:56<14:42:56,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.5906, Val Loss: 1.5337


training:   5%|▍         | 4909/100000 [45:28<14:38:53,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.5338, Val Loss: 1.4706


training:   5%|▍         | 4966/100000 [46:00<14:41:20,  1.80it/s]

Skipping batch 4980 as it is not of size 16, but 5


training:   5%|▍         | 4985/100000 [46:10<14:29:31,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.4715, Val Loss: 1.4273


training:   5%|▌         | 5041/100000 [46:42<14:47:39,  1.78it/s]

Writing to Tensorboard: Train Loss: 1.6072, Val Loss: 1.5519


training:   5%|▌         | 5119/100000 [47:25<14:23:20,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.5540, Val Loss: 1.5108


training:   5%|▌         | 5174/100000 [47:57<14:33:35,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.5600, Val Loss: 1.5284


training:   5%|▌         | 5231/100000 [48:29<14:32:04,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.7683, Val Loss: 1.7065


training:   5%|▌         | 5308/100000 [49:11<14:19:30,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.6177, Val Loss: 1.5552


training:   5%|▌         | 5366/100000 [49:44<14:32:44,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6525, Val Loss: 1.5993


training:   5%|▌         | 5423/100000 [50:16<14:31:17,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6175, Val Loss: 1.5618


training:   5%|▌         | 5462/100000 [50:39<14:47:49,  1.77it/s]

Skipping batch 5478 as it is not of size 16, but 5


training:   6%|▌         | 5500/100000 [50:59<14:36:14,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.6686, Val Loss: 1.6050


training:   6%|▌         | 5556/100000 [51:31<14:35:37,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.5749, Val Loss: 1.5165


training:   6%|▌         | 5632/100000 [52:13<14:15:27,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.7130, Val Loss: 1.6504


training:   6%|▌         | 5690/100000 [52:46<14:24:03,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.6857, Val Loss: 1.6444


training:   6%|▌         | 5749/100000 [53:19<14:24:23,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.6471, Val Loss: 1.5951


training:   6%|▌         | 5807/100000 [53:51<14:27:09,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6083, Val Loss: 1.5600


training:   6%|▌         | 5886/100000 [54:35<14:14:06,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.6576, Val Loss: 1.6212


training:   6%|▌         | 5941/100000 [55:06<14:25:13,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.7139, Val Loss: 1.6722


training:   6%|▌         | 5961/100000 [55:18<14:47:24,  1.77it/s]

Skipping batch 5976 as it is not of size 16, but 5


training:   6%|▌         | 5997/100000 [55:38<14:46:54,  1.77it/s]

Writing to Tensorboard: Train Loss: 1.6466, Val Loss: 1.5834


training:   6%|▌         | 6075/100000 [56:21<14:10:36,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.7054, Val Loss: 1.6800


training:   6%|▌         | 6132/100000 [56:53<14:25:10,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6412, Val Loss: 1.5988


training:   6%|▌         | 6208/100000 [57:36<14:18:24,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.6916, Val Loss: 1.6708


training:   6%|▋         | 6265/100000 [58:08<14:24:00,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6427, Val Loss: 1.5869


training:   6%|▋         | 6323/100000 [58:41<14:22:36,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6904, Val Loss: 1.6453


training:   6%|▋         | 6400/100000 [59:24<14:10:38,  1.83it/s]

Writing to Tensorboard: Train Loss: 1.6779, Val Loss: 1.6336


training:   6%|▋         | 6456/100000 [59:56<14:25:26,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.6191, Val Loss: 1.5751
Skipping batch 6474 as it is not of size 16, but 5


training:   7%|▋         | 6513/100000 [1:00:28<14:24:16,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.6502, Val Loss: 1.6079


training:   7%|▋         | 6591/100000 [1:01:11<14:03:08,  1.85it/s]

Writing to Tensorboard: Train Loss: 1.6531, Val Loss: 1.6076


training:   7%|▋         | 6649/100000 [1:01:43<14:14:12,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.6442, Val Loss: 1.5916


training:   7%|▋         | 6708/100000 [1:02:16<14:16:33,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.6712, Val Loss: 1.6290


training:   7%|▋         | 6784/100000 [1:02:59<14:14:54,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.6941, Val Loss: 1.6353


training:   7%|▋         | 6840/100000 [1:03:31<14:23:57,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.5922, Val Loss: 1.5571


training:   7%|▋         | 6896/100000 [1:04:03<14:23:25,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.5819, Val Loss: 1.5563


training:   7%|▋         | 6972/100000 [1:04:45<14:13:40,  1.82it/s]

Skipping batch 6972 as it is not of size 16, but 5
Writing to Tensorboard: Train Loss: 1.7238, Val Loss: 1.6834


training:   7%|▋         | 7028/100000 [1:05:16<14:15:31,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.7064, Val Loss: 1.6797


training:   7%|▋         | 7086/100000 [1:05:49<14:14:43,  1.81it/s]

Writing to Tensorboard: Train Loss: 1.6963, Val Loss: 1.6421


training:   7%|▋         | 7164/100000 [1:06:32<14:00:16,  1.84it/s]

Writing to Tensorboard: Train Loss: 1.7299, Val Loss: 1.6729


training:   7%|▋         | 7221/100000 [1:07:04<14:10:19,  1.82it/s]

Writing to Tensorboard: Train Loss: 1.6831, Val Loss: 1.6431


training:   7%|▋         | 7240/100000 [1:07:16<14:33:51,  1.77it/s]

# Initial Results

With the base REMI vocab size of 411 tokens along with the same embed / head / layer count as our hand built model, the results are significantly worse. The training curve almost immediately flattened out at 1.6 loss, compared to under 1.0 previously.

This could be due to e.g.

- Different encoding. Not only are tokens different in MidiTok / Remi but we used bar / beat *embeddings* rather than tokens previously. There are so many more meanings to the tokens in addition to the count rather than just note / duration / sos / eos / pad that it would be much harder to guess the right answer. If we do see convergence it should mean we have learnt more useful information rather than just gamed the system so to speak.
- Larger vocabulary = more choices to pick for the next token
- The MIDITok data loader only provides max `CHUNK_LENGTH` tokens, whereas our custom 'contiguous' data loader always fed in entire tracks. Only a handful of tracks are between 1 and 2k tokens, most are 3 - 5K, so are being split into multiple pieces.
- The LucidRains Memorizing Transformers architecture is very similar to our model but there are a couple of changes (memory on a subset of layer vs all, lack of absolute embeddings) and of course it is a different implementation. I would expect it to be less buggy if anything though.

I did have to tweak the transformer's KNN memory to run exclusively on the GPU as index was on CPU and the data on disk, which was a major bottleneck (as found in our model). Luckily once again the code was nearly identical so it was easy to follow / edit.

## Things to try

- Train the tokenizer with 1K, 5K, 10K, 20K, 30K size vocab. This should result in less tokens per song. We might need to drop the max seq length to avoid over-padded songs.

- Add in chord / tempo / time sig / rest tokens. This should result in more tokens per song. We might need to increase the max seq length to avoid over-split songs.

> Note - the files will need re-encoding as they are split based on estimated token length, which depends on encoding and vocab size!

> Another note! - When the vocab size was halved, the number of tokens per file seemed to double, as you might expect given the bpe was trained on the same data. That meant doubling the chunk size and segment count in order to split the files into the same number of pieces.

- Augment the dataset with pitch and velocity shifted versions using the MIDITok tools

- Full Lakh dataset, although if we can't do well on a curated, stylistic dataset then a more varied one is likely to decrease performance.


## Experiment log

- Was seeing increase in loss after initial drop which eventually came back down again. Decreasing learning rate made it happen later and worse. Increased learning rate (and added dropout) which seemed to negate it.

- Some apparent cyclic loss patterns so tried shuffling between epochs. Requires either waiting for a pretokenise op between epochs (long even on vg_large, let alone Lakh) or tokenise on the fly which reduces GPU usage to around 70%

# Parallel DataLoaders

Shuffling between epochs should avoid cyclic learning behaviour, but
- If we pretokenise there is a big delay between epochs
- If we don't, the GPU isn't used as effectively

We could have two dataloaders, and pretokenise one whilst the other is in use.

This doesn't seem to be an issue right now but we could implement it if needed.

Claude suggested the following:

In [None]:
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import DataLoader

class ParallelDataLoaderManager:
    def __init__(self, chunk_filepaths: list, tokenizer, chunk_length: int, batch_size: int):
        self.chunk_filepaths = chunk_filepaths
        self.tokenizer = tokenizer
        self.chunk_length = chunk_length
        self.batch_size = batch_size
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.next_loader_future = None
        self.current_loader = None

    def create_dataloader(self, filepaths) -> DataLoader:
        # Using the create_data_loader function from your workspace
        return create_data_loader(
            filepaths,
            self.tokenizer,
            self.chunk_length, 
            self.batch_size
        )
    
    def prepare_next_loader(self):
        # Shuffle filepaths for next epoch
        shuffled_paths = self.chunk_filepaths.copy()
        random.shuffle(shuffled_paths)
        
        # Start creating next loader in background
        self.next_loader_future = self.executor.submit(
            self.create_dataloader,
            shuffled_paths
        )

    def get_loader(self) -> DataLoader:
        if self.current_loader is None:
            # First call - create initial loader
            self.current_loader = self.create_dataloader(self.chunk_filepaths)
            # Start preparing next loader 
            self.prepare_next_loader()
            return self.current_loader

        # Wait for next loader to be ready and swap
        self.current_loader = self.next_loader_future.result()
        # Start preparing next loader
        self.prepare_next_loader()
        return self.current_loader

    def __del__(self):
        self.executor.shutdown()