# Architecture

In [1]:
import torch
import torch.nn as nn

### RMSNorm
- Standard LayerNorm does two things: it centers the data by subtracting the mean, then divides by the standard deviation.
- RMSNorm simplifies this by skipping the mean centering and normalizing only by the root mean square (RMS) of the values.
- It is cheaper to compute and performs about as well as LayerNorm in practice.
$$\text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}$$
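To make the difference concrete, here is a minimal sketch of both normalizations written out by hand, without the learnable scale and shift. The function names `rms_norm` and `layer_norm` are illustrative helpers, not part of any library. The sketch also checks that the two coincide once the input already has zero mean, which is exactly the centering step RMSNorm skips.

```python
import torch
import torch.nn.functional as F

def rms_norm(x, eps=1e-6):
    # Divide by the root mean square over the last dim; no mean subtraction
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms

def layer_norm(x, eps=1e-6):
    # Subtract the mean, then divide by the (biased) standard deviation
    mean = x.mean(dim=-1, keepdim=True)
    var = (x - mean).pow(2).mean(dim=-1, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
x = torch.randn(2, 8) + 3.0  # rows with a clearly nonzero mean

# Sanity check: the hand-written LayerNorm matches PyTorch's (no affine params)
print(torch.allclose(layer_norm(x), F.layer_norm(x, (8,), eps=1e-6), atol=1e-5))  # True

# With a nonzero mean the two differ; after centering they agree
x_centered = x - x.mean(dim=-1, keepdim=True)
print(torch.allclose(rms_norm(x), layer_norm(x), atol=1e-3))                    # False
print(torch.allclose(rms_norm(x_centered), layer_norm(x_centered), atol=1e-5))  # True
```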

In [2]:
class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        # Learnable per-feature scale (and optional shift)
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        input_dtype = x.dtype

        # Compute the normalization in float32 for numerical stability
        if self.qwen3_compatible:
            x = x.to(torch.float32)

        # Mean of squares over the embedding dim (named "variance", but no mean is subtracted);
        # rsqrt turns it into 1 / RMS
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            norm_x = norm_x + self.shift

        # Cast back to the caller's dtype
        return norm_x.to(input_dtype)


### RoPE
- Transformers process all tokens in parallel, so attention on its own has no notion of token order or position.
- Earlier GPT models (GPT-2, for example) add learned absolute position embeddings to the token embeddings once, at the input. Because the position signal is injected only at the bottom of the network, it can get diluted as it passes through many layers.
- RoPE instead encodes each token's absolute position as a rotation applied to its query and key vectors, inside every attention layer. When the dot product between a query and a key is computed, the result depends on the positions only through the relative distance between the two tokens.
<div align="center">
   <img src="./assets/rope.png" alt="RoPE Illustration" width="350px"/><br>
   <sub><b>Figure:</b> A gentle introduction to Rotary Position Embedding (<a href="https://krasserm.github.io/2022/12/13/rotary-position-embedding/">image source</a>)</sub>
</div>
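The relative-distance property can be checked with a single frequency pair, where RoPE reduces to an ordinary 2D rotation. In this minimal sketch the helper `rotate_2d`, the angle step `theta=0.5`, and the vectors are all made up for illustration: rotating a query by position $m$ and a key by position $n$ gives a score that depends only on $n - m$, because $R(m)^\top R(n) = R(n - m)$.

```python
import math
import torch

def rotate_2d(vec, pos, theta=0.5):
    # Rotate a 2D vector by the angle pos * theta (one RoPE frequency pair)
    angle = pos * theta
    c, s = math.cos(angle), math.sin(angle)
    rot = torch.tensor([[c, -s], [s, c]], dtype=vec.dtype)
    return rot @ vec

q = torch.tensor([1.0, 2.0], dtype=torch.float64)
k = torch.tensor([0.5, -1.0], dtype=torch.float64)

# Same relative distance (3) at very different absolute positions
score_a = rotate_2d(q, 5) @ rotate_2d(k, 2)
score_b = rotate_2d(q, 105) @ rotate_2d(k, 102)
# A different relative distance (1) changes the score
score_c = rotate_2d(q, 5) @ rotate_2d(k, 4)

print(torch.isclose(score_a, score_b))  # tensor(True)
print(torch.isclose(score_a, score_c))  # tensor(False)
```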

In [3]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    assert head_dim % 2 == 0, "Head dim must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)

    # Precompute sin and cos
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    # x : (batch-size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dim must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2] # First half
    x2 = x[..., head_dim // 2 :] # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotations
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # cos/sin may be float32 and upcast the result; casting back to the input's
    # lower precision is fine once the rotation has been applied
    return x_rotated.to(dtype=x.dtype)
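The `torch.cat((-x2, x1), dim=-1)` trick in `apply_rope` pairs dimension $i$ with dimension $i + d/2$ and rotates each pair by its own angle. The sketch below, for a single head vector at one arbitrary position (the sizes and position are assumptions chosen for illustration), recomputes the same inverse frequencies as `compute_rope_params` and checks that the rotate-half formula matches an explicit 2D rotation of each pair and preserves the vector's length.

```python
import torch

torch.manual_seed(0)
head_dim, pos, theta_base = 8, 7, 10_000  # toy values for illustration
half = head_dim // 2

# Same inverse frequencies as compute_rope_params: one per (i, i + half) pair
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
# For head_dim = 8 this is about [1, 0.1, 0.01, 0.001]: early pairs rotate fastest
angles = pos * inv_freq  # shape: (half,)
cos = torch.cat([angles.cos(), angles.cos()])  # shape: (head_dim,)
sin = torch.cat([angles.sin(), angles.sin()])

x = torch.randn(head_dim, dtype=torch.float64)
x1, x2 = x[:half], x[half:]

# Rotate-half formulation, as in apply_rope
rope_out = x * cos + torch.cat([-x2, x1]) * sin

# Explicit 2D rotation of each pair (x[i], x[i + half]) by angles[i]
explicit = torch.cat([
    x1 * angles.cos() - x2 * angles.sin(),
    x1 * angles.sin() + x2 * angles.cos(),
])

print(torch.allclose(rope_out, explicit))         # True
print(torch.allclose(rope_out.norm(), x.norm()))  # True: rotations preserve length
```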


### FeedForward
**Standard MLP**
- There are two linear layers. 
- The input vector is first projected up to a higher dimension, and a ReLU activation is applied to that intermediate vector. ReLU clamps every negative value to exactly 0.
- Finally the intermediate vector is projected back to the original input dimension.

**SwiGLU in Qwen**
- There are three linear layers. 
- The input vector is projected independently into two higher-dimensional intermediate vectors: the candidate and the gate.
- SiLU (x * sigmoid(x)) is applied to the gate, and the result is multiplied element-wise with the candidate.
- Unlike ReLU, SiLU does not zero out negative inputs; it maps them to small negative values (its minimum is about -0.28).
- Essentially, this creates a gating mechanism in which each feature of the candidate vector is weighted by the corresponding entry of the activated gate vector (element-wise multiplication).
- Finally, their product is projected back to the input dimension.
  
<div align="center">
   <img src="./assets/activation.png" alt="Activations" width="350px"/><br>
   <sub><b>Figure:</b> SiLU vs ReLU</sub>
</div>
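A quick numeric comparison makes both points concrete. This sketch uses toy sizes (`emb_dim=16`, `hidden_dim=64`, chosen only for illustration) and bare `nn.Linear` layers rather than the notebook's `FeedForward` class: it shows ReLU zeroing negative inputs while SiLU keeps small negative outputs, and counts the weights of a two-matrix MLP versus the three-matrix SwiGLU layout.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ReLU vs SiLU on a few inputs
z = torch.tensor([-3.0, -1.0, 0.0, 2.0])
relu_out = F.relu(z)  # negatives clamped to exactly 0
silu_out = F.silu(z)  # z * sigmoid(z): negatives become small negative values
print(relu_out)
print(silu_out)

# Parameter cost: a standard MLP has two weight matrices, SwiGLU has three
emb_dim, hidden_dim = 16, 64  # hypothetical toy sizes
mlp_layers = [
    nn.Linear(emb_dim, hidden_dim, bias=False),  # project up
    nn.Linear(hidden_dim, emb_dim, bias=False),  # project down
]
swiglu_layers = [
    nn.Linear(emb_dim, hidden_dim, bias=False),  # gate
    nn.Linear(emb_dim, hidden_dim, bias=False),  # candidate
    nn.Linear(hidden_dim, emb_dim, bias=False),  # project down
]

def count(layers):
    # Total number of trainable weights across the given layers
    return sum(p.numel() for layer in layers for p in layer.parameters())

print(count(mlp_layers), count(swiglu_layers))  # 2048 3072
```

Because SwiGLU spends a third more parameters at the same hidden size, models that use it often choose a somewhat smaller hidden dimension than a ReLU or GELU MLP would.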

In [4]:
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Gate projection: emb_dim -> hidden_dim (passed through SiLU)
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        # Candidate projection: emb_dim -> hidden_dim
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        # Down projection back to emb_dim
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        # SwiGLU: SiLU-activated gate times candidate, element-wise
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)

### GQA
**Self-Attention: Scaled Dot Product**
1. Compute the query, key, and value (Q, K, V) matrices by multiplying the input embeddings with learned weight matrices.
2. Take the dot products of the queries with the keys ($QK^\top$) to get the unnormalized attention scores.
3. Scale the scores by dividing by $\sqrt{d_k}$, where $d_k$ is the key dimension.
4. Normalize the scaled scores with a softmax to get the attention weights.
5. Finally, multiply the attention weights by V to obtain the context vectors.
<div align="center">
   <img src="./assets/self-attention.webp" alt="Self Attention" width="350px"/><br>
   <sub><b>Figure:</b> Self-Attention for a token that is second in a sequence (<a href="https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention">image source</a>)</sub>
</div>
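The five steps map almost line for line onto tensor operations. Here is a minimal single-head sketch, with toy sizes and random matrices standing in for the learned weights (both assumptions for illustration), cross-checked against PyTorch's built-in `F.scaled_dot_product_attention`:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_in, d_k = 4, 8, 6  # hypothetical toy sizes
x = torch.randn(seq_len, d_in)  # one embedding per token

# 1. Project the inputs into queries, keys and values (random stand-in weights)
W_q, W_k, W_v = (torch.randn(d_in, d_k) / math.sqrt(d_in) for _ in range(3))
Q, K, V = x @ W_q, x @ W_k, x @ W_v  # each: (seq_len, d_k)

# 2. Unnormalized attention scores: every query dotted with every key
scores = Q @ K.T  # (seq_len, seq_len)

# 3. Scale by sqrt(d_k)
scores = scores / math.sqrt(d_k)

# 4. Softmax over each row turns scores into attention weights
weights = torch.softmax(scores, dim=-1)

# 5. Weighted sum of the values gives one context vector per token
context = weights @ V  # (seq_len, d_k)

# Cross-check against PyTorch's fused implementation (adds a batch dim of 1)
reference = F.scaled_dot_product_attention(Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0))[0]
print(torch.allclose(context, reference, atol=1e-5))  # True
```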

**Adding causal masking and multi-head attention (MHA) gets you to the decoder self-attention of the original *Attention Is All You Need* paper.**
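Extending the sketch to several heads and a causal mask takes only a few more lines. Here the per-head queries, keys, and values are random stand-ins for the projected inputs (an assumption to keep the example short); the mask hides future positions, and the heads are concatenated and mixed by an output projection as in MHA.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, num_heads, seq_len, head_dim = 1, 2, 5, 4  # toy sizes
# Random stand-ins for the per-head projected queries, keys and values
q, k, v = (torch.randn(batch, num_heads, seq_len, head_dim) for _ in range(3))

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5  # (batch, heads, seq, seq)
# Causal mask: True above the diagonal marks future positions to hide
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)
context = weights @ v  # every head attends independently

print(weights[0, 0, 0])  # token 0 sees only itself: all weight on position 0
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(context, reference, atol=1e-5))  # True

# MHA: concatenate the heads, then mix them with an output projection
d_out = num_heads * head_dim
out_proj = nn.Linear(d_out, d_out, bias=False)
out = out_proj(context.transpose(1, 2).reshape(batch, seq_len, d_out))
print(out.shape)  # torch.Size([1, 5, 8])
```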

In MHA, each head has its own key and value projections. GQA reduces memory usage, most notably the size of the KV cache during inference, by letting a group of query heads share a single key and value projection.
<div align="center">
   <img src="./assets/mha-gqa.webp" alt="MHA vs. GQA" width="350px"/><br>
   <sub><b>Figure:</b> MHA vs. GQA (<a href="https://github.com/rasbt/LLMs-from-scratch/blob/main/ch04/04_gqa/README.md">image source</a>)</sub>
</div>
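The grouping can be sketched by storing fewer key/value heads and expanding them to match the query heads just before attention. The sizes and the name `num_kv_groups` are assumptions for illustration, and the tensors are random stand-ins for the projected inputs. The last lines compare how many values per token and per layer a KV cache would hold under MHA versus GQA.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, head_dim = 1, 6, 8
num_heads, num_kv_groups = 8, 2  # hypothetical: 4 query heads share each K/V head
group_size = num_heads // num_kv_groups

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_groups, seq_len, head_dim)  # only 2 key heads stored
v = torch.randn(batch, num_kv_groups, seq_len, head_dim)  # only 2 value heads stored

# Expand K/V so each query head sees its group's K/V head:
# query heads 0-3 use KV head 0, query heads 4-7 use KV head 1
k_exp = k.repeat_interleave(group_size, dim=1)  # (batch, num_heads, seq_len, head_dim)
v_exp = v.repeat_interleave(group_size, dim=1)

context = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=True)
print(context.shape)  # torch.Size([1, 8, 6, 8])

# KV-cache values per token and layer: MHA keeps K and V for every head,
# GQA only for each group
mha_cache = 2 * num_heads * head_dim
gqa_cache = 2 * num_kv_groups * head_dim
print(mha_cache, gqa_cache)  # 128 32
```

Expanding with `repeat_interleave` keeps the sketch simple; it copies the shared heads only at attention time, so the saving applies to the stored keys and values (and their projection weights), not to this temporary expansion.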