In [123]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


using PreNorm instead of AddNorm

In [124]:
  
class PreNorm(nn.Module):
    def __init__(self, d_model, sublayer, dropout_rate=0.1):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.sublayer = sublayer

    def forward(self, x, key_input=None, value_input=None):
        normalized_output = self.norm(x)
        if key_input is None and value_input is None:
            sublayer_output = self.sublayer(normalized_output)
        else:
            sublayer_output = self.sublayer(normalized_output, key_input=key_input, value_input=value_input)
        dropped_output = self.dropout(sublayer_output)
        return x + dropped_output


using GLU instead of Position wise FFN

using GELU instead of ReLU

In [125]:
class GLUFeedForward(nn.Module):
    def __init__(self, d_model, d_ff=None, dropout_rate = 0.1):
        super(GLUFeedForward, self).__init__()
        d_ff = d_ff or d_model
        self.gate_proj = nn.Linear(d_model , d_ff)
        self.value_proj = nn.Linear(d_model, d_ff)
        self.output_proj = nn.Linear(d_ff, d_model)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        gate = self.gate_proj(x)
        value = self.value_proj(x)

        x = F.gelu(gate)*value
        x = self.dropout(x)

        #project back to d_model
        x = self.output_proj(x)
        return x

updating Multi head attention with FAVOR+ Fast attention via positive orthogonal random features kindof linear attention

In [126]:
 def linear_attention(self, Q, K, V):
        # Q, K, V : shape = (batch_size, num_heads, seq_len, head_dim)
        batch_size, num_heads, seq_len, head_dim = Q.shape

        # Random feature projection one per head
        if not hasattr(self, 'rand_projs'):
            self.rand_projs = orthogonal_random_features(
                head_dim, self.num_heads, self.num_rfs, Q.device
            )

        # Projecting Q and K through random features
        Q_rand = torch.einsum('bhsd,hrd->bhsr', Q, self.rand_projs)  # (b, h, s, r)
        K_rand = torch.einsum('bhsd,hrd->bhsr', K, self.rand_projs)  # (b, h, s, r)

In [127]:
def orthogonal_random_features(dim, num_heads, num_rfs, device):
    # Initialize random features for each head
    rand_proj = torch.randn(num_heads, num_rfs, dim, device=device)
    q, _ = torch.linalg.qr(rand_proj.transpose(-2, -1))
    return q.transpose(-2, -1) 


def elu_kernal(x):
    return F.elu(x) + 1

class LinearAttention(nn.Module):
    def __init__(self, d_model , num_heads, dropout_rate = 0.1, num_rfs = 10):
        super(LinearAttention, self).__init__()
        assert d_model % num_heads ==0

        self.d_k =d_model // num_heads
        self.d_model = d_model
        self.num_rfs = num_rfs # number of random features
        self.num_heads = num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

        #Learnable scale Parameter for queries / keys
        self.scale = nn.Parameter(torch.tensor(d_model ** -0.5))

    def linear_attention(self, Q, K, V):
        # Q, K, V : shape = (batch_size, num_heads, seq_len, head_dim)
        batch_size, num_heads, seq_len, head_dim = Q.shape

        # Random feature projection one per head
        if not hasattr(self, 'rand_projs'):
            self.rand_projs = orthogonal_random_features(
                head_dim, self.num_heads, self.num_rfs, Q.device
            )

        # Projecting Q and K through random features
        Q_rand = torch.einsum('bhsd,hrd->bhsr', Q, self.rand_projs)  # (b, h, s, r)
        K_rand = torch.einsum('bhsd,hrd->bhsr', K, self.rand_projs)  # (b, h, s, r)

        Q_rand = torch.matmul(Q, self.rand_projs[:num_heads])  #(b, h, s, 1, r)
        K_rand = torch.matmul(K, self.rand_projs[:num_heads]) #(b,h,s,1,r)

        #applying kernal function
        Q_feat = elu_kernal(Q_rand).squeeze(-2)    #(b, h, s, r)
        K_feat = elu_kernal(K_rand).squeeze(-2)  #(b, h, s, r)

        #compute KV numerator: (b,h,r,d)
        K_feat_V = torch.einsum('bhsv, bhsd-> bhvd' , K_feat, V)

        #denominator: (b, h, r)
        z_denom = 1. / (torch.einsum('bhsr , bhs->bhr', Q_feat, torch.sum(K_feat, dim=2)) + 1e-6)

        #numerator: (b, h, s, d)
        attn_output = torch.einsum('bhsr, bhvd->bhsd', Q_feat, K_feat_V)

        #normalize
        attn_output = attn_output * z_denom.unsqueeze(-1)

        return attn_output
    

    def forward(self, query_input, key_input=None , value_input=None, mask=None):
        batch_size = query_input.size(0)

        
        if key_input is None:
            key_input = query_input
        if value_input is None:
            value_input = query_input

        Q = self.W_q(query_input)
        K = self.W_k(key_input)
        V = self.W_v(value_input)

        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)

        attn_output = self.linear_attention(Q, K , V)

        attn_output = attn_output.transpose(1,2).contiguous().view(batch_size , -1, self.d_model)
        output = self.W_o(attn_output)

        return output, None #no attention weights returned


using RoPE instead of normal Positional Encoding

In [128]:

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Create inverse frequency vector for half the dimension
        inv_freq = 1. / (base ** (torch.arange(0, dim//2, dtype=torch.float32) / (dim//2)))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_dim=1):
        batch_size, seq_len, dim = x.shape
        
        # Create position indices
        pos = torch.arange(seq_len, device=x.device).type(torch.float32)
        
        # Compute frequencies
        freqs = torch.einsum('i,j->ij', pos, self.inv_freq)
        
        # Compute sin and cos
        emb = torch.cat((freqs, freqs), dim=-1).view(seq_len, dim//2, 2)
        cos = emb[..., 0].view(1, seq_len, dim//2)
        sin = emb[..., 1].view(1, seq_len, dim//2)
        
        # Expand to match input batch size
        cos = cos.expand(batch_size, -1, -1)
        sin = sin.expand(batch_size, -1, -1)
        
        # Split input into half for rotation
        x1, x2 = x.chunk(2, dim=-1)
        
        # Apply rotation using sin and cos
        rotated = torch.cat([
            x1 * cos - x2 * sin,
            x2 * cos + x1 * sin
        ], dim=-1)
        
        return rotated

In [129]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model , num_heads, d_ff, dropout_rate=0.1, num_rfs=64):
        super(EncoderLayer,self).__init__()

        self.self_attn = LinearAttention(
            d_model = d_model,
            num_heads= num_heads,
            dropout_rate= dropout_rate,
            num_rfs=num_rfs
        )

        #preNorm wrappers 
        self.attn_norm = PreNorm(d_model , self.self_attn)

        # GLU FFN
        self.ffn = GLUFeedForward(
            d_model=d_model,
            d_ff=d_ff,
            dropout_rate=dropout_rate
        )

        self.ffn_norm = PreNorm(d_model, self.ffn)


        # RoPE
        self.rope = RotaryPositionalEmbedding(d_model)


    def forward(self, x, src_mask = None):
        x_pos = self.rope(x) #(batch_size, seq_len, d_model)
        x = self.attn_norm(x_pos) #self attention sublayer with prenorm includes residual connection internally
        x = self.ffn_norm(x) #feed forward network sublayer with PreNorm


        return x, None

In [130]:
class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout_rate=0.1, num_rfs=64):
        super(Encoder, self).__init__()
        self.num_layers = num_layers

        self.layers = nn.ModuleList([
            EncoderLayer(
                d_model = d_model,
                num_heads= num_heads,
                d_ff = d_ff,
                dropout_rate=dropout_rate,
                num_rfs= num_rfs
            )
            for _ in range(num_layers)
        ])

    def forward(self, x, src_mask = None):
        for layer in self.layers:
            x, _ = layer(x, src_mask) #second return value is none (no attention weights are being returned)\
        return x, None
        
         

In [131]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model , d_ff , num_heads, dropout_rate= 0.1, num_rfs = 64):
        super(DecoderLayer, self).__init__()
        
        self.masked_self_attn = LinearAttention(
            d_model = d_model,
            num_heads= num_heads,
            dropout_rate= dropout_rate,
            num_rfs=num_rfs
        )

        self.self_attn_norm= PreNorm(d_model, self.masked_self_attn)

#encoder decoder cross attention
        self.encoder_decoder_attn = LinearAttention(
            d_model= d_model,
            num_heads= num_heads,
            dropout_rate= dropout_rate,
            num_rfs= num_rfs
        )

        self.encoder_decoder_attn_norm = PreNorm(d_model, self.encoder_decoder_attn)

            # GLU FFN
        self.ffn = GLUFeedForward(
            d_model=d_model,
            d_ff=d_ff,
            dropout_rate=dropout_rate
        )

        self.ffn_norm = PreNorm(d_model, self.ffn)

        self.rope = RotaryPositionalEmbedding(d_model)


    def forward(self, x , encoder_output , src_mask = None, tgt_mask= None):
        """
        Args:
            x: Input to decoder (batch_size, tgt_seq_len, d_model)
            encoder_output: Output from encoder (batch_size, src_seq_len, d_model)
            src_mask: Mask for encoder outputs (optional)
            tgt_mask: Causal mask for decoder self-attention (optional)

        Returns:
            x: Updated decoder output
            None, None: Placeholder for attention weights (not returned by linear attention)
        """
        

        x_pos = self.rope(x)
        x = self.self_attn_norm(x_pos)
        x = self.encoder_decoder_attn_norm(x, key_input=encoder_output, value_input=encoder_output)
        x = self.ffn_norm(x)


        return x, None , None
        

In [132]:
class Decoder(nn.Module):
    def __init__(self, num_layers, d_model , num_heads, d_ff, dropout_rate = 0.1 , num_rfs = 64):
        super(Decoder , self).__init__()
        self.num_layers = num_layers

        self.layers = nn.ModuleList([
            DecoderLayer(
                d_model = d_model,
                num_heads= num_heads,
                d_ff = d_ff,
                dropout_rate=dropout_rate,
                num_rfs= num_rfs
            )
            for _ in range(num_layers)
        ])

    def forward(self, x , encoder_output , src_mask = None, tgt_mask = None):
        for layer in self.layers:
            x, _, _ = layer(x , encoder_output, src_mask , tgt_mask)
        return x, None, None
    

In [133]:
import torch
import torch.nn as nn

class Transformer(nn.Module):

    def __init__(
        self,
        num_layers,
        d_model,
        num_heads,
        d_ff,
        input_vocab_size,
        target_vocab_size,
        dropout_rate=0.1,
        num_rfs=64  # Number of random features for Linear Attention
    ):
        super(Transformer , self).__init__()

        # Input Embeddings
        self.src_embedding = nn.Embedding(input_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(target_vocab_size, d_model)

        # Encoder and Decoder Stacks (with RoPE, Linear Attention, GLU, PreNorm)
        self.encoder = Encoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_rate=dropout_rate,
            num_rfs=num_rfs
        )
        self.decoder = Decoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_rate=dropout_rate,
            num_rfs=num_rfs
        )

        # Final projection layer
        self.final_linear = nn.Linear(d_model, target_vocab_size)

        # Dropout
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Args:
            src: Source token indices (batch_size, src_seq_len)
            tgt: Target token indices (batch_size, tgt_seq_len)
            src_mask: Optional mask for encoder (batch_size, 1, 1, src_seq_len)
            tgt_mask: Optional causal mask for decoder (1, 1, tgt_seq_len, tgt_seq_len)

        Returns:
            final_output: Logits over target vocabulary
            None: Placeholder (no attention weights returned in linear attention)
        """

        # Input embeddings
        src_embedded = self.dropout(self.src_embedding(src))
        tgt_embedded = self.dropout(self.tgt_embedding(tgt))

        # Encoder pass
        encoder_output, _ = self.encoder(src_embedded, src_mask)

        # Decoder pass
        decoder_output, _, _ = self.decoder(tgt_embedded, encoder_output, src_mask, tgt_mask)

        # Final linear projection
        final_output = self.final_linear(decoder_output)

        return final_output

In [134]:
# to be used while training 
def create_padding_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)

def create_look_ahead_mask(size):
    return torch.triu(torch.ones(size, size), diagonal=1).type(torch.bool).unsqueeze(0).unsqueeze(0)  # (1, 1, size, size)

Testing the Transformer

In [None]:
print("Testing Full Transformer Model...")

num_layers = 2
d_model = 256 
num_heads = 4
d_ff = 1024 
input_vocab_size = 1000 
target_vocab_size = 800 
dropout_rate = 0.1
PAD_IDX = 0 

batch_size = 2 
src_seq_len_test = 50
tgt_seq_len_test = 40

# Dummy input data
dummy_src = torch.randint(1, input_vocab_size, (batch_size, src_seq_len_test))
dummy_tgt = torch.randint(1, target_vocab_size, (batch_size, tgt_seq_len_test))

# Apply padding
dummy_src[0, 45:] = PAD_IDX
dummy_tgt[1, 30:] = PAD_IDX

# Create masks
def create_padding_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)

def create_look_ahead_mask(size):
    return torch.triu(torch.ones(size, size), diagonal=1).type(torch.bool).unsqueeze(0).unsqueeze(0)

src_padding_mask = create_padding_mask(dummy_src, PAD_IDX)
tgt_padding_mask = create_padding_mask(dummy_tgt, PAD_IDX)
look_ahead_mask = create_look_ahead_mask(tgt_seq_len_test)

# Combine padding and look-ahead masks
tgt_mask = tgt_padding_mask & (create_look_ahead_mask(tgt_seq_len_test) == 0)

# Initialize model
transformer_model = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
    dropout_rate=dropout_rate
)

# Forward pass
output_logits = transformer_model(dummy_src, dummy_tgt, src_padding_mask, tgt_mask)

# Print shapes to verify
print(f"\nSource input shape: {dummy_src.shape}")
print(f"Target input shape: {dummy_tgt.shape}")
print(f"Output logits shape: {output_logits.shape}")

# Assert output shape
assert output_logits.shape == (batch_size, tgt_seq_len_test, target_vocab_size), "Output shape mismatch!"

print("\n✅ Full Transformer Model test passed successfully!")

Testing Full Transformer Model...


RuntimeError: einsum(): subscript s has size 64 for operand 1 which does not broadcast with previously seen size 50