In [1]:
import os
os.chdir(r'5 - TransformerXL')
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
from itertools import groupby
from functools import reduce
from typing import Collection, List
from pathlib import Path
import music21 as m21
musescore_path = '/usr/bin/mscore'
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)
from midi_encoding import *
from einops import rearrange, repeat, pack, unpack, einsum
import faiss

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}.")

Using cuda.


In [2]:
# When running on CUDA, report which GPU model PyTorch is using.
if device == "cuda":
    gpu_name = torch.cuda.get_device_name()
    print(f"Device: {gpu_name}.")

Device: NVIDIA GeForce RTX 4090.


In [3]:
# IPython shell escape: print the nvidia-smi report (driver, memory usage, utilisation) for this machine.
!nvidia-smi

Sun Aug 18 17:27:54 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.31.01              Driver Version: 560.81         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
| 30%   32C    P0             46W /  450W |    3513MiB /  24564MiB |     36%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
# Build the token vocabulary and display its size as the cell output.
# NOTE(review): MusicVocab is not imported by name — presumably it comes from
# `from midi_encoding import *`; confirm.
vocab = MusicVocab()
vocab.size

392

# Building the Model

Our memory-augmented transformer will be very similar to the vanilla model developed in the previous notebook.

There will be three major additions, all of which are fairly simple:

- Relative positional embeddings.
- KNN lookup for keys (and their associated values).
- Recurrent 'TransformerXL' style memory.

## Einops

First, let's condense our previous implementation in two ways:

- Add a dimension to our `MultiHeadAttention` module, abandoning the separate `SelfAttentionHead` module.
- Switch to using einops for shape manipulation as it is both simpler to write and read.

In [5]:
class MultiHeadAttention(torch.nn.Module):
    """Causal multi-head self-attention with every head computed in one batched pass.

    Args:
        n_embed: Size of the input and output embedding dimension.
        n_head: Number of attention heads; each head works on
            ``n_embed // n_head`` channels.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, n_embed, n_head = 8, dropout = 0.2):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        # May be smaller than n_embed when n_embed is not divisible by n_head.
        head_total_size = n_head * self.head_size
        self.key = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.proj = torch.nn.Linear(head_total_size, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, time, n_embed).

        Returns:
            Tensor of shape (batch, time, n_embed).
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Split heads
        q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(k, 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(v, 'b t (h d) -> b h t d', h = self.n_head)

        # Without einsum we had to swap dims using k.transpose(-2, -1)
        w = einsum(q, k, 'b h i d, b h j d -> b h i j') * (self.head_size ** -0.5)

        # TODO: Relative positional encoding

        # Causal mask: True marks the *future* positions (strictly above the
        # diagonal), which must be hidden. The diagonal offset `j - i + 1`
        # equals 1 for square scores and keeps the mask aligned to the most
        # recent keys if more keys than queries are ever supplied.
        # The mask is created on the same device as the scores.
        i, j = w.shape[-2:]
        future = torch.ones((i, j), dtype = torch.bool, device = w.device).triu(j - i + 1)
        w = w.masked_fill(future, float('-inf'))
        w = F.softmax(w, dim=-1)

        attention_scores = w@v

        # Concat heads
        attention_scores = rearrange(attention_scores, 'b h t d -> b t (h d)')

        # TODO: KNN memory

        out = self.proj(attention_scores)
        return self.dropout(out)

## KNN Memory

If we want to add KNN memory, we need an indexed data store to look up and retrieve keys and values.

For the index we can use [Faiss](https://github.com/facebookresearch/faiss) from Meta.

For the data store we can simply use a memory-mapped numpy array.

The following code is taken from the [Colab Notebook](https://colab.research.google.com/drive/1XZz1sjNt1MKRG6ul_hOGSJFQLS4lRtmJ?usp=sharing#scrollTo=mxnmadjyLgH1) accompanying the [Coding a Paper](https://www.youtube.com/playlist?list=PLam9sigHPGwOe8VDoS_6VT4jjlgs9Uepb) series:

In [6]:
class KNN():
    """Fixed-capacity key/value memory with nearest-neighbour lookup on keys.

    Keys are indexed in a flat Faiss L2 index; the (key, value) pairs
    themselves live in a memory-mapped float32 array of shape
    ``(max_memories, 2, dim)``, where slot 0 of the middle axis holds the key
    and slot 1 the value. Row ``n`` of the store corresponds to the ``n``-th
    vector added to the index.

    Args:
        dim: Size of each key / value vector.
        max_memories: Maximum number of (key, value) pairs that can be stored.
        db_filepath: Path of the file backing the memmap (created or
            overwritten, since the memmap is opened in ``'w+'`` mode).
    """

    def __init__(self, dim, max_memories, db_filepath):
        self.dim = dim
        self.max_memories = max_memories
        self.shape = (max_memories, 2, dim) # TODO: Memory per batch dim, cleared when file is exhausted
        self.db_offset = 0
        self.db_filepath = db_filepath
        self.db = np.memmap(self.db_filepath, mode = 'w+', dtype = np.float32, shape = self.shape)
        self.index = faiss.IndexFlatL2(dim)


    def add_to_db(self, new_data):
        """Append (key, value) pairs to the memmap store and flush to disk.

        Args:
            new_data: Tensor whose first dimension is the number of pairs to
                append; the remaining dimensions must match ``(2, dim)``.

        Raises:
            IndexError: If the new rows would not fit in the store. Nothing is
                written in that case.
        """
        new_data_len = new_data.shape[0] # (b t)
        end = self.db_offset + new_data_len
        if end > self.max_memories:
            raise IndexError(
                f"KNN memory is full: cannot add {new_data_len} entries at offset "
                f"{self.db_offset} (max_memories = {self.max_memories})"
            )
        # .cpu() is required before .numpy() when the tensor lives on an accelerator.
        self.db[self.db_offset:end] = new_data.detach().cpu().numpy()
        self.db_offset = end
        # Write to file
        self.db.flush()


    def search_and_retrieve(self, query, k):
        """Look up the ``k`` nearest keys for each query row and fetch their pairs.

        Args:
            query: Contiguous float32 numpy array of shape ``(n, dim)``.
            k: Number of neighbours to retrieve per query row.

        Returns:
            Numpy array of shape ``(n, k, 2, dim)``. Neighbour slots that Faiss
            could not fill (fewer than ``k`` vectors indexed) are all zeros.
        """
        # The tooltip says the args are (n, x, k) but that's the CPP api, it's actually (x, k) in Python (n is the first dim of x anyway so can be inferred).
        _, indices = self.index.search(query, k)

        # Faiss pads missing results with -1, which numpy would silently treat
        # as "last row". Read a valid row instead and blank those slots out.
        missing = indices < 0
        kvs = np.array(self.db[np.where(missing, 0, indices)])
        kvs[missing] = 0
        return kvs

    def add(self, new_data):
        """Store a batch of (key, value) pairs and index their keys.

        Args:
            new_data: Tensor of shape ``(batch, time, 2, dim)`` with keys at
                index 0 and values at index 1 of the third axis.
        """
        # einops' rearrange only accepts named axes (no literal `2`), and cannot
        # infer a grouped axis like `(h d)` here, so use plain axis names.
        new_data = rearrange(new_data, 'b t kv d -> (b t) kv d')

        # Add to db
        self.add_to_db(new_data)

        # Only keys are used in knn index
        keys = new_data[:, 0]

        # Add (b t) d keys to index; Faiss needs contiguous float32 on the host.
        keys = np.ascontiguousarray(keys.detach().cpu().numpy(), dtype = np.float32)
        self.index.add(keys)

    def search(self, query, k):
        """Retrieve the ``k`` nearest stored (key, value) pairs for every query vector.

        Args:
            query: Tensor of shape ``(batch, time, dim)``.
            k: Number of neighbours per query vector.

        Returns:
            Tensor of shape ``(batch, time, k, 2, dim)`` on the same device as
            ``query``.
        """
        T = query.shape[1]
        flat_query = rearrange(query, 'b t d -> (b t) d')
        flat_query = np.ascontiguousarray(flat_query.detach().cpu().numpy(), dtype = np.float32)

        kvs = self.search_and_retrieve(flat_query, k)
        kvs = torch.tensor(kvs, device = query.device)
        kvs = rearrange(kvs, '(b t) k kv d -> b t k kv d', t = T)

        return kvs


    def clear(self):
        """Drop every stored memory: reset the index, zero the store, rewind the offset."""
        self.index.reset()
        self.db[:] = 0
        self.db_offset = 0

### TODO

- Test KNN
- Add KNN lookup to vanilla transformer
- XL recurrence
- Relative Positional embeddings

Once all that is working, consider:

- Ragged memmap for data loading (allows moving to bigger dataset)
- Byte pair encoding (bigger vocab with common token pairs gives us a bigger effective context) 