In [11]:
import torch
from torch import nn
import math

In [None]:
# Build Map:

# Transformer = Tokenizer -> Embedder -> Attention Block (Nx) -> Norm -> Linear -> Softmax
# Tokenizer = BPE / import
# Embedder = random weights
# Attn Block = RMSNorm + MHA + (add residual x) + RMSNorm + FFN (MLP) + (add residual x)
# RMSNorm = x / RMS(x) + eps
# MHA = project x to W_i, Q_i, K_i and ship to Attn Head i -> concat results *W_o -> out
# Attn Head i = [Q, K = RoPE(Q, K)] -> softmax((QK^T) / sqrt(q) + M) V -> out
# MLP = (linear -> activation) (Nx) -> out

# i guess we're doing batched


# k = dim(embedding), T = #toks, H = #heads, B = batch size, eps for numerical stability
k = 1024
T = 200
H = 8
eps = 10 ** -4
B = 50

In [None]:
class Transformer(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x

In [None]:
class AttentionHead(nn.Module):
    # TODO:
    # RoPE
    # KV cache
    # update mask?
    # q, v here or elsewhere; enforce?
    # matmuls
    # require_grad?

    def __init__(self, q, v) -> None:
        # q = v = k / H (?)
        super().__init__()
        self.q = q
        self.v = v
        self.W_q = nn.Parameter(torch.rand((k, q)))
        self.W_k = nn.Parameter(torch.rand((k, q)))
        self.W_v = nn.Parameter(torch.rand((k, v)))
        self.register_buffer('mask', torch.Tensor([[-math.inf if j > i else 0 for j in range(n)] for i in range(n)]))
        
    def forward(self, x):
        Q, K, V = x @ self.W_q, x @ self.W_k, x @ self.W_v
        pattern = (Q @ K.T) / math.sqrt(self.q) + self.mask # (nxn)
        torch.nn.Softmax(pattern)
        return pattern @ V # (nxv)

In [None]:
class MLP(nn.Module):
    # TODO:
    # SwiGLU
    # support n layers?

    def __init__(self, d_ff) -> None:
        # project onto d_ff-space
        super().__init__()
        self.stack = nn.Sequential([nn.Linear(k, d_ff),
                                    nn.GELU(),
                                    nn.Linear(d_ff, k),
                                    nn.GELU()])
        
    def forward(self, x):
        return self.stack(x)

In [None]:
class AttentionBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(k // H, k // H) for _ in range(H)])
        self.W_o = nn.Parameter(torch.rand(k, k))
        self.MLP = MLP()
        self.RMS_1 = nn.RMSNorm()
        self.RMS_2 = nn.RMSNorm()

    def forward(self, x):
        x_norm = self.RMS_1(x)
        x_out = torch.cat([head(x_norm) for head in self.heads], dim=1)
        y = x + (x_out @ self.W_o)
        y_norm = self.RMS_2(y)
        y_out = self.MLP(y_norm)
        out = x + y_out
        return out

In [None]:
attn = AttentionHead(2, 2)
x = torch.ones((n, k))
attn.forward(x)

tensor([[      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [      -inf,       -inf],
        [     