In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


Using PreNorm (LayerNorm applied before the sublayer, inside the residual branch) instead of Add & Norm (post-norm).

In [3]:
class PreNorm(nn.Module):
    """Pre-normalisation residual wrapper around an arbitrary sublayer.

    Computes ``x + dropout(sublayer(layer_norm(x)))``.

    Args:
        d_model: Size of the last dimension normalised by ``nn.LayerNorm``.
        sublayer: Callable applied to the normalised input; its output is
            added to ``x``, so it must be broadcast-compatible with ``x``.
        dropout_rate: Dropout probability applied to the sublayer output.
    """

    def __init__(self, d_model, sublayer, dropout_rate=0.1):
        super().__init__()
        # Registration order (norm, dropout, sublayer) is kept so that
        # parameter / state_dict ordering stays the same.
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.sublayer = sublayer

    def forward(self, x):
        """Return ``x`` plus the dropped-out sublayer output of its normalised form."""
        residual_branch = self.dropout(self.sublayer(self.norm(x)))
        return x + residual_branch
    
    


Using a GLU (gated linear unit) feed-forward block instead of the position-wise FFN.

Using GELU instead of ReLU as the gate activation.

In [4]:
class GLUFeedForward(nn.Module):
    """Gated feed-forward block with a GELU gate.

    Computes ``output_proj(dropout(gelu(gate_proj(x)) * value_proj(x)))``.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (gated) feature size; any falsy value (``None`` or ``0``)
            falls back to ``d_model``.
        dropout_rate: Dropout probability applied to the gated hidden state.
    """

    def __init__(self, d_model, d_ff=None, dropout_rate = 0.1):
        super().__init__()
        hidden_size = d_ff if d_ff else d_model

        # Two parallel input projections: one produces the gate, one the values.
        self.gate_proj = nn.Linear(d_model, hidden_size)
        self.value_proj = nn.Linear(d_model, hidden_size)
        # Maps the gated hidden state back to d_model.
        self.output_proj = nn.Linear(hidden_size, d_model)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the gated feed-forward transform to ``x`` (last dim ``d_model``)."""
        gated = F.gelu(self.gate_proj(x)) * self.value_proj(x)
        return self.output_proj(self.dropout(gated))

Replacing multi-head attention with a FAVOR+-style linear attention (FAVOR+ = Fast Attention Via positive Orthogonal Random features).

In [5]:
def orthogonal_random_features(dim, num_heads, head_dim, device):
    """Draw a random projection with orthonormal column blocks for every head.

    Args:
        dim: Number of rows of each projection, i.e. the size of the vectors
            that will be projected.
        num_heads: Number of independent projections to draw.
        head_dim: Number of columns (random features) per projection.
        device: Device on which the projection is created.

    Returns:
        Tensor of shape ``(num_heads, dim, head_dim)``. When
        ``head_dim <= dim`` all columns of a head are mutually orthonormal.
        When ``head_dim > dim`` the columns are a concatenation of
        independently drawn blocks of at most ``dim`` orthonormal columns.

    Raises:
        ValueError: If any of ``dim``, ``num_heads`` or ``head_dim`` is < 1.
    """
    if dim < 1 or num_heads < 1 or head_dim < 1:
        raise ValueError(
            "dim, num_heads and head_dim must all be positive, got "
            f"dim={dim}, num_heads={num_heads}, head_dim={head_dim}"
        )

    # A reduced QR of a (dim, k) matrix yields only min(dim, k) columns, so a
    # single factorisation cannot deliver more than `dim` features. Draw as
    # many independent blocks as needed instead of silently returning fewer
    # columns than requested.
    blocks = []
    remaining = head_dim
    while remaining > 0:
        width = min(remaining, dim)
        gaussian = torch.randn(num_heads, dim, width, device=device)
        q, _ = torch.linalg.qr(gaussian)
        blocks.append(q)
        remaining -= width

    if len(blocks) == 1:
        return blocks[0]
    return torch.cat(blocks, dim=-1)


def elu_kernal(x):
    """Feature map ``elu(x) + 1`` applied element-wise to ``x``."""
    activated = F.elu(x)
    return activated + 1

class LinearAttention(nn.Module):
    """Multi-head linear attention over positive random-projection features.

    Queries and keys are split into heads, scaled by a learnable scalar,
    projected onto per-head random directions and passed through
    ``elu(x) + 1`` so every feature is strictly positive. With feature maps
    ``phi(Q)`` and ``phi(K)`` the output for each query position is

        phi(q) @ (sum_s phi(k_s)^T v_s)  /  (phi(q) @ sum_s phi(k_s))

    which never materialises a (query_len x key_len) attention matrix.

    Args:
        d_model: Width of the input and output embeddings.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout_rate: Probability used to build ``self.dropout``.
        num_rfs: Number of random features requested per head.

    NOTE(review): ``self.dropout`` is constructed but not applied anywhere in
    this module, and ``mask`` in ``forward`` is accepted but ignored; confirm
    whether callers rely on either.
    """

    def __init__(self, d_model , num_heads, dropout_rate = 0.1, num_rfs = 10):
        super(LinearAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads  # read by forward() when splitting heads
        self.d_k = d_model // num_heads
        self.num_rfs = num_rfs  # number of random features

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

        # Learnable scale applied to queries and keys before the feature map.
        self.scale = nn.Parameter(torch.tensor(d_model ** -0.5))

        # One random projection per head, acting on head-sized (d_k) vectors.
        # Stored as a buffer so it follows the module across devices and is
        # saved in the state_dict; a reloaded model therefore keeps the
        # projections it was trained with.
        self.register_buffer(
            "rand_projs",
            orthogonal_random_features(
                self.d_k, self.num_heads, self.num_rfs, self.W_q.weight.device
            ),
        )

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def linear_attention(self, Q, K, V):
        """Linear attention on per-head tensors.

        Args:
            Q: Queries of shape (batch, heads, query_len, d_k).
            K: Keys of shape (batch, heads, key_len, d_k).
            V: Values of shape (batch, heads, key_len, d_k).

        Returns:
            Tensor of shape (batch, heads, query_len, d_k).
        """
        Q = Q * self.scale
        K = K * self.scale

        # Per-head projection followed by the positive feature map: (b, h, s, r).
        Q_feat = elu_kernal(torch.einsum('bhsd,hdr->bhsr', Q, self.rand_projs))
        K_feat = elu_kernal(torch.einsum('bhsd,hdr->bhsr', K, self.rand_projs))

        # Key/value summary, summed over key positions: (b, h, r, d).
        K_feat_V = torch.einsum('bhsr,bhsd->bhrd', K_feat, V)

        # Normaliser per query position: (b, h, s). The features are strictly
        # positive, the epsilon only guards against underflow.
        key_feat_sum = K_feat.sum(dim=2)  # (b, h, r)
        z_denom = 1. / (torch.einsum('bhsr,bhr->bhs', Q_feat, key_feat_sum) + 1e-6)

        # Numerator: contract the shared feature axis r -> (b, h, s, d).
        attn_output = torch.einsum('bhsr,bhrd->bhsd', Q_feat, K_feat_V)

        return attn_output * z_denom.unsqueeze(-1)

    def forward(self, query_input, key_input , value_input, mask=None):
        """Apply linear attention.

        Args:
            query_input: Tensor of shape (batch, query_len, d_model).
            key_input: Tensor of shape (batch, key_len, d_model).
            value_input: Tensor of shape (batch, key_len, d_model).
            mask: Accepted for interface compatibility; not applied.

        Returns:
            Tuple ``(output, None)`` where ``output`` has shape
            (batch, query_len, d_model). The second element is always ``None``
            because no attention-weight matrix is computed.
        """
        batch_size = query_input.size(0)

        Q = self._split_heads(self.W_q(query_input), batch_size)
        K = self._split_heads(self.W_k(key_input), batch_size)
        V = self._split_heads(self.W_v(value_input), batch_size)

        attn_output = self.linear_attention(Q, K, V)

        # Merge heads back: (b, h, s, d_k) -> (b, s, d_model).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(attn_output)

        return output, None  # no attention weights returned


Using RoPE (rotary positional embeddings) instead of the standard positional encoding.

In [6]:

def rotate_half(x):
    """Split the last dimension in two chunks ``(a, b)`` and return ``(-b, a)``."""
    front, back = torch.chunk(x, 2, dim=-1)
    pieces = (-back, front)
    return torch.cat(pieces, dim=-1)

class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding with precomputed cos/sin tables.

    Args:
        dim: Size of the last (feature) dimension of the tensors to rotate;
            must be even.
        max_seq_len: Number of positions for which tables are precomputed.
        base: Base of the geometric progression of rotation frequencies.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        if dim % 2 != 0:
            # An odd dim would build tables of width dim + 1, which can never
            # match the input it is supposed to rotate.
            raise ValueError(f"dim must be even, got {dim}")

        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Inverse frequencies, one per pair of features: shape (dim / 2,).
        inv_freq = 1. / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        # Angle for every (position, frequency) pair, duplicated so that both
        # halves of the feature dimension share the same angle: (max_seq_len, dim).
        t = torch.arange(self.max_seq_len, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        # Cached tables of shape (1, max_seq_len, dim).
        self.register_buffer("cos_cached", emb.cos()[None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, :, :])

    def forward(self, x, seq_dim=1):
        """Rotate ``x`` according to the position along ``seq_dim``.

        Args:
            x: Tensor with at least two dimensions whose last dimension has
                size ``dim``.
            seq_dim: Index of the sequence axis (negative values count from
                the end). Must not be the last axis.

        Returns:
            ``x * cos + rotate_half(x) * sin`` with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_dim`` is invalid, the last dimension of ``x``
                differs from ``dim``, or the sequence is longer than
                ``max_seq_len``.
        """
        ndim = x.dim()
        if ndim < 2:
            raise ValueError(f"expected an input with at least 2 dimensions, got {ndim}")

        axis = seq_dim + ndim if seq_dim < 0 else seq_dim
        if not 0 <= axis < ndim - 1:
            raise ValueError(
                f"seq_dim={seq_dim} is not a valid sequence axis for a {ndim}-D input"
            )
        if x.shape[-1] != self.dim:
            raise ValueError(
                f"last dimension of x must be {self.dim}, got {x.shape[-1]}"
            )

        seq_len = x.shape[axis]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )

        # Place the position axis at `axis` and the feature axis last; every
        # other axis is size 1 so the tables broadcast over it. This makes the
        # result independent of where the sequence axis sits in the layout.
        table_shape = [1] * ndim
        table_shape[axis] = seq_len
        table_shape[-1] = self.dim
        cos_pos = self.cos_cached[0, :seq_len].reshape(table_shape)
        sin_pos = self.sin_cached[0, :seq_len].reshape(table_shape)

        x_rotated = rotate_half(x)
        return (x * cos_pos) + (x_rotated * sin_pos)
