In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class TransformerEncoderBlock(nn.Module):
    """Post-LayerNorm Transformer encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer is applied as ``x = LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        emb_dim: model width D.
        ffn_dim: hidden width of the feed-forward network.
        n_heads: number of attention heads (must divide ``emb_dim``).
        dropout: dropout probability used on the attention weights and on both
            residual branches.

    Note:
        ``MultiHeadAttention`` is defined in a later cell of this notebook; run that
        cell before instantiating this block.
    """

    def __init__(self, emb_dim, ffn_dim, n_heads, dropout=0.1):
        super().__init__()
        # Forward the block's dropout so that dropout=0.0 really disables all dropout.
        self.self_attn = MultiHeadAttention(emb_dim, n_heads, dropout=dropout)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

        # The trailing Dropout is the FFN branch's residual dropout.
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, emb_dim),
            nn.Dropout(dropout),
        )
        # Residual dropout for the attention branch.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x``.

        Args:
            x: input of shape (B, S, D).
            mask: optional mask forwarded to ``MultiHeadAttention`` (nonzero/True = visible).

        Returns:
            Tensor of shape (B, S, D).
        """
        attn_out = self.self_attn(x, kv=None, mask=mask)
        x = self.ln1(x + self.dropout(attn_out))  # residual connection + post-LN

        # self.ffn already ends in Dropout; applying self.dropout again would drop twice.
        x = self.ln2(x + self.ffn(x))
        return x

class TransformerDecoderBlock(nn.Module):
    """Post-LayerNorm Transformer decoder layer: self-attention, cross-attention, then FFN.

    Each sub-layer is applied as ``x = LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        emb_dim: model width D.
        ffn_dim: hidden width of the feed-forward network.
        n_heads: number of attention heads (must divide ``emb_dim``).
        dropout: dropout probability used on the attention weights and on every
            residual branch.

    Note:
        ``MultiHeadAttention`` is defined in a later cell of this notebook; run that
        cell before instantiating this block.
    """

    def __init__(self, emb_dim, ffn_dim, n_heads, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.ln3 = nn.LayerNorm(emb_dim)

        self.self_attn = MultiHeadAttention(emb_dim, n_heads, dropout=dropout)   # decoder self-attn
        self.cross_attn = MultiHeadAttention(emb_dim, n_heads, dropout=dropout)  # encoder-decoder attn

        # The trailing Dropout is the FFN branch's residual dropout.
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, emb_dim),
            nn.Dropout(dropout),
        )

        # Residual dropout for the two attention branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, self_mask=None, encoder_decoder_mask=None):
        """Decode ``x`` while attending to ``encoder_output``.

        Args:
            x: decoder input, (B, S_tgt, D).
            encoder_output: encoder memory, (B, S_src, D).
            self_mask: mask for decoder self-attention, e.g. a (S_tgt, S_tgt) causal mask.
            encoder_decoder_mask: mask for cross-attention, e.g. a (B, S_src) padding mask.
                Both masks follow ``MultiHeadAttention``: nonzero/True = visible.

        Returns:
            Tensor of shape (B, S_tgt, D).
        """
        attn_out = self.self_attn(x, kv=None, mask=self_mask)
        x = self.ln1(x + self.dropout(attn_out))

        # MultiHeadAttention.forward names its query argument `x`; a `q=` keyword raises TypeError.
        cross_out = self.cross_attn(x, kv=encoder_output, mask=encoder_decoder_mask)
        x = self.ln2(x + self.dropout(cross_out))

        # self.ffn already ends in Dropout; applying self.dropout again would drop twice.
        x = self.ln3(x + self.ffn(x))
        return x

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encodings.

    ``pe[pos, 2i]   = sin(pos * base^(-2i/D))``
    ``pe[pos, 2i+1] = cos(pos * base^(-2i/D))``

    ``forward`` returns only the encoding table for the input length; the caller adds
    it to the embeddings, where it broadcasts over the batch dimension.

    Args:
        emb_dim: embedding width D. Odd widths are supported; the last column is then a sine.
        max_len: longest sequence the precomputed table covers.
        base: wavelength base of the geometric frequency progression.
    """

    def __init__(self, emb_dim, max_len=5000, base=10000):
        super().__init__()
        pe = torch.zeros(max_len, emb_dim)  # (S, D)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (S, 1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * (-math.log(base) / emb_dim))  # (ceil(D/2),)
        angles = position * div_term  # (S, ceil(D/2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd D there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : emb_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # (1, S, D)

    def forward(self, x):
        """Return the encodings for ``x``'s sequence length.

        Args:
            x: tensor whose dim 1 is the sequence length, e.g. (B, S, D).

        Returns:
            Tensor of shape (1, S, D).

        Raises:
            ValueError: if S exceeds ``max_len`` (slicing would silently return fewer rows).
        """
        seq_len = x.shape[1]
        max_len = self.pe.shape[1]
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        return self.pe[:, :seq_len]  # (1, S, D)


class RoPE(nn.Module):
    """Rotary position embedding over interleaved (even, odd) feature pairs.

    Feature pair ``(x[2i], x[2i+1])`` at position ``p`` is rotated by the angle
    ``p * base^(-2i/head_dim)``.

    Args:
        emb_dim: total width across heads; ``head_dim = emb_dim // n_heads``.
        n_heads: number of heads (must divide ``emb_dim``).
        max_len: number of precomputed positions.
        base: frequency base.
    """

    def __init__(self, emb_dim, n_heads, max_len=5000, base=10000):
        super().__init__()
        assert emb_dim % n_heads == 0, "emb_dim must be divisible by n_heads"
        self.base = base
        self.head_dim = emb_dim // n_heads
        self.max_len = max_len
        # Features are rotated in pairs; an odd head_dim would only fail later with a broadcast error.
        assert self.head_dim % 2 == 0, "head_dim must be even for RoPE"

        inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float) / self.head_dim))  # (D/2,)
        angle = torch.arange(max_len, dtype=torch.float).unsqueeze(1) * inv_freq  # (S, D/2)

        # Stored as (1, S, 1, D/2) so they broadcast over batch and heads of a (B, S, H, D) input.
        self.register_buffer('cos', angle.cos()[None, :, None, :], persistent=False)
        self.register_buffer('sin', angle.sin()[None, :, None, :], persistent=False)

    def forward(self, x, start=0):
        """Rotate ``x`` of shape (B, S, H, head_dim); row ``s`` is at position ``start + s``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        B, S, H, D = x.shape
        assert D == self.head_dim, "x.shape[-1] must match head_dim"
        assert start + S <= self.max_len, "start + S must not exceed max_len"

        cos = self.cos[:, start:start + S].to(x.device, dtype=x.dtype)  # (1, S, 1, D/2)
        sin = self.sin[:, start:start + S].to(x.device, dtype=x.dtype)  # (1, S, 1, D/2)

        x_even, x_odd = x[..., 0::2], x[..., 1::2]  # (B, S, H, D/2)
        rot_even = x_even * cos - x_odd * sin
        rot_odd = x_even * sin + x_odd * cos
        # (..., D/2, 2) -> (..., D): rot_even lands on even slots and rot_odd on odd slots.
        return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)


        

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries are projected from ``x``; keys and values from ``kv`` (defaults to ``x``,
    i.e. self-attention).

    Args:
        emb_dim: model width D (must be divisible by ``n_heads``).
        n_heads: number of heads H; each head has width ``D // H``.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, emb_dim, n_heads, dropout=0.1):
        super().__init__()
        assert emb_dim % n_heads == 0, "emb_dim must be divisible by n_heads"

        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.head_dim = emb_dim // n_heads

        self.W_Q = nn.Linear(emb_dim, emb_dim)  # (B, S_q, D)
        self.W_K = nn.Linear(emb_dim, emb_dim)  # (B, S_k, D)
        self.W_V = nn.Linear(emb_dim, emb_dim)  # (B, S_k, D)

        self.W_O = nn.Linear(emb_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def _prepare_mask(self, mask, B, S_q, S_k):
        """Convert ``mask`` to a boolean tensor broadcastable to (B, H, S_q, S_k).

        Nonzero/True entries are visible; zero/False entries are masked out.
        Accepted shapes:
            (S_q, S_k)      causal/visibility mask shared across the batch
            (B, S_k)        key-padding mask
            (B, S_q, S_k)   per-example mask
            4-D             any shape broadcastable to (B, H, S_q, S_k)
        A 2-D mask is matched against (S_q, S_k) first. When ``B == S_q`` a key-padding
        mask is therefore ambiguous; pass it as (B, 1, 1, S_k) instead.
        """
        if mask.dtype != torch.bool:
            mask = mask != 0

        if mask.ndim == 2 and mask.shape == (S_q, S_k):
            return mask[None, None]            # (1, 1, S_q, S_k)
        if mask.ndim == 2 and mask.shape == (B, S_k):
            return mask[:, None, None, :]      # (B, 1, 1, S_k)
        if mask.ndim == 3 and mask.shape == (B, S_q, S_k):
            return mask[:, None]               # (B, 1, S_q, S_k)
        if mask.ndim == 4 and all(m in (1, t) for m, t in zip(mask.shape, (B, self.n_heads, S_q, S_k))):
            return mask
        raise ValueError(f"Invalid mask shape: {mask.shape}")

    def forward(self, x, kv=None, mask=None):
        """Attend from ``x`` to ``kv``.

        Args:
            x: queries, (B, S_q, D).
            kv: key/value source, (B, S_k, D); ``None`` means self-attention (``kv = x``).
            mask: optional visibility mask (see ``_prepare_mask``).

        Returns:
            Tensor of shape (B, S_q, D). Query rows whose keys are all masked receive
            zero attention, so their output is ``W_O``'s bias.
        """
        B, S_q, D = x.shape
        if kv is None:
            kv = x
        S_k = kv.shape[1]

        # project, split heads, and move heads forward
        q = self.W_Q(x).view(B, S_q, self.n_heads, self.head_dim).transpose(1, 2)   # (B, H, S_q, h_dim)
        k = self.W_K(kv).view(B, S_k, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, S_k, h_dim)
        v = self.W_V(kv).view(B, S_k, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, S_k, h_dim)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, S_q, S_k)

        if mask is not None:
            mask = self._prepare_mask(mask, B, S_q, S_k)
            scores = scores.masked_fill(~mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)  # for each query over all keys
        if mask is not None:
            # A fully masked row is softmax over all -inf = NaN; zero it so NaN cannot spread.
            attn = attn.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)
        attn = self.dropout(attn)

        context = torch.matmul(attn, v)                                 # (B, H, S_q, h_dim)
        context = context.transpose(1, 2).contiguous().view(B, S_q, D)  # concat heads -> (B, S_q, D)
        return self.W_O(context)                                        # mix head information


In [None]:
class MultiHeadAttentionWithCache(nn.Module):
    """Multi-head attention that also returns its keys/values for reuse as a KV cache.

    Decoder usage: one prefill call with the prompt (S > 1), then decode calls with one
    new token (S = 1), each passing the previous call's ``present_kv`` back as ``past_kv``.
    Self-attention is causal by default, so prefill and token-by-token decoding produce
    the same outputs for the same prefix.

    Args:
        emb_dim: model width D (must be divisible by ``n_heads``).
        n_heads: number of heads H; each head has width ``D // H``.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, emb_dim, n_heads, dropout=0.1):
        super().__init__()
        assert emb_dim % n_heads == 0, "emb_dim must be divisible by n_heads"

        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.head_dim = emb_dim // n_heads

        self.W_Q = nn.Linear(emb_dim, emb_dim)
        self.W_K = nn.Linear(emb_dim, emb_dim)
        self.W_V = nn.Linear(emb_dim, emb_dim)

        self.W_O = nn.Linear(emb_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, kv=None, past_kv=None, mask=None, is_causal=None):
        """Attend from ``x`` to cached plus new keys/values.

        Args:
            x: queries, (B, S_q, D).
            kv: key/value source, (B, S_new, D); ``None`` means self-attention on ``x``.
            past_kv: optional ``(k, v)`` from a previous call, each (B, H, S_past, h_dim);
                the new keys/values are appended along the sequence axis.
            mask: optional visibility mask over the full key range, shaped
                (S_q, S_past + S_new) or (B, S_past + S_new); nonzero/True = visible.
            is_causal: if True, query ``i`` (absolute position ``S_past + i``) only sees
                keys at positions ``<= S_past + i``. Defaults to True for self-attention
                and False when ``kv`` is given.

        Returns:
            ``(out, present_kv)``: ``out`` is (B, S_q, D); ``present_kv`` is the
            concatenated ``(k, v)``, each (B, H, S_past + S_new, h_dim).

        NOTE: ``past_kv`` is always concatenated. For cross-attention, compute the
        encoder keys/values once instead of feeding ``present_kv`` back in.
        """
        B, S_q, D = x.shape
        self_attention = kv is None
        if self_attention:
            kv = x
        if is_causal is None:
            is_causal = self_attention
        S_new = kv.shape[1]

        q = self.W_Q(x).view(B, S_q, self.n_heads, self.head_dim).transpose(1, 2)     # (B, H, S_q, h_dim)
        k = self.W_K(kv).view(B, S_new, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, S_new, h_dim)
        v = self.W_V(kv).view(B, S_new, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, S_new, h_dim)

        past_len = 0
        if past_kv is not None:
            past_len = past_kv[0].shape[2]
            k = torch.cat([past_kv[0], k], dim=2)  # (B, H, S_past + S_new, h_dim)
            v = torch.cat([past_kv[1], v], dim=2)
        S_k = k.shape[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, S_q, S_k)

        visible = None
        if is_causal:
            # Query i sits at absolute position past_len + i and may see keys 0..past_len + i.
            causal = torch.ones(S_q, S_k, dtype=torch.bool, device=scores.device).tril(diagonal=past_len)
            visible = causal[None, None]  # (1, 1, S_q, S_k)
        if mask is not None:
            if mask.dtype != torch.bool:
                mask = mask != 0
            if mask.ndim == 2 and mask.shape == (S_q, S_k):
                mask = mask[None, None]            # (1, 1, S_q, S_k)
            elif mask.ndim == 2 and mask.shape == (B, S_k):
                mask = mask[:, None, None, :]      # (B, 1, 1, S_k)
            else:
                raise ValueError(f"Invalid mask shape: {mask.shape}")
            visible = mask if visible is None else (visible & mask)

        if visible is not None:
            scores = scores.masked_fill(~visible, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        if visible is not None:
            # A fully masked row is softmax over all -inf = NaN; zero it so NaN cannot spread.
            attn = attn.masked_fill(~visible.any(dim=-1, keepdim=True), 0.0)
        attn = self.dropout(attn)

        context = torch.matmul(attn, v)                                 # (B, H, S_q, h_dim)
        context = context.transpose(1, 2).contiguous().view(B, S_q, D)  # (B, S_q, D)

        out = self.W_O(context)
        present_kv = (k, v)
        return out, present_kv


In [None]:
class MultiQueryAttention(nn.Module):
    """Multi-query attention: H query heads share a single key head and a single value head.

    Args:
        emb_dim: model width D (must be divisible by ``n_heads``).
        n_heads: number of query heads H; every head has width ``D // H``.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, emb_dim, n_heads, dropout=0.1):
        super().__init__()
        assert emb_dim % n_heads == 0, "emb_dim must be divisible by n_heads"

        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.head_dim = emb_dim // n_heads

        self.W_Q = nn.Linear(emb_dim, emb_dim)        # -> (B, S_q, D)
        self.W_K = nn.Linear(emb_dim, self.head_dim)  # -> (B, S_k, h_dim), one shared head
        self.W_V = nn.Linear(emb_dim, self.head_dim)  # -> (B, S_k, h_dim), one shared head
        self.W_O = nn.Linear(emb_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def _prepare_mask(self, mask, B, S_q, S_k):
        """Convert ``mask`` to a boolean tensor broadcastable to (B, H, S_q, S_k).

        Nonzero/True entries are visible. Accepted shapes: (S_q, S_k), (B, S_k),
        (B, S_q, S_k), or a 4-D mask broadcastable to (B, H, S_q, S_k). A 2-D mask is
        matched against (S_q, S_k) first, so when ``B == S_q`` pass a key-padding mask
        as (B, 1, 1, S_k).
        """
        if mask.dtype != torch.bool:
            mask = mask != 0

        if mask.ndim == 2 and mask.shape == (S_q, S_k):
            return mask[None, None]            # (1, 1, S_q, S_k)
        if mask.ndim == 2 and mask.shape == (B, S_k):
            return mask[:, None, None, :]      # (B, 1, 1, S_k)
        if mask.ndim == 3 and mask.shape == (B, S_q, S_k):
            return mask[:, None]               # (B, 1, S_q, S_k)
        if mask.ndim == 4 and all(m in (1, t) for m, t in zip(mask.shape, (B, self.n_heads, S_q, S_k))):
            return mask
        raise ValueError(f"Invalid mask shape: {mask.shape}")

    def forward(self, x, kv=None, mask=None):
        """Attend from ``x`` to ``kv`` (``None`` = self-attention).

        Args:
            x: queries, (B, S_q, D).
            kv: key/value source, (B, S_k, D).
            mask: optional visibility mask (see ``_prepare_mask``).

        Returns:
            Tensor of shape (B, S_q, D). Fully masked query rows receive zero attention.
        """
        B, S_q, D = x.shape
        if kv is None:
            kv = x
        S_k = kv.shape[1]

        q = self.W_Q(x).view(B, S_q, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, S_q, h_dim)
        # The shared K/V head keeps a size-1 head axis so matmul broadcasts it over all H query heads.
        k = self.W_K(kv).view(B, S_k, 1, self.head_dim).transpose(1, 2)            # (B, 1, S_k, h_dim)
        v = self.W_V(kv).view(B, S_k, 1, self.head_dim).transpose(1, 2)            # (B, 1, S_k, h_dim)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, S_q, S_k)

        if mask is not None:
            mask = self._prepare_mask(mask, B, S_q, S_k)
            scores = scores.masked_fill(~mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        if mask is not None:
            # A fully masked row is softmax over all -inf = NaN; zero it so NaN cannot spread.
            attn = attn.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)
        attn = self.dropout(attn)

        context = torch.matmul(attn, v)                                 # (B, H, S_q, h_dim)
        context = context.transpose(1, 2).contiguous().view(B, S_q, D)  # concat heads -> (B, S_q, D)
        return self.W_O(context)
        
        

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention: H query heads share G key/value heads.

    Query heads are split into G contiguous groups of ``H // G`` heads. Group ``g``
    (query heads ``g*H//G`` to ``(g+1)*H//G - 1``) uses key/value head ``g``.
    ``G == H`` is ordinary multi-head attention; ``G == 1`` is multi-query attention.

    Args:
        emb_dim: model width D (must be divisible by ``n_heads``).
        n_heads: number of query heads H.
        kv_groups: number of key/value heads G (must divide ``n_heads``).
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, emb_dim, n_heads, kv_groups, dropout=0.1):
        super().__init__()
        assert emb_dim % n_heads == 0, "emb_dim must be divisible by n_heads"
        assert emb_dim % kv_groups == 0, "emb_dim must be divisible by kv_groups"
        assert n_heads % kv_groups == 0, "n_heads must be divisible by kv_groups"

        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.kv_groups = kv_groups
        self.head_dim = emb_dim // n_heads

        self.W_Q = nn.Linear(emb_dim, emb_dim)                          # -> (B, S_q, D)
        self.W_K = nn.Linear(emb_dim, self.kv_groups * self.head_dim)   # -> (B, S_k, G*h_dim)
        self.W_V = nn.Linear(emb_dim, self.kv_groups * self.head_dim)   # -> (B, S_k, G*h_dim)
        self.W_O = nn.Linear(emb_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def _prepare_mask(self, mask, B, S_q, S_k):
        """Convert ``mask`` to a boolean tensor broadcastable to (B, H, S_q, S_k).

        Nonzero/True entries are visible. Accepted shapes: (S_q, S_k), (B, S_k),
        (B, S_q, S_k), or a 4-D mask broadcastable to (B, H, S_q, S_k). A 2-D mask is
        matched against (S_q, S_k) first, so when ``B == S_q`` pass a key-padding mask
        as (B, 1, 1, S_k).
        """
        if mask.dtype != torch.bool:
            mask = mask != 0

        if mask.ndim == 2 and mask.shape == (S_q, S_k):
            return mask[None, None]            # (1, 1, S_q, S_k)
        if mask.ndim == 2 and mask.shape == (B, S_k):
            return mask[:, None, None, :]      # (B, 1, 1, S_k)
        if mask.ndim == 3 and mask.shape == (B, S_q, S_k):
            return mask[:, None]               # (B, 1, S_q, S_k)
        if mask.ndim == 4 and all(m in (1, t) for m, t in zip(mask.shape, (B, self.n_heads, S_q, S_k))):
            return mask
        raise ValueError(f"Invalid mask shape: {mask.shape}")

    def forward(self, x, kv=None, mask=None):
        """Attend from ``x`` to ``kv`` (``None`` = self-attention).

        Args:
            x: queries, (B, S_q, D).
            kv: key/value source, (B, S_k, D).
            mask: optional visibility mask (see ``_prepare_mask``).

        Returns:
            Tensor of shape (B, S_q, D). Fully masked query rows receive zero attention.
        """
        B, S_q, D = x.shape
        if kv is None:
            kv = x
        S_k = kv.shape[1]

        q = self.W_Q(x).view(B, S_q, self.n_heads, self.head_dim).transpose(1, 2)     # (B, H, S_q, h_dim)
        k = self.W_K(kv).view(B, S_k, self.kv_groups, self.head_dim).transpose(1, 2)  # (B, G, S_k, h_dim)
        v = self.W_V(kv).view(B, S_k, self.kv_groups, self.head_dim).transpose(1, 2)  # (B, G, S_k, h_dim)

        if self.kv_groups != self.n_heads:
            # repeat_interleave keeps groups contiguous ([g0, g0, g1, g1]); Tensor.repeat would
            # tile them ([g0, g1, g0, g1]) and pair query heads with the wrong K/V head.
            heads_per_group = self.n_heads // self.kv_groups
            k = k.repeat_interleave(heads_per_group, dim=1)  # (B, H, S_k, h_dim)
            v = v.repeat_interleave(heads_per_group, dim=1)  # (B, H, S_k, h_dim)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, S_q, S_k)

        if mask is not None:
            mask = self._prepare_mask(mask, B, S_q, S_k)
            scores = scores.masked_fill(~mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        if mask is not None:
            # A fully masked row is softmax over all -inf = NaN; zero it so NaN cannot spread.
            attn = attn.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)
        attn = self.dropout(attn)

        context = torch.matmul(attn, v)                                 # (B, H, S_q, h_dim)
        context = context.transpose(1, 2).contiguous().view(B, S_q, D)  # concat heads -> (B, S_q, D)
        return self.W_O(context)

        
        

In [None]:
class MutiheadLatentAttention(nn.Module):
    """Multi-head latent attention with decoupled rotary embeddings (DeepSeek-V2 style).

    Keys/values are reconstructed from a shared low-rank latent ``c_kv``; queries come
    from a separate latent ``c_q``. Each head's query/key is the concatenation of a
    position-free content part of width ``up_dim // n_heads`` and a rotary part of width
    ``rope_dim``. The rotary key is computed once and shared by all heads.

    Args:
        emb_dim: model width D (must be divisible by ``n_heads``).
        down_dim: width of the compressed latents.
        up_dim: total width of the per-head content queries/keys/values
            (must be divisible by ``n_heads``).
        rope_dim: per-head width of the rotary part.
        n_heads: number of heads H.
        dropout: dropout on the attention weights and on the output projection.

    Note:
        ``RoPE`` is defined in an earlier cell of this notebook.
    """

    def __init__(self, emb_dim, down_dim, up_dim, rope_dim, n_heads, dropout=0.1):
        super().__init__()
        assert emb_dim % n_heads == 0, "emb_dim must be divisible by n_heads"
        assert up_dim % n_heads == 0, "up_dim must be divisible by n_heads"

        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.rope_dim = rope_dim
        self.head_dim = emb_dim // n_heads
        self.v_head_dim = up_dim // n_heads

        # low-rank compression of the key/value and query inputs
        self.down_proj_kv = nn.Linear(emb_dim, down_dim)
        self.down_proj_q = nn.Linear(emb_dim, down_dim)

        # up-projections to the per-head content parts
        self.up_proj_q = nn.Linear(down_dim, up_dim)
        self.up_proj_k = nn.Linear(down_dim, up_dim)
        self.up_proj_v = nn.Linear(down_dim, up_dim)

        # decoupled rotary parts: per-head queries from c_q, one shared key from the raw input
        self.proj_qr = nn.Linear(down_dim, rope_dim * n_heads)
        self.proj_kr_ = nn.Linear(emb_dim, rope_dim)

        self.rope_q = RoPE(rope_dim * n_heads, n_heads)  # head_dim == rope_dim
        self.rope_k = RoPE(rope_dim, 1)                  # head_dim == rope_dim

        self.dropout = nn.Dropout(dropout)
        self.res_dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(n_heads * self.v_head_dim, emb_dim)

    def _prepare_mask(self, mask, B, S_q, S_k):
        """Return ``mask`` as a boolean tensor broadcastable to (B, H, S_q, S_k); True = visible.

        For backward compatibility, any 2-D mask shaped (B or 1, S_k or 1) is treated as a
        key-padding mask. Also accepted: (S_q, S_k), (B, S_q, S_k), and 4-D masks
        broadcastable to (B, H, S_q, S_k). Nonzero entries count as visible.
        """
        if mask.dtype != torch.bool:
            mask = mask != 0

        if mask.ndim == 2 and mask.shape[0] in (1, B) and mask.shape[1] in (1, S_k):
            return mask[:, None, None, :]      # (B, 1, 1, S_k)
        if mask.ndim == 2 and mask.shape == (S_q, S_k):
            return mask[None, None]            # (1, 1, S_q, S_k)
        if mask.ndim == 3 and mask.shape == (B, S_q, S_k):
            return mask[:, None]               # (B, 1, S_q, S_k)
        if mask.ndim == 4 and all(m in (1, t) for m, t in zip(mask.shape, (B, self.n_heads, S_q, S_k))):
            return mask
        raise ValueError(f"Invalid mask shape: {mask.shape}")

    def forward(self, h, kv=None, mask=None):
        """Attend from ``h`` to ``kv`` (``None`` = self-attention).

        Args:
            h: queries, (B, S_q, D).
            kv: key/value source, (B, S_k, D).
            mask: optional visibility mask (see ``_prepare_mask``).

        Returns:
            Tensor of shape (B, S_q, D). Fully masked query rows receive zero attention.
        """
        B, S_q, _ = h.shape
        if kv is None:
            kv = h
        S_k = kv.shape[1]
        H, d_c, d_r = self.n_heads, self.v_head_dim, self.rope_dim

        # content (position-free) parts, reconstructed from the latents
        c_kv = self.down_proj_kv(kv)                                      # (B, S_k, down_dim)
        k_c = self.up_proj_k(c_kv).view(B, S_k, H, d_c).transpose(1, 2)   # (B, H, S_k, d_c)
        v = self.up_proj_v(c_kv).view(B, S_k, H, d_c).transpose(1, 2)     # (B, H, S_k, d_c)

        c_q = self.down_proj_q(h)                                         # (B, S_q, down_dim)
        q_c = self.up_proj_q(c_q).view(B, S_q, H, d_c).transpose(1, 2)    # (B, H, S_q, d_c)

        # Rotary parts. RoPE.forward expects (B, S, heads, dim), so rotate BEFORE moving heads forward.
        q_r = self.proj_qr(c_q).view(B, S_q, H, d_r)                      # (B, S_q, H, d_r)
        q_r = self.rope_q(q_r).transpose(1, 2)                            # (B, H, S_q, d_r)

        k_r = self.proj_kr_(kv).unsqueeze(2)                              # (B, S_k, 1, d_r)
        k_r = self.rope_k(k_r).transpose(1, 2)                            # (B, 1, S_k, d_r)
        k_r = k_r.expand(B, H, S_k, d_r)                                  # shared by every head

        q = torch.cat([q_c, q_r], dim=-1)                                 # (B, H, S_q, d_c + d_r)
        k = torch.cat([k_c, k_r], dim=-1)                                 # (B, H, S_k, d_c + d_r)

        # Scale by the square root of the full query/key width (content + rotary).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_c + d_r)  # (B, H, S_q, S_k)

        if mask is not None:
            mask = self._prepare_mask(mask, B, S_q, S_k)
            scores = scores.masked_fill(~mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)  # dropout belongs on the weights, not on the logits
        if mask is not None:
            # A fully masked row is softmax over all -inf = NaN; zero it so NaN cannot spread.
            attn = attn.masked_fill(~mask.any(dim=-1, keepdim=True), 0.0)
        attn = self.dropout(attn)

        context = torch.matmul(attn, v)                                       # (B, H, S_q, d_c)
        context = context.transpose(1, 2).contiguous().view(B, S_q, H * d_c)  # (B, S_q, up_dim)

        out = self.fc(context)                                                # (B, S_q, D)
        return self.res_dropout(out)