In [30]:
import os
os.chdir(r'6 - Byte Pair encoding')
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter
from itertools import chain, cycle, groupby
from functools import reduce
from typing import Collection, List
from pathlib import Path
import music21 as m21
musescore_path = '/usr/bin/mscore'
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)
from midi_encoding import *
from data_loading import *
from einops import rearrange, repeat, pack, unpack, einsum
import faiss
import time
import math
import pickle

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device}.")

FileNotFoundError: [Errno 2] No such file or directory: '6 - Byte Pair encoding'

# Byte Pair Encoding (BPE)

## What is the aim here?

- Increase information density in the context by introducing dedicated tokens for common pairs of tokens (and pairs of pairs, and pairs of those, etc etc)

This will be achieved by adapting the `Vocab` class to be *trainable*.

Training the `Vocab` class means feeding it a bunch of example data (which may or may not be the training set - rare items may be deliberately represented).

It uses this example data to *extend* the base vocabulary with pairs of tokens, and then pairs of those pairs etc etc. until the desired vocab size is reached.

This trained extended vocab can then be used to encode training data in a much more efficient, compressed form.

In our particular music-based case, you might imagine common chords being represented as a single token rather than a whole string of notes and durations.

> A C Major with notes that lasted slightly different lengths of time wouldn't be grouped unfortunately...

In [2]:
# Create a nested tensor
# Zeros represent the time idx
nested_tensor = torch.nested.nested_tensor([torch.tensor([[[1,0], [2,0]], [[3,0], [4,0]]]), torch.tensor([[[5,0], [6,0]]]), torch.tensor([[[7,0], [8,0]], [[9,0], [10,0]], [[11,0], [12,0]]])])
nested_tensor

  return _nested.nested_tensor(


nested_tensor([
  tensor([[[1, 0],
           [2, 0]],
  
          [[3, 0],
           [4, 0]]]),
  tensor([[[5, 0],
           [6, 0]]]),
  tensor([[[ 7,  0],
           [ 8,  0]],
  
          [[ 9,  0],
           [10,  0]],
  
          [[11,  0],
           [12,  0]]])
])

In [3]:
flattened_tensor = torch.cat([t.flatten(0,1) for t in nested_tensor.unbind()])
flattened_tensor

tensor([[ 1,  0],
        [ 2,  0],
        [ 3,  0],
        [ 4,  0],
        [ 5,  0],
        [ 6,  0],
        [ 7,  0],
        [ 8,  0],
        [ 9,  0],
        [10,  0],
        [11,  0],
        [12,  0]])

In [4]:
flattened_tensor[:, 0]

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

In [53]:
class MusicVocab():
    def __init__(self):
        self.itos = {k:v for k,v in enumerate(ALL_TOKENS)}
        self.stoi = {v:k for k,v in enumerate(ALL_TOKENS)}
        self.idx_to_elem = {k:k for k,v in enumerate(ALL_TOKENS)} # 1 is 1, 2 is 2 etc. until we merge.
        self.merges = None
    
    def to_indices(self, tokens):
        return [self.stoi[w] for w in tokens]

    def to_tokens(self, idxs, sep=' '):
        items = [self.itos[idx] for idx in idxs]
        return sep.join(items) if sep is not None else items
    
    def to_element(self, idxs):
        return [self.idx_to_elem[idx] for idx in idxs]

    def get_stats(self, idxs):
        stats = {}
        for pair in zip(idxs[:-1], idxs[1:]):
            stats[pair] = stats.get(pair, 0) + 1
        return stats

    def merge(self, idxs, pair, idx):
        new_idxs = []
        i = 0
        while i < len(idxs):
            current_item = idxs[i]
            next_item = idxs[i+1] if i < len(idxs) - 1 else None
            if next_item is not None and current_item == pair[0] and next_item == pair[1]:
                new_idxs.append(idx)
                i += 2
            else:
                new_idxs.append(current_item)
                i += 1
        return new_idxs
    
    def concat_lists(self, list1, list2):
        if not isinstance(list1, list):
            list1 = [list1]
        if not isinstance(list2, list):
            list2 = [list2]
        return list1 + list2

    # Pass in data already encoded using untrained vocab
    def train(self, dataset, max_vocab_size):

        if self.merges is not None:
            raise Exception("Already trained")
        
        self.merges = {}

        # Might have to skip the clone, depending on memory usage
        idxs = torch.cat([t.flatten(0,1) for t in dataset.data.unbind()]) # Flatten the nested tensor
        idxs = idxs[:, 0].detach().cpu().tolist() # Discard time idx and convert to numpy
        initial_size = self.size
        num_merges = max_vocab_size - initial_size

        for i in range(num_merges):
            stats = self.get_stats(idxs)
            pair = max(stats, key=stats.get)
            idx = initial_size + i
            print(f"Merging {pair} to a new token {idx}") 
            idxs = self.merge(idxs, pair, idx)
            self.merges[pair] = idx
        
        for (p0, p1), idx in self.merges.items():
            value = f"{self.itos[p0]}{self.itos[p1]}"
            self.itos[idx] = value
            self.stoi[value] = idx
            self.idx_to_elem[idx] = self.concat_lists(self.idx_to_elem[p0] ,self.idx_to_elem[p1])
    
    def save(self, path):
        state = {
            'itos': self.itos,
            'stoi': self.stoi,
            'idx_to_elem': self.idx_to_elem,
            'merges': self.merges
        }
        with open(path, 'wb') as file:
            pickle.dump(state, file)
    
    def load(self, path):
        with open(path, 'rb') as file:
            state = pickle.load(file)
            self.itos = state['itos']
            self.stoi = state['stoi']
            self.idx_to_elem = state['idx_to_elem']
            self.merges = state['merges']
    
    def encode(self, note_position_score):
        nps = note_position_score.copy()
        note_idx_score = nps[:, :2] # Note and duration, drop tidx
        pos_score = nps[:, 2]
        note_min_idx, _ = self.note_range
        dur_min_idx, _ = self.duration_range
        note_idx_score += np.array([note_min_idx, dur_min_idx])
        note_idx_score = note_idx_score.reshape(-1)
        
        while True:
            stats = self.get_stats(note_idx_score)
            # Iterate keys and get the pair with the min number of merges, so we do earlier merges first
            pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if pair not in self.merges:
                break
            idx = self.merges[pair]
            note_idx_score = self.merge(note_idx_score, pair, idx)

        # Questions:
        # 1. How do we handle the position score? - Pass the position to merge and it can be included in the merged score
        # 2. How do we decode to midi? We would need to 'un-pair' the tokens

        return note_idx_score
    
    def decode(self, note_idx_score):
        return None

    @property
    def sos_idx(self): return self.stoi[SOS]
    @property
    def eos_idx(self): return self.stoi[EOS]
    @property
    def sep_idx(self): return self.stoi[SEP]
    @property
    def pad_idx(self): return self.stoi[PAD]
    @property
    def note_position_enc_range(self): return (self.stoi[SEP], self.stoi[DURATION_END]+1)
    @property
    def note_range(self): return self.stoi[NOTE_START], self.stoi[NOTE_END]+1
    @property
    def duration_range(self): return self.stoi[DURATION_START], self.stoi[DURATION_END]+1
    @property
    def size(self): return len(self.itos)

Question
- Should we consider SEPARATOR_IDX as a hard terminator of bpe (like a word space) and pre-chunk into simultaneous 'actions'? Instead of 'optional space followed by a sequence of chars' it would be 'optional separator followed by an action (sequence of notes and durations)'.

Answer: 
- No need because an 'action' is *always* followed by a separator, it's not like we have a range of punctuation to differentiate (i.e. dog. vs dog, vs dog!).

In [54]:
vocab = MusicVocab()
init_vocab_size = vocab.size
init_vocab_size

389

In [7]:
sample_length = 256
max_file_length = 32
vg_large_path = Path('../data/midi/vg_large')
vg_large_file_names = [f for f in os.listdir(vg_large_path) if os.path.isfile(os.path.join(vg_large_path, f))]
midi_path = Path('../data/midi/vg_large')
score_path = Path(f'../data/numpy/vg_large/all')

dataset = MidiDataset(vg_large_file_names, midi_path, score_path, sample_length, max_file_length)

In [8]:
dataset.load_samples(vocab, device) # Could load directly to CPU here as it is only used for training the vocab but wanted to replicate existing code

In [55]:
vocab.train(dataset, max_vocab_size=395)

Merging (3, 136) to a new token 389
Merging (136, 389) to a new token 390
Merging (2, 2) to a new token 391
Merging (3, 134) to a new token 392
Merging (134, 392) to a new token 393
Merging (134, 389) to a new token 394


In [56]:
trained_vocab_size = vocab.size

In [57]:
for idx in range(init_vocab_size, trained_vocab_size):
    print(f"{idx}: {vocab.itos[idx]}")

389: <|sep|>d4
390: d4<|sep|>d4
391: <|pad|><|pad|>
392: <|sep|>d2
393: d2<|sep|>d2
394: d2<|sep|>d4


In [58]:
for idx in range(init_vocab_size, trained_vocab_size):
    print(f"{idx}: {vocab.idx_to_elem[idx]}")

389: [3, 136]
390: [136, 3, 136]
391: [2, 2]
392: [3, 134]
393: [134, 3, 134]
394: [134, 3, 136]
