In [1]:
import os
os.chdir(r'9 - Rebuild')
import torch
from miditok import REMI, TokenizerConfig  # here we choose to use REMI
from pathlib import Path
import random
from miditok.utils import split_files_for_training
from miditok.data_augmentation import augment_dataset
from miditok.pytorch_data import DatasetMIDI, DataCollator
from torch.utils.data import DataLoader
from miditok import TokSequence
from multiprocessing import Pool
from memorizing_transformers_pytorch import MemorizingTransformer
import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device}.")

Using cuda.


In [2]:
if device == "cuda":
    print(f"Device: {torch.cuda.get_device_name()}.")

Device: NVIDIA GeForce RTX 4090.


# Memorizing Transformers

From our MIDITok research, we know that we we want to 

- Train a BPE tokenizer on the entire dataset.
- Save it / Load it (BPE is deterministic so data doesn't need to be decoded with a the same tokenizer it was encoded with, providing whatever is used was trained on the same data with the same config. Unigram is *not* deterministic however so would require the same exact tokenizer for encode / decode).
- Shuffle file names, so songs aren't biased to a set.
- Split into test / train / validation sets.
- Split the files into chunks for each set.
- Optionally augment the dataset with pitch / velocity / duration shifted versions
- Shuffle the chunks when loading, so that parts of a single song aren't biased to a batch.
- Load the chunks with `max_seq_len` equal to that used when splitting files, to minimise padding / truncated data.
- Split the chunks into context-length sequences and feed them through contiguously.
- Manually reset memories between chunks rather than auto-reset on BOS / EOS tokens.

In [3]:
CHUNK_LENGTH = 2048
SEGMENTS = 8
TIMESTEPS = CHUNK_LENGTH // SEGMENTS # Context length
BATCH_SIZE = 16
VOCAB_SIZE = 411 # REMI untrained token count is 411
N_EMBED = 512
N_LAYER = 8
N_HEAD = 8
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-3
NUM_BATCHES = int(1e5)
MAX_GRAD_CLIP_NORM = 0.5
VALIDATE_EVERY  = 64
DIM_HEAD = N_EMBED // N_HEAD
VERSION_LABEL = "augmented"
TOKENIZER_CONFIG = 'Basic'

midi_path = Path(f'../data/midi')
dataset_name = 'vg_large'
midi_dataset_path = Path(f'{midi_path}/{dataset_name}')
midi_file_paths = [p.resolve() for p in midi_dataset_path.glob("**/*.mid")]

tokenizer_save_path = Path(f'../data/vocab/MidiTok/{dataset_name}_{VOCAB_SIZE}_{TOKENIZER_CONFIG}.json')

if not tokenizer_save_path.exists():
    TOKENIZER_PARAMS = {
        "pitch_range": (21, 109),
        "beat_res": {(0, 4): 8, (4, 12): 4},
        "num_velocities": 32,
        "use_programs": True
        # "use_chords": True,
        # "use_time_signatures": True,
        # "use_tempos": True,
        # "num_tempos": 32,  # number of tempo bins
        # "tempo_range": (40, 250)
    }
    tokenizer_confg = TokenizerConfig(**TOKENIZER_PARAMS)
    tokenizer = REMI(tokenizer_confg)
    print(f"Untrained token count: {tokenizer.len}")
    tokenizer.train(vocab_size=VOCAB_SIZE, files_paths=midi_file_paths)
    tokenizer.save(tokenizer_save_path)
else:
    tokenizer = REMI(params=tokenizer_save_path)

  super().__init__(tokenizer_config, params)


In [4]:
random.seed(42)
random.shuffle(midi_file_paths)
len(midi_file_paths)

3839

In [5]:
n1 = int(0.8 * len(midi_file_paths))
n2 = int(0.9 * len(midi_file_paths))
train_filepaths = midi_file_paths[:n1]
valid_filepaths = midi_file_paths[n1:n2]
test_filepaths = midi_file_paths[n2:]

print(f'Train files: {len(train_filepaths)}, Valid files: {len(valid_filepaths)}, Test files: {len(test_filepaths)}')

Train files: 3071, Valid files: 384, Test files: 384


In [6]:
chunk_path = Path(f'{midi_path}/mtok_split/{dataset_name}/v-{VOCAB_SIZE}_t-{TOKENIZER_CONFIG}_c-{CHUNK_LENGTH}')
train_chunk_path = Path(f'{chunk_path}/train')
valid_chunk_path = Path(f'{chunk_path}/valid')
test_chunk_path = Path(f'{chunk_path}/test')

split_data = [
    (train_filepaths, train_chunk_path),
    (valid_filepaths, valid_chunk_path),
    (test_filepaths, test_chunk_path)
]

def chunk_files(filepaths, tokenizer, chunks_dir, max_seq_len):
    split_files_for_training(
        files_paths=filepaths,
        tokenizer=tokenizer,
        save_dir=chunks_dir,
        max_seq_len=max_seq_len,
        num_overlap_bars=1
    )
    augment_dataset(
        chunks_dir,
        pitch_offsets=[-12, 12],
        velocity_offsets=[-4, 4],
        duration_offsets=[-0.5, 0.5],
    )

if not chunk_path.exists():
    with Pool(processes=3) as pool:
        pool.starmap(chunk_files, [(filepaths, tokenizer, chunks_dir, CHUNK_LENGTH) for filepaths, chunks_dir in split_data])

Splitting music files (../data/midi/mtok_split/vg_large/v-411_t-Basic_c-2048/test): 100%|██████████| 384/384 [00:08<00:00, 46.93it/s]]
Splitting music files (../data/midi/mtok_split/vg_large/v-411_t-Basic_c-2048/valid): 100%|██████████| 384/384 [00:08<00:00, 43.51it/s]
Performing data augmentation: 100%|██████████| 943/943 [00:39<00:00, 23.99it/s]]in):  65%|██████▍   | 1989/3071 [00:34<00:17, 60.24it/s]
Performing data augmentation: 100%|██████████| 1063/1063 [00:44<00:00, 23.98it/s]n):  75%|███████▍  | 2295/3071 [00:40<00:13, 59.16it/s]
Splitting music files (../data/midi/mtok_split/vg_large/v-411_t-Basic_c-2048/train): 100%|██████████| 3071/3071 [01:06<00:00, 46.17it/s]
Performing data augmentation: 100%|██████████| 7973/7973 [18:03<00:00,  7.36it/s]


In [7]:
train_chunk_filepaths = list(train_chunk_path.glob("**/*.mid"))
valid_chunk_filepaths = list(valid_chunk_path.glob("**/*.mid"))
test_chunk_filepaths = list(test_chunk_path.glob("**/*.mid"))

print(f'Train chunks: {len(train_chunk_filepaths)}, Valid chunks: {len(valid_chunk_filepaths)}, Test chunks: {len(test_chunk_filepaths)}')

Train chunks: 52558, Valid chunks: 6993, Test chunks: 6267


In [8]:
chunk_paths = [train_chunk_filepaths, valid_chunk_filepaths, test_chunk_filepaths]

for chunk_path in chunk_paths:
    random.shuffle(chunk_path)

In [9]:
def create_data_loader(chunks_path, tokenizer, max_seq_len, batch_size):
    collator = DataCollator(tokenizer.pad_token_id) # copy_inputs_as_labels and shift_labels not needed as done by the transformer
    dataset = DatasetMIDI(
        pre_tokenize=True,
        files_paths=chunks_path,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"])
    return DataLoader(dataset=dataset, collate_fn=collator, batch_size=batch_size)

In [10]:
from itertools import cycle

train_loader, valid_loader, test_loader = map(
    lambda chunk_filepaths: cycle(create_data_loader(chunk_filepaths, tokenizer, CHUNK_LENGTH, BATCH_SIZE)),
    chunk_paths
)

Pre-tokenizing: 100%|██████████| 52558/52558 [15:43<00:00, 55.68it/s]
Pre-tokenizing: 100%|██████████| 6993/6993 [02:01<00:00, 57.54it/s]
Pre-tokenizing: 100%|██████████| 6267/6267 [01:52<00:00, 55.54it/s]


In [11]:
model_name = f'memorizing_miditok_{dataset_name}_t-{TIMESTEPS}_v-{VOCAB_SIZE}_{VERSION_LABEL}'
model_load_path = Path(f'../data/checkpoints/{model_name}.dat')
model_save_path = Path(f'../data/checkpoints/{model_name}.dat')
log_dir = Path(f'../tensorboard/{model_name}')
tensorboard_writer = SummaryWriter(log_dir)

Now we can create a transformer and set up our training loop

In [12]:
model = MemorizingTransformer(
    num_tokens = VOCAB_SIZE,
    dim = N_EMBED,
    depth = N_LAYER,
    heads = N_HEAD,
    dim_head = DIM_HEAD,
    attn_dropout = 0.2,
    ff_dropout = 0.2,
    memorizing_layers = (4, 5),
    max_knn_memories = CHUNK_LENGTH, # No point in having more meories than the chunk length as we clear them at the end of each chunk
    num_retrieved_memories = 32, # Top K
    xl_memory_layers = (2, 3, 4, 5),
    xl_max_memories = TIMESTEPS, # One context-length of XL memory
    pad_id = tokenizer.pad_token_id,
    # shift_knn_memories_down = 1,
    # shift_xl_memories_down = 1
).to(device)

print(sum(p.numel() for p in model.parameters()))

21958571


In [13]:
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model's state_dict:
token_emb.weight 	 torch.Size([411, 512])
rel_pos_bias.relative_attention_bias.weight 	 torch.Size([32, 8])
knn_rel_pos_bias.relative_attention_bias.weight 	 torch.Size([32, 8])
layers.0.0.fn.to_q.weight 	 torch.Size([512, 512])
layers.0.0.fn.to_kv.weight 	 torch.Size([128, 512])
layers.0.0.fn.to_out.weight 	 torch.Size([512, 512])
layers.0.0.fn.to_out.bias 	 torch.Size([512])
layers.0.0.norm.weight 	 torch.Size([512])
layers.0.0.norm.bias 	 torch.Size([512])
layers.0.1.fn.net.0.weight 	 torch.Size([2048, 512])
layers.0.1.fn.net.0.bias 	 torch.Size([2048])
layers.0.1.fn.net.3.weight 	 torch.Size([512, 2048])
layers.0.1.fn.net.3.bias 	 torch.Size([512])
layers.0.1.norm.weight 	 torch.Size([512])
layers.0.1.norm.bias 	 torch.Size([512])
layers.1.0.fn.to_q.weight 	 torch.Size([512, 512])
layers.1.0.fn.to_kv.weight 	 torch.Size([128, 512])
layers.1.0.fn.to_out.weight 	 torch.Size([512, 512])
layers.1.0.fn.to_out.bias 	 torch.Size([512])
layers.1.0.norm.weight 	 torch.Si

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [15]:
def save_checkpoint(optimizer, completed_iterations, train_loss, val_loss):
    tensorboard_writer.add_scalar('Loss/train', train_loss, completed_iterations)
    tensorboard_writer.add_scalar('Loss/val', val_loss, completed_iterations)
    print(f'Writing to Tensorboard: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    torch.save({
        'iter': completed_iterations,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, model_save_path)

completed_iterations = 0
if model_load_path.exists():
    checkpoint = torch.load(model_load_path)
    completed_iterations = checkpoint['iter']
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"Loaded model from iteration {completed_iterations}")

In [None]:
for i in tqdm.tqdm(range(NUM_BATCHES - completed_iterations), mininterval = 10., desc = 'training'):
    model.train()

    data = next(train_loader)["input_ids"].to(device)
    if data.shape[0] != BATCH_SIZE:
        print(f'Skipping batch {i} as it is not of size {BATCH_SIZE}, but {data.shape[0]}')
        data = next(train_loader)["input_ids"].to(device)

    train_loss = 0.
    with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        xl_memories = None    
        seq, labels = data[:, :-1], data[:, 1:]

        for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
            loss, xl_memories = model(
                seq_segment,
                labels = labels_segment,
                knn_memories = knn_memories,
                xl_memories = xl_memories
            )

            train_loss += loss.item() / SEGMENTS
            (loss / SEGMENTS).backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optimizer.step()
    optimizer.zero_grad()

    if not (i % VALIDATE_EVERY):
        model.eval()

        valid_data = next(valid_loader)["input_ids"].to(device)
        if valid_data.shape[0] != BATCH_SIZE:
            print(f'Skipping validation batch {i} as it is not of size {BATCH_SIZE}, but {valid_data.shape[0]}')
            valid_data = next(valid_loader)["input_ids"].to(device)

        valid_loss = 0.

        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
            xl_memories = None    
            seq, labels = data[:, :-1], data[:, 1:]

            for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
                loss, xl_memories = model(
                    seq_segment,
                    labels = labels_segment,
                    knn_memories = knn_memories,
                    xl_memories = xl_memories
                )

                valid_loss += loss.item() / SEGMENTS

        save_checkpoint(optimizer, i + completed_iterations, train_loss, valid_loss)

  x.storage().data_ptr() + x.storage_offset() * 4)


Writing to Tensorboard: Train Loss: 6.1722, Val Loss: 4.2459


training:   0%|          | 50/100000 [00:30<16:27:01,  1.69it/s]

Writing to Tensorboard: Train Loss: 2.3053, Val Loss: 2.2537


training:   0%|          | 123/100000 [01:12<15:37:30,  1.78it/s]

Writing to Tensorboard: Train Loss: 2.0493, Val Loss: 2.0265


training:   0%|          | 179/100000 [01:44<15:45:41,  1.76it/s]

Writing to Tensorboard: Train Loss: 2.0483, Val Loss: 1.9822


training:   0%|          | 255/100000 [02:27<15:27:20,  1.79it/s]

Writing to Tensorboard: Train Loss: 1.9109, Val Loss: 1.8700


training:   0%|          | 312/100000 [02:59<15:22:29,  1.80it/s]

Writing to Tensorboard: Train Loss: 1.8810, Val Loss: 1.8357


training:   0%|          | 367/100000 [03:42<15:49:15,  1.75it/s]

Writing to Tensorboard: Train Loss: 2.0241, Val Loss: 1.9513


training:   0%|          | 439/100000 [04:13<15:53:25,  1.74it/s]

Writing to Tensorboard: Train Loss: 1.7619, Val Loss: 1.7221


training:   0%|          | 457/100000 [04:25<16:26:29,  1.68it/s]

# Initial Results

With the base REMI vocab size of 411 tokens along with the same embed / head / layer count as our hand built model, the results are significantly worse. The training curve almost immediately flattened out at 1.6 loss, compared to under 1.0 previously.

This could be due to e.g.

- Different encoding. Not only are tokens different in MidiTok / Remi but we used bar / beat *embeddings* rather than tokens previously. There are so many more meanings to the tokens in addition to the count rather than just note / duration / sos / eos / pad that it would be much harder to guess the right answer. If we do see convergence it should mean we have learnt more useful information rather than just gamed the system so to speak.
- Larger vocabulary = more choices to pick for the next token
- The MIDITok data loader only provides max `CHUNK_LENGTH` tokens, whereas our custom 'contiguous' data loader always fed in entire tracks. Only a handful of tracks are between 1 and 2k tokens, most are 3 - 5K, so are being split into multiple pieces.
- The LucidRains Memorizing Transformers architecture is very similar to our model but there are a couple of changes (memory on a subset of layer vs all, lack of absolute embeddings) and of course it is a different implementation. I would expect it to be less buggy if anything though.

I did have to tweak the transformer's KNN memory to run exclusively on the GPU as index was on CPU and the data on disk, which was a major bottleneck (as found in our model). Luckily once again the code was nearly identical so it was easy to follow / edit.

## Things to try

- Train the tokenizer with 1K, 5K, 10K, 20K, 30K size vocab. This should result in less tokens per song. We might need to drop the max seq length to avoid over-padded songs.

- Add in chord / tempo / time sig / rest tokens. This should result in more tokens per song. We might need to increase the max seq length to avoid over-split songs.

> Note - the files will need re-encoding as they are split based on estimated token length, which depends on encoding and vocab size!

> Another note! - When the vocab size was halved, the number of tokens per file seemed to double, as you might expect given the bpe was trained on the same data. That meant doubling the chunk size and segment count in order to split the files into the same number of pieces.

- Augment the dataset with pitch and velocity shifted versions using the MIDITok tools

- Full Lakh dataset, although if we can't do well on a curated, stylistic dataset then a more varied one is likely to decrease performance.


## Experiment log

- Was seeing increase in loss after initial drop which eventually came back down again. Decreasing learning rate made it happen later and worse. Increased learning rate (and added dropout) which seemed to negate it.

- Some apparent cyclic loss patterns so tried shuffling between epochs. Requires either waiting for a pretokenise op between epochs (long even on vg_large, let alone Lakh) or tokenise on the fly which reduces GPU usage to around 70%

# Parallel DataLoaders

Shuffling between epochs should avoid cyclic learning behaviour, but
- If we pretokenise there is a big delay between epochs
- If we don't, the GPU isn't used as effectively

We could have two dataloaders, and pretokenise one whilst the other is in use.

This doesn't seem to be an issue right now but we could implement it if needed.

Claude suggested the following:

In [None]:
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import DataLoader

class ParallelDataLoaderManager:
    def __init__(self, chunk_filepaths: list, tokenizer, chunk_length: int, batch_size: int):
        self.chunk_filepaths = chunk_filepaths
        self.tokenizer = tokenizer
        self.chunk_length = chunk_length
        self.batch_size = batch_size
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.next_loader_future = None
        self.current_loader = None

    def create_dataloader(self, filepaths) -> DataLoader:
        # Using the create_data_loader function from your workspace
        return create_data_loader(
            filepaths,
            self.tokenizer,
            self.chunk_length, 
            self.batch_size
        )
    
    def prepare_next_loader(self):
        # Shuffle filepaths for next epoch
        shuffled_paths = self.chunk_filepaths.copy()
        random.shuffle(shuffled_paths)
        
        # Start creating next loader in background
        self.next_loader_future = self.executor.submit(
            self.create_dataloader,
            shuffled_paths
        )

    def get_loader(self) -> DataLoader:
        if self.current_loader is None:
            # First call - create initial loader
            self.current_loader = self.create_dataloader(self.chunk_filepaths)
            # Start preparing next loader 
            self.prepare_next_loader()
            return self.current_loader

        # Wait for next loader to be ready and swap
        self.current_loader = self.next_loader_future.result()
        # Start preparing next loader
        self.prepare_next_loader()
        return self.current_loader

    def __del__(self):
        self.executor.shutdown()