In [1]:
import os
os.chdir(r'6 - Byte Pair encoding')
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter
from itertools import chain, cycle, groupby
from functools import reduce
from typing import Collection, List
from pathlib import Path
import music21 as m21
musescore_path = '/usr/bin/mscore'
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)
from midi_encoding import *
from einops import rearrange, repeat, pack, unpack, einsum
import faiss
import time
import math
import pickle

# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}.")

Using cuda.


# Byte Pair Encoding (BPE)

## What is the aim here?

- Increase information density in the context by introducing dedicated tokens for common pairs of tokens (and then pairs of those pairs, and so on).

This will be achieved by adapting the `Vocab` class to be *trainable*.

Training the `Vocab` class means feeding it a bunch of example data (which may or may not be the training set - rare items may be deliberately represented).

It uses this example data to *extend* the base vocabulary with pairs of tokens, then pairs of those pairs, and so on, until the desired vocab size is reached.

This trained extended vocab can then be used to encode training data in a much more efficient, compressed form.

In our particular music-based case, you might imagine common chords being represented as a single token rather than a whole string of notes and durations.

> Unfortunately, a C major chord whose notes last slightly different lengths of time wouldn't be grouped...

In [2]:
# Build a nested tensor from three components of shape (2, 2, 2), (1, 2, 2) and (3, 2, 2).
# Each innermost pair is [value, time idx]; the zeros represent the time idx.
nested_tensor = torch.nested.nested_tensor(
    [
        torch.tensor([[[1, 0], [2, 0]], [[3, 0], [4, 0]]]),
        torch.tensor([[[5, 0], [6, 0]]]),
        torch.tensor([[[7, 0], [8, 0]], [[9, 0], [10, 0]], [[11, 0], [12, 0]]]),
    ]
)

# Flatten manually: merge the first two dims of every component, then
# concatenate the components into a single (N, 2) tensor.
flattened_tensor = torch.cat(
    tuple(component.flatten(0, 1) for component in nested_tensor.unbind())
)

print(flattened_tensor)

tensor([[ 1,  0],
        [ 2,  0],
        [ 3,  0],
        [ 4,  0],
        [ 5,  0],
        [ 6,  0],
        [ 7,  0],
        [ 8,  0],
        [ 9,  0],
        [10,  0],
        [11,  0],
        [12,  0]])


  return _nested.nested_tensor(


In [3]:
# Keep only column 0 (the values); column 1 is the time idx placeholder.
token_column = flattened_tensor[:, 0]
print(token_column)

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])


In [4]:
class MusicVocab():
    """Token vocabulary that can be extended with byte-pair-encoding (BPE) merges.

    The base vocabulary is built from ``ALL_TOKENS``. Calling :meth:`train`
    repeatedly replaces the most frequent adjacent pair of token indices with a
    new index and registers the concatenated token string for it.

    Attributes:
        itos: dict mapping token index -> token string.
        stoi: dict mapping token string -> token index.
        cache_path: value passed to the constructor (stored, not used by this class).
        merges: dict mapping ``(left_idx, right_idx)`` -> merged idx, in the
            order the merges were learned. Not persisted by :meth:`save`.
        trained: ``True`` once :meth:`train` has completed.
    """

    def __init__(self, cache_path):
        self.itos = {k:v for k,v in enumerate(ALL_TOKENS)}
        self.stoi = {v:k for k,v in enumerate(ALL_TOKENS)}
        self.cache_path = cache_path
        # Defined up front so both attributes can be read before train() is called.
        self.merges = {}
        self.trained = False

    def to_indices(self, tokens):
        """Return the list of indices for an iterable of token strings.

        Raises:
            KeyError: if a token is not in the vocabulary.
        """
        return [self.stoi[w] for w in tokens]

    def to_tokens(self, idxs, sep=' '):
        """Map indices back to token strings.

        Returns a single string joined by ``sep``, or the list of token strings
        when ``sep`` is ``None``.
        """
        items = [self.itos[idx] for idx in idxs]
        return sep.join(items) if sep is not None else items

    def get_stats(self, idxs):
        """Count every adjacent pair in ``idxs``.

        Returns:
            dict mapping ``(left, right)`` -> number of occurrences, in order of
            first appearance. Empty when ``idxs`` has fewer than two items.
        """
        stats = {}
        for pair in zip(idxs, idxs[1:]):
            stats[pair] = stats.get(pair, 0) + 1
        return stats

    def merge(self, idxs, pair, idx):
        """Return a copy of ``idxs`` with each occurrence of ``pair`` replaced by ``idx``.

        The scan is left to right and non-overlapping: after a match both
        members of the pair are consumed.
        """
        left, right = pair[0], pair[1]
        last = len(idxs) - 1
        new_idxs = []
        i = 0
        while i <= last:
            if i < last and idxs[i] == left and idxs[i+1] == right:
                new_idxs.append(idx)
                i += 2
            else:
                new_idxs.append(idxs[i])
                i += 1
        return new_idxs

    # Pass in data already encoded using untrained vocab
    def train(self, dataset, max_vocab_size=400):
        """Learn BPE merges from ``dataset`` until ``max_vocab_size`` is reached.

        Args:
            dataset: object whose ``data.unbind()`` yields tensors with at least
                three dims; after merging the first two dims, column 0 is taken
                as the token idx and the rest (the time idx) is discarded.
            max_vocab_size: target vocabulary size. No merges are performed if
                the vocabulary is already at least this large.

        Training stops early when no adjacent pairs remain (fewer than two
        tokens left), instead of failing on an empty pair table.
        """
        idxs = torch.cat([t.flatten(0,1) for t in dataset.data.unbind()]) # Flatten the nested tensor
        idxs = idxs[:, 0].detach().cpu().tolist() # Discard time idx and convert to a plain Python list
        num_merges = max_vocab_size - self.size
        merges = {}

        for i in range(num_merges):
            stats = self.get_stats(idxs)
            if not stats:
                # Nothing left to pair up; max() on an empty dict would raise.
                break
            # Ties resolve to the pair that was seen first.
            pair = max(stats, key=stats.get)
            # self.size is constant inside this loop (itos is extended afterwards).
            idx = self.size + i
            print(f"Merging {pair} to a new token {idx}")
            idxs = self.merge(idxs, pair, idx)
            merges[pair] = idx

        # Register merged tokens in learning order so that a later merge can
        # reference the string of an earlier one.
        # NOTE(review): the merged string is a plain concatenation; if it equals
        # an existing token string, that stoi entry is overwritten - confirm
        # ALL_TOKENS makes this impossible.
        for (p0, p1), idx in merges.items():
            value = self.itos[p0] + self.itos[p1]
            self.itos[idx] = value
            self.stoi[value] = idx

        # Keep the merge table; getattr covers instances created without __init__.
        self.merges = {**getattr(self, 'merges', {}), **merges}
        self.trained = True

    def save(self, path):
        """Pickle the index -> token table (``itos``) to ``path``."""
        with open(path, 'wb') as file:
            pickle.dump(self.itos, file)

    def load(self, path):
        """Restore ``itos`` from ``path`` and rebuild ``stoi`` to match it.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        from a trusted source.
        """
        with open(path, 'rb') as file:
            self.itos = pickle.load(file)
        # Without this the reverse lookup would not match the loaded table.
        self.stoi = {token: idx for idx, token in self.itos.items()}

    def encode(self, note_position_score):
        """Not implemented yet; always returns ``None``."""
        return None

    def decode(self, note_index_score):
        """Not implemented yet; always returns ``None``."""
        return None

    # Indices of the special tokens.
    @property
    def sos_idx(self): return self.stoi[SOS]
    @property
    def eos_idx(self): return self.stoi[EOS]
    @property
    def sep_idx(self): return self.stoi[SEP]
    @property
    def pad_idx(self): return self.stoi[PAD]
    # (start, end) index tuples; each end is one past the index of the *_END token.
    @property
    def note_position_enc_range(self): return (self.stoi[SEP], self.stoi[DURATION_END]+1)
    @property
    def note_range(self): return self.stoi[NOTE_START], self.stoi[NOTE_END]+1
    @property
    def duration_range(self): return self.stoi[DURATION_START], self.stoi[DURATION_END]+1
    # Current number of entries in the vocabulary (base tokens plus merges).
    @property
    def size(self): return len(self.itos)

Question
- Should we consider SEPARATOR_IDX as a hard terminator for BPE (like a space between words) and pre-chunk the data into simultaneous 'actions'? Instead of 'an optional space followed by a sequence of chars' it would be 'an optional separator followed by an action (a sequence of notes and durations)'.

Answer: 
- No need, because an 'action' is *always* followed by a separator; it's not as though we have a range of punctuation to differentiate (e.g. dog. vs dog, vs dog!).

In [5]:
# Create a nested tensor
nested_tensor = torch.nested.nested_tensor([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]]), torch.tensor([[7, 8], [9, 10], [11, 12]])])

# Manually flatten the nested tensor by concatenating the individual tensors
flattened_tensor = torch.cat([t.flatten() for t in nested_tensor.unbind()])

print(flattened_tensor)

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
