In [1]:
import torch
import einx
import torch.nn as nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features, device=None, dtype=None):
        super().__init__()
        w = torch.empty(out_features, in_features)
        std = (2 / (in_features + out_features)) ** 1/2
        nn.init.trunc_normal_(w, mean = 0, std = std, a = -3 * std, b = 3 * std)
        self.mat = nn.Parameter(w)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        return einx.dot("out [in], ... [in] -> ... out", self.mat, x)

class MyEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        w = torch.empty(num_embeddings, embedding_dim)
        nn.init.trunc_normal_(w, mean = 0, std = 1, a = -3 , b = 3)
        self.mat = nn.Parameter(w)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        return self.mat[x]

In [2]:
w = torch.empty(3, 7)
nn.init.trunc_normal_(w, mean = 0, std = 1, a = -3, b = 3)
x = nn.Parameter(w)

x[torch.LongTensor([[1, 2, 2, 2], [0, 0, 0, 0]])][0][1]

tensor([-1.1431, -0.4317, -0.6513, -0.2015,  1.3787, -0.9340,  0.6720],
       grad_fn=<SelectBackward0>)

In [3]:
class MyRMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        gain = torch.ones(d_model)
        self.gain = nn.Parameter(gain)
        self.d_model = d_model
        self.eps = eps
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x = x.to(torch.float32)
        RMS_A = x ** 2
        RMS_A = RMS_A.sum(axis=-1)
        print(RMS_A)
        RMS_A /= self.d_model
        print(RMS_A)
        RMS_A += self.eps
        RMS_A = RMS_A ** (1/2)
        print(RMS_A)
        RMS_A = RMS_A.unsqueeze(-1)
        x = x / RMS_A * self.gain
        print(x)
        return x.to(in_dtype)
rms = MyRMSNorm(3)

batch = torch.tensor([[2, 1, 1.], [0, 0, 0]])
rms(batch)

tensor([6., 0.])
tensor([2., 0.])
tensor([1.4142, 0.0032])
tensor([[1.4142, 0.7071, 0.7071],
        [0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)


tensor([[1.4142, 0.7071, 0.7071],
        [0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

In [105]:
import numpy as np
class MyRope(nn.Module):
    def __init__(self, d_key, theta, max_seq_length, device=None, dtype=None):
        super().__init__()
        seqOfThetaArr = []
        for i in range(max_seq_length):
            thetaArr = []
            k = 1
            for idx in range(d_key):
                tik = i / (theta ** ((2 * k - 2)/d_key) )
                if idx % 2 == 0:
                    thetaArr.append(np.array([np.cos(tik), -np.sin(tik)]))
                else:
                    thetaArr.append(np.array([np.sin(tik), np.cos(tik)]))
                    k += 1
            seqOfThetaArr.append(np.stack(thetaArr))
        precompThetaArr = np.stack(seqOfThetaArr)
        self.register_buffer("rotaryTable", torch.tensor(precompThetaArr), persistent=False)
        self.d_key = d_key
        self.evenIndices = [x if x % 2 == 0 else x - 1 for x in range(d_key)]
        self.oddIndices = [x if x % 2 == 1 else x + 1 for x in range(d_key)]

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        token_positions = token_positions.unsqueeze(-1)
        tablesOfInterest = einx.get_at("[i] k z, ... [idx] -> ... k z", self.rotaryTable, token_positions)
        evenX = x[..., self.evenIndices]
        oddX = x[..., self.oddIndices]
        evenTables = tablesOfInterest[..., 0]
        oddTables = tablesOfInterest[..., 1]

        return (evenX * evenTables + oddX * oddTables).to(x.dtype)
rope = MyRope(4, 2, 2)
rope(torch.tensor([[0, 1, 2, 3]]), torch.tensor([1,]))

tensor([[0, 0, 0, 3]])

In [17]:
def softmax(x: torch.Tensor, dim):
    x = x - x.max(dim=dim, keepdim=True).values
    denom = torch.sum(torch.exp(x), dim=dim, keepdim=True)
    x = torch.exp(x) / denom
    return x

softmax(x, 1)

tensor([[0.0066, 0.0179, 0.9756],
        [0.1173, 0.8668, 0.0159]])

In [18]:
x = torch.tensor([
    [0, 1, 5],
    [2, 4, 0]
])

z = x - x.max(dim=0, keepdim=True).values

denom = torch.sum(torch.exp(z), dim=0, keepdim=True)
torch.exp(z) / denom

tensor([[0.1192, 0.0474, 0.9933],
        [0.8808, 0.9526, 0.0067]])

In [19]:
z = x - x.max(dim=1, keepdim=True).values
torch.sum(torch.exp(z), dim=1, keepdim=True)
torch.exp(z) / denom

tensor([[0.0059, 0.0174, 0.9933],
        [0.1192, 0.9526, 0.0182]])

In [53]:
keys = torch.randn(2, 3, 5)
queries = torch.randn(2, 1, 5)


In [81]:
values = torch.tensor([
    [
        [5., 5],
        [-1, -1],
        [-1, -1]
    ],
    [
        [0, 0],
        [1, 1],
        [2, 2]
    ]
])

mask = torch.tensor([[True, True, False]])

def scaledDotProdAttention(queries, keys, values, mask=None):

    presoftAttention = einx.dot("b ... key [dim], b ... quer [dim] -> b ... quer key", keys, queries) / keys.shape[-1] ** 0.5
    if mask is not None:
        presoftAttention[~mask.expand(presoftAttention.shape)] = - torch.inf
    
    softAttention = softmax(presoftAttention, dim = -1)
    return einx.dot("b ... quer [key], b ... [key] d_v -> b ... quer d_v", softAttention, values)
softAttention

tensor([[[0.2156, 0.7844, 0.0000]],

        [[0.6810, 0.3190, 0.0000]]])

In [70]:
einx.dot("b ... quer [key], b ... [key] d_v -> b ... quer d_v", softAttention, values)

tensor([[[0.2933, 0.2933]],

        [[0.3190, 0.3190]]])

In [102]:
fakeAttention = torch.tensor([0.5, 1, 0.5])
vals = torch.tensor([
    [1, 0.],
    [100, 100],
    [3, 0]
])
einx.dot("[s], [s] k -> k", fakeAttention, vals, a=1)

tensor([102., 100.])

In [99]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dk = d_model // num_heads

        self.Wq = MyLinear(d_model, d_model)
        self.Wk = MyLinear(d_model, d_model)
        self.Wv = MyLinear(d_model, d_model)
        self.Wo = MyLinear(d_model, d_model)
    def forward(self, x):
        seq = x.shape[-2]
        queries = self.Wq(x)
        keys = self.Wk(x)
        values = self.Wv(x)
        queries = einx.rearrange("... seq (heads dk) -> ... heads seq dk", queries, dk=self.dk)
        keys = einx.rearrange("... seq (heads dk) -> ... heads seq dk", keys, dk=self.dk)
        values = einx.rearrange("... seq (heads dk) -> ... heads seq dk", values, dk=self.dk)
        mask = torch.tril(torch.ones((seq, seq))).bool()
        attended = scaledDotProdAttention(queries, keys, values, mask)
        multiHead = einx.rearrange("... heads seq dk -> ... seq (heads dk)", attended)
        return self.Wo(multiHead)

mhsa = MultiHeadSelfAttention(8, 2)
mhsa(torch.randn(3, 5, 8))

tensor([[[-5.6356e-03,  2.7596e-03,  1.3597e-02,  1.7189e-02,  1.4190e-02,
           6.9568e-03,  3.5568e-02,  3.4001e-03],
         [ 2.8591e-02,  4.3619e-03,  2.5728e-02, -9.0700e-03,  1.2817e-03,
           1.8614e-03,  3.1938e-02, -4.8912e-03],
         [-1.6222e-03,  2.9483e-03,  3.3452e-04,  7.3399e-03,  4.7355e-03,
          -3.2652e-05,  9.7079e-03, -6.0467e-03],
         [-9.9415e-03,  4.8823e-03, -3.5935e-03,  9.5569e-03,  5.3491e-03,
           3.8651e-03,  8.7364e-03, -2.3996e-05],
         [-7.2212e-03,  7.8951e-03, -1.1571e-05,  7.1348e-03,  2.5053e-03,
           1.6936e-03,  1.1393e-02,  5.0010e-03]],

        [[ 5.6048e-02, -4.6875e-02,  6.7406e-02, -1.8357e-02,  5.8747e-03,
          -4.3492e-02,  1.8585e-02, -2.8123e-02],
         [ 1.4272e-02, -3.6911e-02,  2.9286e-02, -1.7568e-03,  6.8680e-03,
          -2.8475e-02, -5.3976e-03, -1.4381e-02],
         [ 5.8715e-03, -2.1917e-02,  1.2103e-02, -3.2530e-03,  5.3082e-03,
          -1.2299e-02, -9.1095e-03, -1.4640e-02]

In [106]:
class MultiHeadSelfAttentionRope(nn.Module):
    def __init__(self, d_model, num_heads, max_seq_len, theta, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dk = d_model // num_heads

        self.rope = MyRope(self.dk, theta, max_seq_len)

        self.Wq = MyLinear(d_model, d_model)
        self.Wk = MyLinear(d_model, d_model)
        self.Wv = MyLinear(d_model, d_model)
        self.Wo = MyLinear(d_model, d_model)
    def forward(self, x, token_positions):
        seq = x.shape[-2]
        queries = self.Wq(x)
        keys = self.Wk(x)
        values = self.Wv(x)
        queries = einx.rearrange("... seq (heads dk) -> ... heads seq dk", queries, dk=self.dk)
        keys = einx.rearrange("... seq (heads dk) -> ... heads seq dk", keys, dk=self.dk)
        values = einx.rearrange("... seq (heads dk) -> ... heads seq dk", values, dk=self.dk)

        queries = self.rope(queries, token_positions.unsqueeze(-2).expand(keys.shape[:-1]))
        keys = self.rope(keys, token_positions.unsqueeze(-2).expand(keys.shape[:-1]))
        
        mask = torch.tril(torch.ones((seq, seq))).bool()

        print(queries.dtype)
        attended = scaledDotProdAttention(queries, keys, values, mask)
        multiHead = einx.rearrange("... heads seq dk -> ... seq (heads dk)", attended)
        return self.Wo(multiHead)
mhsar = MultiHeadSelfAttentionRope(8, 2, 5, 1e-5)
mhsar(torch.randn(3, 5, 8), torch.arange(5).expand(torch.randn(3, 5, 8).shape[:-1]))

torch.float32


tensor([[[-0.0183,  0.0232, -0.0330,  0.0124,  0.0167, -0.0213, -0.0121,
           0.0343],
         [-0.0160,  0.0069, -0.0132, -0.0006, -0.0093, -0.0103,  0.0054,
           0.0082],
         [ 0.0025,  0.0078,  0.0001, -0.0079, -0.0146, -0.0042,  0.0045,
           0.0096],
         [-0.0040, -0.0073,  0.0083, -0.0046, -0.0153,  0.0078,  0.0082,
          -0.0047],
         [-0.0104, -0.0101,  0.0114, -0.0054, -0.0137,  0.0085,  0.0082,
          -0.0050]],

        [[ 0.0026,  0.0199, -0.0410,  0.0478,  0.0391, -0.0033, -0.0245,
           0.0200],
         [-0.0072,  0.0125, -0.0237,  0.0472,  0.0413, -0.0006, -0.0246,
           0.0041],
         [ 0.0085,  0.0062, -0.0120,  0.0297,  0.0246,  0.0077, -0.0151,
           0.0072],
         [ 0.0037,  0.0073, -0.0027,  0.0286,  0.0210,  0.0057, -0.0162,
          -0.0012],
         [ 0.0022,  0.0048, -0.0017,  0.0251,  0.0187,  0.0066, -0.0181,
          -0.0027]],

        [[ 0.0132, -0.0215,  0.0254, -0.0368, -0.0446,  0.0126,  0

In [108]:
class MySwiGLU(nn.Module):
    def __init__(self, d_model, d_ff = None, device=None, dtype=None):
        super().__init__()
        if d_ff is None:
            d_ff = int(round(8 / 3 * d_model / 64) * 64)
        self.d_ff = d_ff

        self.w1 = MyLinear(d_model, d_ff)
        self.w2 = MyLinear(d_ff, d_model)
        self.w3 = MyLinear(d_model, d_ff)

        
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w1x = self.w1(x)
        w3x = self.w3(x)
        elwiseProd = self.SILU(w1x) * w3x
        out = self.w2(elwiseProd)
        return out

    def SILU(self, x: torch.Tensor):
        return x * torch.sigmoid(x)
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, max_seq_len, theta):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len
        self.theta = theta

        self.RMS1 = MyRMSNorm(d_model)
        self.RMS2 = MyRMSNorm(d_model)

        self.mhsar = MultiHeadSelfAttentionRope(d_model, num_heads, max_seq_len, theta)
        self.ffwd = MySwiGLU(d_model, d_ff)

    def forward(self, x):
        token_positions = torch.arange(x.shape[-2]).expand(x.shape[:-1])
        y = x + self.mhsar(self.RMS1(x), token_positions)
        print("y Shape", y.shape)
        z = y + self.ffwd(self.RMS2(y))
        return z

In [109]:
t = MySwiGLU(64, 128)
t(torch.randn(4, 12, 64))

tensor([[[-5.5784e-05,  2.1722e-05, -4.0647e-05,  ...,  4.3912e-06,
           8.6138e-06, -4.3669e-05],
         [-1.0096e-04, -3.7781e-06,  3.6266e-05,  ..., -1.0608e-04,
           3.1835e-06,  2.4812e-06],
         [ 2.5156e-05, -6.6149e-05, -5.3761e-05,  ...,  1.9268e-05,
          -3.5005e-06, -6.9269e-05],
         ...,
         [ 2.8158e-05, -5.7070e-05,  8.8641e-05,  ...,  3.2948e-05,
          -1.7302e-05,  1.3374e-05],
         [ 6.1499e-05, -3.5933e-05, -3.2347e-05,  ..., -5.6971e-05,
          -3.9987e-05,  4.1134e-05],
         [ 3.7481e-05,  2.4902e-05,  1.1735e-05,  ...,  8.8581e-06,
          -2.2123e-05, -6.4748e-05]],

        [[-3.7254e-05, -2.6964e-05,  1.0151e-04,  ..., -5.7524e-06,
          -1.5768e-06, -2.1357e-05],
         [-2.7622e-05, -9.3886e-06,  2.0122e-05,  ...,  5.3260e-05,
          -1.4699e-05,  2.7704e-05],
         [ 4.4525e-05, -1.8425e-05,  4.2244e-05,  ..., -5.1800e-05,
          -9.2712e-06,  3.7986e-05],
         ...,
         [ 3.0380e-05, -1

In [None]:
class Transfomer(nn.Module):
    def __init__(self, vocab_size, context_length, num_layers, d_model, num_heads, d_ff, theta):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.embedding = MyEmbedding(vocab_size, d_model)
        self.transformerBlocks = nn.ModuleList([TransformerBlock(d_model, num_heads, d_ff, context_length, theta) for _ in range(num_layers)])
        self.rmsNorm = MyRMSNorm(d_model)
        self.outFwd = MyLinear(d_model, vocab_size)
    def forward(self, tokIds):
        emb = self.embedding(tokIds)
        for block in self.transformerBlocks:
            emb = block(emb)
        emb = self.rmsNorm(emb)
        emb = self.outFwd(emb)
        emb = softmax(emb, dim=-1)
        return emb
        