In [1]:
import torch
import einx
import torch.nn as nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features, device=None, dtype=None):
        super().__init__()
        w = torch.empty(out_features, in_features)
        std = (2 / (in_features + out_features)) ** 1/2
        nn.init.trunc_normal_(w, mean = 0, std = std, a = -3 * std, b = 3 * std)
        self.mat = nn.Parameter(w)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        return einx.dot("out [in], ... [in] -> ... out", self.mat, x)

class MyEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        w = torch.empty(num_embeddings, embedding_dim)
        nn.init.trunc_normal_(w, mean = 0, std = 1, a = -3 , b = 3)
        self.mat = nn.Parameter(w)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        return self.mat[x]

In [2]:
w = torch.empty(3, 7)
nn.init.trunc_normal_(w, mean = 0, std = 1, a = -3, b = 3)
x = nn.Parameter(w)

x[torch.LongTensor([[1, 2, 2, 2], [0, 0, 0, 0]])][0][1]

tensor([-1.1431, -0.4317, -0.6513, -0.2015,  1.3787, -0.9340,  0.6720],
       grad_fn=<SelectBackward0>)

In [3]:
class MyRMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        gain = torch.ones(d_model)
        self.gain = nn.Parameter(gain)
        self.d_model = d_model
        self.eps = eps
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x = x.to(torch.float32)
        RMS_A = x ** 2
        RMS_A = RMS_A.sum(axis=-1)
        print(RMS_A)
        RMS_A /= self.d_model
        print(RMS_A)
        RMS_A += self.eps
        RMS_A = RMS_A ** (1/2)
        print(RMS_A)
        RMS_A = RMS_A.unsqueeze(-1)
        x = x / RMS_A * self.gain
        print(x)
        return x.to(in_dtype)
rms = MyRMSNorm(3)

batch = torch.tensor([[2, 1, 1.], [0, 0, 0]])
rms(batch)

tensor([6., 0.])
tensor([2., 0.])
tensor([1.4142, 0.0032])
tensor([[1.4142, 0.7071, 0.7071],
        [0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)


tensor([[1.4142, 0.7071, 0.7071],
        [0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

In [4]:
import numpy as np
class MyRope(nn.Module):
    def __init__(self, d_key, theta, max_seq_length, device=None, dtype=None):
        super().__init__()
        seqOfThetaArr = []
        for i in range(max_seq_length):
            thetaArr = []
            k = 1
            for idx in range(d_key):
                tik = i / (theta ** ((2 * k - 2)/d_key))
                if idx % 2 == 0:
                    thetaArr.append(np.array([np.cos(tik), -np.sin(tik)]))
                else:
                    thetaArr.append(np.array([np.sin(tik), np.cos(tik)]))
                    k += 1
            seqOfThetaArr.append(np.stack(thetaArr))
        precompThetaArr = np.stack(seqOfThetaArr)
        print(precompThetaArr)
        self.register_buffer("rotaryTable", torch.tensor(precompThetaArr), persistent=False)
        self.d_key = d_key
        self.evenIndices = [x if x % 2 == 0 else x - 1 for x in range(d_key)]
        self.oddIndices = [x if x % 2 == 1 else x + 1 for x in range(d_key)]

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        token_positions = token_positions.unsqueeze(-1)
        tablesOfInterest = einx.get_at("[i] k z, ... [idx] -> ... k z", self.rotaryTable, token_positions)
        evenX = x[..., self.evenIndices]
        oddX = x[..., self.oddIndices]
        evenTables = tablesOfInterest[..., 0]
        oddTables = tablesOfInterest[..., 1]

        print(evenX)
        print(evenTables)
        print()
        print(oddX)
        print(oddTables)

        return evenX * evenTables + oddX * oddTables
rope = MyRope(4, 2, 2)
rope(torch.tensor([[0, 1, 2, 3]]), torch.tensor([1,]))

[[[ 1.         -0.        ]
  [ 0.          1.        ]
  [ 1.         -0.        ]
  [ 0.          1.        ]]

 [[ 0.54030231 -0.84147098]
  [ 0.84147098  0.54030231]
  [ 0.7602446  -0.64963694]
  [ 0.64963694  0.7602446 ]]]
tensor([[0, 0, 2, 2]])
tensor([[0.5403, 0.8415, 0.7602, 0.6496]], dtype=torch.float64)

tensor([[1, 1, 3, 3]])
tensor([[-0.8415,  0.5403, -0.6496,  0.7602]], dtype=torch.float64)


tensor([[-0.8415,  0.5403, -0.4284,  3.5800]], dtype=torch.float64)

In [17]:
def softmax(x: torch.Tensor, dim):
    x = x - x.max(dim=dim, keepdim=True).values
    denom = torch.sum(torch.exp(x), dim=dim, keepdim=True)
    x = torch.exp(x) / denom
    return x

softmax(x, 1)

tensor([[0.0066, 0.0179, 0.9756],
        [0.1173, 0.8668, 0.0159]])

In [18]:
x = torch.tensor([
    [0, 1, 5],
    [2, 4, 0]
])

z = x - x.max(dim=0, keepdim=True).values

denom = torch.sum(torch.exp(z), dim=0, keepdim=True)
torch.exp(z) / denom

tensor([[0.1192, 0.0474, 0.9933],
        [0.8808, 0.9526, 0.0067]])

In [19]:
z = x - x.max(dim=1, keepdim=True).values
torch.sum(torch.exp(z), dim=1, keepdim=True)
torch.exp(z) / denom

tensor([[0.0059, 0.0174, 0.9933],
        [0.1192, 0.9526, 0.0182]])

In [53]:
keys = torch.randn(2, 3, 5)
queries = torch.randn(2, 1, 5)


In [69]:
values = torch.tensor([
    [
        [5., 5],
        [-1, -1],
        [-1, -1]
    ],
    [
        [0, 0],
        [1, 1],
        [2, 2]
    ]
])

mask = torch.tensor([[True, True, False]])

def scaledDotProdAttention(queries, keys, values, mask=None):

    presoftAttention = einx.dot("b ... key [dim], b ... quer [dim] -> b ... quer key", keys, queries) / keys.shape[-1] ** 0.5
    if mask is not None:
        presoftAttention[~mask.expand(presoftAttention.shape)] = - torch.inf
    
    softAttention = softmax(presoftAttention, dim = -1)
    return einx.dot("b ... quer [key], b ... [key] d_v -> b ... quer d_v", softAttention, values)
softAttention

tensor([[[0.2156, 0.7844, 0.0000]],

        [[0.6810, 0.3190, 0.0000]]])

tensor([[[0.2933, 0.2933]],

        [[0.3190, 0.3190]]])

In [44]:
z = torch.tensor([0.5, 1, 0.5])
vals = torch.tensor([
    [1, 0.],
    [100, 100],
    [3, 0]
])
einx.dot("[s], [s] k -> k", z, vals, a=1)

tensor([102., 100.])