In [1]:
import os
os.chdir(r'5 - TransformerXL')
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
from itertools import groupby
from functools import reduce
from typing import Collection, List
from pathlib import Path
import music21 as m21
musescore_path = '/usr/bin/mscore'
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)
from midi_encoding import *
from einops import rearrange, repeat, pack, unpack, einsum
import faiss
import time

# Pick the fastest available backend: NVIDIA CUDA, then Apple Metal (MPS), else the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}.")

Using cuda.


In [2]:
# Report which GPU model is in use (only when the CUDA backend was selected).
if device == "cuda":
    gpu_name = torch.cuda.get_device_name()
    print(f"Device: {gpu_name}.")

Device: NVIDIA GeForce RTX 4090.


In [3]:
# Show the driver/CUDA versions and the GPU's current memory usage and utilisation.
!nvidia-smi

Thu Aug 22 17:52:01 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.02              Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
| 30%   31C    P0             50W /  450W |    3042MiB /  24564MiB |     36%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
# Token vocabulary from midi_encoding; the cell displays how many tokens it contains.
vocab = MusicVocab()
vocab.size

392

# Building the Model

Our memory-augmented transformer will be very similar to the vanilla model developed in the previous notebook.

There will be three major additions, all of which are fairly simple:

- Relative positional embeddings.
- KNN lookup for keys (and their associated values).
- Recurrent 'TransformerXL' style memory.

## Einops

First, let's condense our previous implementation in two ways:

- Add a head dimension to the tensors in our `MultiHeadAttention` module, so all heads are computed at once and the separate `SelfAttentionHead` module is no longer needed.
- Switch to using einops for shape manipulation as it is both simpler to write and read.

In [5]:
class MultiHeadAttention(torch.nn.Module):
    """Causal multi-head self-attention, with all heads computed in a single batched op.

    Args:
        n_embed: Width of the input / output embeddings.
        n_head: Number of attention heads; each head gets ``n_embed // n_head`` dims.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, n_embed, n_head = 8, dropout = 0.2):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        head_total_size = n_head * self.head_size
        self.key = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.proj = torch.nn.Linear(head_total_size, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (b, t, n_embed); returns the same shape."""
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Split heads
        q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(k, 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(v, 'b t (h d) -> b h t d', h = self.n_head)

        # Without einsum we had to swap dims using k.transpose(-2, -1)
        w = einsum(q, k, 'b h i d, b h j d -> b h i j') * (self.head_size ** -0.5)

        # TODO: Relative positional encoding

        # Causal mask: True marks keys in a query's *future*, which must be hidden.
        # (Masking the lower triangle instead hides the past, and leaves the last row
        # entirely -inf, i.e. NaN after softmax.) `diagonal=j - i` keeps any extra
        # leading keys visible when keys outnumber queries. Built on w's device so it
        # also works on GPU.
        i, j = w.shape[-2:]
        mask = torch.logical_not(
            torch.tril(torch.ones((i, j), dtype=torch.bool, device=w.device), diagonal=j - i)
        )
        w = w.masked_fill(mask, float('-inf'))
        w = F.softmax(w, dim=-1)

        weighted_values = w @ v  # b h t d

        # Concat heads
        weighted_values = rearrange(weighted_values, 'b h t d -> b t (h d)')

        # TODO: KNN memory

        out = self.proj(weighted_values)
        return self.dropout(out)

## KNN Memory

If we want to add KNN memory, we need an indexed data store to look up and retrieve keys and values.

For the index we can use [Faiss](https://github.com/facebookresearch/faiss) from Meta.

For the data store we can simply use a memory-mapped numpy array.

The following code is adapted from the [Colab Notebook](https://colab.research.google.com/drive/1XZz1sjNt1MKRG6ul_hOGSJFQLS4lRtmJ?usp=sharing#scrollTo=gs7RpvCdePZr) accompanying the [Coding a Paper](https://www.youtube.com/playlist?list=PLam9sigHPGwOe8VDoS_6VT4jjlgs9Uepb) series:

In [6]:
class KNN():
    """KNN memory: a Faiss L2 index over keys, backed by a memory-mapped (key, value) store.

    ``IndexFlatL2`` numbers vectors 0, 1, 2, ... in insertion order, and ``add_to_db`` writes
    pairs to consecutive rows starting at 0, so an id returned by a Faiss search is directly a
    row index into ``self.db``.

    Args:
        dim: Size of each key / value vector.
        max_memories: Capacity of the store (number of (key, value) pairs).
        db_filepath: ``Path`` or ``str`` of the backing file. It (and its parent folder) is
            created if missing, otherwise it is reopened read/write.
    """

    def __init__(self, dim, max_memories, db_filepath):
        self.dim = dim
        self.max_memories = max_memories
        self.shape = (max_memories, 2, dim)
        self.db_offset = 0
        # Callers may pass a plain string (e.g. an f-string), which has no `.exists()`.
        db_filepath = Path(db_filepath)
        db_filepath.parent.mkdir(parents=True, exist_ok=True)
        db_mode = 'r+' if db_filepath.exists() else 'w+'  # 'w+' creates the file
        self.db = np.memmap(db_filepath, mode=db_mode, dtype=np.float32, shape=self.shape)
        self.index = faiss.IndexFlatL2(dim)

    @staticmethod
    def _to_numpy(tensor):
        """Return a detached, C-contiguous float32 NumPy copy of ``tensor`` (from any device)."""
        # `.numpy()` only works on CPU tensors; Faiss requires contiguous float32 input.
        return np.ascontiguousarray(tensor.detach().to('cpu', torch.float32).numpy())

    def add_to_db(self, new_data):
        """Write a (t, 2, dim) block of (key, value) pairs to the next free rows and flush.

        Raises:
            IndexError: if the block does not fit in the remaining capacity.
        """
        new_data_len = new_data.shape[0]  # (t)
        end = self.db_offset + new_data_len
        if end > self.max_memories:
            raise IndexError(
                f"KNN memory full: cannot add {new_data_len} entries at offset "
                f"{self.db_offset} (max_memories={self.max_memories})."
            )
        self.db[self.db_offset:end] = self._to_numpy(new_data)
        self.db_offset = end
        # Write to file
        self.db.flush()

    def search_and_retrieve(self, query, k):
        """Return the (key, value) pairs of the ``k`` nearest stored keys for each query row.

        Args:
            query: Contiguous float32 array of shape (t, dim).
            k: Number of neighbours per query.

        Returns:
            Array of shape (t, k, 2, dim).
        """
        # The tooltip says the args are (n, x, k) but that's the CPP api, it's actually (x, k) in Python (n is the first dim of x anyway so can be inferred).
        _distances, indices = self.index.search(query, k)
        return self.db[indices]

    def add(self, new_data):
        """Store a (t, 2, dim) tensor of (key, value) pairs; only the keys go into the index."""
        self.add_to_db(new_data)
        keys = new_data[..., 0, :]  # (t, dim)
        self.index.add(self._to_numpy(keys))

    def search(self, query, k):
        """Return the top-``k`` (key, value) pairs for each row of a (t, dim) ``query`` tensor.

        The result has shape (t, k, 2, dim) and lives on ``query``'s device. While fewer than
        ``k`` memories are stored, all-zero pairs are returned instead.
        """
        T, C = query.shape
        if self.index.ntotal < k:
            return torch.zeros((T, k, 2, C), device=query.device)
        kvs = self.search_and_retrieve(self._to_numpy(query), k)
        return torch.tensor(kvs, device=query.device)

    def clear(self):
        """Forget all memories: empty the index, zero the store and reset the write offset."""
        self.index.reset()
        self.db[:] = 0
        self.db_offset = 0

Let's test it

In [7]:
# Smoke-test the KNN memory: c-dimensional keys/values, t (key, value) pairs per add.
c = 4
t = 2
n_adds = 1000

knn = KNN(c, 100000, Path('../data/numpy/knn-test.db'))

# Seeded so the query (and therefore the retrieved neighbours) is reproducible across runs.
rng = np.random.default_rng(0)

for _ in range(n_adds):
    vector_data = torch.from_numpy(rng.random((t, 2, c), dtype=np.float32))
    knn.add(vector_data)

query_data = torch.from_numpy(rng.random((t, c), dtype=np.float32))
query_data

tensor([[0.2335, 0.3282, 0.2448, 0.4773],
        [0.9102, 0.1002, 0.2155, 0.3850]])

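Search returns a `(t, k, 2, c)` tensor containing, for each of the `t` queries (each of size `c`), its `top_k` nearest keys and their associated values.

Here our query is `(2, 4)` and `top_k = 2`, so the result is `(2, 2, 2, 4)`: 2 queries × 2 neighbours × (key, value) × 4 channels.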

In [8]:
top_k = 2
# Result is (t, k, 2, c): for every query row, its top_k nearest keys and their values.
knn.search(query_data, k=top_k)

tensor([[[[0.2213, 0.3221, 0.2657, 0.5227],
          [0.3680, 0.0744, 0.3142, 0.5691]],

         [[0.3283, 0.3834, 0.2424, 0.5119],
          [0.7836, 0.4819, 0.4896, 0.9313]]],


        [[[0.8644, 0.1047, 0.1088, 0.4424],
          [0.4714, 0.6815, 0.9910, 0.2402]],

         [[0.8612, 0.0825, 0.1059, 0.3038],
          [0.6274, 0.8587, 0.2832, 0.2548]]]])

Now we can integrate the memory into our multiheaded attention.

We will make a new class for this, as we only use KNN memory in the second-to-last layer.

It will have a separate KNN memory for each row of the batch, and we will clear a row's memory whenever the file feeding that row changes.

We will know this as the `CustomMidiDataset` returns the file indices of each batch along with the data. 

These can be passed to our model, which in turn can pass them to the KNN attention block.

In [12]:
# Define i and j
i, j = 4, 4

# Create the mask
mask = torch.logical_not(torch.tril(torch.ones((i, j), dtype=torch.bool)))

# Example tensor w
w = torch.rand((i, j))

# Apply the mask
w = w.masked_fill(mask, float('-inf'))

print("Mask:\n", mask)
print("Original w:\n", torch.rand((i, j)))  # Example tensor before masking
print("Masked w:\n", w)

Mask:
 tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
Original w:
 tensor([[0.9309, 0.0458, 0.1893, 0.1659],
        [0.0803, 0.6873, 0.5376, 0.6530],
        [0.9119, 0.4572, 0.0314, 0.4545],
        [0.2545, 0.7634, 0.3870, 0.7951]])
Masked w:
 tensor([[0.1389,   -inf,   -inf,   -inf],
        [0.1542, 0.8455,   -inf,   -inf],
        [0.9759, 0.5102, 0.5729,   -inf],
        [0.1436, 0.0020, 0.8549, 0.0032]])


In [9]:
class KNNAttention(torch.nn.Module):
    """Causal multi-head attention augmented with a KNN memory of earlier (key, value) pairs.

    Each batch row has its own ``KNN`` memory. On every call, each query also attends over
    its ``k`` nearest stored keys (and their values) from its row's memory. The local and
    memory results are mixed per head by a learned sigmoid gate. The current segment's keys
    and values are then added to memory for later calls. A row's memory is cleared when the
    file index passed for that row changes.

    Args:
        dbFilePath: Filename prefix for the memory-mapped stores (one file per batch row).
        batch_size: Number of batch rows, i.e. number of KNN memories.
        n_embed: Width of the input / output embeddings.
        k: Number of memories retrieved per query.
        n_head: Number of attention heads; each head gets ``n_embed // n_head`` dims.
        dropout: Dropout probability applied to the projected output.
        max_memories: Capacity of each row's memory (number of stored (key, value) pairs).
    """

    def __init__(self, dbFilePath, batch_size, n_embed, k, n_head = 8, dropout = 0.2, max_memories = 100000):
        super().__init__()
        self.n_embed = n_embed
        self.k = k
        self.n_head = n_head
        self.head_size = n_embed // n_head
        head_total_size = n_head * self.head_size
        self.key = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.project = torch.nn.Linear(head_total_size, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

        # Memory per batch row. KNN memory will get or create the files, so we just need to be
        # consistent with the file names. Stored vectors are head-concatenated, i.e.
        # n_head * head_size wide (less than n_embed when n_embed % n_head != 0).
        self.knn = {
            row: KNN(head_total_size, max_memories, Path(f'{dbFilePath}-batch_dim-{row}.db'))
            for row in range(batch_size)
        }
        self.current_file_indexes = None

        # Per-head gate logit: sigmoid(gate_bias) is the weight given to the memory result.
        self.gate_bias = torch.nn.Parameter(torch.randn(self.n_head, 1, 1))

    def _reset_changed_memories(self, batch_file_indexes):
        """Clear the memory of every batch row whose file index differs from the previous call."""
        if self.current_file_indexes is not None:
            # zip tolerates a batch-size change (e.g. a smaller final batch).
            pairs = zip(self.current_file_indexes, batch_file_indexes)
            for row, (previous, current) in enumerate(pairs):
                if previous != current:
                    self.knn[row].clear()
        self.current_file_indexes = batch_file_indexes

    def forward(self, x, batch_file_indexes):
        """Attend over ``x`` (b, t, n_embed) locally and over each row's KNN memory.

        Args:
            x: Input of shape (b, t, n_embed).
            batch_file_indexes: Per-row file identifiers; a changed value resets that row's memory.

        Returns:
            Tensor of shape (b, t, n_embed).
        """
        self._reset_changed_memories(batch_file_indexes)

        B = x.shape[0]

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # This helps to mitigate drift in the embeddings which can cause the historical keys to become less aligned to the current queries.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # Head-concatenated forms: queries for the KNN search, and this segment's (key, value)
        # pairs to store afterwards, shape b t 2 (h d).
        q_search = q
        current_kv = torch.stack((k, v), dim=-2).detach()

        ### LOCAL ATTENTION

        # Split heads
        q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(k, 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(v, 'b t (h d) -> b h t d', h = self.n_head)

        # Without einsum we had to swap dims using k.transpose(-2, -1)
        w = einsum(q, k, 'b h i d, b h j d -> b h i j') * (self.head_size ** -0.5)

        # TODO: Relative positional encoding

        # Causal mask (True = future key, hidden). Built per call because its shape depends
        # on whether keys were extended; created on w's device so it also works on GPU.
        i, j = w.shape[-2:]
        mask = torch.logical_not(
            torch.tril(torch.ones((i, j), dtype=torch.bool, device=w.device), diagonal=j - i)
        )
        w = w.masked_fill(mask, float('-inf'))
        w = F.softmax(w, dim=-1)

        weighted_values = w @ v  # b h t d

        ### KNN ATTENTION

        # KNN.search only returns real neighbours once a memory holds at least k entries
        # (otherwise zeros), so only those rows should use the memory result.
        knn_mask = torch.tensor(
            [self.knn[row].index.ntotal >= self.k for row in range(B)],
            dtype=torch.bool,
            device=x.device,
        )

        if knn_mask.any():
            # Faiss / NumPy work on the CPU; move the retrieved memories back to x's device.
            mem_kv = torch.stack(
                [
                    self.knn[row]
                    .search(q_search[row].detach().cpu(), k=self.k)
                    .to(device=x.device, dtype=q_search.dtype)
                    for row in range(B)
                ],
                dim=0,
            )  # b t k 2 (h d)

            mem_k, mem_v = mem_kv.unbind(dim=-2)
            mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h=self.n_head)
            mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h=self.n_head)

            # Query / memory-key affinities for each of the k retrieved memories: (b, h, t, k).
            # einops.einsum takes the tensors first and the pattern last.
            mem_w = einsum(q, mem_k, 'b h t d, b h t k d -> b h t k') * (self.head_size ** -0.5)
            mem_w = F.softmax(mem_w, dim=-1)

            # Weighted sum of the retrieved values over the k dimension: (b, h, t, d).
            mem_weighted_values = einsum(mem_w, mem_v, 'b h t k, b h t k d -> b h t d')

            ## Combined attention
            # The sigmoid keeps the per-head mix a convex combination; gate is (h, 1, 1) and
            # broadcasts over (b, h, t, d).
            gate = self.gate_bias.sigmoid()
            combined_weighted_values = mem_weighted_values * gate + weighted_values * (1 - gate)

            # Rows without usable memories keep their purely local result.
            weighted_values = torch.where(
                knn_mask.view(B, 1, 1, 1), combined_weighted_values, weighted_values
            )

        # Concat heads
        weighted_values = rearrange(weighted_values, 'b h t d -> b t (h d)')
        out = self.project(weighted_values)

        # Memorise this segment's keys/values (detached, on CPU) for future calls.
        for row in range(B):
            self.knn[row].add(current_kv[row].cpu())

        return self.dropout(out)

# XL Memory

We simply need to prepend the previous segment's key and value tensors to the current ones. This lets the queries attend to (and exchange information with) the previous timesteps, which in turn attended to the segment before them, and so on. The result is a kind of delay-line memory that fades out over time.

In [11]:
class XLAttention(torch.nn.Module):
    """Causal multi-head attention with a Transformer-XL style recurrent memory.

    Keys and values from the previous segment (``xl_memory``) are prepended to the current
    segment's keys and values, so each query can also attend to the previous segment. Each
    call returns the current segment's keys and values, to be passed in as the next call's
    ``xl_memory``.

    Args:
        n_embed: Width of the input / output embeddings.
        n_head: Number of attention heads; each head gets ``n_embed // n_head`` dims.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, n_embed, n_head = 8, dropout = 0.2):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        head_total_size = n_head * self.head_size
        self.key = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.project = torch.nn.Linear(head_total_size, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, xl_memory = None):
        """Attend over ``x`` and, if given, the previous segment's memory.

        Args:
            x: Input of shape (b, t, n_embed).
            xl_memory: Optional (b, t_mem, 2, n_head * head_size) tensor of stacked previous
                keys/values, as returned by the previous call.

        Returns:
            ``(out, new_xl_memory)``. ``out`` has shape (b, t, n_embed). ``new_xl_memory`` holds
            this segment's detached keys/values, with shape (b, t, 2, n_head * head_size).
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Chris's implementation scales the queries by `head_size ** -0.5` here. That is
        # mathematically equivalent to scaling the attention scores, which we do below
        # instead (only one of the two should be applied).

        # This segment's keys/values become the next call's memory. Detach them so gradients
        # do not flow back into the previous segment's graph (Transformer-XL stop-gradient).
        new_xl_memory = torch.stack((k, v), dim=-2).detach()  # b t 2 (h d)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim = -2) # assume stacked; each b t_mem (h d)
            k = torch.cat((k_xl, k), dim = -2) # prepend XL memory along time
            v = torch.cat((v_xl, v), dim = -2) # prepend XL memory along time

        ### LOCAL ATTENTION

        # Split heads
        q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(k, 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(v, 'b t (h d) -> b h t d', h = self.n_head)

        w = einsum(q, k, 'b h i d, b h j d -> b h i j') * (self.head_size ** -0.5)

        # TODO: Relative positional encoding

        # Causal mask (True = hidden). With j - i memory keys prepended, query n sits at
        # absolute position (j - i) + n and may see every key up to and including that
        # position, so the lower triangle is shifted right by j - i. Can't cache this as
        # its shape depends on whether we have XL memory or not.
        i, j = w.shape[-2:]
        mask = torch.logical_not(
            torch.tril(torch.ones((i, j), dtype=torch.bool, device=w.device), diagonal=j - i)
        )
        w = w.masked_fill(mask, float('-inf'))
        w = F.softmax(w, dim=-1)

        weighted_values = w @ v  # b h t d
        # Concat heads
        weighted_values = rearrange(weighted_values, 'b h t d -> b t (h d)')

        out = self.project(weighted_values)

        return self.dropout(out), new_xl_memory

### TODO

- Relative Positional embeddings
- Wire it all together
- Training loop (inc. pause / save and load / continue)

Once all that is working, consider

- Ragged memmap for data loading (will make loading much faster)
- Byte pair encoding (bigger vocab with common token pairs gives us a bigger effective context) 