In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import json
import matplotlib.pyplot as plt
import numpy as np

# Configuration: seed, data location, and which record to use as the running example.
SEED = 42
DATA_PATH = "../.data/en_fa_train.jsonl"
sample_idx = 0

# Seed every RNG library imported above so Restart & Run All reproduces
# the same random weights and inputs.
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.set_grad_enabled(False)  # No gradients for this didactic walkthrough (global switch)

# Load the English/Persian JSONL data: one JSON object per non-empty line.
# Blank lines (e.g. stray trailing newlines) are skipped; json.loads would raise on them.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]

assert data, f"No records found in {DATA_PATH}"

# We will use a real sentence (the record's "input" field) for our token embeddings
text = data[sample_idx]["input"]
print(f"Using sample sentence: {text}")

Using sample sentence: I invited my foolish friend Jay around for tennis because I thought he'd make me look good.


## 1. The Feed-Forward Network

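The Feed-Forward Network (FFN) is a simple two-layer MLP applied independently to each position (token). It expands each token's representation to a wider hidden dimension $d_{ff}$, applies a non-linearity, and projects back to $d_{model}$. This lets the model mix features non-linearly within each token.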

Math: $\mathrm{FFN}(x) = \sigma(xW_1 + b_1)W_2 + b_2$, where $\sigma$ is the activation function (GELU in the code below).

In [2]:
class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: w_2(GELU(w_1(x))).

    Expands each token's feature vector from d_model to d_ff, applies GELU,
    and projects back to d_model. Both linear layers carry biases, matching
    FFN(x) = sigma(x W_1 + b_1) W_2 + b_2.

    Args:
        d_model: Size of the input/output feature dimension.
        d_ff: Size of the hidden (expanded) dimension.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # W1: Expansion layer (d_model -> d_ff)
        self.w_1 = nn.Linear(d_model, d_ff)
        # W2: Projection layer (d_ff -> d_model)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Apply the FFN to every position independently.

        Args:
            x: Tensor of shape (..., d_model). nn.Linear acts only on the last
                dimension, so any leading dims are preserved, e.g.
                (batch, seq, d_model) or just (seq, d_model).

        Returns:
            Tensor of shape (..., d_model).
        """
        return self.w_2(self.gelu(self.w_1(x)))

## 2. Layer Normalization (LayerNorm)

LayerNorm stabilizes training by normalizing each token's feature vector independently. Unlike BatchNorm, it uses no batch statistics. It therefore behaves identically during training and inference and is well-suited to varying sequence lengths and batch sizes.

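Math: $\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta$, where $\mu$ and $\sigma^2$ are the mean and (biased) variance of $x$ over the feature dimension, and $\gamma$, $\beta$ are a learned per-feature scale and shift.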

In [4]:
class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension.

    Each vector along the last dim is shifted to zero mean and scaled to unit
    (biased) variance, then transformed by a learned gain ``gamma`` and
    offset ``beta``.

    Args:
        d_model: Size of the feature dimension normalized over.
        eps: Added to the variance before the square root for numerical stability.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Identity affine transform at init: scale of ones, shift of zeros.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        # Statistics are per position: reduce the last dim, keep it for broadcasting.
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, unbiased=False, keepdim=True)  # population variance
        normalized = centered / (variance + self.eps).sqrt()
        return normalized * self.gamma + self.beta

## 3. Advanced: Modern Variants

Modern architectures (Llama, Gemma, Mistral) often replace standard LayerNorm with RMSNorm, which is cheaper to compute. They also replace the GELU/ReLU FFN with a SwiGLU FFN for better parameter efficiency.

### 3.1 RMSNorm

RMSNorm divides the input by its root-mean-square and then applies a learned per-feature gain $\gamma$. It removes LayerNorm's "centering" (mean subtraction) step and its bias $\beta$.

Math: $\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d} \sum_i x_i^2 + \epsilon}} \odot \gamma$

In [6]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Divides each feature vector by sqrt(mean(x**2) + eps) and multiplies by a
    learned per-feature gain. Unlike LayerNorm, there is no mean subtraction
    and no bias.

    Args:
        d_model: Size of the feature dimension (length of the learned gain).
        eps: Added to the mean square before the square root for stability.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learned gain, initialized to ones (identity scaling).
        self.gamma = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # keepdim=True leaves a trailing size-1 axis so the RMS broadcasts over x.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denominator = (mean_square + self.eps).sqrt()
        return self.gamma * (x / denominator)

### 3.2 SwiGLU

SwiGLU is a gated linear unit (GLU) variant. It applies the Swish activation (called SiLU in PyTorch) to a gate branch, which then multiplies a second, linear branch element-wise.

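Math: $\mathrm{SwiGLU}(x, W, V) = \mathrm{SiLU}(xW) \odot xV$, where $\odot$ is element-wise multiplication. The SwiGLU FFN then projects back to $d_{model}$: $\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = (\mathrm{SiLU}(xW) \odot xV)\,W_2$ (no biases).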

In [7]:
class SwiGLU(nn.Module):
    """Bias-free SwiGLU feed-forward block: w_down(SiLU(w_gate(x)) * w_up(x)).

    Args:
        d_model: Input/output feature size.
        d_ff: Hidden size shared by the gate and up projections.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Two parallel projections into d_ff -- the gate branch (W) and the
        # value branch (V) -- followed by a projection back to d_model.
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        # The SiLU-activated gate modulates the linear "up" branch element-wise.
        hidden = F.silu(self.w_gate(x)) * self.w_up(x)
        return self.w_down(hidden)

## 4. Residual Connections

Finally, we wrap each sublayer in a residual connection. This gives gradients a direct path through deep stacks of blocks.

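Math: $y = x + \mathrm{Sublayer}(x)$. In the pre-norm arrangement used below, the sublayer receives a normalized input while the residual path stays untouched: $y = x + \mathrm{Sublayer}(\mathrm{Norm}(x))$.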

In [8]:
# Demo of a single Transformer sub-block (Pre-norm): y = x + FFN(Norm(x))
d_model = 512
d_ff = 2048  # 4 x d_model expansion

ffn_sublayer = PositionWiseFFN(d_model, d_ff)
norm = LayerNorm(d_model)

x = torch.randn(1, 10, d_model)  # (batch, seq, d_model): random stand-in activations

# Standard Pre-norm architecture:
# x = x + FFN(Norm(x))
output = x + ffn_sublayer(norm(x))

# The residual sum only makes sense if the sublayer preserves the input shape;
# check it rather than announcing success unconditionally.
assert output.shape == x.shape, f"shape mismatch: {output.shape} vs {x.shape}"

# Exercise the Section 3 variants in the same slot (RMSNorm + SwiGLU).
# Built after `x` is drawn, so the random input above is unaffected.
modern_output = x + SwiGLU(d_model, d_ff)(RMSNorm(d_model)(x))
assert modern_output.shape == x.shape, f"shape mismatch: {modern_output.shape} vs {x.shape}"

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("Successfully computed residual FFN block.")

Input shape: torch.Size([1, 10, 512])
Output shape: torch.Size([1, 10, 512])
Successfully computed residual FFN block.
