In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Unlike LayerNorm there is no mean subtraction and no bias: every vector
    along the final axis is divided by sqrt(mean(x**2) + eps) and then scaled
    elementwise by a learnable gain that starts at ones.

    Args:
        emb_dim: size of the last dimension (length of the gain vector).
        eps: small constant added inside the square root for stability.
    """

    def __init__(self, emb_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        # Mean of squares along the feature axis, kept for broadcasting back over x.
        mean_square = torch.mean(torch.square(x), dim=-1, keepdim=True)
        denominator = torch.sqrt(mean_square + self.eps)
        normalized = x / denominator
        return self.weight * normalized

In [None]:
class FeedForward(nn.Module):
    """
    SwiGLU MLP: W2( SiLU(x W1) ⊙ (x W3) ).

    W1 is the gate projection (the branch passed through SiLU), W3 is the up
    projection (the linear branch), and W2 projects back down to emb_dim.
    This matches the LLaMA reference layout (w1 = gate, w3 = up, w2 = down),
    so weights mapped by those names compute the intended function.

    LLaMA sizes the hidden layer as d_ff ≈ (2/3) * 4 * emb_dim = (8/3) * emb_dim,
    rounded up to a hardware-friendly multiple. This keeps the parameter count
    close to a standard 4x MLP despite having three matrices.

    Args:
        emb_dim: input and output width.
        d_ff: hidden (intermediate) width.
    """
    def __init__(self, emb_dim, d_ff):
        super().__init__()
        self.W1 = nn.Linear(emb_dim, d_ff, bias=False)  # gate (goes through SiLU)
        self.W3 = nn.Linear(emb_dim, d_ff, bias=False)  # up (linear branch)
        self.W2 = nn.Linear(d_ff, emb_dim, bias=False)  # down

    def forward(self, x):
        # The activation belongs on the gate branch (W1), not the up branch (W3).
        return self.W2(F.silu(self.W1(x)) * self.W3(x))

In [None]:
class DecoderBlock(nn.Module):
    """One pre-norm decoder layer: attention sub-layer then SwiGLU MLP sub-layer.

    Each sub-layer normalizes its input with RMSNorm, applies the transform,
    passes the result through dropout, and adds it back to the residual stream.

    Args:
        emb_dim: model width.
        n_heads: number of attention heads, forwarded to MultiGroupAttention.
        kv_groups: key/value group count, forwarded to MultiGroupAttention.
        d_ff: hidden width of the feed-forward network.
        dropout: dropout probability for the residual branches (also forwarded
            to the attention module).
    """

    def __init__(self, emb_dim, n_heads, kv_groups, d_ff, dropout=0.0):
        super().__init__()
        # Registration order is kept as-is so state_dict keys and init order are unchanged.
        self.attn_norm = RMSNorm(emb_dim)
        self.attn = MultiGroupAttention(emb_dim, n_heads, kv_groups, dropout=dropout)
        self.ff_norm = RMSNorm(emb_dim)
        self.ff = FeedForward(emb_dim, d_ff)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention branch: self-attention (kv=None) over the normalized stream.
        attended = self.attn(self.attn_norm(x), kv=None, mask=mask)
        hidden = x + self.dropout(attended)

        # MLP branch on the updated residual stream.
        mlp_out = self.ff(self.ff_norm(hidden))
        return hidden + self.dropout(mlp_out)


In [None]:
class LlamaForCausalLM(nn.Module):
    """Decoder-only causal language model.

    Pipeline: token embedding -> ``config.n_layers`` DecoderBlocks -> final
    RMSNorm -> linear LM head (weight-tied to the token embedding).

    Reads from ``config``: vocab_size, emb_dim, max_seq_len, n_heads,
    kv_groups, d_ff, dropout, n_layers.
    """
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.max_seq_len = config.max_seq_len

        self.tok_emb = nn.Embedding(config.vocab_size, config.emb_dim)
        # Everything after n_heads is passed by keyword. DecoderBlock's third
        # positional parameter is kv_groups, so passing d_ff positionally and then
        # kv_groups=... as a keyword raised "got multiple values for 'kv_groups'".
        self.blocks = nn.ModuleList([
            DecoderBlock(
                config.emb_dim,
                config.n_heads,
                kv_groups=config.kv_groups,
                d_ff=config.d_ff,
                dropout=config.dropout,
            )
            for _ in range(config.n_layers)
        ])
        self.final_norm = RMSNorm(config.emb_dim)
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)

        # Tie the output projection to the input embedding (one shared Parameter).
        # This saves parameters. Note that the original LLaMA checkpoints keep the
        # two matrices separate.
        self.lm_head.weight = self.tok_emb.weight

        self._init_weights()

    def _init_weights(self):
        """Initialize all submodules: N(0, 0.02) for Linear/Embedding weights,
        zeros for any Linear biases, ones for RMSNorm gains."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                # Submodules defined elsewhere (e.g. attention) may carry biases.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
            elif isinstance(m, RMSNorm):
                nn.init.ones_(m.weight)

    def _build_attention_mask(self, input_ids, key_padding_mask=None):
        """Build a boolean attention mask where True means "query may attend to key".

        Args:
            input_ids: (B, S) token ids; used for the sequence length and device.
            key_padding_mask: optional (B, S) mask, True/nonzero = real token (keep).

        Returns:
            (S, S) lower-triangular causal mask when ``key_padding_mask`` is None,
            otherwise a (B, S_q, S_k) mask combining causality and key padding.
        """
        _, S = input_ids.shape
        device = input_ids.device
        causal = torch.tril(torch.ones(S, S, device=device, dtype=torch.bool))  # (S,S)
        if key_padding_mask is None:
            return causal  # (S,S)
        # Move to the ids' device so the '&' below cannot fail on a device mismatch.
        keep_keys = key_padding_mask.to(device=device, dtype=torch.bool)  # (B,S), True=keep
        mask = keep_keys.unsqueeze(1) & causal.unsqueeze(0)  # (B, S_q, S_k)
        # Always let each position attend to itself. Under left padding a padded query
        # (e.g. position 0) otherwise has an all-False row. If the attention layer fills
        # masked scores with -inf, softmax turns that row into NaN, and the NaN spreads to
        # real tokens in later layers via 0 * NaN in the value sum. Real queries already
        # attend to themselves, so only pad-query rows change.
        diagonal = torch.eye(S, device=device, dtype=torch.bool)
        # NOTE(review): if MultiGroupAttention scores are (B, H, S, S), a (B, S, S) mask
        # broadcasts B against H unless the attention layer unsqueezes it; confirm there.
        return mask | diagonal.unsqueeze(0)

    def forward(self, input_ids, key_padding_mask=None):
        """Compute next-token logits.

        Args:
            input_ids: (B, S) token ids with S <= max_seq_len.
            key_padding_mask: optional (B, S) mask, True = real token.

        Returns:
            (B, S, vocab_size) logits.

        Raises:
            ValueError: if S exceeds ``max_seq_len``.
        """
        B, S = input_ids.shape
        if S > self.max_seq_len:
            raise ValueError(f"sequence length {S} exceeds max_seq_len {self.max_seq_len}")

        attn_mask = self._build_attention_mask(input_ids, key_padding_mask)

        x = self.tok_emb(input_ids)  # (B,S,D)
        for blk in self.blocks:
            x = blk(x, mask=attn_mask)
        x = self.final_norm(x)
        logits = self.lm_head(x)     # (B,S,V)
        return logits