In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as Fn

To implement
- RoPE (rotary position embeddings) for Q and K
- SwiGLU
- RMSNorm
- Gating after attention (G1 from the Gated Attention paper)
- Other ideas (Manifold-Constrained Hyper-Connections from DeepSeek); this might be compute-intensive, so first try implementing the above
- How can the 5 brain-frequency ideas be incorporated?

In [15]:
def rotFrequency(headDim, seqLen, theta=10000.0):
    """Precompute RoPE rotation angles for every (position, channel-pair).

    Args:
        headDim: per-head feature size; must be even because channels are
            rotated in pairs.
        seqLen: number of positions to precompute (0 .. seqLen - 1).
        theta: base of the geometric frequency progression.

    Returns:
        float32 tensor of shape (seqLen, headDim // 2) whose entry [p, i]
        equals p * theta ** (-2 * i / headDim).
    """
    assert headDim % 2 == 0
    pairExponents = torch.arange(0, headDim, 2).float() / headDim
    invFreq = 1.0 / (theta ** pairExponents)
    positions = torch.arange(seqLen).float()
    # Broadcasting (seqLen, 1) * (1, headDim // 2) forms the outer product.
    return positions[:, None] * invFreq[None, :]

# Angles for 512-dim heads over a 4096-position context -> (4096, 256).
freq = rotFrequency(headDim=512, seqLen=4096)
freq.shape

torch.Size([4096, 256])

In [16]:
def ropEQK(x, freqs):
    """Apply rotary position embeddings (RoPE) to a query or key tensor.

    Each adjacent channel pair (x[..., 2i], x[..., 2i+1]) at position p is
    rotated by the angle freqs[p, i], and the rotated pair is written back
    into the same two slots. A zero angle therefore leaves ``x`` unchanged,
    and successive rotations compose (rotating by a, then by b, equals
    rotating by a + b).

    Args:
        x: tensor laid out as (batch, seqLen, numHeads, headDim); headDim
            must be even. The angles broadcast over dims 0 and 2.
        freqs: angle table of shape (maxSeqLen, headDim // 2), e.g. from
            ``rotFrequency``, with maxSeqLen >= seqLen. Only the first seqLen
            rows are used, so one table can serve shorter sequences.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    seqLen = x.shape[1]
    # Slice to the actual sequence length and follow x onto its device.
    angles = freqs[:seqLen].to(x.device)
    cos = angles.cos()[None, :, None, :]
    sin = angles.sin()[None, :, None, :]

    evenVals = x[..., 0::2]
    oddVals = x[..., 1::2]

    rotEven = evenVals * cos - oddVals * sin
    rotOdd = evenVals * sin + oddVals * cos

    # Re-interleave: (..., headDim // 2, 2) -> (..., headDim), so every pair
    # returns to its original channel positions.
    rotated = torch.stack((rotEven, rotOdd), dim=-1).flatten(-2)
    return rotated.type_as(x)

# ropEQK broadcasts the angles over dims 0 and 2, so q is laid out as
# (batch, seqLen, heads, headDim); seqLen matches the 4096 rows of `freq`.
q = torch.randn(1, 4096, 16, 512)
qRot = ropEQK(x=q, freqs=freq)
qRot.shape

torch.Size([1, 4096, 16, 512])

In [17]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-feature gain.

        y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps)

    The statistic is computed in at least float32 and the result is cast back
    to the input dtype. This keeps half-precision inputs from overflowing
    when squared, while float64 inputs keep full precision.

    Args:
        dimension: size of the last (feature) axis of the inputs.
        eps: small constant added inside the square root for numerical
            stability (default 1e-6).
    """

    def __init__(self, dimension, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dimension))

    def _norm(self, x):
        # Scale each feature vector (last axis) to unit root-mean-square.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Upcast fp16/bf16 before squaring: in float16, |x| > ~256 makes
        # x ** 2 overflow to inf and the output collapse to zero.
        computeDtype = torch.promote_types(x.dtype, torch.float32)
        xNorm = self._norm(x.to(computeDtype))
        return (self.weight * xNorm).type_as(x)

# RMS-normalize a (2, 4096, 512) batch over its 512-wide last axis; the shape is unchanged.
rmsN = RMSNorm(dimension=512)
embeddings = torch.randn((2, 4096, 512))
norm = rmsN(embeddings)
norm.shape

torch.Size([2, 4096, 512])

In [30]:
class FeedForwardLayer(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)).

    By default the input is first mean-pooled over dimension 1, so an input
    of shape (batch, seqLen, dimension) yields (batch, latentDim). Pass
    poolSequence=False to apply the block independently at every position
    (as a transformer FFN does), giving (batch, seqLen, latentDim).

    Args:
        dimension: size of the input feature axis.
        latentDim: size of the output feature axis.
        hidden_dim: width of the gated hidden layer; defaults to 4 * dimension.
        poolSequence: if True (default), average over dim 1 before the FFN.
    """

    def __init__(self, dimension=2048, latentDim=4096, hidden_dim=None, poolSequence=True):
        super().__init__()

        if hidden_dim is None:
            hidden_dim = 4 * dimension

        self.poolSequence = poolSequence
        self.w1 = nn.Linear(dimension, hidden_dim)  # gate branch, passed through SiLU
        self.w2 = nn.Linear(hidden_dim, latentDim)  # output projection
        self.w3 = nn.Linear(dimension, hidden_dim)  # value branch

    def forward(self, x):
        if self.poolSequence:
            x = x.mean(1)
        gated = Fn.silu(self.w1(x)) * self.w3(x)
        return self.w2(gated)


# Input is (batch, seqLen, dimension); forward mean-pools over dim 1 and then
# projects to latentDim (default 4096), so the result is (batch, latentDim).
ffn = FeedForwardLayer(dimension=2048)
x = torch.randn((2, 1024, 2048))
out = ffn(x)
out.shape

torch.Size([2, 4096])

In [50]:
class MultiHeadSelfAttentionGating(nn.Module):
    """Multi-head self-attention with a sigmoid output gate (G1 position).

    After scaled dot-product attention, each head's output is multiplied
    elementwise by sigmoid(gateProj(x)), computed from the same input ``x``.
    Only then is the result merged across heads and passed through the output
    projection and dropout.

    Args:
        embedDimension: model width; must be divisible by numHeads.
        numHeads: number of attention heads.
        dropout: dropout probability applied to the projected output.
    """

    def __init__(self, embedDimension, numHeads, dropout=0.2):
        super().__init__()

        assert embedDimension % numHeads == 0, "Embedding Dimension is Not Divisible By NumHeads"
        self.embedDimension = embedDimension
        self.numHeads = numHeads
        self.headDim = embedDimension // numHeads

        self.queryKeyValue = nn.Linear(embedDimension, embedDimension * 3, bias=False)
        self.drop = nn.Dropout(dropout)
        self.gateProj = nn.Linear(embedDimension, embedDimension, bias=False)
        self.scale = self.headDim ** -0.5
        self.outProjection = nn.Linear(embedDimension, embedDimension)

        nn.init.xavier_uniform_(self.queryKeyValue.weight)
        nn.init.xavier_uniform_(self.outProjection.weight)

    def forward(self, x, causal=False):
        """Attend over the sequence and apply the output gate.

        Args:
            x: tensor of shape (batch, seqLen, embedDimension).
            causal: if True, position i attends only to positions <= i.
                The default (False) keeps full bidirectional attention.

        Returns:
            Tensor of shape (batch, seqLen, embedDimension).
        """
        batchSize, seqLen, embedDim = x.shape

        # (B, N, 3E) -> (B, N, 3, H, D) -> q, k, v each of shape (B, H, N, D)
        qkv = self.queryKeyValue(x).reshape(batchSize, seqLen, 3, self.numHeads, self.headDim)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(2))

        attentionScore = (q @ k.transpose(-2, -1)) * self.scale
        if causal:
            # Strictly-upper-triangular entries are the future positions.
            futureMask = torch.ones(seqLen, seqLen, dtype=torch.bool, device=x.device).triu(1)
            attentionScore = attentionScore.masked_fill(futureMask, float("-inf"))
        out = attentionScore.softmax(dim=-1) @ v

        # G1 gate: elementwise and head-specific (split per head like the
        # values), computed from the block input x.
        gate = torch.sigmoid(self.gateProj(x))
        gate = gate.reshape(batchSize, seqLen, self.numHeads, self.headDim).transpose(1, 2)
        out = out * gate

        # Merge heads back: (B, H, N, D) -> (B, N, E)
        out = out.transpose(1, 2).reshape(batchSize, seqLen, embedDim)
        return self.drop(self.outProjection(out))
    
# 16 heads of size 128 over a (2, 64, 2048) input. The module is still in
# training mode here, so the output dropout (p=0.2 by default) is active.
mhsa = MultiHeadSelfAttentionGating(embedDimension=2048, numHeads=16)
x = torch.randn((2, 64, 2048))
out = mhsa(x)
out.shape

torch.Size([2, 64, 2048])