In [2]:
import torch
import torch.nn as nn
from einops import rearrange, reduce

In [3]:
# simple function
def mod_addition(a, b, P):
    return (a + b) % P

P = 13
a = 6
b = 11

print(mod_addition(a, b, P))

4


In [None]:
# complex algorithm
def mod_addition_alg(a, b, P):
    w = torch.tensor([14, 35, 41, 42, 52])                          # frequencies
    c = rearrange(torch.arange(P), "x -> x 1")                      # potential c values
    cos = torch.cos;    sin = torch.sin                             # to simplify following expression

    terms = (cos(w*a)*cos(w*b) - sin(w*a)*sin(w*b)) * cos(w*c) \
        + (sin(w*a)*cos(w*b) - cos(w*a)*sin(w*b)) * sin(w*c)        # trig terms, expanded out
    logits = reduce(terms, "h w -> h", "sum")                       # sum across the frequencies
    print(logits)

    i_max = torch.argmax(logits).item()
    return c[i_max]

P = 13
a = 6
b = 11

print(mod_addition_alg(a, b, P))

tensor([ 0.9075, -1.0094,  1.0857, -1.1346,  1.1547, -1.1456,  1.1075, -1.0413,
         0.9488, -0.8322,  0.6946, -0.5394,  0.3705])
tensor([4])


In [3]:
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)

<torch._C.Generator at 0x1d4ffccd550>

In [None]:
# Model architecture
# "one-layer ReLU transformer, token embeddings with d = 128, learned positional embeddings, 4 attention heads of dimension d/4 = 32, and n = 512 hidden units in the MLP. In other experiments, we vary the depth and dimension of the model. We did not use LayerNorm or tie our embed/unembed matrices."

class SelfAttentionHead(nn.Module):
    def __init__(self, d_token, d_head, l_context):
        super().__init__()

        self.d_head = d_head

        self.key = nn.Linear(d_token, d_head, bias=False)
        self.query = nn.Linear(d_token, d_head, bias=False)
        self.value = nn.Linear(d_token, d_head, bias=False)
        self.softmax = nn.Softmax(1)
        
        self.register_buffer("tril", torch.tril(torch.ones(l_context, l_context))) # TODO: delete

    def forward(self, x):
        # x:        (b, t, d_token)
        # K, Q, V:  (b, t, d_head)
        # Q @ K':   (b, t, t)
        _, T, _ = x.shape

        K_transpose = rearrange(self.key(x), "b t d_head -> b d_head t")
        Q = self.query(x)
        V = self.value(x)

        wei = self.softmax(
            ((Q @ K_transpose) / (self.d_head)**0.5).masked_fill(
                self.tril[:T, :T] == 0, -torch.inf)
        ) 
        attention = wei @ V

        return attention


class MultiHeadAttention(nn.Module):
    def __init__(self, l_context, d_token, N_heads, d_head):
        super().__init__()
        self.N_heads = N_heads
        self.d_head = d_head

        self.sa_heads = nn.ModuleList(
            [SelfAttentionHead(d_token, d_head, l_context) for _ in range(self.N_heads)]
        )

        # projects from concatenated sa_head outputs to something to add to the residual stream
        self.proj = nn.Linear(N_heads*d_head, d_token) 

    def forward(self, x):
        x_concat = torch.concat([sa(x) for sa in self.sa_heads], dim=2)
        x += self.proj(x_concat)
        
        return x



class Block(nn.Module):
    def __init__(self, l_context, d_token, 
                 N_heads, d_head, 
                 d_ffwd):
        super().__init__()
        
        self.sa_heads = MultiHeadAttention(l_context, d_token, N_heads, d_head)

        self.ln1 = nn.LayerNorm(d_token)

        self.ffwd_head = nn.Sequential(
            nn.Linear(d_token, d_ffwd),
            nn.ReLU(),
            nn.Linear(d_ffwd, d_token)
        )

        self.ln2 = nn.LayerNorm(d_token)


    def forward(self, x):
        x += self.ln1(self.sa_heads(x))
        x += self.ln2(self.ffwd_head(x))

        return x


class ModAdd(torch.nn.Module):
    def __init__(self, p, 
                 d,
                 N_heads, d_head,
                 d_ffwd, 
                 n):
        """
        model = ModAdd(p, ...)
        c = model(a, b)
        -> c = (a + b) % p
        """

        super().__init__()

        self.d_inout = p                # prime number on the RHS of '%'
        # Residual stream
        self.d_token = d                # embedding dimension of tokens
        self.l_context = 2              # context length
        # Attention heads
        self.N_heads = N_heads          # no. of sa_heads working in parallel
        self.d_head = d_head            # dimension of each sa_head (ie. final dim of K, Q, V arrays)
        # Feed-forward layer (within block)
        self.d_ffwd = d_ffwd
        # Linear layer TODO: delete
        # self.d_linear = n

        # Layers
        self.token_embedding_table = nn.Embedding(
            num_embeddings=self.d_inout, 
            embedding_dim=self.d_token
        )
        
        self.block = Block(
            self.l_context, self.d_token, 
            self.N_heads, self.d_head, 
            self.d_ffwd)

        self.linear = nn.Linear(self.d_token, self.d_inout)

        self.cross_entropy = nn.CrossEntropyLoss()
        
    def forward(self, X, X_target=None):
        x = self.token_embedding_table(X)           # no positional encoding needed
        x_ffwd = self.block(x)
        logits = self.linear(x_ffwd)

        if X_target is None:
            loss = None
            return logits, loss
        
        B, T, C = logits.shape
        assert (B == 1) & (C == self.d_inout)
        loss = self.cross_entropy(logits.view(B))
        
        return logits, loss

p = 7
model = ModAdd(p=p, d=128, N_heads=4, d_head=32, d_ffwd=128*4, n=512).to(device)

X = torch.randint(0, p, [5, 2])
Y_pred = model(X)
print(Y_pred.shape)

wei: tensor([[0.5000, 0.5000],
        [0.5000, 0.5000]], grad_fn=<SelectBackward0>)
Q: tensor([[-0.1264, -0.0083,  0.0090, -0.9598, -0.5485,  0.5829, -0.9033,  0.1759,
         -1.4082,  0.5907, -0.0114,  0.1595,  0.6726,  0.2361, -0.2484, -0.2638,
         -0.2969,  0.5122,  0.0926, -0.1770,  0.4638, -0.8553,  0.3490, -0.8564,
         -1.4068, -1.3824, -0.6120,  0.5088,  0.1498,  0.0990,  0.5465, -0.9410],
        [-0.1264, -0.0083,  0.0090, -0.9598, -0.5485,  0.5829, -0.9033,  0.1759,
         -1.4082,  0.5907, -0.0114,  0.1595,  0.6726,  0.2361, -0.2484, -0.2638,
         -0.2969,  0.5122,  0.0926, -0.1770,  0.4638, -0.8553,  0.3490, -0.8564,
         -1.4068, -1.3824, -0.6120,  0.5088,  0.1498,  0.0990,  0.5465, -0.9410]],
       grad_fn=<SelectBackward0>)
K_transpose: tensor([[-6.8638e-01, -6.8638e-01],
        [ 1.4519e-01,  1.4519e-01],
        [-1.0565e+00, -1.0565e+00],
        [-1.1058e+00, -1.1058e+00],
        [-4.7714e-01, -4.7714e-01],
        [-3.1765e-01, -3.1765e-01]

In [None]:
"""
Changes:
- redo self-attention
    -> no need for 'attention' part of transformer? (since position isn't important)
       (maybe make it easy to turn on and off or at least have a think about it)
    -> actually gonna be simpler because there's only ever one new token and it has (a,b) as context (no need for tril mumbo jumbo)
- simplify the loss calculation
    -> the output is a (..., 1) vector (ie. c), so that might make it a bit simpler...? (not really, but just need to implement)
- sample the logits (just argmax for accuracy)

TODO:
- train it
- look at Nander's library
"""

In [None]:
x = torch.randn(4, 3)
wei = torch.tril(torch.ones(3, 3))
wei /= reduce(wei, "i j -> j 1", "sum")
print(wei)
xbow = wei @ x

tensor([[0.3333, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [1.0000, 1.0000, 1.0000]])


In [22]:
# Data
# 30% training, 70% test

In [None]:
# Training
# full-batch gradient descent
# AdamW optimiser
# learning rage = 0.001, 
# N_epochs = 40_000