In [1]:
import os
os.chdir(r'6 - Byte Pair encoding')
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter
from itertools import chain, cycle, groupby
from functools import reduce
from typing import Collection, List
from pathlib import Path
import music21 as m21
musescore_path = '/usr/bin/mscore'
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)
from midi_encoding import *
from data_loading import *
from einops import rearrange, repeat, pack, unpack, einsum
import faiss
import time
import math
import pickle

# Pick the best available backend: CUDA first, then Apple's MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}.")

# Byte Pair Encoding (BPE)

## What is the aim here?

- Increase information density in the context by introducing dedicated tokens for common pairs of tokens (and pairs of pairs, and pairs of those, etc etc)

This will be achieved by adapting the `Vocab` class to be *trainable*.

Training the `Vocab` class means feeding it a bunch of example data (which may or may not be the training set - rare items may be deliberately represented).

It uses this example data to *extend* the base vocabulary with pairs of tokens, and then pairs of those pairs etc etc. until the desired vocab size is reached.

This trained extended vocab can then be used to encode training data in a much more efficient, compressed form.

In our particular music-based case, you might imagine common chords being represented as a single token rather than a whole string of notes and durations.

> Unfortunately, a C major chord whose notes last slightly different lengths of time wouldn't be grouped into the same token...

In [2]:
# Build a small nested (ragged) tensor to experiment with.
# Every innermost pair is [value, time idx]; the zeros stand in for the time idx.
nested_tensor = torch.nested.nested_tensor(
    [
        torch.tensor([[[1, 0], [2, 0]], [[3, 0], [4, 0]]]),
        torch.tensor([[[5, 0], [6, 0]]]),
        torch.tensor([[[7, 0], [8, 0]], [[9, 0], [10, 0]], [[11, 0], [12, 0]]]),
    ]
)
nested_tensor

In [3]:
# Collapse the first two dims of every component, then stack all the rows into one 2-D tensor.
flattened_tensor = torch.cat(
    [torch.flatten(component, 0, 1) for component in nested_tensor.unbind()],
    dim=0,
)
flattened_tensor

In [4]:
# Keep only column 0 of every row, dropping the time idx column.
flattened_tensor[:, 0]

In [5]:
class MusicVocab():
    """Vocabulary over `ALL_TOKENS` that can be extended with byte pair merges.

    `ALL_TOKENS` and the special-token names (`SOS`, `EOS`, `SEP`, `PAD`,
    `NOTE_START`, ...) are not defined in this notebook; presumably they come
    from the star-imported project modules.

    Attributes:
        itos: index -> token string. A merged token's string is the strings of
            its two parts joined by a space.
        stoi: token string -> index (the inverse of `itos`).
        idx_to_elem: index -> list of base-vocab indices the token expands to.
            A base token expands to itself.
        merges: `None` until `train` or `load_state_dict` has run; afterwards a
            dict mapping an index pair `(left, right)` to the merged token
            index. Lower merged indices were learned earlier.
    """

    def __init__(self):
        self.itos = {k:v for k,v in enumerate(ALL_TOKENS)}
        self.stoi = {v:k for k,v in enumerate(ALL_TOKENS)}
        self.idx_to_elem = {k:[k] for k,v in enumerate(ALL_TOKENS)} # 1 is [1], 2 is [2] etc. until we merge.
        self.merges = None

    def to_indices(self, tokens):
        """Map token strings to their indices."""
        return [self.stoi[w] for w in tokens]

    def to_tokens(self, idxs, sep=' '):
        """Map indices to token strings, joined with `sep` (or as a list when `sep` is None)."""
        items = [self.itos[idx] for idx in idxs]
        return sep.join(items) if sep is not None else items

    def to_element(self, idxs):
        """Map each index to the list of base-vocab indices it expands to."""
        return [self.idx_to_elem[idx] for idx in idxs]

    def get_stats(self, idxs):
        """Count how often each adjacent pair occurs in `idxs`.

        Returns a dict `(left, right) -> count`, which is empty when `idxs`
        holds fewer than two items.
        """
        stats = {}
        for pair in zip(idxs[:-1], idxs[1:]):
            stats[pair] = stats.get(pair, 0) + 1
        return stats

    def merge(self, idxs, pos, pair, idx):
        """Replace each non-overlapping occurrence of `pair` in `idxs` with `idx`.

        The sequence is scanned left to right. `pos` is either None or a
        sequence parallel to `idxs`; when given, the result keeps one position
        per output token, and a merged token takes the position of the first
        item of the pair.

        NOTE(review): positions are not compared before merging, so a pair
        that straddles two time steps is merged like any other.

        Returns:
            `(new_idxs, new_pos)` as lists; `new_pos` is None when `pos` is None.
        """
        new_idxs = []
        new_pos = None if pos is None else []
        last = len(idxs) - 1
        i = 0
        while i <= last:
            current_item = idxs[i]

            if pos is not None:
                new_pos.append(pos[i])

            if i < last and current_item == pair[0] and idxs[i + 1] == pair[1]:
                new_idxs.append(idx)
                i += 2
            else:
                new_idxs.append(current_item)
                i += 1

        return new_idxs, new_pos

    # Pass in data already encoded using untrained vocab
    def train(self, dataset, max_vocab_size):
        """Learn merges from `dataset` until the vocab holds `max_vocab_size` tokens.

        `dataset.data` must provide `unbind()` yielding tensors whose first two
        dims can be flattened into rows; column 0 of each row is used as the
        token idx and the rest (the time idx) is discarded.

        Training stops early if the data runs out of adjacent pairs, so the
        final vocab may be smaller than `max_vocab_size`.

        Raises:
            Exception: if this vocab has already been trained or restored.
        """
        if self.merges is not None:
            raise Exception("Already trained")

        idxs = torch.cat([t.flatten(0,1) for t in dataset.data.unbind()]) # Flatten the nested tensor
        idxs = idxs[:, 0].detach().cpu().tolist() # Discard time idx and convert to list
        initial_size = self.size
        num_merges = max_vocab_size - initial_size

        # Collected locally so a failure above leaves the vocab untrained rather than half-initialised.
        merges = {}

        for i in range(num_merges):
            stats = self.get_stats(idxs)
            if not stats:
                # Fewer than two tokens remain, so there is nothing left to pair up.
                print(f"No pairs left to merge, stopping at vocab size {initial_size + i}")
                break
            pair = max(stats, key=stats.get)
            idx = initial_size + i
            print(f"Merging {pair} to a new token {idx}")
            idxs, _ = self.merge(idxs, None, pair, idx)
            merges[pair] = idx

        self.merges = merges

        # Register the new tokens in learning order, so both parts of a pair already exist.
        for (p0, p1), idx in self.merges.items():
            value = f"{self.itos[p0]} {self.itos[p1]}"
            self.itos[idx] = value
            self.stoi[value] = idx
            self.idx_to_elem[idx] = self.idx_to_elem[p0] + self.idx_to_elem[p1]

    def state_dict(self):
        """Return the trained state (`idx_to_elem` and `merges`) needed to rebuild this vocab."""
        return {
            'idx_to_elem': self.idx_to_elem,
            'merges': self.merges
        }

    def load_state_dict(self, state_dict):
        """Restore a state produced by `state_dict` and rebuild `itos` / `stoi` from it.

        Token strings are rebuilt from the current `itos`, so every index in
        `idx_to_elem` must expand to indices this vocab already knows.
        """
        self.merges = state_dict['merges']
        self.idx_to_elem = state_dict['idx_to_elem']
        # Key by the stored indices themselves rather than by their position in the dict.
        self.itos = {k:self.to_tokens(v) for k,v in self.idx_to_elem.items()}
        self.stoi = {v:k for k,v in self.itos.items()}

    def encode(self, note_position_score):
        """Turn a note position score into merged token indices plus their positions.

        Args:
            note_position_score: 2-D numpy array whose columns are note,
                duration and time idx.

        Returns:
            `(idxs, positions)` as numpy arrays of equal length. Before
            merging, every note and duration is its own token and both share
            the row's time idx. An untrained vocab applies no merges.
        """
        nps = note_position_score.copy()
        note_dur_score = nps[:, :2] # Note and duration, drop tidx

        # Offset the note and duration values by the min index to get their index
        note_min_idx, _ = self.note_range
        dur_min_idx, _ = self.duration_range
        note_idx_score = note_dur_score + np.array([note_min_idx, dur_min_idx])

        note_idx_score = note_idx_score.reshape(-1) # Flatten note and duration into a single dimension
        pos_score = np.repeat(nps[:, 2], 2) # Double up positions for flattened note and duration

        merges = self.merges if self.merges is not None else {}

        while True:
            stats = self.get_stats(note_idx_score)

            # Pick the pair that was learned first (lowest merged index), so earlier merges are applied first.
            # `default=None` covers sequences with fewer than two tokens, which have no pairs at all.
            pair = min(stats, key=lambda p: merges.get(p, float('inf')), default=None)

            if pair is None or pair not in merges:
                print("No more merges to do")
                break

            idx = merges[pair]
            print(f"Replacing {pair} with token {idx}")
            note_idx_score, pos_score = self.merge(note_idx_score, pos_score, pair, idx)

        return np.array(note_idx_score), np.array(pos_score)

    def decode(self, note_idx_score):
        """Expand token indices back into an array of (note, duration) rows."""
        # Convert idxs to positions and pair up note / durations
        merge_chunks = [self.idx_to_elem[idx] for idx in note_idx_score]
        position_score = np.array(list(chain.from_iterable(merge_chunks))).reshape(-1, 2)

        # Offset the note and duration idxs by their respective min index to get their actual value
        if position_score.shape[0] != 0:
            note_min_idx, _ = self.note_range
            dur_min_idx, _ = self.duration_range
            position_score -= np.array([note_min_idx, dur_min_idx])

        return position_score

    @property
    def sos_idx(self): return self.stoi[SOS]
    @property
    def eos_idx(self): return self.stoi[EOS]
    @property
    def sep_idx(self): return self.stoi[SEP]
    @property
    def pad_idx(self): return self.stoi[PAD]
    @property
    def note_position_enc_range(self): return (self.stoi[SEP], self.size)
    @property
    def note_range(self): return self.stoi[NOTE_START], self.stoi[NOTE_END]+1
    @property
    def duration_range(self): return self.stoi[DURATION_START], self.stoi[DURATION_END]+1
    @property
    def size(self): return len(self.itos)

Question
- Should we consider SEPARATOR_IDX as a hard terminator of bpe (like a word space) and pre-chunk into simultaneous 'actions'? Instead of 'optional space followed by a sequence of chars' it would be 'optional separator followed by an action (sequence of notes and durations)'.

Answer: 
- No need because an 'action' is *always* followed by a separator, it's not like we have a range of punctuation to differentiate (i.e. dog. vs dog, vs dog!).
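- However, we did need to prevent merging across time index boundaries to ensure each token had a distinct timestep, which meant skipping a merge during encoding if the previous pair was SEP, DUR. (Note: `MusicVocab.merge` as defined above does not compare positions, so this check is not part of the class as shown in this notebook.)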

In [6]:
# Untrained (base) vocab, used only to load the dataset samples.
dataset_vocab = MusicVocab()

In [7]:
sample_length = 256
max_file_length = 32

# The MIDI files are listed from, and loaded from, the same directory.
midi_path = Path('../data/midi/vg_large')
vg_large_path = midi_path
score_path = Path('../data/numpy/vg_large/all')

# Directory listings come back in arbitrary order. Sort the names so that the dataset,
# and therefore the order in which BPE breaks ties between equally common pairs,
# is the same on every run and machine.
vg_large_file_names = sorted(f.name for f in vg_large_path.iterdir() if f.is_file())

dataset = MidiDataset(vg_large_file_names, midi_path, score_path, sample_length, max_file_length)

In [8]:
# Load the samples with the untrained vocab onto `device`.
dataset.load_samples(dataset_vocab, device) # Could load directly to CPU here as it is only used for training the vocab but wanted to replicate existing code

In [9]:
# A second, fresh vocab: this is the one that gets trained.
# Its size before training is the index at which merged tokens start.
vocab = MusicVocab()
init_vocab_size = vocab.size
init_vocab_size

In [10]:
# Learn merges from the loaded samples, aiming for a total vocab size of 395 (base tokens + merged tokens).
vocab.train(dataset, max_vocab_size=395)

In [11]:
# Vocab size after training (base tokens plus learned merges).
trained_vocab_size = vocab.size

In [12]:
# Token strings of the first six merged tokens, which sit directly after the base vocab.
for offset in range(6):
    idx = init_vocab_size + offset
    print(f"{idx}: {vocab.itos[idx]}")

In [13]:
# The same six merged tokens, shown as the base-vocab indices they expand to.
for offset in range(6):
    idx = init_vocab_size + offset
    print(f"{idx}: {vocab.idx_to_elem[idx]}")

In [14]:
# Read a single MIDI file with music21 and plot it as a reference for the round trip below.
final_fantasy_midi_path = Path('../data/midi/Mana_GB_Final Fantasy Adventure_Battle 2.mid')
final_fantasy_midi = m21.midi.MidiFile()
final_fantasy_midi.open(final_fantasy_midi_path)
try:
    final_fantasy_midi.read()
finally:
    # Release the file handle even if parsing the file fails.
    final_fantasy_midi.close()

final_fantasy_m21stream = m21.midi.translate.midiFileToStream(final_fantasy_midi)
final_fantasy_m21stream.plot()

In [15]:
# Encode the MIDI file using the trained vocab (midifile_to_idx_score is a project helper not defined in this notebook).
final_fantasy_idx_score = midifile_to_idx_score(final_fantasy_midi_path, vocab)

# Show column 0 as token strings — presumably the token idx column, with the time idx alongside; TODO confirm.
vocab.to_tokens(final_fantasy_idx_score[:, 0])

In [16]:
# Size of the encoded score, for comparison with the figure noted below.
final_fantasy_idx_score.shape # 3162 before adding 6 tokens to the vocab

In [17]:
# Decode column 0 of the encoded score with the same vocab and plot the result for comparison
# with the original plot (idx_to_stream_enc is a project helper; presumably it returns a music21 stream).
reconstructed_final_fantasy_stream = idx_to_stream_enc(final_fantasy_idx_score[:, 0], vocab)
reconstructed_final_fantasy_stream.plot()

In [18]:
# Render the original stream as MIDI so it can be compared by ear.
final_fantasy_m21stream.show('midi')

In [19]:
# Render the reconstructed stream as MIDI for comparison with the original.
reconstructed_final_fantasy_stream.show('midi')

It works!

## Save / Restore

It is important that we use the same vocab to decode data that we used to encode it, otherwise tokens might not exist or have other meanings.

This also goes for decoding performances generated by the model. We must use the same vocab which encoded the training data (and all training data must have been encoded with the same vocab). This stands to reason, otherwise tokens will have no consistent meaning.

For this reason we should save it alongside our model during training so that we can load it during inference.

In [20]:
# Take the trained vocab's state (idx_to_elem and merges).
trained_state = vocab.state_dict()

# A fresh vocab starts at the base vocab size.
untrained_vocab = MusicVocab()
untrained_vocab.size

In [21]:
# Restore the trained state into the fresh vocab; its size now includes the merged tokens.
untrained_vocab.load_state_dict(trained_state)
untrained_vocab.size

In [22]:
# Decode the score that was encoded by the trained vocab using the restored vocab instead, then plot it.
untrained_final_fantasy_stream = idx_to_stream_enc(final_fantasy_idx_score[:, 0], untrained_vocab)
untrained_final_fantasy_stream.plot()

In [23]:
# Render the stream decoded by the restored vocab as MIDI.
untrained_final_fantasy_stream.show('midi')

In [24]:
# Print every (index, token string) entry of the restored vocab.
for idx, token in untrained_vocab.itos.items():
    print((idx, token))

In [25]:
# Print every (token string, index) entry of the restored vocab.
for token, idx in untrained_vocab.stoi.items():
    print((token, idx))

Save and restore looking good.

## Adding to the training loop

I think the method will be:

1. Create a vocab.
2. Restore the vocab and load the merged data, if they exist.
3. If they don't exist:
   - Load the unmerged data.
   - Train the vocab and save it.
   - Merge the data and save it.

A dataset is **useless** without the vocab that encoded it, as is a model without the vocab of the data it was trained on, so save it in the model checkpoints, and load when restoring a model.