In [1]:
import os
os.chdir(r'5 - TransformerXL')
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
from itertools import groupby
from functools import reduce
from typing import Collection, List
from pathlib import Path
import music21 as m21
musescore_path = '/usr/bin/mscore'
m21.environment.set('musicxmlPath', musescore_path)
m21.environment.set('musescoreDirectPNGPath', musescore_path)
from midi_encoding import *
from einops import rearrange, repeat, pack, unpack, einsum
import faiss
import time
import math

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device}.")

Using cuda.


In [2]:
if device == "cuda":
    print(f"Device: {torch.cuda.get_device_name()}.")

Device: NVIDIA GeForce RTX 4090.


In [3]:
!nvidia-smi

Fri Aug 23 14:50:52 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.02              Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
| 30%   32C    P0             49W /  450W |    1946MiB /  24564MiB |     40%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
vocab = MusicVocab()
vocab.size

392

# Building the Model

Our memory-augmented transformer will be very similar to the vanilla model developed in the previous notebook.

There will be three major additions, all of which are fairly simple:

- Relative positional embeddings.
- KNN lookup for keys (and their associated values).
- Recurrent 'TransformerXL' style memory.

## Einops

First, lets condense our previous implementation in two ways

- Add a dimension to our `MultiHeadAttention` module, abandoning the separate `SelfAttentionHead` module.
- Switch to using einops for shape manipulation as it is both simpler to write and read.

In [5]:
class MultiHeadAttention(torch.nn.Module):

    def __init__(self, n_embed, n_head = 8, dropout = 0.2):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        head_total_size = n_head * self.head_size
        self.key = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.proj = torch.nn.Linear(head_total_size, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Split heads
        q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(k, 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(v, 'b t (h d) -> b h t d', h = self.n_head)

        # Without einsum we had to swap dims using k.transpose(-2, -1)
        w = einsum(q, k, 'b h i d, b h j d -> b h i j') * (self.head_size ** -0.5)
        
        # TODO: Relative positional encoding

        i, j = w.shape[-2:]
        mask = torch.tril(torch.ones((i,j), dtype = torch.bool))
        w = w.masked_fill(mask, float('-inf'))
        w = F.softmax(w, dim=-1)

        weighted_values = w@v

        # Concat heads
        weighted_values = rearrange(weighted_values, 'b h t d -> b t (h d)')

        # TODO: KNN memory

        out = self.proj(weighted_values)
        return self.dropout(out)

## KNN Memory

If we want to add KNN memory, we need an indexed data store to look up and retrieve keys and values.

For the index we can use [Faiss](https://github.com/facebookresearch/faiss) from Meta.

For the data store we can simply use a memory-mapped numpy array.

The following code is adapted from the [Colab Notebook](https://colab.research.google.com/drive/1XZz1sjNt1MKRG6ul_hOGSJFQLS4lRtmJ?usp=sharing#scrollTo=gs7RpvCdePZr) accompanying the [Coding a Paper](https://www.youtube.com/playlist?list=PLam9sigHPGwOe8VDoS_6VT4jjlgs9Uepb) series:

In [6]:
class KNN():
    def __init__(self, dim, max_memories, db_filepath):
        self.dim = dim
        self.max_memories = max_memories
        self.shape = (max_memories, 2, dim)
        self.db_offset = 0
        if db_filepath.exists():
            dbMode = 'r+'
        else:
            dbMode = 'w+' # Create file if it doesn't exist
        self.db = np.memmap(db_filepath, mode = dbMode, dtype = np.float32, shape = self.shape)
        self.index = faiss.IndexFlatL2(dim)

    def add_to_db(self, new_data):
        new_data_len = new_data.shape[0] # (t)
        ids = (np.arange(new_data_len) + self.db_offset)
        self.db[ids] = new_data.detach().numpy()
        self.db_offset += new_data_len
        # Write to file
        self.db.flush()

    def search_and_retrieve(self, query, k):

        # The tooltip says the args are (n, x, k) but that's the CPP api, it's actually (x, k) in Python (n is the first dim of x anyway so can be inferred).
        distances, indices = self.index.search(query, k)
        
        kvs = self.db[indices]
        return kvs

    def add(self, new_data):
        # new_data = (t, 2, c)

        # Add to db
        self.add_to_db(new_data)

        # Only keys are used in knn index
        keys, vals = new_data.unbind(dim=-2)

        # Add (t, c) tensors to index
        keys = keys.detach().numpy()
        keys = np.ascontiguousarray(keys)
        self.index.add(keys)

    def search(self, query, k):

        T, C = query.shape
        
        # If we have enough memories, search and retrieve, otherwise return zeros
        if self.index.ntotal >= k:
            kvs = self.search_and_retrieve(np.ascontiguousarray(query.detach().numpy()), k)
            kvs = torch.tensor(kvs)
        else:
            kvs = torch.zeros((T, k, 2, C), device=device)

        return kvs

    def clear(self):
        self.index.reset()
        self.db[:] = 0
        self.db_offset = 0

Let's test it

In [7]:
c = 4
t = 2

knn = KNN(c, 100000, Path('../data/numpy/knn-test.db'))

for i in range(1000):
    vector_data = torch.tensor(np.random.random((t, 2, c)).astype('float32'))
    knn.add(vector_data)

query_data = torch.tensor(np.random.random((t, c)).astype('float32'))
query_data

tensor([[0.2238, 0.3509, 0.1005, 0.2910],
        [0.9363, 0.2582, 0.4334, 0.8027]])

Search returns a `(t k 2 c)` tensor which contains the top_k keys and values for each `(t c)` query.

Here our query is `(2 * 4)` so our results will be `(2 * 2 * 2 * 4)`

In [8]:
top_k = 2
knn.search(query_data, top_k) # (t k two c) tensor, returns top_k keys and values for each query.

tensor([[[[0.1990, 0.3517, 0.1679, 0.2943],
          [0.9202, 0.7429, 0.7107, 0.8525]],

         [[0.1331, 0.3914, 0.1099, 0.2552],
          [0.3810, 0.6787, 0.6610, 0.5767]]],


        [[[0.9419, 0.2734, 0.3867, 0.8327],
          [0.4675, 0.3117, 0.3460, 0.1590]],

         [[0.9230, 0.3176, 0.5073, 0.7970],
          [0.5202, 0.4738, 0.0330, 0.5286]]]])

Now we can integrate the memory into our multiheaded attention.

We will make a new class for this as we only use KNN on the second to last layer.

It will have a KNN memory for each batch dimension, and we will clear that memory if the file in that batch dimension changes.

We will know this as the `CustomMidiDataset` returns the file indices of each batch along with the data. 

These can be passed to our model, which in turn can pass them to the KNN attention block.

In [9]:
# Define i and j
i, j = 4, 4

# Create the mask
mask = torch.logical_not(torch.tril(torch.ones((i, j), dtype=torch.bool)))

# Example tensor w
w = torch.rand((i, j))

# Apply the mask
w = w.masked_fill(mask, float('-inf'))

print("Mask:\n", mask)
print("Original w:\n", torch.rand((i, j)))  # Example tensor before masking
print("Masked w:\n", w)

Mask:
 tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
Original w:
 tensor([[0.5835, 0.5431, 0.3137, 0.5969],
        [0.8058, 0.2264, 0.0946, 0.2151],
        [0.8888, 0.7451, 0.1404, 0.6317],
        [0.9363, 0.1994, 0.5483, 0.1877]])
Masked w:
 tensor([[0.7811,   -inf,   -inf,   -inf],
        [0.6212, 0.6408,   -inf,   -inf],
        [0.0985, 0.5483, 0.7571,   -inf],
        [0.6869, 0.0328, 0.5287, 0.2488]])


In [10]:
class KNNAttention(torch.nn.Module):

    def __init__(self, dbFilePath, batch_size,  k, n_embed, n_head, dropout):
        super().__init__()
        self.n_embed = n_embed
        self.k = k
        self.n_head = n_head
        head_size = n_embed // n_head
        self.scale_factor = head_size ** -0.5
        self.key = torch.nn.Linear(n_embed, n_embed, bias=False)
        self.query = torch.nn.Linear(n_embed, n_embed, bias=False)
        self.value = torch.nn.Linear(n_embed, n_embed, bias=False)
        self.project = torch.nn.Linear(n_embed, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

        # Memory per batch dim
        self.knn = {i: KNN(n_embed, 100000, f'{dbFilePath}-batch_dim-{i}.db') for i in range(batch_size)} # KNN memory will get or create the files, so we just need to be consistent with the file names.
        self.current_file_idxs = None

        self.gate_bias = torch.nn.Parameter(torch.randn(self.n_head, 1, 1))
    

    def forward(self, relative_positions, batch_file_idxs, x):

        # Clear batch dim's knn memory if file changes
        if self.current_file_idxs != None:
            for i in range(len(self.current_file_idxs)):
                if self.current_file_idxs[i] != batch_file_idxs[i]:
                    self.knn[i].clear()

        self.current_file_idxs = batch_file_idxs

        B, T, C = x.shape

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # This helps to mitigate drift in the embeddings which can cause the historical keys to become less aligned to the current queries.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        ### LOCAL ATTENTION

        # Split heads
        q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(k, 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(v, 'b t (h d) -> b h t d', h = self.n_head)

        # Without einsum we had to swap dims using k.transpose(-2, -1)
        w = einsum(q, k, 'b h i d, b h j d -> b h i j')
        i, j = w.shape[-2:]
        
        # Add relative positional encoding and scale
        w = w + relative_positions[..., -i:, -j:]
        w = w * self.scale_factor

        mask = torch.logical_not(torch.tril(torch.ones((i,j), dtype = torch.bool))) # Can't cache this as its shape depends on whether we have XL memory or not.
        w = w.masked_fill(mask, float('-inf'))
        w = F.softmax(w, dim=-1)

        weighted_values = w@v # b h t d

        ### KNN ATTENTION
        knn_mask = torch.tensor([self.knn[i].index.ntotal > 0 for i in range(B)], dtype=torch.bool, device=device)

        # Only do knn if there are at least some memories
        if knn_mask.any():

            t1 = time.time()
            print ("Begin KNN operations")

            # Convert queries to search form
            q = rearrange(q, 'b h t d -> b t (h d)')

            # KNN returns zeroes if it doesn't have data.
            mem_kv = torch.stack([self.knn[i].search(q[i], k = self.k) for i in range(B)], dim = 0) # b, t, k, 2, c
            
            mem_k, mem_v = mem_kv.unbind(dim = -2)
            mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h=self.n_head)
            mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h=self.n_head)

            # Convert queries to attention form
            q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)

            # Sum over d for each combination of batch, head, time and top k to get qk affinities, and hence weights for each k. resulting in a tensor of shape (b, h, t, k).
            mem_w = einsum('b h t d, b h t k d -> b h t k', q, mem_k)
            mem_w = mem_w * self.scale_factor
            mem_w = F.softmax(mem_w, dim=-1)

            # Weighted sum over the top k dimension for each combination of b, h, and t, resulting in a tensor of shape (b, h, t, d). Equivalent to doing w@v for each k and summing.
            mem_weighted_values = einsum('b h t k, b h t k d -> b h t d', mem_w, mem_v)

            ## Combined attention
            
            # Assume every memory has content. Empty memories will be masked out below.
            combined_weighted_values = mem_weighted_values * self.gate_bias + weighted_values * (1 - self.gate_bias)

            # Mask out combined weighted values where knn memory *is* empty and non-combined values where it *is not* empty, then merge them.
            combined_weighted_values = combined_weighted_values * knn_mask.view(B, 1, 1, 1) + weighted_values * (~knn_mask).view(B, 1, 1, 1)

            # Concat heads
            combined_weighted_values = rearrange(combined_weighted_values, 'b h t d -> b t (h d)')
            out = self.project(combined_weighted_values)

            t2 = time.time()
            print ("End KNN operations, time taken:", t2-t1)

        else:
            # Concat heads
            weighted_values = rearrange(weighted_values, 'b h t d -> b t (h d)')
            out = self.project(weighted_values)

        current_kv = torch.stack((k, v), dim=-2) # (b, t, 2, c)
        for i in range(B):
            self.knn[i].add(current_kv[i])

        return self.dropout(out)

# XL Memory

We simply need to append the previous iteration's key and value tensors to the current, allowing the queries to search over / swap information from the previous timesteps (which did the same to the timestep before that etc etc), creating a kind of delay-line memory that fades out over time.

In [11]:
class XLAttention(torch.nn.Module):

    def __init__(self, n_embed, n_head, dropout):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        head_total_size = n_head * self.head_size
        self.key = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.query = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.value = torch.nn.Linear(n_embed, head_total_size, bias=False)
        self.project = torch.nn.Linear(head_total_size, n_embed)
        self.dropout = torch.nn.Dropout(dropout)    

    def forward(self, relative_positions, x, xl_memory = None):

        B, T, C = x.shape

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Chris's implementation does `queries = queries * (self.head_size ** -0.5)` here but I don't think it is correct.

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim = -2) # assume stacked
            keys = torch.cat((k_xl, keys), dim = -2) # prepend XL memory
            values = torch.cat((v_xl, values), dim = -2) # prepend XL memory

        ### LOCAL ATTENTION

        # Split heads
        q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(k, 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(v, 'b t (h d) -> b h t d', h = self.n_head)

        w = einsum(q, k, 'b h i d, b h j d -> b h i j')
        i, j = w.shape[-2:]

        # Add relative positional encoding and scale
        w = w + relative_positions[..., -i:, -j:]
        w = w * (self.head_size ** -0.5)
        
        mask = torch.logical_not(torch.tril(torch.ones((i,j), dtype = torch.bool))) # Can't cache this as its shape depends on whether we have XL memory or not.
        w = w.masked_fill(mask, float('-inf'))
        w = F.softmax(w, dim=-1)

        weighted_values = w@v # b h t d
        # Concat heads
        weighted_values = rearrange(weighted_values, 'b h t d -> b t (h d)')
        
        out = self.project(weighted_values)

        # new XL memories

        # Concatenate key and value heads
        k = rearrange(k, 'b h t d -> b t (h d)', h = self.n_head)
        v = rearrange(v, 'b h t d -> b t (h d)', h = self.n_head)
        current_kv = torch.stack((k, v), dim=-2) # b t 2 (h d)

        if xl_memory is None:
            new_xl_memory = current_kv
        else:
            new_xl_memory = current_kv[:, -T:]

        return self.dropout(out), new_xl_memory

# KNN-XL

Now we can add XL memory to out KNN Attention

In [12]:
class KNN_XLAttention(torch.nn.Module):

    def __init__(self, dbFilePath, batch_size, k, n_embed, n_head, dropout):
        super().__init__()
        self.n_embed = n_embed
        self.k = k
        self.n_head = n_head
        head_size = n_embed // n_head
        self.scale_factor = head_size ** -0.5
        self.key = torch.nn.Linear(n_embed, n_embed, bias=False)
        self.query = torch.nn.Linear(n_embed, n_embed, bias=False)
        self.value = torch.nn.Linear(n_embed, n_embed, bias=False)
        self.project = torch.nn.Linear(n_embed, n_embed)
        self.dropout = torch.nn.Dropout(dropout)

        # Memory per batch dim
        self.knn = {i: KNN(n_embed, 100000, Path(f'{dbFilePath}/batch_dim-{i}.db')) for i in range(batch_size)} # KNN memory will get or create the files, so we just need to be consistent with the file names.
        self.current_file_idxs = None

        self.gate_bias = torch.nn.Parameter(torch.randn(self.n_head, 1, 1))
    

    def forward(self, batch_file_idxs, relative_positions, x, xl_memory = None):

        # Clear batch dim's knn memory if file changes
        if self.current_file_idxs != None:
            for i in range(len(self.current_file_idxs)):
                if self.current_file_idxs[i] != batch_file_idxs[i]:
                    self.knn[i].clear()

        self.current_file_idxs = batch_file_idxs

        B, T, C = x.shape

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # This helps to mitigate drift in the embeddings which can cause the historical keys to become less aligned to the current queries.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim = -2) # assume stacked
            keys = torch.cat((k_xl, keys), dim = -2) # prepend XL memory
            values = torch.cat((v_xl, values), dim = -2) # prepend XL memory

        ### LOCAL ATTENTION

        # Split heads
        q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)
        k = rearrange(k, 'b t (h d) -> b h t d', h = self.n_head)
        v = rearrange(v, 'b t (h d) -> b h t d', h = self.n_head)

        w = einsum(q, k, 'b h i d, b h j d -> b h i j')
        i, j = w.shape[-2:]
        
        # Add relative positional encoding and scale
        w = w + relative_positions[..., -i:, -j:]
        w = w * self.scale_factor

        mask = torch.logical_not(torch.tril(torch.ones((i,j), dtype = torch.bool))) # Can't cache this as its shape depends on whether we have XL memory or not.
        w = w.masked_fill(mask, float('-inf'))
        w = F.softmax(w, dim=-1)

        weighted_values = w@v # b h t d

        ### KNN ATTENTION
        knn_mask = torch.tensor([self.knn[i].index.ntotal > 0 for i in range(B)], dtype=torch.bool, device=device)

        # Only do knn if there are at least some memories
        if knn_mask.any():

            t1 = time.time()
            print ("Begin KNN operations")

            # Convert queries to search form
            q = rearrange(q, 'b h t d -> b t (h d)')

            # KNN returns zeroes if it doesn't have data.
            mem_kv = torch.stack([self.knn[i].search(q[i], k = self.k) for i in range(B)], dim = 0) # b, t, k, 2, c
            
            mem_k, mem_v = mem_kv.unbind(dim = -2)
            mem_k = rearrange(mem_k, 'b t k (h d) -> b h t k d', h=self.n_head)
            mem_v = rearrange(mem_v, 'b t k (h d) -> b h t k d', h=self.n_head)

            # Convert queries to attention form
            q = rearrange(q, 'b t (h d) -> b h t d', h = self.n_head)

            # Sum over d for each combination of batch, head, time and top k to get qk affinities, and hence weights for each k. resulting in a tensor of shape (b, h, t, k).
            mem_w = einsum('b h t d, b h t k d -> b h t k', q, mem_k)
            mem_w = mem_w * self.scale_factor
            mem_w = F.softmax(mem_w, dim=-1)

            # Weighted sum over the top k dimension for each combination of b, h, and t, resulting in a tensor of shape (b, h, t, d). Equivalent to doing w@v for each k and summing.
            mem_weighted_values = einsum('b h t k, b h t k d -> b h t d', mem_w, mem_v)

            ## Combined attention
            
            # Assume every memory has content. Empty memories will be masked out below.
            combined_weighted_values = mem_weighted_values * self.gate_bias + weighted_values * (1 - self.gate_bias)

            # Mask out combined weighted values where knn memory *is* empty and non-combined values where it *is not* empty, then merge them.
            combined_weighted_values = combined_weighted_values * knn_mask.view(B, 1, 1, 1) + weighted_values * (~knn_mask).view(B, 1, 1, 1)

            # Concat heads
            combined_weighted_values = rearrange(combined_weighted_values, 'b h t d -> b t (h d)')
            out = self.project(combined_weighted_values)

            t2 = time.time()
            print ("End KNN operations, time taken:", t2-t1)

        else:
            # Concat heads
            weighted_values = rearrange(weighted_values, 'b h t d -> b t (h d)')
            out = self.project(weighted_values)


        # New XL memories

        # Concatenate key and value heads
        k = rearrange(k, 'b h t d -> b t (h d)', h = self.n_head)
        v = rearrange(v, 'b h t d -> b t (h d)', h = self.n_head)
        current_kv = torch.stack((k, v), dim=-2) # b t 2 c

        if xl_memory is None:
            new_xl_memory = current_kv
        else:
            new_xl_memory = current_kv[:, -T:]

        for i in range(B):
            self.knn[i].add(new_xl_memory[i])

        return self.dropout(out), new_xl_memory

# Relative Positional Embeddings

In the above Attention classes we have passed relative embeddings to the `forward` methods, which are then added to the weights.

This is in contrast to the fixed positional embeddings used by a vanilla transformer, which are applied at the input of the model.

We haven't yet defined the relative embedding structure. It is very similar to fixed embeddings, with the 0 point offset by the current position as defined by the triangular mask.

In [13]:
block_size = 7
q_pos = torch.arange(block_size, dtype=torch.long)
q_pos = q_pos.reshape(q_pos.shape[0], 1)
k_pos = torch.arange(block_size, dtype=torch.long)
rel_pos = k_pos - q_pos
inv_rel_pos = -rel_pos
inv_rel_pos

tensor([[ 0, -1, -2, -3, -4, -5, -6],
        [ 1,  0, -1, -2, -3, -4, -5],
        [ 2,  1,  0, -1, -2, -3, -4],
        [ 3,  2,  1,  0, -1, -2, -3],
        [ 4,  3,  2,  1,  0, -1, -2],
        [ 5,  4,  3,  2,  1,  0, -1],
        [ 6,  5,  4,  3,  2,  1,  0]])

In [14]:
masked_rel_pos = torch.max(inv_rel_pos, torch.zeros_like(inv_rel_pos))
masked_rel_pos

tensor([[0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0],
        [2, 1, 0, 0, 0, 0, 0],
        [3, 2, 1, 0, 0, 0, 0],
        [4, 3, 2, 1, 0, 0, 0],
        [5, 4, 3, 2, 1, 0, 0],
        [6, 5, 4, 3, 2, 1, 0]])

Rather than give every position its own index, beyond a certain distance we will group them into logarithmically bigger buckets as they get further from the current position.

In [15]:
n_buckets = 6
max_distance = 20
max_exact = n_buckets // 2
is_small = masked_rel_pos < max_exact
val_if_large = max_exact + (torch.log(masked_rel_pos.float() / max_exact) / math.log(max_distance / max_exact) * (n_buckets - max_exact)).long()
position_bucket_indices = torch.where(is_small, masked_rel_pos, val_if_large) # below a certain distance, use the raw value, otherwise use the log-scaled value
position_bucket_indices

tensor([[0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0],
        [2, 1, 0, 0, 0, 0, 0],
        [3, 2, 1, 0, 0, 0, 0],
        [3, 3, 2, 1, 0, 0, 0],
        [3, 3, 3, 2, 1, 0, 0],
        [4, 3, 3, 3, 2, 1, 0]])

In [16]:
block_pos = torch.arange(block_size, dtype=torch.long)
# context_pos = torch.arange(2*block_size, dtype=torch.long)
context_pos = torch.arange(-block_size, block_size, dtype=torch.long) # XL memory, context is twice block size, and current position starts in the middle.
block_rel_pos = rearrange(block_pos, 'i -> i 1')
context_rel_pos = rearrange(context_pos, 'j -> 1 j')
rel_pos = context_rel_pos - block_rel_pos
rel_pos

tensor([[ -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6],
        [ -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5],
        [ -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4],
        [-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3],
        [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2],
        [-12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1],
        [-13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0]])

In [17]:
class XLRelativePosition(torch.nn.Module):
  def __init__(
      self,
      n_buckets,
      max_distance,
      n_head,
      scaling_factor):
    
    super().__init__()
    self.scale = scaling_factor
    self.num_buckets = n_buckets
    self.max_distance = max_distance
    self.relative_attention_embedding = torch.nn.Embedding(n_buckets, n_head)

  def relative_position_bucket(self, relative_position_matrix):
    inv_rel_pos = -relative_position_matrix
    masked_rel_pos = torch.max(inv_rel_pos, torch.zeros_like(inv_rel_pos))

    max_exact = self.num_buckets // 2

    is_small = masked_rel_pos < max_exact
    val_if_large = max_exact + (torch.log(masked_rel_pos.float() / max_exact) / math.log(self.max_distance / max_exact) * (self.num_buckets - max_exact)).long()

    # Clip the values to the number of buckets - 1
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, self.num_buckets - 1))

    return torch.where(is_small, masked_rel_pos, val_if_large)

  def forward(self, block_size):
    block_pos = torch.arange(block_size, dtype=torch.long)
    context_pos = torch.arange(-block_size, block_size, dtype=torch.long) # XL memory, context is twice block size, and current position starts in the middle.
    block_rel_pos = rearrange(block_pos, 'i -> i 1')
    context_rel_pos = rearrange(context_pos, 'j -> 1 j')
    rel_pos = context_rel_pos - block_rel_pos

    position_bucket_indices = self.relative_position_bucket(rel_pos)

    rp_values = self.relative_attention_embedding(position_bucket_indices)

    rp_values = rearrange(rp_values, 'i j h -> () h i j')
    return rp_values * self.scale


In [18]:
rel_pos = XLRelativePosition(n_buckets = 5, max_distance = 8, n_head = 2)
rp_values = rel_pos(6)
rp_values

tensor([[[[-0.6735,  0.8919,  0.8919,  1.2615,  1.2615, -0.3989, -0.7912,
           -0.7912, -0.7912, -0.7912, -0.7912, -0.7912],
          [-0.6735, -0.6735,  0.8919,  0.8919,  1.2615,  1.2615, -0.3989,
           -0.7912, -0.7912, -0.7912, -0.7912, -0.7912],
          [-0.6735, -0.6735, -0.6735,  0.8919,  0.8919,  1.2615,  1.2615,
           -0.3989, -0.7912, -0.7912, -0.7912, -0.7912],
          [-0.6735, -0.6735, -0.6735, -0.6735,  0.8919,  0.8919,  1.2615,
            1.2615, -0.3989, -0.7912, -0.7912, -0.7912],
          [-0.6735, -0.6735, -0.6735, -0.6735, -0.6735,  0.8919,  0.8919,
            1.2615,  1.2615, -0.3989, -0.7912, -0.7912],
          [-0.6735, -0.6735, -0.6735, -0.6735, -0.6735, -0.6735,  0.8919,
            0.8919,  1.2615,  1.2615, -0.3989, -0.7912]],

         [[-1.0920,  0.4090,  0.4090, -1.0790, -1.0790,  0.7588, -0.9241,
           -0.9241, -0.9241, -0.9241, -0.9241, -0.9241],
          [-1.0920, -1.0920,  0.4090,  0.4090, -1.0790, -1.0790,  0.7588,
       

In [19]:
class FeedForward(torch.nn.Module):

    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_embed, 4 * n_embed), # 4x is a common expansion factor
            torch.nn.GELU(),
            torch.nn.Linear(4 * n_embed, n_embed), # Project back to the residual stream
            torch.nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
    
class Block(torch.nn.Module):

    def __init__(self, n_embed, attention, dropout):
        super().__init__()
        self.attention = attention
        self.ff = FeedForward(n_embed, dropout)
        self.layer_norm1 = torch.nn.LayerNorm(n_embed)
        self.layer_norm2 = torch.nn.LayerNorm(n_embed)

    def forward(self, x):
        # Residual connections
        attn_out, new_xl_memories = self.attention(self.layer_norm1(x))
        x = x + attn_out
        x = x + self.ff(self.layer_norm2(x))
        return x, new_xl_memories

In [20]:
class DecoderTransformer_KNN_XL(torch.nn.Module):

    def __init__(
            self, 
            vocab_size, 
            batch_size, # required so we can init a memory per batch dim
            n_embed = 384, # /6 heads = 64 per head
            n_head = 6, 
            n_layer = 6, 
            max_bar_position = 1024,
            top_k = 5,
            dropout = 0.2,
            n_rel_pos_buckets = 32,
            rel_pos_max_distance = 128):
        
        super().__init__()
        head_size = n_embed // n_head
        scaling_factor = head_size ** 0.5
        self.n_layer = n_layer
        self.max_bar_position = max_bar_position
        self.token_embedding = torch.nn.Embedding(vocab_size, n_embed)
        self.rel_pos = XLRelativePosition(n_buckets = n_rel_pos_buckets, max_distance = rel_pos_max_distance, n_head = n_head, scaling_factor = scaling_factor)
        self.rel_pos_knn = XLRelativePosition(n_buckets = n_rel_pos_buckets, max_distance = rel_pos_max_distance, n_head = n_head, scaling_factor = scaling_factor)
        self.beat_embedding = torch.nn.Embedding(SAMPLES_PER_BAR, n_embed)
        self.bar_embedding = torch.nn.Embedding(max_bar_position, n_embed)
        
        self.blocks = torch.nn.ModuleList([])
        for i in range(n_layer):

            if self.isKNNLayer(i):
                attention_type = KNN_XLAttention(
                            Path('../data/numpy/knn-demo'),
                            batch_size,
                            top_k,
                            n_embed,
                            n_head,
                            dropout)
            else:
                attention_type = XLAttention(
                            n_embed,
                            n_head,
                            dropout)

            self.blocks.append(Block(n_embed, attention_type, dropout))
            
        self.layer_norm = torch.nn.LayerNorm(n_embed)
        self.lm_head = torch.nn.Linear(n_embed, vocab_size)

    def isKNNLayer(self, i):
        return i == self.n_layer - 2

    def forward(self, batch_file_idxs, x, xl_memories=None, targets=None):

        B, T, C = x.shape()

        # Could split these out in one go using the unbind function
        token_idx = x[:, :, 0] # (B,T)
        time_idx = x[:, :, 1] # (B,T)

        sample_idx = time_idx % SAMPLES_PER_BAR # (B,T)
        bar_idx = (time_idx // SAMPLES_PER_BAR) % self.max_bar_position # (B,T)

        if xl_memories is None:
            xl_memories = (None,) * self.n_layer
        else:
            xl_memories = xl_memories

        rel_pos = self.rel_pos(T)
        rel_pos_knn = self.rel_pos_knn(T)

        token_embed = self.token_embedding(token_idx) # (B,T,Embed)
        bar_embed = self.bar_embedding(bar_idx) # (B,T,Embed)
        sample_embed = self.beat_embedding(sample_idx) # (B,T,Embed)

        x = token_embed + bar_embed + sample_embed

        # Store the XL memories for each pass
        new_xl_memories = []

        for i, block in enumerate(self.blocks):

            if self.isKNNLayer(i):
                x, xl_mem = block(batch_file_idxs, rel_pos_knn, x, xl_memories[i])
            else:
                x, xl_mem = block(rel_pos, x, xl_memories[i])

            new_xl_memories.append(xl_mem.detach())

        x = self.layer_norm(x)

        # TODO: Convert this section to use einops rearrange
        if targets is None:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None
        else:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # Flatten all the batches
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        if len(new_xl_memories) > 0:
            return logits, loss, new_xl_memories
        else:
            return logits, loss

In [21]:
batch_size = 64

model = DecoderTransformer_KNN_XL(vocab_size=vocab.size, batch_size=batch_size)

print(sum(p.numel() for p in model.parameters()))

11347982


This is slightly less than our vanilla model, which was surprising. Looking at the breakdown below though, it seems to be that our relative position params are quite small (`n_bucket * n_head` rather than `T * n_embed`).

Also we aren't registering a `tril` buffer which was included in the vanilla param count (256 * 256 * n_head * n_layer).

In [22]:
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model's state_dict:
token_embedding.weight 	 torch.Size([392, 384])
rel_pos.relative_attention_embedding.weight 	 torch.Size([32, 6])
rel_pos_knn.relative_attention_embedding.weight 	 torch.Size([32, 6])
beat_embedding.weight 	 torch.Size([32, 384])
bar_embedding.weight 	 torch.Size([1024, 384])
blocks.0.attention.key.weight 	 torch.Size([384, 384])
blocks.0.attention.query.weight 	 torch.Size([384, 384])
blocks.0.attention.value.weight 	 torch.Size([384, 384])
blocks.0.attention.project.weight 	 torch.Size([384, 384])
blocks.0.attention.project.bias 	 torch.Size([384])
blocks.0.ff.net.0.weight 	 torch.Size([1536, 384])
blocks.0.ff.net.0.bias 	 torch.Size([1536])
blocks.0.ff.net.2.weight 	 torch.Size([384, 1536])
blocks.0.ff.net.2.bias 	 torch.Size([384])
blocks.0.layer_norm1.weight 	 torch.Size([384])
blocks.0.layer_norm1.bias 	 torch.Size([384])
blocks.0.layer_norm2.weight 	 torch.Size([384])
blocks.0.layer_norm2.bias 	 torch.Size([384])
blocks.1.attention.key.weight 	 torch.Size([38

### TODO

- Training loop (inc. pause / save and load / continue)

Once that is working, consider

- Ragged memmap for data loading (will make loading much faster)
- Byte pair encoding (bigger vocab with common token pairs gives us a bigger effective context) 