In [1]:
import os
# Run from the notebook's own folder so the `midi_encoding` module (presumably local
# to that folder) and the relative `../data/...` paths used below resolve. Guarded so
# that re-running this cell, or starting the kernel inside that folder, does not fail
# with FileNotFoundError.
NOTEBOOK_DIR = '5 - TransformerXL'
if os.path.basename(os.getcwd()) != NOTEBOOK_DIR:
    os.chdir(NOTEBOOK_DIR)
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
from itertools import groupby
from functools import reduce
from typing import Collection, List
from pathlib import Path
import music21 as m21
# MuseScore is used by music21 to render scores; override the default with $MUSESCORE_PATH.
musescore_path = os.environ.get('MUSESCORE_PATH', '/usr/bin/mscore')
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)
from midi_encoding import *  # must come after the chdir above
from einops import rearrange, repeat, pack, unpack, einsum
import faiss

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}.")

Using cuda.


In [2]:
# Name the GPU when CUDA was selected above; there is nothing to report for MPS/CPU.
if device == "cuda":
    print("Device: " + torch.cuda.get_device_name() + ".")

Device: NVIDIA GeForce RTX 4090.


In [3]:
# Show GPU status (driver/CUDA version, memory use, utilisation) via an IPython shell escape.
!nvidia-smi

Mon Aug 19 14:45:34 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.31.01              Driver Version: 560.81         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
| 30%   33C    P0             44W /  450W |    2962MiB /  24564MiB |     34%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
# Token vocabulary for the MIDI encoding. MusicVocab presumably comes from the
# `midi_encoding` star import in the first cell. The last expression displays its size.
vocab = MusicVocab()
vocab.size

392

# Building the Model

Our memory-augmented transformer will be very similar to the vanilla model developed in the previous notebook.

There will be three major additions, all of which are fairly simple:

- Relative positional embeddings.
- KNN lookup for keys (and their associated values).
- Recurrent 'TransformerXL' style memory.

## Einops

First, let's condense our previous implementation in two ways:

- Add a dimension to our `MultiHeadAttention` module, abandoning the separate `SelfAttentionHead` module.
- Switch to using einops for shape manipulation, as it is simpler both to write and to read.

In [5]:
class MultiHeadAttention(torch.nn.Module):
    """Causal multi-head self-attention with all heads computed in one batched pass.

    Heads are split and merged with einops instead of separate per-head modules.

    Args:
        n_embed: Size of the input/output embedding (last dim of ``x``).
        n_head: Number of heads; each head gets ``n_embed // n_head`` dims.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, n_embed, n_head = 8, dropout = 0.2):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        head_total_size = n_head * self.head_size
        self.key = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.proj = torch.nn.Linear(head_total_size, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """Attend each position to itself and earlier positions only.

        Args:
            x: Tensor of shape ``(b, t, n_embed)``.

        Returns:
            Tensor of shape ``(b, t, n_embed)``.
        """
        # Project, then split heads: (b, t, h*d) -> (b, h, t, d)
        q = rearrange(self.query(x), 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(self.key(x), 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(self.value(x), 'b t (h d) -> b h t d', h = self.n_head)

        # Without einsum we had to swap dims using k.transpose(-2, -1)
        w = einsum(q, k, 'b h i d, b h j d -> b h i j') * (self.head_size ** -0.5)

        # TODO: Relative positional encoding

        # Causal mask: True marks *future* keys, which must not be attended to.
        # masked_fill writes -inf where the mask is True, so the mask must be the
        # strict upper triangle. Offsetting the diagonal by (j - i) keeps it correct
        # if keys ever outnumber queries (e.g. prepended XL memory).
        i, j = w.shape[-2:]
        future = torch.ones((i, j), dtype = torch.bool, device = w.device).triu(j - i + 1)
        w = w.masked_fill(future, float('-inf'))
        w = F.softmax(w, dim=-1)

        out = w @ v

        # Concat heads
        out = rearrange(out, 'b h t d -> b t (h d)')

        # TODO: KNN memory

        return self.dropout(self.proj(out))

## KNN Memory

If we want to add KNN memory, we need an indexed data store to look up and retrieve keys and values.

For the index we can use [Faiss](https://github.com/facebookresearch/faiss) from Meta.

For the data store we can simply use a memory-mapped numpy array.

The following code is adapted from the [Colab Notebook](https://colab.research.google.com/drive/1XZz1sjNt1MKRG6ul_hOGSJFQLS4lRtmJ?usp=sharing#scrollTo=gs7RpvCdePZr) accompanying the [Coding a Paper](https://www.youtube.com/playlist?list=PLam9sigHPGwOe8VDoS_6VT4jjlgs9Uepb) series:

In [6]:
class KNN():
    """Exact nearest-neighbour memory of (key, value) pairs.

    Keys are indexed with a flat Faiss L2 index. The full (key, value) pairs live in a
    memory-mapped float32 numpy array on disk. Rows are appended in the same order that
    keys are added to the index, so Faiss id ``n`` refers to row ``n`` of the array.

    Args:
        dim: Size of each key and each value vector.
        max_memories: Capacity (number of key/value pairs) of the on-disk store.
        db_filepath: Path of the memmap file; it is (re)created with mode ``'w+'``.
    """
    def __init__(self, dim, max_memories, db_filepath):
        self.dim = dim
        self.max_memories = max_memories
        self.shape = (max_memories, 2, dim) # TODO: Memory per batch dim, cleared when file is exhausted
        self.db_offset = 0  # rows written so far == next free row
        self.db_filepath = db_filepath
        self.db = np.memmap(self.db_filepath, mode = 'w+', dtype = np.float32, shape = self.shape)
        self.index = faiss.IndexFlatL2(dim)


    def add_to_db(self, new_data):
        """Append ``new_data`` (a tensor of shape ``(n, 2, dim)``) to the memmap store.

        Raises:
            IndexError: if the rows would not fit in ``max_memories``; nothing is written.
        """
        new_data_len = new_data.shape[0] # (b t)
        end = self.db_offset + new_data_len
        if end > self.max_memories:
            raise IndexError(
                f'KNN store full: cannot add {new_data_len} rows at offset {self.db_offset} '
                f'(max_memories={self.max_memories}); call clear() first.'
            )
        # .cpu() so tensors living on the GPU can be copied into the numpy memmap.
        self.db[self.db_offset:end] = new_data.detach().cpu().numpy()
        self.db_offset = end
        # Write to file
        self.db.flush()


    def search_and_retrieve(self, query, k):
        """Return the stored (key, value) pairs of the ``k`` nearest keys for each query row.

        Args:
            query: float32, C-contiguous numpy array of shape ``(n, dim)``.
            k: Number of neighbours per query.

        Returns:
            numpy array of shape ``(n, k, 2, dim)``. Slots without a neighbour
            (the index holds fewer than ``k`` keys) are all zeros.
        """
        # The tooltip says the args are (n, x, k) but that's the CPP api, it's actually (x, k) in Python (n is the first dim of x anyway so can be inferred).
        _, indices = self.index.search(query, k)

        # Faiss reports a missing neighbour as id -1. Indexing the memmap with -1 would
        # silently read its *last* row (real data once the store is full), so zero those slots.
        missing = indices < 0
        kvs = self.db[np.where(missing, 0, indices)]
        kvs[missing] = 0
        return kvs

    def add(self, new_data):
        """Store a ``(b, t, 2, dim)`` batch of (key, value) pairs; only the keys are indexed."""
        # Detach, move to CPU and cast to float32: both Faiss and the memmap require it.
        new_data = rearrange(new_data.detach().to(device = 'cpu', dtype = torch.float32),
                             'b t two c -> (b t) two c')

        # Add to db first: if the store is full this raises before the index is touched,
        # keeping index ids and db rows aligned.
        self.add_to_db(new_data)

        # Only keys are used in knn index: (b t) c
        keys = new_data[:, 0]
        self.index.add(np.ascontiguousarray(keys.numpy()))

    def search(self, query, k):
        """Find the ``k`` nearest stored keys for every query vector.

        Args:
            query: Tensor of shape ``(b, t, dim)``.
            k: Number of neighbours per query.

        Returns:
            float32 tensor of shape ``(b, t, k, 2, dim)`` on ``query``'s device, holding
            the retrieved keys (``[..., 0, :]``) and values (``[..., 1, :]``).
        """
        T = query.shape[1]
        flat = rearrange(query.detach().to(device = 'cpu', dtype = torch.float32), 'b t c -> (b t) c')

        kvs = self.search_and_retrieve(np.ascontiguousarray(flat.numpy()), k)
        kvs = torch.tensor(np.asarray(kvs), device = query.device)
        return rearrange(kvs, '(b t) k two c -> b t k two c', t = T)


    def clear(self):
        """Forget all stored memories (empties the index and zeroes the store)."""
        self.index.reset()
        self.db[:] = 0
        self.db_offset = 0

Let's test it with some random data:

In [7]:
# Smoke-test the store with random data: c-dimensional keys/values, t positions per item.
c, t = 4, 2

knn = KNN(c, 100_000, '../data/numpy/knn-test.db')

# 1000 items, each holding t (key, value) pairs -> shape (1000, t, 2, c).
vector_data = torch.from_numpy(np.random.random((1000, t, 2, c)).astype(np.float32))
knn.add(vector_data)

# A single query sequence of t vectors -> shape (1, t, c).
query_data = torch.from_numpy(np.random.random((1, t, c)).astype(np.float32))
query_data

tensor([[[0.2857, 0.8111, 0.4562, 0.2322],
         [0.8220, 0.7890, 0.6071, 0.4583]]])

Search returns a `(b t k 2 c)` tensor containing, for each of the `b * t` query vectors (each of size `c`), the `top_k` nearest keys and their associated values.

Here our query has shape `(1, 2, 4)` and `top_k = 2`, so the result has shape `(1, 2, 2, 2, 4)`.

In [8]:
# Retrieve the top_k nearest stored (key, value) pairs for every query position.
top_k = 2
knn.search(query_data, k=top_k)  # -> (b, t, k, 2, c)

tensor([[[[[0.3740, 0.7438, 0.4960, 0.2564],
           [0.9457, 0.5970, 0.5391, 0.2538]],

          [[0.2522, 0.9185, 0.4160, 0.2056],
           [0.6594, 0.8644, 0.6513, 0.3425]]],


         [[[0.7901, 0.8044, 0.6787, 0.4075],
           [0.6474, 0.6122, 0.0258, 0.6570]],

          [[0.7565, 0.8019, 0.6250, 0.3903],
           [0.2035, 0.2504, 0.8388, 0.0084]]]]])

Now we can integrate the memory into our multiheaded attention.

We will make a new class for this, as we only use KNN on the second-to-last layer.

It will keep a separate KNN memory for each element of the batch (each index along the batch dimension), and it will clear that memory whenever the file being read at that batch index changes.

We can detect this because the `CustomMidiDataset` returns the file index of each batch element along with the data.

These can be passed to our model, which in turn can pass them to the KNN attention block.

In [9]:
class KNNAttention(torch.nn.Module):
    """Causal multi-head attention augmented with a per-batch-element KNN memory.

    On each forward pass the module:

    1. clears the memory of any batch element whose source file changed;
    2. computes ordinary causal attention over the current segment;
    3. retrieves the ``k`` nearest stored keys (and their values) for every query and
       attends over them;
    4. blends the two results with a learned per-head gate ``sigmoid(gate_bias)``;
    5. writes the current segment's (detached) keys and values into memory so that
       later segments can retrieve them.

    Args:
        batch_size: Maximum batch size; one KNN store is created per batch index.
        n_embed: Size of the input/output embedding.
        k: Number of memories retrieved per query.
        n_head: Number of heads; each head gets ``n_embed // n_head`` dims.
        dropout: Dropout probability applied to the projected output.
        max_memories: Capacity of each per-batch KNN store.
    """

    def __init__(self, batch_size, n_embed, k, n_head = 8, dropout = 0.2, max_memories = 100000):
        super().__init__()
        self.n_embed = n_embed
        self.k = k
        self.n_head = n_head
        self.head_size = n_embed // n_head
        head_total_size = n_head * self.head_size
        self.key = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.proj = torch.nn.Linear(head_total_size, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

        # Memory per batch index. Stored vectors are all heads concatenated, so their
        # size is head_total_size (equal to n_embed when n_head divides n_embed).
        self.knn = {i: KNN(head_total_size, max_memories, f'../data/numpy/knn-{i}.db') for i in range(batch_size)}
        self.current_files = None

        # Per-head gate logits; squashed with sigmoid in forward() so the mix stays in (0, 1).
        self.gate_bias = torch.nn.Parameter(torch.randn(self.n_head, 1, 1))


    def forward(self, x, batch_files):
        """Attend locally and over retrieved memories, then store this segment.

        Args:
            x: Tensor of shape ``(b, t, n_embed)``.
            batch_files: Per-batch-element file identifiers (length ``b``), compared
                with the previous call's to decide which memories to clear.

        Returns:
            Tensor of shape ``(b, t, n_embed)``.
        """
        B, T = x.shape[:2]

        # Clear a batch index's memory when its file changes, so retrieved memories
        # always come from the piece currently being read. Indices that were absent
        # from the previous (smaller) batch are cleared too, since their last file is unknown.
        if self.current_files is not None:
            for i, new_file in enumerate(batch_files):
                if i >= len(self.current_files) or self.current_files[i] != new_file:
                    self.knn[i].clear()
        self.current_files = batch_files

        # Project, then split heads: (b, t, h*d) -> (b, h, t, d)
        q = rearrange(self.query(x), 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(self.key(x), 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(self.value(x), 'b t (h d) -> b h t d', h = self.n_head)

        # --- Local causal attention over the current segment ---
        w = einsum(q, k, 'b h i d, b h j d -> b h i j') * (self.head_size ** -0.5)

        # TODO: Relative positional encoding

        # True marks future keys; masked_fill blanks them with -inf.
        i, j = w.shape[-2:]
        future = torch.ones((i, j), dtype = torch.bool, device = w.device).triu(j - i + 1)
        w = w.masked_fill(future, float('-inf'))
        w = F.softmax(w, dim=-1)
        # Kept as (b, h, t, d) so it can be gated per head below.
        local_out = w @ v

        # --- KNN memory ---
        # Search with all heads' queries concatenated, matching the layout of stored keys.
        # Faiss works on CPU numpy arrays, so detach and move to CPU before searching.
        q_flat = rearrange(q, 'b h t d -> b t (h d)').detach().cpu()
        mem_kv = torch.cat([self.knn[b].search(q_flat[b:b+1], k = self.k) for b in range(B)], dim = 0)  # b t k 2 (h d)
        mem_kv = mem_kv.to(device = x.device, dtype = q.dtype)
        mem_k, mem_v = mem_kv.unbind(dim = -2)
        mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h = self.n_head)
        mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h = self.n_head)

        # einops.einsum takes the tensors first and the pattern last.
        mem_w = einsum(q, mem_k, 'b h t d, b h t k d -> b h t k') * (self.head_size ** -0.5)
        mem_w = F.softmax(mem_w, dim=-1)
        mem_out = einsum(mem_w, mem_v, 'b h t k, b h t k d -> b h t d')

        # --- Combine: per-head gate in (0, 1) decides how much comes from memory ---
        gate = self.gate_bias.sigmoid()
        combined = mem_out * gate + local_out * (1 - gate)
        combined = rearrange(combined, 'b h t d -> b t (h d)')

        # --- Write this segment to memory for future segments ---
        # Done after the search so tokens never retrieve themselves; detached so no
        # gradient flows into the (non-differentiable) store.
        k_flat = rearrange(k, 'b h t d -> b t (h d)')
        v_flat = rearrange(v, 'b h t d -> b t (h d)')
        new_kv = torch.stack((k_flat, v_flat), dim = -2).detach().cpu()  # b t 2 (h d)
        for b in range(B):
            store = self.knn[b]
            # Stores have a fixed capacity; start afresh rather than overflow.
            if store.db_offset + T > store.max_memories:
                store.clear()
            store.add(new_kv[b:b+1])

        out = self.proj(combined)
        return self.dropout(out)

### TODO

- Add KNN lookup to vanilla transformer
- XL recurrence
- Relative Positional embeddings

Once all that is working, consider

- Ragged memmap for data loading (allows moving to a bigger dataset)
- Byte-pair encoding (a bigger vocabulary that merges common token pairs gives us a longer effective context)