**Outline**


Input
  ↓
LayerNorm
  ↓
Multi-Head Self-Attention
  ↓
Residual Add
  ↓
LayerNorm
  ↓
MLP (Feed Forward)
  ↓
Residual Add


**Step 1: Scaled Dot-Product Attention (from scratch)**

In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F


In [4]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are each produced by their own linear
    projection of the same input.

    Args:
        embed_dim: Size of the last input dimension; every projection maps
            ``embed_dim -> embed_dim``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)

        # Logits are divided by sqrt(embed_dim) before the softmax.
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor shaped ``(..., seq_len, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)

        # (..., seq_len, seq_len) similarity of every query with every key.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Normalise over the key axis so each query's weights sum to 1.
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)

In [None]:
class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Args:
        embed_dim: Size of the last input dimension; the query, key and
            value projections all map ``embed_dim -> embed_dim``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)

        # Logits are divided by sqrt(embed_dim) before the softmax.
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor shaped ``(..., seq_len, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)

        score = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # ``dim=-1`` (keyword, negative index): softmax over the key axis.
        attn = F.softmax(score, dim=-1)
        out = torch.matmul(attn, V)

        return out
        

**Step 2: Multi-Head Attention (manual, clean)**

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built on one fused QKV projection.

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()

        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # A single linear layer emits queries, keys and values side by side.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Attend over the second axis of a 3-D input.

        Args:
            x: Tensor shaped ``(B, N, D)`` with ``D == embed_dim``.

        Returns:
            Tensor shaped ``(B, N, D)``.
        """
        batch, tokens, width = x.shape
        heads, per_head = self.num_heads, self.head_dim

        # (B, N, 3*D) -> (B, N, 3, H, d) -> (3, B, H, N, d)
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, per_head)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        logits = torch.matmul(queries, keys.transpose(-2, -1)) / (per_head ** 0.5)
        weights = F.softmax(logits, dim=-1)

        # (B, H, N, d) -> (B, N, H, d) -> (B, N, D): concatenate the heads.
        mixed = torch.matmul(weights, values)
        merged = mixed.transpose(1, 2).reshape(batch, tokens, width)

        return self.out(merged)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention using one fused QKV projection.

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Name must match what ``forward`` reads.
        self.head_dim = embed_dim // num_heads

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Attend over the second axis of a 3-D input.

        Args:
            x: Tensor shaped ``(B, N, D)`` with ``D == embed_dim``.

        Returns:
            Tensor shaped ``(B, N, D)``.
        """
        B, N, D = x.shape

        # Keep the projection in a local: rebinding ``self.qkv`` would
        # replace the layer with a tensor and break the next call.
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        # (B, N, 3, H, d) -> (3, B, H, N, d)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scale by sqrt of the per-head width, not the number of heads.
        score = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Normalise over keys (last axis); dim=1 would mix across heads.
        attn = F.softmax(score, dim=-1)

        out = attn @ V
        # (B, H, N, d) -> (B, N, H, d) so the reshape concatenates heads.
        out = out.transpose(1, 2).reshape(B, N, D)

        return self.out(out)

**Step 3: Feed-Forward Network (MLP block)**

In [7]:
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Linear.

    Args:
        embed_dim: Size of the input and output features.
        hidden_dim: Size of the intermediate layer.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),  # the module is spelled ``GELU``; ``nn.Gelu`` does not exist
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        """Run ``x`` (last dimension ``embed_dim``) through the MLP."""
        return self.net(x)


In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Linear.

    Args:
        embed_dim: Size of the input and output features.
        hidden_dim: Size of the intermediate layer.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            # Project back to embed_dim so the output can be added to a
            # residual stream of the same width.
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        """Run ``x`` (last dimension ``embed_dim``) through the MLP."""
        return self.net(x)

**Step 4: Transformer Encoder Block (THIS IS HUGE)**

In [8]:
class TransformerBlock(nn.Module):
    """Encoder block: LayerNorm -> attention -> add, LayerNorm -> MLP -> add.

    Args:
        embed_dim: Feature width handled by both sub-layers.
        num_heads: Head count forwarded to ``MultiHeadAttention``.
        mlp_ratio: The MLP hidden width is ``embed_dim * mlp_ratio``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4):
        super().__init__()

        # Attention sub-layer and the norm applied to its input.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads)

        # MLP sub-layer and the norm applied to its input.
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = FeedForward(embed_dim, embed_dim * mlp_ratio)

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return the result."""
        attended = self.attn(self.norm1(x))
        x = x + attended

        transformed = self.mlp(self.norm2(x))
        return x + transformed

In [None]:
class TransformerBlock(nn.Module):
    """Encoder block: LayerNorm -> attention -> add, LayerNorm -> MLP -> add.

    Args:
        embed_dim: Feature width handled by both sub-layers.
        num_heads: Head count forwarded to ``MultiHeadAttention``.
        mlp_ratio: The MLP hidden width is ``embed_dim * mlp_ratio``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4):
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        # The attention class in this file is ``MultiHeadAttention``.
        self.attn = MultiHeadAttention(embed_dim, num_heads)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = FeedForward(embed_dim, embed_dim * mlp_ratio)

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return the result."""
        # Each sub-layer sees a normalised input; its output is added back
        # onto the un-normalised stream.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        return x

**Step 5: Stack blocks → Transformer Encoder**

In [9]:
class TransformerEncoder(nn.Module):
    """A stack of ``depth`` ``TransformerBlock`` modules applied in order.

    Args:
        depth: Number of blocks to create.
        embed_dim: Feature width passed to every block.
        num_heads: Head count passed to every block.
    """

    def __init__(self, depth, embed_dim, num_heads):
        super().__init__()

        blocks = [TransformerBlock(embed_dim, num_heads) for _ in range(depth)]
        # ModuleList registers each block so its parameters are tracked.
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in sequence and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [None]:
class TransformerEncoder(nn.Module):
    """A stack of ``depth`` ``TransformerBlock`` modules applied in order.

    Args:
        depth: Number of blocks to create.
        embed_dim: Feature width passed to every block.
        num_heads: Head count passed to every block.
    """

    def __init__(self, depth, embed_dim, num_heads):
        super().__init__()

        # ModuleList registers each block so its parameters are tracked.
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

    def forward(self, x):
        """Feed ``x`` through every block in sequence and return the result."""
        # The loop variable must be the name the body calls.
        for layer in self.layers:
            x = layer(x)

        return x