In [1]:
import os
os.chdir(r'5 - TransformerXL')
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
from itertools import groupby
from functools import reduce
from typing import Collection, List
from pathlib import Path
import music21 as m21
musescore_path = '/usr/bin/mscore'
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)
# from midi_encoding import *

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device}.")

Using cuda.


# TransformerXL

Our vanilla transformer showed improvements, but each token could only attend to other tokens within the current block.

We also only used absolute positional encodings, so tokens knew where they were in the sequence but not where they were relative to other tokens.

[TransformerXL](https://research.google/blog/transformer-xl-unleashing-the-potential-of-attention-models/) tackles both of these problems by:
1. Using a 'memory' for keys and values from the previous block, allowing information to propagate through time.
2. Employing relative positional encoding.
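The memory idea in point 1 can be sketched as single-head causal attention where the keys and values come from the cached previous block concatenated with the current one, while the queries come only from the current block. This is a minimal illustration, not the TransformerXL implementation: `attend_with_memory` and the projection matrices are hypothetical, it caches exactly one previous block, and it uses no positional encoding at all (relative or otherwise).

```python
import torch
import torch.nn.functional as F

def attend_with_memory(x, mem, w_q, w_k, w_v):
    """Single-head causal attention over [memory; current block] (sketch).

    x:   (batch, block_len, d) inputs for the current block
    mem: (batch, mem_len, d) inputs cached from the previous block
    w_q, w_k, w_v: (d, d) hypothetical projection matrices
    """
    ctx = torch.cat([mem, x], dim=1)      # keys / values see memory + current block
    q = x @ w_q                           # queries only come from the current block
    k = ctx @ w_k
    v = ctx @ w_v
    mem_len, block_len = mem.size(1), x.size(1)
    # Query i sits at position mem_len + i, so it may attend to keys 0..mem_len + i
    mask = torch.ones(block_len, mem_len + block_len).tril(diagonal=mem_len).bool()
    scores = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    scores = scores.masked_fill(~mask, float('-inf'))
    out = F.softmax(scores, dim=-1) @ v
    new_mem = x.detach()                  # gradients don't flow back into the previous block
    return out, new_mem

torch.manual_seed(0)
d = 8
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
mem = torch.zeros(2, 0, d)                # empty memory for the first block of a song
first = torch.randn(2, 4, d)              # batch of 2, block length 4
out, mem = attend_with_memory(first, mem, w_q, w_k, w_v)
```

Detaching the memory is what keeps training cost bounded: the next block can read the previous block's keys and values, but backpropagation stops at the block boundary.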

For the memory to work, consecutive blocks of each song need to be fed in sequentially, so our loading and batching strategy will once again need revisiting.

# Coding A Paper

Luckily I found [this walkthrough](https://www.youtube.com/playlist?list=PLam9sigHPGwOe8VDoS_6VT4jjlgs9Uepb) in the style of Karpathy's makemore videos.

## Notes

### Ep.2 Keeping GPUs Busy

We need to keep our music blocks contiguous across batches, e.g. for a batch size of four:


Batches 1-3 make up chunk 1 (songs 1-4) and batches 4-6 make up chunk 2 (songs 5-8):

| Batch row | Batch 1 | Batch 2 | Batch 3 | Batch 4 | Batch 5 | Batch 6 |
|-----------|---------|---------|---------|---------|---------|---------|
| 1 | Song 1, Block 1 | Song 1, Block 2 | Song 1, Block 3 | Song 5, Block 1 | Song 5, Block 2 | Song 5, Block 3 |
| 2 | Song 2, Block 1 | Song 2, Block 2 | Song 2, Block 3 | Song 6, Block 1 | Song 6, Block 2 | Song 6, Block 3 |
| 3 | Song 3, Block 1 | Song 3, Block 2 | Song 3, Block 3 | Song 7, Block 1 | Song 7, Block 2 | Song 7, Block 3 |
| 4 | Song 4, Block 1 | Song 4, Block 2 | Song 4, Block 3 | Song 8, Block 1 | Song 8, Block 2 | Song 8, Block 3 |

Note that the above shows songs that are all the same length, which of course isn't what we have in reality.

This means that we either
- Crop long songs
- Pad short songs
- Connect them in a ragged way

The video takes the cropping approach: pick a 'chunk' size (a multiple of the block size) and crop each song to a multiple of this chunk size, e.g.

In [2]:
blocks = 3
block_size = 256
chunk_size = blocks * block_size
chunk_size

768

So we take the song length modulo the chunk size and crop off the remainder.
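A minimal sketch of the crop, assuming a song is already encoded as a 1-D tensor of token ids (`crop_to_chunks` is a hypothetical helper). Songs shorter than one chunk come back empty and would simply be dropped.

```python
import torch

def crop_to_chunks(tokens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Drop the trailing tokens that don't fill a whole chunk."""
    usable = len(tokens) - (len(tokens) % chunk_size)
    return tokens[:usable]

song = torch.arange(2000)              # stand-in for an encoded song of 2000 tokens
cropped = crop_to_chunks(song, 768)
print(len(cropped))                    # 1536, i.e. two chunks of 768
```

This ignores the extra token needed to build labels offset by one; accounting for it means cropping `len(tokens) - 1` instead.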

Use `reshape` (or `view`?) to rearrange a song into chunks, then `concat` to join the songs into one list of chunks, then `chunk` to split them into batches.

Following the above, each row of batch 1 (block 1 of a song) should be the precursor to the same row of batch 2 (block 2 of the same song).

Data and labels per chunk are the same as in a vanilla transformer - labels are data offset by one.
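Putting these steps together, here is one possible sketch (not the video's code) that builds data / label batches where row `r` of batch `t + 1` continues row `r` of batch `t` within each chunk. It assumes songs are 1-D token tensors and uses `reshape` + `torch.cat` as above, but swaps `chunk` for a `reshape` / `permute` to interleave the blocks. `make_batches` is a hypothetical name, and memory only carries over within a chunk, matching the cropping approach.

```python
import torch

def make_batches(songs, block_size, blocks, batch_size):
    """Return (data, labels), each shaped (n_batches, batch_size, block_size)."""
    chunk_size = blocks * block_size
    xs, ys = [], []
    for song in songs:
        n = ((len(song) - 1) // chunk_size) * chunk_size   # keep one extra token for labels
        if n == 0:
            continue                                        # too short for a single chunk
        xs.append(song[:n].reshape(-1, chunk_size))         # (n_chunks, chunk_size)
        ys.append(song[1:n + 1].reshape(-1, chunk_size))    # labels are data offset by one
    x, y = torch.cat(xs), torch.cat(ys)                     # (total_chunks, chunk_size)
    n_groups = len(x) // batch_size                         # drop chunks that can't fill a batch

    def _arrange(t):
        t = t[:n_groups * batch_size].reshape(n_groups, batch_size, blocks, block_size)
        t = t.permute(0, 2, 1, 3)                           # (groups, blocks, rows, block_size)
        return t.reshape(-1, batch_size, block_size)        # one entry per batch

    return _arrange(x), _arrange(y)

# Four toy "songs" of 30 tokens each, offset so they're easy to tell apart
songs = [torch.arange(s * 1000, s * 1000 + 30) for s in range(4)]
x, y = make_batches(songs, block_size=4, blocks=3, batch_size=2)
```

Here `x[0, 0]` is `[0, 1, 2, 3]` and `x[1, 0]` is `[4, 5, 6, 7]`, so the first row continues from one batch to the next.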

# To consider

- Bigger data set
    - Cleaning / preparing files
    - Wrap in PyTorch data helper classes (help with batching?)
    - Streaming?
    - Parallel loading / processing on CPU with [pebble](https://pypi.org/project/Pebble/)
    - If we are moving to pop music with Lakh MIDI dataset, how will we handle instruments and percussion?

- Encoding strategy
    - Could encode on the fly if that is quicker than the time the GPU takes to process a batch, otherwise pre-encode
    - Follow the one-song-per-batch-row layout outlined above

- Use einops rather than manually applying transformation functions where practical

- Relative positional encodings

- KNN memory
    - Vector index (faiss?)
    - Memory mapped file for db
    - Second to last block only
    - Look up K nearest keys / values

- TransformerXL recurrent memory

- Vectorise head operations

- Monitoring during training
    - TensorBoard?

- 'Reverse teacher forcing' (offset future mask extra step)


# Lakh MIDI Dataset

This is a huge (~6GB) set of MIDI files of pretty much every kind of music scraped from across the internet.

See [the website](https://colinraffel.com/projects/lmd/#get) for more details.

## Data Quality

If we move to processing such a large dataset, we are going to need to pay more attention to quality.

That is, there will likely be corrupt files, but also files with long gaps of silence etc. that could throw off the training.
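For the silence problem, a simple filter could measure the longest stretch where nothing is sounding. This sketch assumes notes have already been extracted as `(start, end)` times in beats (e.g. from a parsed music21 score); `longest_gap`, `looks_usable` and the 16-beat threshold are hypothetical choices. Corrupt files could be caught separately with a `try` / `except` around parsing.

```python
def longest_gap(notes):
    """Longest stretch (in beats) where no note is sounding."""
    if not notes:
        return 0.0
    notes = sorted(notes)
    gap, sounding_until = 0.0, notes[0][1]
    for start, end in notes[1:]:
        gap = max(gap, start - sounding_until)     # silence before this note starts
        sounding_until = max(sounding_until, end)  # overlapping notes extend the sound
    return gap

def looks_usable(notes, max_gap_beats=16.0):
    """Reject files with no notes or with very long silences."""
    return bool(notes) and longest_gap(notes) <= max_gap_beats

print(longest_gap([(0, 1), (0.5, 4), (10, 11)]))   # 6
```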

## Instrument Info

We are going to get a lot of different instruments, and setting them all as piano will lose a huge amount of information.

General MIDI (GM) defines 128 instruments, identified by their program number.

Percussion in particular really needs to be mapped correctly. 

Rather than having a GM instrument per drum, percussion is mapped across 47 notes (35 -> 81) on channel 10.

> NOTE - GM Level 2 expanded the range of percussion, amongst other things, to notes 27 -> 87. Perhaps it is better to have a 128-dim embedding and be done with it?

It would also be good to incorporate pitch bend (and modulation?) control info, as these are used a lot. However, pitch bend is very high resolution, both in range (14-bit, i.e. 16,384 values) and in sample rate if you want smooth (i.e. not stepped) bends.
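To make the resolution concrete: a MIDI pitch bend message carries two 7-bit data bytes (LSB first, then MSB) that combine into one 14-bit value, with 8192 meaning no bend. The quantisation step below is a hypothetical way to cut that down to a handful of tokens; `quantise_bend` and the 17-bin default are assumptions, and coarser bins are exactly what produces stepped bends.

```python
def pitch_bend_value(lsb: int, msb: int) -> int:
    """Combine the two 7-bit data bytes into a 14-bit value (0-16383)."""
    return (msb << 7) | lsb

def quantise_bend(value: int, bins: int = 17) -> int:
    """Map a 14-bit bend onto `bins` evenly spaced steps (0 .. bins - 1)."""
    return round(value / 16383 * (bins - 1))

centre = pitch_bend_value(0x00, 0x40)   # no bend
print(centre, quantise_bend(centre))    # 8192 8
```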

### Encoding

Ignoring the pitch bend / CC data, we have 128 instruments that can each play 128 notes, plus 47 percussion sounds that each play a single note.


In [3]:
instruments = 128
pitches = 128
percussion = 47
(instruments * pitches) + percussion

16431
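One way to turn a note into one of the 128 + 47 = 175 instrument classes is sketched below. It assumes 1-based channel numbers (so percussion is channel 10; zero-based libraries call it 9) and the GM Level 1 percussion range; `instrument_class` is a hypothetical helper.

```python
PERC_LOW, PERC_HIGH = 35, 81   # GM Level 1 percussion notes on channel 10
N_PROGRAMS = 128

def instrument_class(channel: int, program: int, note: int) -> int:
    """Melodic notes map to their program (0-127); percussion to 128 + (note - 35)."""
    if channel == 10:
        if not PERC_LOW <= note <= PERC_HIGH:
            raise ValueError(f"percussion note {note} is outside 35-81")
        return N_PROGRAMS + (note - PERC_LOW)
    return program

print(instrument_class(10, 0, 81))   # 174, the last of the 175 classes
```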

#### Tokens

We could swap our `n{i}` for `{instrument}{i}` tokens, but that would result in 16431 note tokens (as opposed to the 128 we currently have).

Some of these would also be very rarely used.
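For comparison, a flat numbering for the `{instrument}{i}` token option could look like this. It is a sketch assuming GM Level 1 percussion (notes 35-81); `melodic_token` and `percussion_token` are hypothetical names.

```python
N_PROGRAMS, N_PITCHES = 128, 128
PERC_LOW, PERC_HIGH = 35, 81                  # GM Level 1 percussion key range
N_PERCUSSION = PERC_HIGH - PERC_LOW + 1       # 47

def melodic_token(program: int, pitch: int) -> int:
    """Index for GM program 0-127 playing MIDI pitch 0-127."""
    return program * N_PITCHES + pitch

def percussion_token(note: int) -> int:
    """Index for a percussion sound, placed after all melodic tokens."""
    if not PERC_LOW <= note <= PERC_HIGH:
        raise ValueError(f"note {note} is outside the GM Level 1 percussion range")
    return N_PROGRAMS * N_PITCHES + (note - PERC_LOW)

vocab_size = N_PROGRAMS * N_PITCHES + N_PERCUSSION
print(vocab_size)   # 16431
```

Every combination gets its own row in the embedding table, which is why the rarely used ones are a concern: they would get very few gradient updates.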

#### Embeddings

We could use an embedding lookup to add instrument information to each token, the same as we do for bar / beat, packaging the instrument info alongside the note and timestep.

The trouble with this is we will face the same challenge that came with bars and beats, which is reconstructing them at the output.
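On the input side, the embedding option is straightforward. A minimal sketch, assuming instruments are already encoded as integer class ids (0-174) aligned one-to-one with the tokens; `NoteAndInstrumentEmbedding` and the sizes used are hypothetical.

```python
import torch
import torch.nn as nn

class NoteAndInstrumentEmbedding(nn.Module):
    """Add a learned instrument embedding to each token embedding."""
    def __init__(self, vocab_size, n_instruments, d_model):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.instrument_emb = nn.Embedding(n_instruments, d_model)

    def forward(self, tokens, instruments):
        # tokens, instruments: (batch, seq_len) integer tensors of the same shape
        return self.token_emb(tokens) + self.instrument_emb(instruments)

emb = NoteAndInstrumentEmbedding(vocab_size=400, n_instruments=175, d_model=64)
tokens = torch.randint(0, 400, (2, 16))
instruments = torch.randint(0, 175, (2, 16))
print(emb(tokens, instruments).shape)   # torch.Size([2, 16, 64])
```

The hard part is the output side, since nothing here tells the model how to produce the instrument again.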

We could have a second output layer with `instruments + percussion` (175) neurons, one per instrument, and softmax over it to get the most likely instrument for each note?

Unlike bar and beat, we don't rely on it being perfect in order to render the performance.

We might expect the residual stream to carry the instrument embedding straight through the network to the output layer.

> Copilot suggests we sum the losses from the two output heads

I think this is the most reasonable way to proceed.

- Encode instrument info alongside note, duration and timestep
- Use this to create a second set of labels
- Embed the instrument info at the input layer
- Pass the output of the transformer through two linear layers for classification: one with vocab-size outputs and one with 175 outputs for the instruments.
- Score the outputs against the respective labels
- Sum the loss
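The output side of the steps above could be sketched as two linear heads on the transformer's final hidden states, each scored with cross-entropy against its own labels, with the losses summed. `TwoHeadOutput` and the sizes are hypothetical, and the transformer itself is replaced by random hidden states `h`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoHeadOutput(nn.Module):
    """Predict the next token and its instrument; return the summed loss."""
    def __init__(self, d_model, vocab_size, n_instruments=175):
        super().__init__()
        self.token_head = nn.Linear(d_model, vocab_size)
        self.instrument_head = nn.Linear(d_model, n_instruments)

    def forward(self, h, token_labels=None, instrument_labels=None):
        # h: (batch, seq_len, d_model) output of the transformer
        token_logits = self.token_head(h)
        instrument_logits = self.instrument_head(h)
        if token_labels is None or instrument_labels is None:
            return token_logits, instrument_logits, None
        # cross_entropy expects (N, C) logits, so flatten batch and time together
        token_loss = F.cross_entropy(token_logits.flatten(0, 1), token_labels.flatten())
        instrument_loss = F.cross_entropy(instrument_logits.flatten(0, 1), instrument_labels.flatten())
        return token_logits, instrument_logits, token_loss + instrument_loss

heads = TwoHeadOutput(d_model=64, vocab_size=400)
h = torch.randn(2, 16, 64)
token_logits, instrument_logits, loss = heads(
    h, torch.randint(0, 400, (2, 16)), torch.randint(0, 175, (2, 16))
)
loss.backward()
```

A plain sum weights both tasks equally; scaling the instrument loss by a factor would be an easy knob to try if one head dominates.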

## Plan

With all of the above considered, it is probably best to update the architecture first and the dataset / encoding afterwards; otherwise there will be a lot of changes at once and any bugs will be hard to track down.


### Loading data

PyTorch has a helper class for loading data in a custom order, the [batch sampler](https://medium.com/@haleema.ramzan/how-to-build-a-custom-batch-sampler-in-pytorch-ce04161583ee), which is passed to `DataLoader` via its `batch_sampler` argument.

Its job is to generate the indices of the items in the next batch. `DataLoader` passes each of these indices to another customised PyTorch class, a [Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), which uses it to return the appropriate sample.
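A toy example of that hand-off, before the real classes: the batch sampler yields a list of indices per batch, `DataLoader` calls `dataset[i]` for each one, then collates the results into a tensor. `SquaresDataset` and `ListBatchSampler` are hypothetical toy classes.

```python
from torch.utils.data import Dataset, DataLoader, Sampler

class SquaresDataset(Dataset):
    """Toy dataset: item i is i squared."""
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx * idx

class ListBatchSampler(Sampler):
    """Toy batch sampler: yields pre-chosen lists of indices, one list per batch."""
    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        yield from self.batches

    def __len__(self):
        return len(self.batches)

loader = DataLoader(SquaresDataset(), batch_sampler=ListBatchSampler([[0, 1], [5, 3]]))
for batch in loader:
    print(batch)   # tensor([0, 1]) then tensor([25, 9])
```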



In [4]:
from torch.utils.data import Dataset, DataLoader, Sampler

class CustomTextDataset(Dataset):
    def __init__(self, file_paths, sample_length):
        self.file_paths = file_paths
        self.data = []
        self.file_lengths = []
        self.total_samples = 0
        self.sample_length = sample_length

    def load_samples(self):
        for file_path in self.file_paths:
            with open(file_path, 'r') as file:
                samples = file.read().split('\n')  # Assuming each line is a sample
                samples = [[int(char) for char in sample] for sample in samples if len(sample) == self.sample_length] # Filter out samples that are not the correct length
                if len(samples) == 0: # Skip files with no valid samples
                    continue
                self.data.append(samples)
                self.file_lengths.append(len(samples))
        self.total_samples = sum(self.file_lengths)
    
    def __len__(self):
        return self.total_samples
    
    def __getitem__(self, idx):
        # idx is a [file_idx, sample_idx] pair yielded by CustomBatchSampler
        file_idx, sample_idx = int(idx[0]), int(idx[1])
        sample = self.data[file_idx][sample_idx]
        return torch.tensor(sample)

class CustomBatchSampler(Sampler):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.current_files = list(range(batch_size))
        self.batch_indices = []
    
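    def precompute_indices(self):
        # Reset state so calling this twice doesn't duplicate batches
        self.current_files = list(range(self.batch_size))
        self.batch_indices = []
        # One read position per *loaded* file: files with no valid samples were
        # skipped by the dataset, so size this by file_lengths, not file_paths
        file_indices = [0] * len(self.dataset.file_lengths)
        while sum(file_indices) < self.dataset.total_samples: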
            batch = []
            for i in range(self.batch_size):
                current_file = self.current_files[i]

                # Check if the current file is exhausted
                if file_indices[current_file] == self.dataset.file_lengths[current_file]:
                    # Find the next file that hasn't been started
                    found_new_file = False
                    for next_file in range(len(self.dataset.file_lengths)):
                        if file_indices[next_file] == 0:
                            current_file = next_file
                            self.current_files[i] = current_file
                            found_new_file = True
                            break
                    
                    if not found_new_file:
                        # No more unstarted files, break the loop
                        return

                indices = [current_file, file_indices[current_file]]
                batch.append(indices)
                
                file_indices[current_file] += 1

            self.batch_indices.append(torch.tensor(batch))
        return

    def __iter__(self):
        for index in self.batch_indices:
            yield index

    def __len__(self):
        return len(self.batch_indices)

We can test this with some dummy data. How about:

- Batch size 4
- 6 batches

- Row 1: 1 file of 6 samples
- Row 2: 2 files of 3 samples
- Row 3: 2 files, one of 2 samples and one of 4 samples
- Row 4: 3 files of 2 samples

We would expect these file indices in the respective batch rows:

- 1
- 2, 7
- 3, 5
- 4, 6, 8

If we add up to `batch_size - 1` extra files, they shouldn't be used, as the sampler should `return` once it can't fill a whole batch.

In [5]:
batch_size = 4

file_names = [      # File index
    '6.txt',        # 1
    '3.1.txt',      # 2
    '2.1.txt',      # 3
    '2.2.txt',      # 4
    '4.txt',        # 5
    '2.3.txt',      # 6
    '3.2.txt',      # 7
    '2.4.txt',      # 8
    'extra.1.txt',  # 9
    'extra.2.txt',  # 10
    'extra.3.txt',  # 11
    'empty.txt',    # 12
    'bad.txt',      # 13
]

file_paths = list(map(lambda filename: Path(f'../data/text/{filename}'), file_names))

dataset = CustomTextDataset(file_paths, 3)
dataset.load_samples()

sampler = CustomBatchSampler(dataset, batch_size)
sampler.precompute_indices()

dataloader = DataLoader(dataset, batch_sampler=sampler)

for batch in dataloader:
    print(batch)

tensor([[6, 1, 1],
        [3, 1, 1],
        [2, 1, 1],
        [2, 2, 1]])
tensor([[6, 1, 2],
        [3, 1, 2],
        [2, 1, 2],
        [2, 2, 2]])
tensor([[6, 1, 3],
        [3, 1, 3],
        [4, 1, 1],
        [2, 3, 1]])
tensor([[6, 1, 4],
        [3, 2, 1],
        [4, 1, 2],
        [2, 3, 2]])
tensor([[6, 1, 5],
        [3, 2, 2],
        [4, 1, 3],
        [2, 4, 1]])
tensor([[6, 1, 6],
        [3, 2, 3],
        [4, 1, 4],
        [2, 4, 2]])


In [6]:
len(sampler)

6