In [2]:
import torch
from torch import nn
import math

In [None]:
# Build Map:

# Transformer = Tokenizer -> Embedder -> Attention Block (Nx) -> Norm -> Linear -> Softmax
# Tokenizer = BPE / import
# Embedder = random weights
# Attn Block = RMSNorm + MHA + (add residual x) + RMSNorm + FFN (MLP) + (add residual x)
# RMSNorm = x / sqrt(mean(x^2) + eps) * g   (g = learned per-feature gain, as in nn.RMSNorm)
# MHA = project x to Q_i, K_i, V_i and ship to Attn Head i -> concat results * W_o -> out
# Attn Head i = [Q, K = RoPE(Q, K)] -> softmax((QK^T) / sqrt(d/H) + M) V -> out
# MLP = (linear -> activation) (Nx) -> out

# TODO:
#   DTYPE configs?
#   general dtype enforcement
#   remove global params



# ---- Model sizes (module-level; read by the classes defined below) ----
d = 1024            # dim(embedding)
d_ff = 2048         # dim(FFN hidden layers)
H = 8               # number of attention heads
N = 4098            # vocab size
MAXT = 4098         # max tokens
RoPE_base = 10000   # RoPE base (just need sufficiently large?)

# ---- Shape of the demo input ----
B = 1               # batch size
T = 200             # number of tokens

# ---- Numerics ----
eps = 10 ** -4      # for numerical stability
f32 = torch.float32

In [4]:
class FeedForwardNetwork(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU -> Linear.

    The input is projected from ``dim`` up to ``hidden_dim`` and back down to
    ``dim``, so the output has the same shape as the input.

    Args:
        dim: width of the input/output features. Defaults to the
            notebook-level ``d`` when omitted.
        hidden_dim: width of the hidden layer. Defaults to the notebook-level
            ``d_ff`` when omitted.
    """
    # TODO:
    #   SwiGLU
    #   support n layers?

    def __init__(self, dim=None, hidden_dim=None) -> None:
        super().__init__()
        # Only fall back to the notebook globals for arguments that were not
        # supplied, so the module can be built without them.
        dim = d if dim is None else dim
        hidden_dim = d_ff if hidden_dim is None else hidden_dim

        # project onto d_ff-space, apply the non-linearity, project back
        self.stack = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``; the shape is preserved."""
        return self.stack(x)

In [67]:
class MultiHeadedAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings (RoPE).

    Queries, keys and values are bias-free linear projections of the input,
    split into ``n_heads`` heads of width ``d_h = dim // n_heads``.  RoPE is
    applied to the queries and keys, the scores are scaled by ``d_h ** -0.5``,
    positions after the query are masked out, and the concatenated heads are
    passed through ``W_o``.

    Args:
        dim: embedding width. Defaults to the notebook-level ``d``.
        n_heads: number of attention heads. Defaults to the notebook-level ``H``.
        max_tokens: longest supported sequence (side length of the causal
            mask). Defaults to the notebook-level ``MAXT``.
        rope_base: base of the RoPE frequencies. Defaults to the
            notebook-level ``RoPE_base``.

    Raises:
        ValueError: if ``n_heads`` does not divide ``dim`` or the per-head
            width is odd (RoPE rotates features in pairs).
    """
    # TODO:
    #   KV cache
    #   require_grad?

    def __init__(self, dim=None, n_heads=None, max_tokens=None, rope_base=None) -> None:
        super().__init__()
        # Only fall back to the notebook globals for arguments that were not
        # supplied, so the module can be built without them.
        self.dim = d if dim is None else dim
        self.n_heads = H if n_heads is None else n_heads
        self.max_tokens = MAXT if max_tokens is None else max_tokens
        rope_base = RoPE_base if rope_base is None else rope_base

        # Validate before allocating parameters.
        if self.dim % self.n_heads:
            raise ValueError('#heads must divide dim(embedding)')
        self.d_h = self.dim // self.n_heads
        if self.d_h & 1:
            raise ValueError('dim(head) must be even')

        # assuming q = v = d/H
        self.W_q = nn.Linear(self.dim, self.dim, bias=False)
        self.W_k = nn.Linear(self.dim, self.dim, bias=False)
        self.W_v = nn.Linear(self.dim, self.dim, bias=False)
        self.W_o = nn.Linear(self.dim, self.dim, bias=False)

        # causal mask : (1 x 1 x max_tokens x max_tokens), True strictly above
        # the diagonal, i.e. where key position > query position
        self.register_buffer(
            'mask',
            torch.triu(torch.ones((1, 1, self.max_tokens, self.max_tokens), dtype=bool), 1),
        )

        # thetas : (d_h/2), one rotation frequency per feature pair
        self.register_buffer(
            'theta',
            rope_base ** (-2 * torch.arange(self.d_h // 2, dtype=torch.float32) / self.d_h),
        )

    @staticmethod
    def _apply_rope(x, cos, sin):
        """Rotate every (even, odd) feature pair of ``x`` by the per-position angle.

        x_(i,2j)   = cos * x_(i,2j) - sin * x_(i,2j+1)
        x_(i,2j+1) = sin * x_(i,2j) + cos * x_(i,2j+1)
        """
        even, odd = x[..., 0::2], x[..., 1::2]
        rotated = torch.stack((cos * even - sin * odd, sin * even + cos * odd), dim=-1)
        # Interleave the pairs back into a single feature axis of width d_h.
        return rotated.flatten(start_dim=-2)

    def forward(self, X):
        """Apply causal self-attention.

        Args:
            X: tensor of shape (batch x tokens x dim).

        Returns:
            Tensor with the same shape as ``X``.

        Raises:
            ValueError: if the last dimension of ``X`` is not ``dim`` or the
                sequence is longer than ``max_tokens``.
        """
        # Take the sizes from the input itself rather than from notebook
        # globals, so any batch size / sequence length works.
        n_batch, n_tok, width = X.shape
        if width != self.dim:
            raise ValueError(f'expected last dimension {self.dim}, got {width}')
        if n_tok > self.max_tokens:
            raise ValueError(f'sequence length {n_tok} exceeds max tokens {self.max_tokens}')

        # Q_h, K_h, V_h : (batch x heads x tokens x d_h)
        head_shape = (n_batch, n_tok, self.n_heads, self.d_h)
        Q_h = self.W_q(X).reshape(head_shape).transpose(1, 2)
        K_h = self.W_k(X).reshape(head_shape).transpose(1, 2)
        V_h = self.W_v(X).reshape(head_shape).transpose(1, 2)

        # angles : (1 x 1 x tokens x d_h/2); positions are created on the same
        # device as theta so the module also works after .to(device)
        positions = torch.arange(n_tok, dtype=torch.float32, device=self.theta.device)
        angles = torch.outer(positions, self.theta)[None, None, :, :]
        cos = torch.cos(angles).to(Q_h.dtype)
        sin = torch.sin(angles).to(K_h.dtype)

        Q_r = self._apply_rope(Q_h, cos, sin)
        K_r = self._apply_rope(K_h, cos, sin)

        # scores, pattern : (batch x heads x tokens x tokens)
        scores = Q_r @ K_r.transpose(-2, -1) * self.d_h ** -0.5
        mask = self.mask[:, :, :n_tok, :n_tok]
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
        pattern = torch.softmax(scores, -1)

        # heads : (batch x tokens x heads x d_h), made contiguous before merging
        heads = (pattern @ V_h).transpose(1, 2).contiguous()
        out = heads.reshape(n_batch, n_tok, self.dim)
        return self.W_o(out)

In [68]:
class AttentionBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer, then feed-forward sub-layer.

    Each sub-layer receives an RMS-normalised copy of its input, and its
    output is added back onto the un-normalised input.
    """

    def __init__(self):
        super().__init__()
        # The two sub-layers, followed by the norm that precedes each of them.
        self.MHA = MultiHeadedAttention()
        self.FFN = FeedForwardNetwork()
        self.RMS1 = nn.RMSNorm(d)
        self.RMS2 = nn.RMSNorm(d)

    def forward(self, X):
        # Every intermediate is added to a same-shaped residual, so the
        # (B x T x d) layout of X is kept throughout.
        after_attention = X + self.MHA(self.RMS1(X))
        return after_attention + self.FFN(self.RMS2(after_attention))

In [69]:
class Transformer(nn.Module):
    """Token ids -> embedding -> one AttentionBlock -> RMSNorm -> vocabulary scores.

    The output projection reuses the embedding matrix (transposed), so no
    separate output weight is created.  No softmax is applied here.
    """
    # TODO:
    #   dropout
    #   input params
    #   LM head? (bias optional) X

    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(N, d)
        self.block = AttentionBlock()
        self.RMS = nn.RMSNorm(d)

    def forward(self, X):
        # Embed, run the block, normalise ...
        hidden = self.RMS(self.block(self.embedding(X)))
        # ... then score every vocabulary entry against the embedding rows.
        return torch.matmul(hidden, self.embedding.weight.T)

In [70]:
# Smoke test: build the model, push one random batch of token ids through it
# and inspect the result.
torch.manual_seed(0)  # fixed seed so the random weights, token ids and printed values reproduce
transformer = Transformer()
X = torch.randint(N, (B, T))  # random token ids in [0, N), shape (B x T)
out = transformer(X)
# One score per vocabulary entry for every token of every sequence.
assert out.shape == (B, T, N), f"unexpected output shape {tuple(out.shape)}"
print(out.shape)
print(out)

torch.Size([1, 200, 4098])
tensor([[[  8.0043, -23.9523,   4.7797,  ...,  16.7017, -12.4647, -17.4723],
         [ -4.6224,  17.3808,   7.3352,  ...,  25.2730,   0.0876,  16.3749],
         [-20.0768, -27.5841,  18.2552,  ...,  43.3483,   9.3240,   8.4941],
         ...,
         [ 71.5850,  -8.3694, -26.9546,  ...,  30.9275,   6.8790,  60.5894],
         [ -2.5121,  -4.8777, -62.1989,  ...,  16.8167,  74.4680, -32.2724],
         [-37.1538,   5.4505, -22.2627,  ..., -15.7488,  20.9891,  33.6959]]],
       grad_fn=<UnsafeViewBackward0>)


In [34]:
# Scratch check of the RoPE sine table: sin(position * frequency) for
# positions 0..5 (rows) and 4 frequencies RoPE_base ** (-2 * j / 4), j = 0..3
# (columns).
# NOTE(review): MultiHeadedAttention divides the exponent by d_h while using
# d_h // 2 frequencies; here 4 frequencies are divided by 4 -- confirm intended.
print(
    torch.outer(
        torch.arange(6),
        RoPE_base ** (-2 * torch.arange(4) / 4),
    ).sin()
)

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 8.4147e-01,  9.9998e-03,  1.0000e-04,  1.0000e-06],
        [ 9.0930e-01,  1.9999e-02,  2.0000e-04,  2.0000e-06],
        [ 1.4112e-01,  2.9995e-02,  3.0000e-04,  3.0000e-06],
        [-7.5680e-01,  3.9989e-02,  4.0000e-04,  4.0000e-06],
        [-9.5892e-01,  4.9979e-02,  5.0000e-04,  5.0000e-06]])


In [57]:
# Scratch check of the interleaving trick: stacking two tensors on a new last
# axis and flattening the last two axes alternates their elements
# (x0, y0, x1, y1, ...) along the final dimension.
x = torch.tensor([[1., 2., 3., 4.],
                  [5., 6., 7., 8.]])
y = torch.tensor([[5., 6., 7., 8.],
                  [1., 2., 3., 4.]])
print(torch.stack((x, y), dim=-1))
print(torch.stack((x, y), dim=-1).flatten(start_dim=-2, end_dim=-1))

tensor([[[1., 5.],
         [2., 6.],
         [3., 7.],
         [4., 8.]],

        [[5., 1.],
         [6., 2.],
         [7., 3.],
         [8., 4.]]])
tensor([[1., 5., 2., 6., 3., 7., 4., 8.],
        [5., 1., 6., 2., 7., 3., 8., 4.]])
