In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [3]:
# Pre-norm residual wrapper (LayerNorm applied before the sublayer) instead of post-norm Add & Norm
class PreNorm(nn.Module):
    """Pre-norm residual wrapper: ``x + Dropout(sublayer(LayerNorm(x)))``.

    The wrapped sublayer is dispatched on its class name:

    * ``GLUFeedForward``: called as ``sublayer_fn(norm(x))``.
    * ``LinearAttention``: called with ``key_input`` / ``value_input`` / ``mask``
      keywords; the second item of the returned tuple (attention weights) is
      discarded.
    * anything else: called with ``norm(x)`` and, when given, the
      ``key_input`` / ``value_input`` keywords (``mask`` is not forwarded).

    Args:
        d_model: feature size normalized by the LayerNorm.
        sublayer_fn: the module to wrap.
        dropout_rate: dropout applied to the sublayer output before the
            residual addition.
    """

    def __init__(self, d_model, sublayer_fn, dropout_rate=0.1):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.sublayer_fn = sublayer_fn

    def forward(self, x, sublayer_input=None, key_input=None, value_input=None, mask=None):
        """Apply the wrapped sublayer to ``norm(x)`` and add the result to ``x``.

        Args:
            x: residual-stream input; the residual is always added to this tensor.
            sublayer_input: optional alternative attention input (e.g. a
                RoPE-rotated copy of ``x``). For ``LinearAttention`` it is
                normalized with the same LayerNorm and used as the query (and,
                for self-attention, the key); values still come from ``norm(x)``.
                Ignored by other sublayer types.
            key_input, value_input: cross-attention keys/values; both must be
                given for the cross-attention path to be taken.
            mask: forwarded to ``LinearAttention`` only.

        Returns:
            ``x + dropout(sublayer_output)``.
        """
        normalized_x = self.norm(x)
        kind = type(self.sublayer_fn).__name__

        if "GLUFeedForward" in kind:
            sublayer_output = self.sublayer_fn(normalized_x)
        elif "LinearAttention" in kind:
            # Previously sublayer_input was accepted but ignored, so RoPE never
            # reached the attention. Use it (normalized) for the query side.
            query = normalized_x if sublayer_input is None else self.norm(sublayer_input)
            if key_input is not None and value_input is not None:  # cross-attention
                sublayer_output, _ = self.sublayer_fn(
                    query, key_input=key_input, value_input=value_input, mask=mask
                )
            else:
                # Self-attention: positions enter Q and K; V stays un-rotated.
                sublayer_output, _ = self.sublayer_fn(
                    query, key_input=query, value_input=normalized_x, mask=mask
                )
        elif key_input is None and value_input is None:
            sublayer_output = self.sublayer_fn(normalized_x)
        else:
            sublayer_output = self.sublayer_fn(
                normalized_x, key_input=key_input, value_input=value_input
            )

        return x + self.dropout(sublayer_output)


In [4]:
# Gated feed-forward network (GLU variant) instead of the position-wise FFN,
# with GELU as the gate activation instead of ReLU (i.e. GEGLU)
class GLUFeedForward(nn.Module):
    """Gated feed-forward block (GEGLU).

    Computes ``output_proj(dropout(gelu(gate_proj(x)) * value_proj(x)))``.

    Args:
        d_model: input and output feature size.
        d_ff: hidden size; when falsy (None or 0) it defaults to ``4 * d_model``.
        dropout_rate: dropout applied to the gated hidden activations.
    """

    def __init__(self, d_model, d_ff=None, dropout_rate=0.1):
        super(GLUFeedForward, self).__init__()
        inner_dim = d_ff if d_ff else 4 * d_model
        self.gate_proj = nn.Linear(d_model, inner_dim)
        self.value_proj = nn.Linear(d_model, inner_dim)
        self.output_proj = nn.Linear(inner_dim, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # GELU-activated gate multiplies the linear "value" branch elementwise.
        hidden = F.gelu(self.gate_proj(x)) * self.value_proj(x)
        return self.output_proj(self.dropout(hidden))


In [5]:
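# Multi-head attention replaced by a FAVOR+-inspired linear attention:
# orthogonal random projections followed by the positive elu(x) + 1 feature map
# (FAVOR+ itself uses exponential positive random features to approximate softmax).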

def orthogonal_random_features(dim, num_heads, num_rfs, device):
    """Draw per-head random projection directions with orthonormal blocks.

    Rows are produced in blocks of at most ``dim`` rows; within a block the
    rows are orthonormal (from a reduced QR of a Gaussian matrix), and blocks
    are drawn independently. A single QR can yield at most ``dim`` orthogonal
    rows, so stacking blocks is what allows ``num_rfs > dim``.

    Args:
        dim: feature size of each direction (the head dimension).
        num_heads: number of independent sets of directions, one per head.
        num_rfs: number of random features (rows) per head.
        device: device on which to allocate the tensor.

    Returns:
        Tensor of shape (num_heads, num_rfs, dim).

    Raises:
        ValueError: if ``num_rfs < 1``.
    """
    if num_rfs < 1:
        raise ValueError(f"num_rfs must be >= 1, got {num_rfs}")

    blocks = []
    remaining = num_rfs
    while remaining > 0:
        rows = min(remaining, dim)
        gaussian = torch.randn(num_heads, rows, dim, device=device)
        # QR of the (dim, rows) matrix gives Q with `rows` orthonormal columns.
        q, _ = torch.linalg.qr(gaussian.transpose(-2, -1))
        blocks.append(q.transpose(-2, -1))  # (num_heads, rows, dim)
        remaining -= rows
    return torch.cat(blocks, dim=1)


def elu_kernel(x):
    """Feature map ``phi(x) = elu(x) + 1`` used by the linear attention.

    ``elu(x) > -1`` for finite inputs, so the result is non-negative (positive
    unless ``exp(x)`` underflows). This keeps the attention normalizer from
    going negative. This is the elu+1 map of kernelized linear attention, not
    the exponential feature map of FAVOR+.
    """
    return 1.0 + F.elu(x)

class LinearAttention(nn.Module):
    """Multi-head linear attention built on random-feature projections.

    Per head, queries and keys are projected onto ``num_rfs`` random directions
    (``rand_projs``) and mapped through ``elu_kernel``. Attention is then
    ``phi(Q) (phi(K)^T V) / (phi(Q) . sum_s phi(K_s))``, which is linear in
    sequence length; no attention matrix is materialized.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of heads; each head has ``d_model // num_heads`` dims.
        dropout_rate: dropout applied after the output projection.
        num_rfs: number of random features per head.
        causal: if True, query position i only attends to key positions <= i
            (prefix-sum formulation; query and key lengths must match).
            Defaults to False, the original non-causal behaviour.
    """

    def __init__(self, d_model, num_heads, dropout_rate=0.1, num_rfs=64, causal=False):
        super(LinearAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.num_rfs = num_rfs  # number of random features per head
        self.num_heads = num_heads
        self.causal = causal

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

        # A registered buffer (instead of a plain attribute) is saved in the
        # state_dict and moved by .to(device). The projection therefore stays
        # fixed for the model's lifetime instead of being redrawn (changing the
        # model's function) after a device move or a checkpoint reload.
        self.register_buffer(
            "rand_projs",
            orthogonal_random_features(self.d_k, num_heads, num_rfs, device=None),
        )

    def _init_random_features(self, head_dim, device):
        """Ensure ``rand_projs`` matches ``head_dim`` and lives on ``device``.

        Features are only (re)drawn when missing or when the head size changed;
        a device mismatch moves the existing projection instead of redrawing it.
        """
        if self.rand_projs is None or self.rand_projs.shape[-1] != head_dim:
            self.rand_projs = orthogonal_random_features(
                head_dim, self.num_heads, self.num_rfs, device
            )
        elif self.rand_projs.device != device:
            self.rand_projs = self.rand_projs.to(device)

    def _apply_mask(self, tensor, mask):
        """Zero out the key/value positions where ``mask`` is 0/False (padding).

        Args:
            tensor: (batch, heads, seq_kv, features) tensor to mask.
            mask: key-padding mask of shape (batch, 1, 1, seq_kv),
                (batch, 1, seq_kv, 1) or (batch, seq_kv); None disables masking.

        Raises:
            ValueError: for any other mask shape (e.g. a (1, 1, S, S) look-ahead
                mask). A look-ahead mask cannot be expressed as key padding;
                use ``causal=True`` instead.
        """
        if mask is None:
            return tensor

        original_shape = tuple(mask.shape)
        if mask.ndim == 4 and mask.shape[1] == 1 and mask.shape[2] == 1:
            mask = mask.transpose(-2, -1)  # (b, 1, 1, s_kv) -> (b, 1, s_kv, 1)
        elif mask.ndim == 2:
            mask = mask.unsqueeze(1).unsqueeze(-1)  # (b, s_kv) -> (b, 1, s_kv, 1)

        if (mask.ndim != 4 or mask.shape[1] != 1 or mask.shape[-1] != 1
                or mask.shape[2] != tensor.shape[2]):
            raise ValueError(
                f"Unsupported attention mask shape {original_shape}; expected a "
                f"key-padding mask of shape (batch, 1, 1, {tensor.shape[2]}) or "
                f"(batch, {tensor.shape[2]})"
            )

        mask = mask.to(tensor.device)
        fill_value = False if tensor.dtype == torch.bool else 0.0
        return tensor.masked_fill(mask == 0, fill_value)  # mask value 0 marks padding

    def linear_attention(self, Q, K, V, key_padding_mask=None):
        """Kernelized attention over already split heads.

        Args:
            Q: (batch, heads, seq_q, head_dim).
            K, V: (batch, heads, seq_kv, head_dim).
            key_padding_mask: optional mask accepted by ``_apply_mask``.

        Returns:
            Tensor of shape (batch, heads, seq_q, head_dim).
        """
        _, _, seq_len_q, head_dim = Q.shape
        seq_len_kv = K.shape[2]

        self._init_random_features(head_dim, Q.device)

        # Project onto random features, then apply the positive feature map:
        # (b, h, s, d) x (h, r, d) -> (b, h, s, r)
        Q_feat = elu_kernel(torch.einsum('bhsd,hrd->bhsr', Q, self.rand_projs))
        K_feat = elu_kernel(torch.einsum('bhsd,hrd->bhsr', K, self.rand_projs))

        # Padded keys must contribute to neither the numerator nor the normalizer.
        K_feat = self._apply_mask(K_feat, key_padding_mask)
        V = self._apply_mask(V, key_padding_mask)

        if self.causal:
            if seq_len_q != seq_len_kv:
                raise ValueError(
                    f"causal attention needs equal query/key lengths, got {seq_len_q} and {seq_len_kv}"
                )
            # Prefix sums along the sequence: position i only sees keys 0..i.
            kv_prefix = torch.einsum('bhsr,bhsd->bhsrd', K_feat, V).cumsum(dim=2)
            numerator = torch.einsum('bhsr,bhsrd->bhsd', Q_feat, kv_prefix)
            denominator = torch.einsum('bhsr,bhsr->bhs', Q_feat, K_feat.cumsum(dim=2))
        else:
            K_feat_V = torch.einsum('bhsr,bhsd->bhrd', K_feat, V)           # (b, h, r, d)
            numerator = torch.einsum('bhsr,bhrd->bhsd', Q_feat, K_feat_V)   # (b, h, s_q, d)
            denominator = torch.einsum('bhsr,bhr->bhs', Q_feat, K_feat.sum(dim=2))  # (b, h, s_q)

        return numerator / (denominator.unsqueeze(-1) + 1e-6)

    def forward(self, query_input, key_input=None, value_input=None, mask=None):
        """Project inputs, run linear attention per head, and merge the heads.

        Args:
            query_input: (batch, seq_q, d_model).
            key_input, value_input: (batch, seq_kv, d_model); each defaults to
                ``query_input`` (self-attention).
            mask: optional key-padding mask (0/False marks padding), see
                ``_apply_mask`` for accepted shapes.

        Returns:
            ``(output, None)``, where output has shape (batch, seq_q, d_model).
            No attention weights are materialized, so the second item is None.
        """
        batch_size = query_input.size(0)

        if key_input is None:
            key_input = query_input
        if value_input is None:
            value_input = query_input

        # (batch, seq, d_model) -> (batch, num_heads, seq, d_k)
        Q = self.W_q(query_input).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key_input).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value_input).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        attn_output = self.linear_attention(Q, K, V, key_padding_mask=mask)

        # Concatenate heads and apply the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.dropout(self.W_o(attn_output))

        return output, None


In [6]:
# Rotary positional embedding (RoPE) instead of additive sinusoidal positional encoding
def rotate_half(x):
    """Map the two halves ``(a, b)`` of the last axis to ``(-b, a)``.

    Halves are split with ``torch.chunk``; this is the pairwise 90-degree
    rotation used by rotate-half RoPE.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)

class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE), rotate-half convention.

    At position ``p``, feature ``j`` and its partner ``j + dim/2`` are rotated by
    the angle ``p * base**(-2j/dim)``. Tables for the first ``max_seq_len``
    positions are precomputed; longer inputs are computed on the fly.

    Args:
        dim: size of the last axis of the inputs; must be even.
        max_seq_len: number of positions to precompute.
        base: frequency base.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be even for RoPE."
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        cos, sin = self._cos_sin(self.max_seq_len)
        self.register_buffer("cos_cached", cos.unsqueeze(0), persistent=False)  # (1, max_seq_len, dim)
        self.register_buffer("sin_cached", sin.unsqueeze(0), persistent=False)  # (1, max_seq_len, dim)

    def _cos_sin(self, seq_len):
        """Return ``(cos, sin)`` tables of shape (seq_len, dim) on ``inv_freq``'s device."""
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    def forward(self, x, seq_dim=1):
        """Rotate ``x`` by position along ``seq_dim``.

        Args:
            x: tensor whose last axis has size ``dim``, e.g. (batch, seq, dim)
                with ``seq_dim=1`` or (batch, heads, seq, dim) with ``seq_dim=2``.
            seq_dim: axis holding the positions (negative values allowed); must
                not be the last (feature) axis.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: if ``seq_dim`` is the feature axis or ``x.shape[-1] != dim``.
        """
        seq_dim = seq_dim % x.ndim
        if seq_dim == x.ndim - 1:
            raise ValueError("seq_dim must not be the last (feature) dimension")
        if x.shape[-1] != self.dim:
            raise ValueError(f"expected last dimension {self.dim}, got {x.shape[-1]}")

        seq_len = x.shape[seq_dim]
        if seq_len > self.max_seq_len:
            cos, sin = self._cos_sin(seq_len)
        else:
            cos = self.cos_cached[0, :seq_len]
            sin = self.sin_cached[0, :seq_len]

        # Align the (seq_len, dim) tables with seq_dim and the last axis of x.
        # Previously they always broadcast on axis -2, which was wrong for any
        # seq_dim other than ndim - 2.
        shape = [1] * x.ndim
        shape[seq_dim] = seq_len
        shape[-1] = self.dim
        cos = cos.reshape(shape).to(dtype=x.dtype)
        sin = sin.reshape(shape).to(dtype=x.dtype)

        return (x * cos) + (rotate_half(x) * sin)


In [7]:
class EncoderLayer(nn.Module):
    """One pre-norm encoder layer: linear self-attention, then a GLU feed-forward.

    Both sublayers are wrapped in ``PreNorm`` (residual + dropout). A rotated
    copy of the layer input, ``self.rope(x)``, is handed to the attention
    wrapper as ``sublayer_input``; how it is used is decided by ``PreNorm``.

    NOTE(review): RoPE is applied to the layer input *before* the W_q/W_k
    projections inside LinearAttention, so the relative-position property of
    RoPE (which needs rotation after projection) does not hold here.

    Args:
        d_model, num_heads, d_ff, dropout_rate, num_rfs: forwarded to the
            attention / feed-forward sublayers.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1, num_rfs=64):
        super(EncoderLayer, self).__init__()

        # The sublayer modules are stored directly as well as inside their
        # PreNorm wrappers; registration order is kept as before.
        attention = LinearAttention(
            d_model=d_model, num_heads=num_heads, dropout_rate=dropout_rate, num_rfs=num_rfs
        )
        self.self_attn_module = attention
        self.attn_norm = PreNorm(d_model, attention, dropout_rate=dropout_rate)

        feed_forward = GLUFeedForward(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
        self.ffn_module = feed_forward
        self.ffn_norm = PreNorm(d_model, feed_forward, dropout_rate=dropout_rate)

        self.rope = RotaryPositionalEmbedding(d_model)

    def forward(self, x, src_mask=None):
        """Return ``(output, None)``; output has the same shape as ``x``.

        ``src_mask`` is passed to the attention wrapper as its key-padding mask.
        """
        # The residual is added to the un-rotated x; only the sublayer sees RoPE.
        attended = self.attn_norm(x, sublayer_input=self.rope(x), mask=src_mask)
        return self.ffn_norm(attended), None


In [8]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` pre-norm ``EncoderLayer``s followed by a final LayerNorm.

    In a pre-norm stack each sublayer normalizes only its own input, so the
    residual stream leaving the last layer is never normalized. The final
    LayerNorm normalizes it before it is consumed downstream (e.g. as the
    decoder's cross-attention keys/values, which ``PreNorm`` does not normalize).

    Args:
        num_layers: number of encoder layers.
        d_model, num_heads, d_ff, dropout_rate, num_rfs: forwarded to each layer.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout_rate=0.1, num_rfs=64):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList([
            EncoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout_rate=dropout_rate,
                num_rfs=num_rfs,
            ) for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask=None):
        """Encode embedded source ``x``; returns ``(normalized_output, None)``."""
        for layer in self.layers:
            x, _ = layer(x, src_mask)
        return self.final_norm(x), None

In [9]:
class DecoderLayer(nn.Module):
    """One pre-norm decoder layer: self-attention, cross-attention, GLU FFN.

    Each sublayer is wrapped in ``PreNorm`` (residual + dropout). A rotated copy
    of the layer input, ``self.rope(x)``, is handed to the self-attention
    wrapper as ``sublayer_input``. Cross-attention takes its keys/values from
    ``encoder_output`` as given.

    NOTE(review): self-attention here is NOT causal. ``tgt_mask`` is handed to
    LinearAttention, which applies masks only as key padding, so every target
    position can attend to later target positions. Autoregressive training
    needs a causal attention path.

    Args:
        d_model, d_ff, num_heads, dropout_rate, num_rfs: forwarded to the
            sublayers (note the d_ff / num_heads order differs from EncoderLayer).
    """

    def __init__(self, d_model, d_ff, num_heads, dropout_rate=0.1, num_rfs=64):
        super(DecoderLayer, self).__init__()

        self_attention = LinearAttention(
            d_model=d_model, num_heads=num_heads, dropout_rate=dropout_rate, num_rfs=num_rfs
        )
        self.masked_self_attn_module = self_attention
        self.self_attn_norm = PreNorm(d_model, self_attention, dropout_rate=dropout_rate)

        # Queries come from the decoder stream; keys/values from encoder_output.
        cross_attention = LinearAttention(
            d_model=d_model, num_heads=num_heads, dropout_rate=dropout_rate, num_rfs=num_rfs
        )
        self.encoder_decoder_attn_module = cross_attention
        self.encoder_decoder_attn_norm = PreNorm(d_model, cross_attention, dropout_rate=dropout_rate)

        feed_forward = GLUFeedForward(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
        self.ffn_module = feed_forward
        self.ffn_norm = PreNorm(d_model, feed_forward, dropout_rate=dropout_rate)

        self.rope = RotaryPositionalEmbedding(d_model)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """Return ``(output, None, None)``; output has the same shape as ``x``.

        ``tgt_mask`` masks the self-attention keys, ``src_mask`` the encoder keys.
        """
        hidden = self.self_attn_norm(x, sublayer_input=self.rope(x), mask=tgt_mask)
        hidden = self.encoder_decoder_attn_norm(
            hidden, key_input=encoder_output, value_input=encoder_output, mask=src_mask
        )
        return self.ffn_norm(hidden), None, None



In [10]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` pre-norm ``DecoderLayer``s followed by a final LayerNorm.

    As in any pre-norm stack, the residual stream leaving the last layer has
    never been normalized; the final LayerNorm normalizes it before the output
    projection is applied.

    Args:
        num_layers: number of decoder layers.
        d_model, num_heads, d_ff, dropout_rate, num_rfs: forwarded to each layer.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout_rate=0.1, num_rfs=64):
        super(Decoder, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList([
            DecoderLayer(
                d_model=d_model, d_ff=d_ff, num_heads=num_heads,
                dropout_rate=dropout_rate, num_rfs=num_rfs,
            ) for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """Decode embedded target ``x``; returns ``(normalized_output, None, None)``."""
        for layer in self.layers:
            x, _, _ = layer(x, encoder_output, src_mask, tgt_mask)
        return self.final_norm(x), None, None



In [11]:
class Transformer(nn.Module):
    """Encoder-decoder model over the Encoder/Decoder stacks defined above.

    Source and target token embeddings are scaled by ``sqrt(d_model)`` and passed
    through dropout. The encoder output feeds the decoder, and the decoder
    output is projected to ``target_vocab_size`` logits.

    Args:
        num_layers, d_model, num_heads, d_ff, dropout_rate: shared by both stacks.
        input_vocab_size / target_vocab_size: embedding table sizes.
        num_rfs: number of random features for LinearAttention.
    """

    def __init__(
        self,
        num_layers,
        d_model,
        num_heads,
        d_ff,
        input_vocab_size,
        target_vocab_size,
        dropout_rate=0.1,
        num_rfs=64,
    ):
        super(Transformer, self).__init__()

        self.d_model = d_model
        self.src_embedding = nn.Embedding(input_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(target_vocab_size, d_model)

        stack_config = dict(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_rate=dropout_rate,
            num_rfs=num_rfs,
        )
        self.encoder = Encoder(**stack_config)
        self.decoder = Decoder(**stack_config)
        self.final_linear = nn.Linear(d_model, target_vocab_size)
        self.dropout = nn.Dropout(dropout_rate)  # applied to both embeddings

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        This covers weight matrices and embedding tables; 1-D parameters (biases,
        LayerNorm affine terms) keep their default initialization.
        """
        matrices = [param for param in self.parameters() if param.dim() > 1]
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Return logits of shape (batch, tgt_len, target_vocab_size).

        ``src_mask`` is used by the encoder and by decoder cross-attention;
        ``tgt_mask`` is used by decoder self-attention.
        """
        scale = math.sqrt(self.d_model)
        src_embedded = self.dropout(self.src_embedding(src) * scale)
        tgt_embedded = self.dropout(self.tgt_embedding(tgt) * scale)

        memory, _ = self.encoder(src_embedded, src_mask)
        decoded, _, _ = self.decoder(tgt_embedded, memory, src_mask, tgt_mask)
        return self.final_linear(decoded)



In [12]:
# Mask creation functions
def create_padding_mask(seq, pad_idx):
    """Key-padding mask: True at real tokens, False where ``seq == pad_idx``.

    For a (batch, seq_len) input the result has shape (batch, 1, 1, seq_len),
    ready to broadcast over heads and query positions.
    """
    keep = seq.ne(pad_idx)
    return keep[:, None, None]

def create_look_ahead_mask(size, device):
    """Causal mask of shape (1, 1, size, size).

    Entry ``[0, 0, i, j]`` is True when ``j <= i`` (allowed) and False for
    future positions ``j > i``.
    """
    lower = torch.tril(torch.ones(size, size, device=device))
    return lower.bool()[None, None]

In [13]:
# --- Test Script ---
print("Testing Full Transformer Model...")

# Seed the RNG so weights, random features and dummy data are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Hyperparameters
num_layers = 2
d_model = 256
num_heads = 4
d_ff = d_model * 4 # Standard FFN hidden dim
input_vocab_size = 1000
target_vocab_size = 800
dropout_rate = 0.1
PAD_IDX = 0
num_rfs_test = 64 # For LinearAttention

# Data parameters
batch_size = 2
src_seq_len_test = 50
tgt_seq_len_test = 40
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dummy input data (ids start at 1, so PAD_IDX only appears where placed below)
dummy_src = torch.randint(1, input_vocab_size, (batch_size, src_seq_len_test), device=device)
dummy_tgt = torch.randint(1, target_vocab_size, (batch_size, tgt_seq_len_test), device=device)

# Apply padding
dummy_src[0, src_seq_len_test-5:] = PAD_IDX
dummy_tgt[1, tgt_seq_len_test-10:] = PAD_IDX

# Create masks
# src_padding_mask: True for non-pad, False for pad. Shape (B, 1, 1, S_src)
src_padding_mask = create_padding_mask(dummy_src, PAD_IDX)

# tgt_padding_mask: True for non-pad, False for pad. Shape (B, 1, 1, S_tgt)
tgt_self_padding_mask = create_padding_mask(dummy_tgt, PAD_IDX)

# look_ahead_mask: True for allowed, False for future. Shape (1, 1, S_tgt, S_tgt)
# NOTE: not passed to the model. LinearAttention applies masks only as key
# padding, so decoder self-attention in this test is NOT causal.
look_ahead_causal_mask = create_look_ahead_mask(tgt_seq_len_test, device)

decoder_self_attn_mask = tgt_self_padding_mask # (B,1,1,S_tgt)

# For encoder-decoder attention, the mask is src_padding_mask (for encoder keys)
decoder_cross_attn_mask = src_padding_mask


# Initialize model
transformer_model = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
    dropout_rate=dropout_rate,
    num_rfs=num_rfs_test
).to(device)
# eval() disables dropout so the forward passes below are deterministic and comparable.
transformer_model.eval()

# Forward pass
# Transformer.forward expects: src, tgt, src_mask (for encoder and cross-attn), tgt_mask (for decoder self-attn)
with torch.no_grad():
    output_logits = transformer_model(dummy_src, dummy_tgt, decoder_cross_attn_mask, decoder_self_attn_mask)

# Print shapes to verify
print(f"\nSource input shape: {dummy_src.shape}")
print(f"Target input shape: {dummy_tgt.shape}")
print(f"Output logits shape: {output_logits.shape}")

# Assert output shape and sanity of values
assert output_logits.shape == (batch_size, tgt_seq_len_test, target_vocab_size), "Output shape mismatch!"
assert torch.isfinite(output_logits).all(), "Non-finite values in output logits!"

# Semantic check: padded source positions are masked as keys in encoder
# self-attention and in cross-attention. Replacing their token ids while
# keeping the same mask must therefore leave the logits unchanged.
perturbed_src = dummy_src.clone()
perturbed_src[0, src_seq_len_test-5:] = 1
with torch.no_grad():
    perturbed_logits = transformer_model(perturbed_src, dummy_tgt, decoder_cross_attn_mask, decoder_self_attn_mask)
assert torch.allclose(output_logits, perturbed_logits, atol=1e-5), "Padded source tokens influenced the output!"

print("\n✅ Full Transformer Model test passed successfully (structurally)!")
print("Note: Logic of LinearAttention causality and full mask interaction needs careful consideration for semantic correctness beyond shape matching.")


Testing Full Transformer Model...
Using device: cuda

Source input shape: torch.Size([2, 50])
Target input shape: torch.Size([2, 40])
Output logits shape: torch.Size([2, 40, 800])

✅ Full Transformer Model test passed successfully (structurally)!
Note: Logic of LinearAttention causality and full mask interaction needs careful consideration for semantic correctness beyond shape matching.
