In [1]:
import os
os.chdir(r'9 - Rebuild')
import torch
from miditok import REMI, TokenizerConfig  # here we choose to use REMI
from pathlib import Path
import random
from miditok.utils import split_files_for_training
from miditok.data_augmentation import augment_dataset
from miditok.pytorch_data import DatasetMIDI, DataCollator
from torch.utils.data import DataLoader
from miditok import TokSequence
from multiprocessing import Pool
from memorizing_transformers_pytorch import MemorizingTransformer
import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device}.")

Using cuda.


In [2]:
if device == "cuda":
    print(f"Device: {torch.cuda.get_device_name()}.")

Device: NVIDIA GeForce RTX 4090.


# Memorizing Transformers

From our MIDITok research, we know that we we want to 

- Train a BPE tokenizer on the entire dataset.
- Save it / Load it (BPE is deterministic so data doesn't need to be decoded with a the same tokenizer it was encoded with, providing whatever is used was trained on the same data with the same config. Unigram is *not* deterministic however so would require the same exact tokenizer for encode / decode).
- Shuffle file names, so songs aren't biased to a set.
- Split into test / train / validation sets.
- Split the files into chunks for each set.
- Optionally augment the dataset with pitch / velocity / duration shifted versions
- Shuffle the chunks when loading, so that parts of a single song aren't biased to a batch.
- Load the chunks with `max_seq_len` equal to that used when splitting files, to minimise padding / truncated data.
- Split the chunks into context-length sequences and feed them through contiguously.
- Manually reset memories between chunks rather than auto-reset on BOS / EOS tokens.

In [3]:
CHUNK_LENGTH = 1024
SEGMENTS = 4
BATCH_SIZE = 16
VOCAB_SIZE = 411 # REMI untrained token count is 411
N_EMBED = 512
N_LAYER = 8
N_HEAD = 8
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-3
NUM_BATCHES = int(1e5)
MAX_GRAD_CLIP_NORM = 0.5
VALIDATE_EVERY  = 64
TIMESTEPS = CHUNK_LENGTH // SEGMENTS
DIM_HEAD = N_EMBED // N_HEAD
VERSION = 2
TOKENIZER_CONFIG = 'Basic'

midi_path = Path(f'../data/midi')
dataset_name = 'vg_large'
midi_dataset_path = Path(f'{midi_path}/{dataset_name}')
midi_file_paths = [p.resolve() for p in midi_dataset_path.glob("**/*.mid")]

tokenizer_save_path = Path(f'../data/vocab/MidiTok/{dataset_name}_{VOCAB_SIZE}_{TOKENIZER_CONFIG}.json')

if not tokenizer_save_path.exists():
    TOKENIZER_PARAMS = {
        "pitch_range": (21, 109),
        "beat_res": {(0, 4): 8, (4, 12): 4},
        "num_velocities": 32,
        "use_programs": True
        # "use_chords": True,
        # "use_time_signatures": True,
        # "use_tempos": True,
        # "num_tempos": 32,  # number of tempo bins
        # "tempo_range": (40, 250)
    }
    tokenizer_confg = TokenizerConfig(**TOKENIZER_PARAMS)
    tokenizer = REMI(tokenizer_confg)
    print(f"Untrained token count: {tokenizer.len}")
    tokenizer.train(vocab_size=VOCAB_SIZE, files_paths=midi_file_paths)
    tokenizer.save(tokenizer_save_path)
else:
    tokenizer = REMI(params=tokenizer_save_path)

Untrained token count: 411


  super().__init__(tokenizer_config, params)
  tokenizer.train(vocab_size=VOCAB_SIZE, files_paths=midi_file_paths)


In [4]:
random.seed(42)
random.shuffle(midi_file_paths)
len(midi_file_paths)

3839

In [5]:
n1 = int(0.8 * len(midi_file_paths))
n2 = int(0.9 * len(midi_file_paths))
train_filepaths = midi_file_paths[:n1]
valid_filepaths = midi_file_paths[n1:n2]
test_filepaths = midi_file_paths[n2:]

print(f'Train files: {len(train_filepaths)}, Valid files: {len(valid_filepaths)}, Test files: {len(test_filepaths)}')

Train files: 3071, Valid files: 384, Test files: 384


In [6]:
chunk_path = Path(f'{midi_path}/{dataset_name}_{VOCAB_SIZE}_{TOKENIZER_CONFIG}_miditok_split')
train_chunk_path = Path(f'{chunk_path}/train')
valid_chunk_path = Path(f'{chunk_path}/valid')
test_chunk_path = Path(f'{chunk_path}/test')

split_data = [
    (train_filepaths, train_chunk_path),
    (valid_filepaths, valid_chunk_path),
    (test_filepaths, test_chunk_path)
]

def chunk_files(filepaths, tokenizer, chunks_dir, max_seq_len):
    split_files_for_training(
        files_paths=filepaths,
        tokenizer=tokenizer,
        save_dir=chunks_dir,
        max_seq_len=max_seq_len,
        num_overlap_bars=1
    )

if not chunk_path.exists():
    with Pool(processes=3) as pool:
        pool.starmap(chunk_files, [(filepaths, tokenizer, chunks_dir, CHUNK_LENGTH) for filepaths, chunks_dir in split_data])

Splitting music files (../data/midi/vg_large_411_Basic_miditok_split/test): 100%|██████████| 384/384 [00:17<00:00, 22.11it/s]]
Splitting music files (../data/midi/vg_large_411_Basic_miditok_split/valid): 100%|██████████| 384/384 [00:19<00:00, 19.81it/s]]
Splitting music files (../data/midi/vg_large_411_Basic_miditok_split/train): 100%|██████████| 3071/3071 [05:29<00:00,  9.33it/s]


In [7]:
train_chunk_filepaths = list(train_chunk_path.glob("**/*.mid"))
valid_chunk_filepaths = list(valid_chunk_path.glob("**/*.mid"))
test_chunk_filepaths = list(test_chunk_path.glob("**/*.mid"))

print(f'Train chunks: {len(train_chunk_filepaths)}, Valid chunks: {len(valid_chunk_filepaths)}, Test chunks: {len(test_chunk_filepaths)}')

Train chunks: 15461, Valid chunks: 2064, Test chunks: 1816


In [8]:
chunk_paths = [train_chunk_filepaths, valid_chunk_filepaths, test_chunk_filepaths]

for chunk_path in chunk_paths:
    random.shuffle(chunk_path)

In [9]:
def create_data_loader(chunks_path, tokenizer, max_seq_len, batch_size):
    collator = DataCollator(tokenizer.pad_token_id) # copy_inputs_as_labels and shift_labels not needed as done by the transformer
    dataset = DatasetMIDI(
        pre_tokenize=True,
        files_paths=chunks_path,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"])
    return DataLoader(dataset=dataset, collate_fn=collator, batch_size=batch_size)

In [10]:
from itertools import cycle

train_loader, valid_loader, test_loader = map(
    lambda chunk_filepaths: cycle(create_data_loader(chunk_filepaths, tokenizer, CHUNK_LENGTH, BATCH_SIZE)),
    chunk_paths
)

Pre-tokenizing: 100%|██████████| 15461/15461 [06:11<00:00, 41.65it/s]
Pre-tokenizing: 100%|██████████| 2064/2064 [00:49<00:00, 41.57it/s]
Pre-tokenizing: 100%|██████████| 1816/1816 [00:43<00:00, 41.37it/s]


In [11]:
model_name = f'memorizing_miditok_{dataset_name}_{VOCAB_SIZE}_v{VERSION}'
model_load_path = Path(f'../data/checkpoints/{model_name}.dat')
model_save_path = Path(f'../data/checkpoints/{model_name}.dat')
log_dir = Path(f'../tensorboard/{model_name}')
tensorboard_writer = SummaryWriter(log_dir)

Now we can create a transformer and set up our training loop

In [12]:
model = MemorizingTransformer(
    num_tokens = VOCAB_SIZE,
    dim = N_EMBED,
    depth = N_LAYER,
    heads = N_HEAD,
    dim_head = DIM_HEAD,
    memorizing_layers = (4, 5),
    max_knn_memories = CHUNK_LENGTH, # No point in having more meories than the chunk length as we clear them at the end of each chunk
    num_retrieved_memories = 32, # Top K
    xl_memory_layers = (2, 3, 4, 5),
    xl_max_memories = TIMESTEPS, # One context-length of XL memory
    # shift_knn_memories_down = 1,
    # shift_xl_memories_down = 1
).to(device)

print(sum(p.numel() for p in model.parameters()))

21958571


In [13]:
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model's state_dict:
token_emb.weight 	 torch.Size([411, 512])
rel_pos_bias.relative_attention_bias.weight 	 torch.Size([32, 8])
knn_rel_pos_bias.relative_attention_bias.weight 	 torch.Size([32, 8])
layers.0.0.fn.to_q.weight 	 torch.Size([512, 512])
layers.0.0.fn.to_kv.weight 	 torch.Size([128, 512])
layers.0.0.fn.to_out.weight 	 torch.Size([512, 512])
layers.0.0.fn.to_out.bias 	 torch.Size([512])
layers.0.0.norm.weight 	 torch.Size([512])
layers.0.0.norm.bias 	 torch.Size([512])
layers.0.1.fn.net.0.weight 	 torch.Size([2048, 512])
layers.0.1.fn.net.0.bias 	 torch.Size([2048])
layers.0.1.fn.net.3.weight 	 torch.Size([512, 2048])
layers.0.1.fn.net.3.bias 	 torch.Size([512])
layers.0.1.norm.weight 	 torch.Size([512])
layers.0.1.norm.bias 	 torch.Size([512])
layers.1.0.fn.to_q.weight 	 torch.Size([512, 512])
layers.1.0.fn.to_kv.weight 	 torch.Size([128, 512])
layers.1.0.fn.to_out.weight 	 torch.Size([512, 512])
layers.1.0.fn.to_out.bias 	 torch.Size([512])
layers.1.0.norm.weight 	 torch.Si

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [None]:
def save_checkpoint(optimizer, completed_iterations, train_loss, val_loss):
    tensorboard_writer.add_scalar('Loss/train', train_loss, completed_iterations)
    tensorboard_writer.add_scalar('Loss/val', val_loss, completed_iterations)
    print(f'Writing to Tensorboard: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    torch.save({
        'iter': completed_iterations,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, model_save_path)

completed_iterations = 0
if model_load_path.exists():
    checkpoint = torch.load(model_load_path)
    completed_iterations = checkpoint['iter']
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"Loaded model from iteration {completed_iterations}")

In [None]:
for i in tqdm.tqdm(range(NUM_BATCHES - completed_iterations), mininterval = 10., desc = 'training'):
    model.train()

    data = next(train_loader)["input_ids"].to(device)
    if data.shape[0] != BATCH_SIZE:
        print(f'Skipping batch {i} as it is not of size {BATCH_SIZE}, but {data.shape[0]}')
        data = next(train_loader)["input_ids"].to(device)

    train_loss = 0.
    with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        xl_memories = None    
        seq, labels = data[:, :-1], data[:, 1:]

        for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
            loss, xl_memories = model(
                seq_segment,
                labels = labels_segment,
                knn_memories = knn_memories,
                xl_memories = xl_memories
            )

            train_loss += loss.item() / SEGMENTS
            (loss / SEGMENTS).backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optimizer.step()
    optimizer.zero_grad()

    if not (i % VALIDATE_EVERY):
        model.eval()

        valid_data = next(valid_loader)
        if valid_data.shape[0] != BATCH_SIZE:
            print(f'Skipping validation batch {i} as it is not of size {BATCH_SIZE}, but {valid_data.shape[0]}')
            valid_data = next(valid_loader)["input_ids"].to(device)

        valid_loss = 0.

        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
            xl_memories = None    
            seq, labels = data[:, :-1], data[:, 1:]

            for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
                loss, xl_memories = model(
                    seq_segment,
                    labels = labels_segment,
                    knn_memories = knn_memories,
                    xl_memories = xl_memories
                )

                valid_loss += loss.item() / SEGMENTS

        save_checkpoint(optimizer, i + completed_iterations, train_loss, valid_loss)

  x.storage().data_ptr() + x.storage_offset() * 4)


Writing to Tensorboard: Train Loss: 6.1980, Val Loss: 4.6114


training:   0%|          | 30/100000 [00:10<9:27:18,  2.94it/s]

Writing to Tensorboard: Train Loss: 2.2646, Val Loss: 2.2434


training:   0%|          | 106/100000 [00:30<7:47:17,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.8780, Val Loss: 1.8546


training:   0%|          | 186/100000 [00:53<7:40:30,  3.61it/s]

Writing to Tensorboard: Train Loss: 1.8046, Val Loss: 1.7817


training:   0%|          | 225/100000 [01:04<7:49:28,  3.54it/s]

Writing to Tensorboard: Train Loss: 1.8478, Val Loss: 1.8281


training:   0%|          | 298/100000 [01:25<7:47:31,  3.55it/s]

Writing to Tensorboard: Train Loss: 1.7682, Val Loss: 1.7525


training:   0%|          | 371/100000 [01:45<7:42:31,  3.59it/s]

Writing to Tensorboard: Train Loss: 1.6069, Val Loss: 1.5878


training:   0%|          | 447/100000 [02:07<7:46:28,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.7061, Val Loss: 1.6903


training:   0%|          | 485/100000 [02:18<7:54:22,  3.50it/s]

Writing to Tensorboard: Train Loss: 1.5747, Val Loss: 1.5558


training:   1%|          | 557/100000 [02:39<7:50:19,  3.52it/s]

Writing to Tensorboard: Train Loss: 1.6171, Val Loss: 1.6016


training:   1%|          | 633/100000 [03:00<7:48:05,  3.54it/s]

Writing to Tensorboard: Train Loss: 1.5247, Val Loss: 1.5052


training:   1%|          | 671/100000 [03:11<7:52:31,  3.50it/s]

Writing to Tensorboard: Train Loss: 1.5205, Val Loss: 1.5034


training:   1%|          | 743/100000 [03:32<7:52:43,  3.50it/s]

Writing to Tensorboard: Train Loss: 1.5802, Val Loss: 1.5625


training:   1%|          | 817/100000 [03:53<7:48:15,  3.53it/s]

Writing to Tensorboard: Train Loss: 1.5951, Val Loss: 1.5781


training:   1%|          | 892/100000 [04:14<7:47:06,  3.54it/s]

Writing to Tensorboard: Train Loss: 1.6127, Val Loss: 1.5950


training:   1%|          | 928/100000 [04:25<7:53:34,  3.49it/s]

Writing to Tensorboard: Train Loss: 1.5401, Val Loss: 1.5208


training:   1%|          | 962/100000 [04:35<8:01:37,  3.43it/s]

Skipping batch 966 as it is not of size 16, but 5


training:   1%|          | 1000/100000 [04:45<7:48:59,  3.52it/s]

Writing to Tensorboard: Train Loss: 1.5587, Val Loss: 1.5372


training:   1%|          | 1074/100000 [05:07<7:49:06,  3.51it/s]

Writing to Tensorboard: Train Loss: 1.4725, Val Loss: 1.4556


training:   1%|          | 1147/100000 [05:27<7:45:46,  3.54it/s]

Writing to Tensorboard: Train Loss: 1.5775, Val Loss: 1.5608


training:   1%|          | 1184/100000 [05:39<7:57:23,  3.45it/s]

Writing to Tensorboard: Train Loss: 1.4653, Val Loss: 1.4500


training:   1%|▏         | 1254/100000 [06:00<8:00:27,  3.43it/s]

Writing to Tensorboard: Train Loss: 1.5190, Val Loss: 1.5053


training:   1%|▏         | 1329/100000 [06:20<7:47:00,  3.52it/s]

Writing to Tensorboard: Train Loss: 1.4617, Val Loss: 1.4416


training:   1%|▏         | 1405/100000 [06:42<7:44:26,  3.54it/s]

Writing to Tensorboard: Train Loss: 1.5434, Val Loss: 1.5280


training:   1%|▏         | 1443/100000 [06:54<7:59:21,  3.43it/s]

Writing to Tensorboard: Train Loss: 1.4841, Val Loss: 1.4726


training:   2%|▏         | 1513/100000 [07:14<7:54:42,  3.46it/s]

Writing to Tensorboard: Train Loss: 1.4343, Val Loss: 1.4187


training:   2%|▏         | 1586/100000 [07:35<7:50:47,  3.48it/s]

Writing to Tensorboard: Train Loss: 1.5344, Val Loss: 1.5210


training:   2%|▏         | 1660/100000 [07:56<7:49:40,  3.49it/s]

Writing to Tensorboard: Train Loss: 1.5473, Val Loss: 1.5286


training:   2%|▏         | 1697/100000 [08:07<7:54:46,  3.45it/s]

Writing to Tensorboard: Train Loss: 1.5451, Val Loss: 1.5274


training:   2%|▏         | 1767/100000 [08:28<7:53:45,  3.46it/s]

Writing to Tensorboard: Train Loss: 1.5629, Val Loss: 1.5493


training:   2%|▏         | 1841/100000 [08:48<7:40:05,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.4252, Val Loss: 1.4129


training:   2%|▏         | 1917/100000 [09:09<7:35:59,  3.59it/s]

Writing to Tensorboard: Train Loss: 1.5289, Val Loss: 1.5131
Skipping batch 1932 as it is not of size 16, but 5


training:   2%|▏         | 1955/100000 [09:20<7:42:04,  3.54it/s]

Writing to Tensorboard: Train Loss: 1.4210, Val Loss: 1.4078


training:   2%|▏         | 2027/100000 [09:42<7:47:46,  3.49it/s]

Writing to Tensorboard: Train Loss: 1.6225, Val Loss: 1.6015


training:   2%|▏         | 2102/100000 [10:02<7:37:24,  3.57it/s]

Writing to Tensorboard: Train Loss: 1.4843, Val Loss: 1.4726


training:   2%|▏         | 2140/100000 [10:13<7:42:00,  3.53it/s]

Writing to Tensorboard: Train Loss: 1.5149, Val Loss: 1.5025


training:   2%|▏         | 2214/100000 [10:35<7:50:27,  3.46it/s]

Writing to Tensorboard: Train Loss: 1.3159, Val Loss: 1.3058


training:   2%|▏         | 2289/100000 [10:56<7:37:29,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.4057, Val Loss: 1.3879


training:   2%|▏         | 2364/100000 [11:17<7:38:28,  3.55it/s]

Writing to Tensorboard: Train Loss: 1.4926, Val Loss: 1.4739


training:   2%|▏         | 2401/100000 [11:29<7:57:04,  3.41it/s]

Writing to Tensorboard: Train Loss: 1.3995, Val Loss: 1.3855


training:   2%|▏         | 2471/100000 [11:50<7:56:24,  3.41it/s]

Writing to Tensorboard: Train Loss: 1.5236, Val Loss: 1.5076


training:   3%|▎         | 2546/100000 [12:12<7:48:38,  3.47it/s]

Writing to Tensorboard: Train Loss: 1.4173, Val Loss: 1.4054


training:   3%|▎         | 2619/100000 [12:33<7:48:00,  3.47it/s]

Writing to Tensorboard: Train Loss: 1.4437, Val Loss: 1.4312


training:   3%|▎         | 2655/100000 [12:44<7:51:54,  3.44it/s]

Writing to Tensorboard: Train Loss: 1.4208, Val Loss: 1.4057


training:   3%|▎         | 2726/100000 [13:04<7:51:07,  3.44it/s]

Writing to Tensorboard: Train Loss: 1.5195, Val Loss: 1.5017


training:   3%|▎         | 2798/100000 [13:25<7:47:27,  3.47it/s]

Writing to Tensorboard: Train Loss: 1.6750, Val Loss: 1.6604


training:   3%|▎         | 2872/100000 [13:46<7:40:57,  3.51it/s]

Writing to Tensorboard: Train Loss: 1.5390, Val Loss: 1.5262
Skipping batch 2898 as it is not of size 16, but 5


training:   3%|▎         | 2909/100000 [13:58<7:54:31,  3.41it/s]

Writing to Tensorboard: Train Loss: 1.4319, Val Loss: 1.4203


training:   3%|▎         | 2981/100000 [14:19<7:49:57,  3.44it/s]

Writing to Tensorboard: Train Loss: 1.5217, Val Loss: 1.5067


training:   3%|▎         | 3057/100000 [14:40<7:33:29,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.4469, Val Loss: 1.4337


training:   3%|▎         | 3136/100000 [15:01<7:28:05,  3.60it/s]

Writing to Tensorboard: Train Loss: 1.4101, Val Loss: 1.4018


training:   3%|▎         | 3175/100000 [15:12<7:29:47,  3.59it/s]

Writing to Tensorboard: Train Loss: 1.3928, Val Loss: 1.3835


training:   3%|▎         | 3250/100000 [15:33<7:24:56,  3.62it/s]

Writing to Tensorboard: Train Loss: 1.4059, Val Loss: 1.4014


training:   3%|▎         | 3326/100000 [15:55<7:31:17,  3.57it/s]

Writing to Tensorboard: Train Loss: 1.5606, Val Loss: 1.5421


training:   3%|▎         | 3363/100000 [16:06<7:38:00,  3.52it/s]

Writing to Tensorboard: Train Loss: 1.5238, Val Loss: 1.5098


training:   3%|▎         | 3434/100000 [16:27<7:40:05,  3.50it/s]

Writing to Tensorboard: Train Loss: 1.4202, Val Loss: 1.4033


training:   4%|▎         | 3509/100000 [16:48<7:35:20,  3.53it/s]

Writing to Tensorboard: Train Loss: 1.4781, Val Loss: 1.4655


training:   4%|▎         | 3547/100000 [16:59<7:40:11,  3.49it/s]

Writing to Tensorboard: Train Loss: 1.5189, Val Loss: 1.5039


training:   4%|▎         | 3625/100000 [17:21<7:31:23,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.5252, Val Loss: 1.5089


training:   4%|▎         | 3706/100000 [17:42<7:16:28,  3.68it/s]

Writing to Tensorboard: Train Loss: 1.4075, Val Loss: 1.3915


training:   4%|▎         | 3747/100000 [17:54<7:17:13,  3.67it/s]

Writing to Tensorboard: Train Loss: 1.4772, Val Loss: 1.4641


training:   4%|▍         | 3825/100000 [18:14<7:05:53,  3.76it/s]

Writing to Tensorboard: Train Loss: 1.4950, Val Loss: 1.4757


training:   4%|▍         | 3864/100000 [18:25<7:12:01,  3.71it/s]

Skipping batch 3864 as it is not of size 16, but 5


training:   4%|▍         | 3904/100000 [18:35<7:04:21,  3.77it/s]

Writing to Tensorboard: Train Loss: 1.3957, Val Loss: 1.3881


training:   4%|▍         | 3944/100000 [18:46<7:09:58,  3.72it/s]

Writing to Tensorboard: Train Loss: 1.3909, Val Loss: 1.3802


training:   4%|▍         | 4020/100000 [19:07<7:07:28,  3.74it/s]

Writing to Tensorboard: Train Loss: 1.5205, Val Loss: 1.5075


training:   4%|▍         | 4059/100000 [19:18<7:16:19,  3.66it/s]

Writing to Tensorboard: Train Loss: 1.4780, Val Loss: 1.4631


training:   4%|▍         | 4137/100000 [19:39<7:13:19,  3.69it/s]

Writing to Tensorboard: Train Loss: 1.5683, Val Loss: 1.5507


training:   4%|▍         | 4215/100000 [20:01<7:15:40,  3.66it/s]

Writing to Tensorboard: Train Loss: 1.5388, Val Loss: 1.5275


training:   4%|▍         | 4253/100000 [20:12<7:33:51,  3.52it/s]

Writing to Tensorboard: Train Loss: 1.5762, Val Loss: 1.5672


training:   4%|▍         | 4326/100000 [20:33<7:35:27,  3.50it/s]

Writing to Tensorboard: Train Loss: 1.4734, Val Loss: 1.4563


training:   4%|▍         | 4402/100000 [20:54<7:23:59,  3.59it/s]

Writing to Tensorboard: Train Loss: 1.4992, Val Loss: 1.4878


training:   4%|▍         | 4472/100000 [21:15<7:32:15,  3.52it/s]

Writing to Tensorboard: Train Loss: 1.3878, Val Loss: 1.3749


training:   5%|▍         | 4509/100000 [21:26<7:43:53,  3.43it/s]

Writing to Tensorboard: Train Loss: 1.3972, Val Loss: 1.3736


training:   5%|▍         | 4582/100000 [21:48<7:40:39,  3.45it/s]

Writing to Tensorboard: Train Loss: 1.4039, Val Loss: 1.3873


training:   5%|▍         | 4659/100000 [22:08<7:22:09,  3.59it/s]

Writing to Tensorboard: Train Loss: 1.6313, Val Loss: 1.6159


training:   5%|▍         | 4699/100000 [22:20<7:21:25,  3.60it/s]

Writing to Tensorboard: Train Loss: 1.5328, Val Loss: 1.5252


training:   5%|▍         | 4777/100000 [22:41<7:13:13,  3.66it/s]

Writing to Tensorboard: Train Loss: 1.5107, Val Loss: 1.4958


training:   5%|▍         | 4817/100000 [22:52<7:18:09,  3.62it/s]

Skipping batch 4830 as it is not of size 16, but 5


training:   5%|▍         | 4855/100000 [23:02<7:13:31,  3.66it/s]

Writing to Tensorboard: Train Loss: 1.5336, Val Loss: 1.5189


training:   5%|▍         | 4893/100000 [23:13<7:20:01,  3.60it/s]

Writing to Tensorboard: Train Loss: 1.6446, Val Loss: 1.6328


training:   5%|▍         | 4968/100000 [23:33<7:11:31,  3.67it/s]

Writing to Tensorboard: Train Loss: 1.6446, Val Loss: 1.6246


training:   5%|▌         | 5046/100000 [23:55<7:11:45,  3.67it/s]

Writing to Tensorboard: Train Loss: 1.6080, Val Loss: 1.5926


training:   5%|▌         | 5085/100000 [24:06<7:21:20,  3.58it/s]

Writing to Tensorboard: Train Loss: 1.5687, Val Loss: 1.5566


training:   5%|▌         | 5161/100000 [24:27<7:10:16,  3.67it/s]

Writing to Tensorboard: Train Loss: 1.5218, Val Loss: 1.5178


training:   5%|▌         | 5241/100000 [24:49<7:10:10,  3.67it/s]

Writing to Tensorboard: Train Loss: 1.5648, Val Loss: 1.5487


training:   5%|▌         | 5281/100000 [25:01<7:20:05,  3.59it/s]

Writing to Tensorboard: Train Loss: 1.5915, Val Loss: 1.5771


training:   5%|▌         | 5353/100000 [25:21<7:22:14,  3.57it/s]

Writing to Tensorboard: Train Loss: 1.5772, Val Loss: 1.5786


training:   5%|▌         | 5425/100000 [25:45<7:55:56,  3.31it/s]

Writing to Tensorboard: Train Loss: 1.6827, Val Loss: 1.6616


training:   6%|▌         | 5502/100000 [26:08<7:48:05,  3.36it/s]

Writing to Tensorboard: Train Loss: 1.5589, Val Loss: 1.5501


training:   6%|▌         | 5542/100000 [26:20<7:42:09,  3.41it/s]

Writing to Tensorboard: Train Loss: 1.5659, Val Loss: 1.5611


training:   6%|▌         | 5617/100000 [26:40<7:26:12,  3.53it/s]

Writing to Tensorboard: Train Loss: 1.6018, Val Loss: 1.5914


training:   6%|▌         | 5694/100000 [27:02<7:20:55,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.5836, Val Loss: 1.5652


training:   6%|▌         | 5732/100000 [27:12<7:18:35,  3.58it/s]

Writing to Tensorboard: Train Loss: 1.4669, Val Loss: 1.4613


training:   6%|▌         | 5769/100000 [27:23<7:18:55,  3.58it/s]

Skipping batch 5796 as it is not of size 16, but 5


training:   6%|▌         | 5808/100000 [27:33<7:09:18,  3.66it/s]

Writing to Tensorboard: Train Loss: 1.6364, Val Loss: 1.6202


training:   6%|▌         | 5884/100000 [27:54<7:14:50,  3.61it/s]

Writing to Tensorboard: Train Loss: 1.5611, Val Loss: 1.5473


training:   6%|▌         | 5918/100000 [28:05<7:29:21,  3.49it/s]

Writing to Tensorboard: Train Loss: 1.6589, Val Loss: 1.6363


training:   6%|▌         | 5991/100000 [28:26<7:25:40,  3.52it/s]

Writing to Tensorboard: Train Loss: 1.6868, Val Loss: 1.6788


training:   6%|▌         | 6066/100000 [28:47<7:24:29,  3.52it/s]

Writing to Tensorboard: Train Loss: 1.5657, Val Loss: 1.5513


training:   6%|▌         | 6140/100000 [29:08<7:25:16,  3.51it/s]

Writing to Tensorboard: Train Loss: 1.6728, Val Loss: 1.6664


training:   6%|▌         | 6177/100000 [29:20<7:33:43,  3.45it/s]

Writing to Tensorboard: Train Loss: 1.5826, Val Loss: 1.5751


training:   6%|▌         | 6247/100000 [29:41<7:38:58,  3.40it/s]

Writing to Tensorboard: Train Loss: 1.7198, Val Loss: 1.7065


training:   6%|▋         | 6321/100000 [30:02<7:28:56,  3.48it/s]

Writing to Tensorboard: Train Loss: 1.5972, Val Loss: 1.5761


training:   6%|▋         | 6395/100000 [30:23<7:24:30,  3.51it/s]

Writing to Tensorboard: Train Loss: 1.6552, Val Loss: 1.6382


training:   6%|▋         | 6432/100000 [30:34<7:30:35,  3.46it/s]

Writing to Tensorboard: Train Loss: 1.6470, Val Loss: 1.6384


training:   7%|▋         | 6504/100000 [30:54<7:24:22,  3.51it/s]

Writing to Tensorboard: Train Loss: 1.6235, Val Loss: 1.6053


training:   7%|▋         | 6578/100000 [31:15<7:17:06,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.6440, Val Loss: 1.6331


training:   7%|▋         | 6651/100000 [31:37<7:31:06,  3.45it/s]

Writing to Tensorboard: Train Loss: 1.6062, Val Loss: 1.6032


training:   7%|▋         | 6686/100000 [31:47<7:33:23,  3.43it/s]

Writing to Tensorboard: Train Loss: 1.5908, Val Loss: 1.5647


training:   7%|▋         | 6760/100000 [32:08<7:19:06,  3.54it/s]

Skipping batch 6762 as it is not of size 16, but 5
Writing to Tensorboard: Train Loss: 1.5420, Val Loss: 1.5306


training:   7%|▋         | 6839/100000 [32:29<7:04:32,  3.66it/s]

Writing to Tensorboard: Train Loss: 1.6492, Val Loss: 1.6018


training:   7%|▋         | 6879/100000 [32:41<7:11:35,  3.60it/s]

Writing to Tensorboard: Train Loss: 1.5623, Val Loss: 1.5864


training:   7%|▋         | 6953/100000 [33:01<7:06:54,  3.63it/s]

Writing to Tensorboard: Train Loss: 1.5739, Val Loss: 1.5618


training:   7%|▋         | 7029/100000 [33:23<7:10:50,  3.60it/s]

Writing to Tensorboard: Train Loss: 1.6359, Val Loss: 1.6194


training:   7%|▋         | 7066/100000 [33:34<7:20:51,  3.51it/s]

Writing to Tensorboard: Train Loss: 1.6860, Val Loss: 1.6762


training:   7%|▋         | 7144/100000 [33:55<7:13:47,  3.57it/s]

Writing to Tensorboard: Train Loss: 1.7023, Val Loss: 1.6844


training:   7%|▋         | 7222/100000 [34:17<7:07:35,  3.62it/s]

Writing to Tensorboard: Train Loss: 1.5672, Val Loss: 1.5554


training:   7%|▋         | 7261/100000 [34:28<7:14:12,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.6303, Val Loss: 1.6090


training:   7%|▋         | 7336/100000 [34:50<7:13:40,  3.56it/s]

Writing to Tensorboard: Train Loss: 1.8188, Val Loss: 1.8008


# Initial Results

With the base REMI vocab size of 411 tokens along with the same embed / head / layer count as our hand built model, the results are significantly worse. The training curve almost immediately flattened out at 1.6 loss, compared to under 1.0 previously.

This could be due to e.g.

- Larger vocabulary = more choices to pick for the next token
- The MIDITok data loader only provides max 1000 tokens, whereas our custom 'contiguous' data loader always fed in entire tracks. Only a handful of tracks are between 1 and 2k tokens, most are 3 - 5K, so are being split into multiple pieces.
- The LucidRains Memorizing Transformers architecture is very similar to our model but there are a couple of changes (memory on a subset of layer vs all) and of course it is a different implementation. I would expect it to be less buggy if anything though.

I did have to tweak the transformer's KNN memory to run exclusively on the GPU as index was on CPU and the data on disk, which was a major bottleneck (as found in our model). Luckily once again the code was nearly identical so it was easy to follow / edit.

## Things to try

- Train the tokenizer with 1K, 5K, 10K, 20K, 30K size vocab. This should result in less tokens per song. We might need to drop the max seq length to avoid over-padded songs.

- Add in chord / tempo / time sig / rest tokens. This should result in more tokens per song. We might need to increase the max seq length to avoid over-split songs.

> Note - the files will need re-encoding as they are split based on estimated token length, which depends on encoding and vocab size!

- Augment the dataset with pitch and velocity shifted versions using the MIDITok tools

- Full Lakh dataset, although if we can't do well on a curated, stylistic dataset then a more varied one is likely to decrease performance.