In [1]:
import os
# Move into the project folder only when we are not already inside it, so that
# re-running this cell in a live kernel does not raise FileNotFoundError.
if os.path.basename(os.getcwd()) != '9 - Rebuild':
    os.chdir(r'9 - Rebuild')
import torch
from miditok import REMI, TokenizerConfig  # here we choose to use REMI
from pathlib import Path
import random
from miditok.utils import split_files_for_training
from miditok.data_augmentation import augment_dataset
from miditok.pytorch_data import DatasetMIDI, DataCollator
from torch.utils.data import DataLoader
from miditok import TokSequence
from multiprocessing import Pool
from memorizing_transformers_pytorch import MemorizingTransformer

# Choose the compute backend: CUDA when available, otherwise the MPS backend,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}.")

Using cuda.


In [2]:
# When running on CUDA, also report which GPU is being used.
if device == "cuda":
    gpu_name = torch.cuda.get_device_name()
    print(f"Device: {gpu_name}.")

Device: NVIDIA GeForce RTX 4090.


# Memorizing Transformers

From our MidiTok research, we know that we want to:

- Train a BPE tokenizer on the entire dataset.
- Save it / Load it (BPE is deterministic, so data doesn't need to be decoded with the exact same tokenizer it was encoded with, provided whatever is used was trained on the same data with the same config. Unigram is *not* deterministic, however, so it would require the exact same tokenizer for encode / decode).
- Shuffle file names.
- Split into test / train / validation sets, so songs aren't biased to a set.
- Split the files into chunks for each set.
- Optionally augment the dataset with pitch / velocity / duration shifted versions.
- Shuffle the chunks when loading, so that parts of a single song aren't biased to a batch.
- Load the chunks with `max_seq_len` equal to that used when splitting files, to minimise padding / truncated data.
- Split the chunks into context-length sequences and feed them through contiguously.
- Manually reset memories between chunks rather than auto-reset on BOS / EOS tokens.

In [3]:
# Hyperparameters shared by the tokenizer, the data pipeline and the model.
CHUNK_LENGTH = 1024  # max_seq_len used both when splitting files into chunks and when loading them
SEGMENTS = 4  # each chunk is divided into this many context-length pieces (CHUNK_LENGTH // SEGMENTS)
BATCH_SIZE = 16  # batch size passed to the data loaders
VOCAB_SIZE = 1000  # vocabulary size for tokenizer training; also used as the model's num_tokens
N_EMBED = 512  # model width (passed as `dim`)
N_LAYER = 8  # number of layers (passed as `depth`)
N_HEAD = 8  # attention heads (passed as `heads`); the head size is N_EMBED // N_HEAD
LEARNING_RATE = 3e-4  # optimizer learning rate
WEIGHT_DECAY = 1e-3  # optimizer weight decay

# Locations of the raw MIDI dataset and of the cached tokenizer file.
midi_path = Path('../data/midi')
dataset_name = 'vg_large'
midi_dataset_path = midi_path / dataset_name

# glob() yields entries in filesystem order, which can differ between machines
# and runs. Sorting gives the seeded shuffle further down a stable starting
# order, so the train / valid / test split is reproducible.
midi_file_paths = sorted(p.resolve() for p in midi_dataset_path.glob("**/*.mid"))

tokenizer_save_path = Path(f'../data/vocab/MidiTok/{dataset_name}.json')

# Train the REMI tokenizer once and cache it on disk; later runs reload the
# saved file instead of retraining.
if tokenizer_save_path.exists():
    tokenizer = REMI(params=tokenizer_save_path)
else:
    TOKENIZER_PARAMS = {
        "pitch_range": (21, 109),
        "beat_res": {(0, 4): 8, (4, 12): 4},
        "num_velocities": 32,
        "use_programs": True,
        # Optional settings, currently disabled:
        # "use_chords": True,
        # "use_time_signatures": True,
        # "use_tempos": True,
        # "num_tempos": 32,  # number of tempo bins
        # "tempo_range": (40, 250)
    }
    tokenizer_config = TokenizerConfig(**TOKENIZER_PARAMS)
    tokenizer = REMI(tokenizer_config)
    tokenizer.train(vocab_size=VOCAB_SIZE, files_paths=midi_file_paths)
    # Make sure the destination folder exists before writing the tokenizer file.
    tokenizer_save_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(tokenizer_save_path)

  super().__init__(tokenizer_config, params)


In [4]:
# Seed the global RNG so that, for a given input order, the shuffle below
# (and the file-level split built from it) is repeatable.
# NOTE(review): shuffle() works in place, so re-running only this cell permutes
# the already-shuffled list again and produces a different order.
random.seed(42)
random.shuffle(midi_file_paths)
len(midi_file_paths)  # shown as the cell output

3839

In [5]:
# 80 / 10 / 10 split at file level, so every song belongs to exactly one set.
num_files = len(midi_file_paths)
n1 = int(0.8 * num_files)
n2 = int(0.9 * num_files)
train_filepaths, valid_filepaths, test_filepaths = (
    midi_file_paths[:n1],
    midi_file_paths[n1:n2],
    midi_file_paths[n2:],
)

print(f'Train files: {len(train_filepaths)}, Valid files: {len(valid_filepaths)}, Test files: {len(test_filepaths)}')

Train files: 3071, Valid files: 384, Test files: 384


In [6]:
# Output folders for the chunked files: one sub-folder per split.
chunk_path = midi_path / f'{dataset_name}_miditok_split'
train_chunk_path = chunk_path / 'train'
valid_chunk_path = chunk_path / 'valid'
test_chunk_path = chunk_path / 'test'

# Pair each list of source files with the folder its chunks are written to.
split_data = list(zip(
    (train_filepaths, valid_filepaths, test_filepaths),
    (train_chunk_path, valid_chunk_path, test_chunk_path),
))

def chunk_files(filepaths, tokenizer, chunks_dir, max_seq_len):
    """Split the given MIDI files into training chunks saved under ``chunks_dir``.

    Thin wrapper around ``split_files_for_training`` that always requests an
    overlap of one bar (``num_overlap_bars=1``).

    Args:
        filepaths: paths of the MIDI files to split.
        tokenizer: tokenizer forwarded to ``split_files_for_training``.
        chunks_dir: directory the chunks are saved to.
        max_seq_len: maximum sequence length forwarded for the chunks.
    """
    split_kwargs = {
        "files_paths": filepaths,
        "tokenizer": tokenizer,
        "save_dir": chunks_dir,
        "max_seq_len": max_seq_len,
        "num_overlap_bars": 1,
    }
    split_files_for_training(**split_kwargs)

# Chunk the three splits in parallel, one worker process per split. The work is
# skipped when the output folder already exists from an earlier run.
if not chunk_path.exists():
    chunk_jobs = [
        (filepaths, tokenizer, chunks_dir, CHUNK_LENGTH)
        for filepaths, chunks_dir in split_data
    ]
    with Pool(processes=3) as pool:
        pool.starmap(chunk_files, chunk_jobs)

In [7]:
# Collect the chunk files of each split. glob() returns entries in filesystem
# order, so sort them to give the shuffle that follows a stable input order.
train_chunk_filepaths = sorted(train_chunk_path.glob("**/*.mid"))
valid_chunk_filepaths = sorted(valid_chunk_path.glob("**/*.mid"))
test_chunk_filepaths = sorted(test_chunk_path.glob("**/*.mid"))

print(f'Train chunks: {len(train_chunk_filepaths)}, Valid chunks: {len(valid_chunk_filepaths)}, Test chunks: {len(test_chunk_filepaths)}')

Train chunks: 8322, Valid chunks: 1087, Test chunks: 974


In [8]:
chunk_paths = [train_chunk_filepaths, valid_chunk_filepaths, test_chunk_filepaths]

# Shuffle each split's chunk list in place so that parts of a single song are
# not grouped into the same batch. The loop variable is deliberately not named
# `chunk_path`: that name already holds the chunk directory Path, and rebinding
# it here would silently turn it into a list.
for split_chunk_filepaths in chunk_paths:
    random.shuffle(split_chunk_filepaths)

In [9]:
def create_data_loader(chunks_path, tokenizer, max_seq_len, batch_size):
    """Build a ``DataLoader`` over the MIDI chunk files of one split.

    Args:
        chunks_path: chunk file paths handed to ``DatasetMIDI``.
        tokenizer: tokenizer supplying the PAD, BOS and EOS token ids.
        max_seq_len: maximum sequence length given to the dataset.
        batch_size: batch size given to the loader.

    Returns:
        A ``DataLoader`` whose collator is configured to copy the inputs as
        labels and to shift the labels.
    """
    bos_id = tokenizer["BOS_None"]
    eos_id = tokenizer["EOS_None"]
    midi_dataset = DatasetMIDI(
        files_paths=chunks_path,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        bos_token_id=bos_id,
        eos_token_id=eos_id,
    )
    batch_collator = DataCollator(
        tokenizer.pad_token_id,
        copy_inputs_as_labels=True,
        shift_labels=True,
    )
    return DataLoader(
        dataset=midi_dataset,
        collate_fn=batch_collator,
        batch_size=batch_size,
    )

In [10]:
from itertools import cycle

def _repeat_loader(loader):
    """Yield batches from ``loader`` endlessly, re-iterating it on every pass.

    ``itertools.cycle`` stores a copy of every item from the first pass and then
    replays those copies, which would keep a whole epoch of batches in memory.
    Iterating the loader again on each pass avoids that. As with ``cycle``,
    an empty loader ends the iteration instead of spinning forever.
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            return


# One endless batch stream per split, in the order train / valid / test.
train_loader, valid_loader, test_loader = [
    _repeat_loader(create_data_loader(chunk_filepaths, tokenizer, CHUNK_LENGTH, BATCH_SIZE))
    for chunk_filepaths in chunk_paths
]

Now we can create a transformer and set up our training loop.

In [11]:
# Context length fed to the model: each chunk is processed as SEGMENTS
# contiguous pieces of TIMESTEPS tokens.
TIMESTEPS = CHUNK_LENGTH // SEGMENTS

model = MemorizingTransformer(
    num_tokens = VOCAB_SIZE,
    dim = N_EMBED,
    depth = N_LAYER,
    heads = N_HEAD,
    dim_head = N_EMBED // N_HEAD,
    memorizing_layers = (4, 5),
    # NOTE(review): presumably this is a count of stored memories rather than
    # of floats - confirm that B*T*C is the intended capacity.
    max_knn_memories = BATCH_SIZE * TIMESTEPS * N_EMBED, # B,T,C
    num_retrieved_memories = 32, # Top K
    xl_memory_layers = (2, 3, 4, 5),
    xl_max_memories = TIMESTEPS, # One context-length of XL memory
    # None disables automatic clearing: we never mix songs in a sequence and
    # reset the memories manually between chunks. An integer here (including 0)
    # is treated as a token id whose presence triggers a clear.
    clear_memories_on_sos_token_id = None,
    # shift_knn_memories_down = 1,
    # shift_xl_memories_down = 1
).to(device)  # follow the backend selected at the top instead of assuming CUDA

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)