In [2]:
import torch
from torch import nn
import math

In [3]:
# Build Map:

# Transformer = Tokenizer -> Embedder -> Attention Block (Nx) -> Norm -> Linear -> Softmax
# Tokenizer = BPE / import
# Embedder = random weights
# Attn Block = RMSNorm + MHA + (add residual x) + RMSNorm + FFN (MLP) + (add residual x)
# RMSNorm = x / (RMS(x) + eps)
# MHA = project x to W_i, Q_i, K_i and ship to Attn Head i -> concat results *W_o -> out
# Attn Head i = [Q, K = RoPE(Q, K)] -> softmax((QK^T) / sqrt(d/H) + M) V -> out
# MLP = (linear -> activation) (Nx) -> out

# i guess we're doing batched


# d = dim(embedding), T = #toks, H = #heads, B = batch size, eps for numerical stability
d = 1024
T = 200
H = 8
B = 1
eps = 10 ** -4

d_ff = 2048 # d_ff = dim(FFN hidden layers)
MAXT = 4098 # max tokens
N = 4098 # vocab size

In [4]:
class FeedForwardNetwork(nn.Module):
    # TODO:
    #   SwiGLU
    #   support n layers?

    def __init__(self) -> None:
        # project onto d_ff-space
        super().__init__()
        self.stack = nn.Sequential(nn.Linear(d, d_ff),
                                    nn.GELU(),
                                    nn.Linear(d_ff, d))
        
    def forward(self, x):
        return self.stack(x)

In [5]:
class MultiHeadedAttention(nn.Module):
    # TODO:
    #   RoPE
    #   KV cache
    #   update mask?
    #   matmuls?
    #   require_grad?
    
    def __init__(self) -> None:
        super().__init__()
        # assuming q = v = d/H
        self.W_q = nn.Linear(d, d, bias=False)
        self.W_k = nn.Linear(d, d, bias=False)
        self.W_v = nn.Linear(d, d, bias=False)
        self.W_o = nn.Linear(d, d, bias=False)
        self.register_buffer('mask', torch.triu(torch.ones((1, 1, MAXT, MAXT)) * -float('inf'), 1)) # torch.finfo(dtype).min

    def forward(self, X):
        # X, Q, K, V : (B x T x d)
        Q, K, V = self.W_q(X), self.W_k(X), self.W_v(X)

        # reshape into heads [row: [tok: [dh: [], dh: [], ... ], tok: []] row: ]
        # [row: [dh: [tok: [], tok: []], dh: [tok: [], tok: []] row: ]
        # Q_h, K_h, V_h : (B x H x T x d_h)
        if d % H:
            raise Exception('dimension must divide number of heads')
        d_h = d // H
        Q_h = torch.reshape(Q, (B, T, H, d_h)).transpose(1, 2)
        K_h = torch.reshape(K, (B, T, H, d_h)).transpose(1, 2)
        V_h = torch.reshape(V, (B, T, H, d_h)).transpose(1, 2)

        # pattern : (B x H x T x T)
        # heads : (B x T x H x d_h)
        # out : (B x T x d)
        pattern = torch.softmax(Q_h @ K_h.transpose(2, 3) / math.sqrt(d_h) + self.mask[:,:,:T,:T], -1)
        heads = (pattern @ V_h).transpose(1, 2)
        out = torch.reshape(heads, (B, T, d))
        return self.W_o(out)

In [6]:
class AttentionBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.MHA = MultiHeadedAttention()
        self.FFN = FeedForwardNetwork()
        self.RMS1 = nn.RMSNorm(d)
        self.RMS2 = nn.RMSNorm(d)
    
    def forward(self, X):
        # X, X_n, X_out, Y, Y_n, Y_out, out : (B x T x d)
        X_n = self.RMS1(X)
        X_out = self.MHA(X_n)
        Y = X + X_out

        Y_n = self.RMS2(Y)
        Y_out = self.FFN(Y_n)
        out = Y + Y_out
        return out

In [7]:
class Transformer(nn.Module):
    # TODO:
    #   dropout
    #   input params
    #   LM head? (bias optional)
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(N, d)
        self.block = AttentionBlock()
        self.RMS = nn.RMSNorm(d)

    def forward(self, X):
        out = self.embedding(X)
        out = self.block(out)
        out = self.RMS(out)
        out = out @ self.embedding.weight.T
        return out

In [10]:
transformer = Transformer()
X = torch.randint(N, (B, T))
out = transformer(X)
print(out.shape)
print(out)

torch.Size([1, 200, 4098])
tensor([[[-34.7669,  36.9258,  17.3228,  ..., -12.8140,  69.5626,  18.0535],
         [  1.7060,  21.3565,  -6.8058,  ...,   7.1570,  17.1251,  -6.3025],
         [  2.7964,  54.0525,  41.5881,  ...,  18.3460,  26.6338,  24.1822],
         ...,
         [-10.4525, -28.0108, -12.2978,  ...,  30.8323,  45.6175,  13.5453],
         [ -3.5123, -62.7142,  -2.9039,  ...,  12.9198,  12.1696, -41.4250],
         [ 31.0708,  52.5245,  34.8461,  ..., -22.3652,   1.8458, -29.1236]]],
       grad_fn=<UnsafeViewBackward0>)


In [9]:
test = torch.Tensor([[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]]])
print(test.transpose(1,2))

tensor([[[1., 3., 5.],
         [2., 4., 6.]],

        [[1., 3., 5.],
         [2., 4., 6.]]])
