In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as Fn
from transformers import AutoModel, AutoTokenizer


To implement
- rot Q and K
- SwigLu
- RMS Norm
- Gating after Attention (G1 from Gating attention paper)
- Other Ideas (Mainfold Constrained Hyper connections from DeepSeek, this might be compute Intensive first try imlementing the above)
- how can you inculcate (5 brain freq. ideas)

In [29]:
def rotFrequency(headDim, seqLen, theta=10000.0):
    assert headDim % 2 == 0
    inv_freq = 1.0 / (
        theta ** (torch.arange(0, headDim, 2).float() / headDim)
    )
    positions = torch.arange(seqLen).float()  
    freqs = torch.outer(positions, inv_freq)

    return freqs

freq = rotFrequency(512, 4096)
freq.shape

torch.Size([4096, 256])

In [30]:
def ropEQK(x, freqs):
    cos = freqs.cos()[None, :, None, :]
    sin = freqs.sin()[None, :, None, :]

    evenVals = x[..., 0::2]
    oddVals  = x[..., 1::2]

    rotated = torch.cat(
        [
            evenVals * cos - oddVals * sin,
            evenVals * sin + oddVals * cos
        ],
        dim=-1
    )

    return rotated

q = torch.randn(1, 4096, 16, 512)
qRot = ropEQK(q, freq)
qRot.shape

torch.Size([1, 4096, 16, 512])

In [None]:
class RotatoryQK(nn.Module):

    def __init__(self, hiddenDim, seqLen, theta = 10000.0):
        super().__init__()
        self.invFreq = 1.0 / (
            theta ** (torch.arange(0, headDim, 2).float() / headDim)
        )
        self.positions = torch.arange(seqLen).float()
        self.freqs = torch.outer(positions, inv_freq)

    def forward(self, x):
        cos = self.freqs.cos()[None, :, None, :]
        sin = self.freqs.sin()[None, :, None, :]

        evenVals = x[..., 0::2]
        oddVals  = x[..., 1::2]
        
        rotated = torch.cat(
            [
                evenVals * cos - oddVals * sin,
                evenVals * sin + oddVals * cos
            ],
            dim=-1
        )
        return rotated


In [31]:
class RMSNorm(nn.Module):
    def __init__(self, dimension):
        super().__init__()
        self.eps = 1e-6
        self.weight = nn.Parameter(torch.ones(dimension))

    def forward(self, x):
        xNorm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True) + self.eps)
        normalized = self.weight * xNorm 
        return normalized.type_as(x)

rmsN = RMSNorm(512)
embeddings = torch.randn(2, 4096, 512)
norm = rmsN(embeddings)
norm.shape

torch.Size([2, 4096, 512])

In [32]:
class FeedForwardLayer(nn.Module):
    def __init__(self, dimension=2048, latentDim = 4096, hidden_dim=None):
        super().__init__()
        
        if hidden_dim is None:
            hidden_dim = 4 * dimension
        
        self.w1 = nn.Linear(dimension, hidden_dim)  
        self.w2 = nn.Linear(hidden_dim, latentDim)  
        self.w3 = nn.Linear(dimension, hidden_dim)  

    def forward(self, x):
        swish = Fn.silu(self.w1(x))  
        xV = self.w3(x)
        x = swish * xV
        
        x = self.w2(x)
        return x



ffn = FeedForwardLayer(dimension=2048)
x = torch.randn(2, 1024, 2048)
out = ffn(x)
out.shape

torch.Size([2, 1024, 4096])

In [33]:
class MultiHeadSelfAttentionGating(nn.Module):
    def __init__(self, embedDimension, numHeads, dropout = 0.2):
        super().__init__()

        assert embedDimension%numHeads == 0, "Embedding Dimension is Not Divisible By NumHeads"
        self.embedDimension = embedDimension
        self.numHeads = numHeads
        self.headDim = embedDimension//numHeads

        self.queryKeyValue = nn.Linear(embedDimension, embedDimension * 3, bias=False)
        self.drop = nn.Dropout(dropout)
        self.gateProj = nn.Linear(embedDimension, embedDimension, bias = False)
        self.scale = self.headDim ** -0.5 
        self.outProjection = nn.Linear(embedDimension, embedDimension)

        nn.init.xavier_uniform_(self.queryKeyValue.weight)
        nn.init.xavier_uniform_(self.outProjection.weight)

    def forward(self, x):
        BatchSize, seqLen, EmbedDim = x.shape

        qkv = self.queryKeyValue(x)
        qkv = qkv.reshape(BatchSize, seqLen, 3, self.numHeads, EmbedDim // self.numHeads)
        # rotFrequency(512, 4096)

        q, k, v = qkv.unbind(2)
        frequencies = rotFrequency(self.headDim, seqLen)
        qRot = ropEQK(q, frequencies)
        kRot = ropEQK(k, frequencies)

        q = qRot.transpose(1, 2)
        k = kRot.transpose(1, 2)
        v = v.transpose(1, 2)


        attentionScore = (q @ k.transpose(-2, -1)) * self.scale
        att = attentionScore.softmax(dim=-1)
        out = att @ v 

        # print(out.shape)
        gate = torch.sigmoid(self.gateProj(x))
        gate = gate.reshape(BatchSize, seqLen, self.numHeads, self.headDim).transpose(1, 2)
        # print(gate.shape, out.shape)

        out = out * gate

        out = out.transpose(1, 2).reshape(BatchSize, seqLen, EmbedDim)
        out = self.outProjection(out)
        out = self.drop(out)
        return out
    
mhsa = MultiHeadSelfAttentionGating(embedDimension = 2048, numHeads = 16)
x = torch.randn(2, 64, 2048)
out = mhsa(x)
out.shape

torch.Size([2, 64, 2048])

In [34]:
modelName = "nomic-ai/nomic-embed-text-v1.5" #"nomic-ai/nomic-embed-text-v1"
tokenizer = AutoTokenizer.from_pretrained(modelName)
model = AutoModel.from_pretrained(modelName, trust_remote_code=True)

texts = [
    "To dos walking towards each other",
    "All Dogs are playing in the Garden",
    "Generate Image of Dog"
]

inputs = tokenizer(texts, return_tensors='pt', padding='max_length', truncation=True, max_length=512)

with torch.no_grad():
    outputs = model(**inputs)
    embeddings2 = outputs.last_hidden_state

embeddings2.shape



torch.Size([3, 512, 768])

In [41]:
class TextToImageLatentModel(nn.Module):

    def __init__(self, embedDimension = 2048, textEmbed = 768,  numHeads = 16, latentDim = 2048):
        super().__init__()

        self.embedDimension = embedDimension
        self.numHeads = numHeads
        self.latentDim = latentDim

        self.mhsa = MultiHeadSelfAttentionGating(self.embedDimension, self.numHeads)
        
        self.rmsNorm = RMSNorm(self.embedDimension)

        self.feedForward = FeedForwardLayer(dimension = embedDimension, latentDim =  latentDim)

    def forward(self, x):
        batchSize, seqlen, _ = x.shape

        x1 = x
        x = self.rmsNorm(x1)

        x = self.mhsa(x)
        
        x1 = x1 + x

        x = self.rmsNorm(x1)
        x = self.feedForward(x)
        # print(x.shape, x1.shape)
        x = x1 + x

        return x

t2iM = TextToImageLatentModel()
x = torch.randn(2, 512, 2048)
out = t2iM(x)
out.shape

torch.Size([2, 512, 2048])

In [47]:
class NLayerT2I(nn.Module):

    def __init__(self, embedDimension = 2048, textEmbed = 768,  numHeads = 16, outDimension = 4096, nBlocks = 6):
        super().__init__()
        
        self.embedDimension = embedDimension
        self.numHeads = numHeads
        self.textEmbed = textEmbed
        self.outDimension = outDimension

        self.linearTextProjection = nn.Linear(textEmbed, embedDimension)

        TextToImageLatentModel(embedDimension = self.embedDimension, textEmbed = self.textEmbed,  numHeads = self.numHeads, latentDim = 2048)

        self.nAttentionBlocks = nn.ModuleList([
            TextToImageLatentModel(embedDimension = self.embedDimension, textEmbed = self.textEmbed,  numHeads = self.numHeads, latentDim = 2048)
            for _ in range(nBlocks)
        ])

        self.rmsNorm2 = RMSNorm(self.embedDimension)
        self.outputLayer = nn.Linear(embedDimension, outDimension)

    
    def forward(self, x):

        batchSize, seqlen, _ = x.shape

        x = self.linearTextProjection(x)

        for block in self.nAttentionBlocks:
            x = block(x)

        x = self.rmsNorm2(x)
        x = x.mean(1)
        x = self.outputLayer(x)
        x = x.reshape(batchSize, self.outDimension)

        return x


finalModel = NLayerT2I()
x = torch.randn(2, 512, 768)
out = finalModel(x)
out.shape

torch.Size([2, 4096])

In [45]:
print(f"Total Parameters: {sum(p.numel() for p in finalModel.parameters() if p.requires_grad)}")

Total Parameters: 437923840
