In [1]:
import torch
import math
import matplotlib.pyplot as plt
import numpy as np

## Position-wise Feed-Forward Network
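Equation: $\mathrm{FFN}(x)=W_2\,\sigma(W_1 x + b_1)+b_2$, where $\sigma$ is an elementwise nonlinearity (e.g. ReLU or GELU) and the same weights are applied to every token independently.
We'll implement a tiny FFN and show how it transforms token vectors.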

In [2]:
class SimpleFFN(torch.nn.Module):
    """Position-wise feed-forward network: ``W2(act(W1 x + b1)) + b2``.

    Args:
        d_model: size of the input and output feature dimension.
        d_ff: output width of the first projection ``W1``. For ``'swiglu'``
            this is the width *before* the value/gate split, so it must be
            even and the second projection consumes ``d_ff // 2`` features.
        activation: ``'gelu'``, ``'relu'``, ``'silu'`` or ``'swiglu'``.
            Any other value falls back to ReLU (kept for backward
            compatibility with existing callers).

    Raises:
        ValueError: if ``activation == 'swiglu'`` and ``d_ff`` is odd.
    """

    def __init__(self, d_model, d_ff, activation='gelu'):
        super().__init__()
        if activation == 'swiglu':
            if d_ff % 2 != 0:
                raise ValueError(
                    f"d_ff must be even for activation='swiglu', got {d_ff}"
                )
            # The gated unit halves the width, so W2 only sees d_ff // 2 features.
            hidden_dim = d_ff // 2
        else:
            hidden_dim = d_ff
        self.W1 = torch.nn.Linear(d_model, d_ff)
        self.W2 = torch.nn.Linear(hidden_dim, d_model)
        self.activation = activation

    def gelu(self, x):
        """Tanh approximation of GELU."""
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    def silu(self, x):
        """SiLU / Swish: ``x * sigmoid(x)``."""
        return x * torch.sigmoid(x)

    def swiglu(self, x):
        """Gated hidden activation: ``value * SiLU(gate)``.

        ``W1`` produces both halves in a single projection; the result is
        split in two along the last dimension, so the returned tensor has
        ``d_ff // 2`` features.
        """
        z = self.W1(x)
        value, gate = z.chunk(2, dim=-1)
        # Gate with SiLU (gate * sigmoid(gate)); a bare sigmoid would be plain GLU.
        return value * self.silu(gate)

    def forward(self, x):
        """Apply the FFN along the last dimension of ``x`` (size ``d_model``).

        Returns a tensor with the same shape as ``x``.
        """
        if self.activation == 'gelu':
            hidden = self.gelu(self.W1(x))
        elif self.activation == 'silu':
            hidden = self.silu(self.W1(x))
        elif self.activation == 'swiglu':
            # swiglu() applies W1 itself because it needs the pre-split projection.
            hidden = self.swiglu(x)
        else:
            # 'relu' and any unrecognised name both use ReLU.
            hidden = torch.relu(self.W1(x))
        return self.W2(hidden)

# Demo: push a toy 4-token sequence through the FFN; the feature width is preserved
d_model, d_ff = 16, 64
ffn = SimpleFFN(d_model, d_ff, activation='gelu')
x = torch.randn((4, d_model))  # (seq_len=4, d_model)
y = ffn(x)
print('Input shape:', x.size())
print('Output shape:', y.size())

Input shape: torch.Size([4, 16])
Output shape: torch.Size([4, 16])


## Layer Normalization (from scratch)
LayerNorm normalizes across the features of each token, $\hat{x}=\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}$, and then applies a learned scale and shift: $y=\gamma\,\hat{x}+\beta$.

In [3]:
class SimpleLayerNorm(torch.nn.Module):
    """Layer normalization over the last dimension with a learned scale and shift.

    Args:
        d_model: size of the normalized (last) dimension.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        # gamma starts at 1 and beta at 0, so the layer is a pure normalization at init.
        self.gamma = torch.nn.Parameter(torch.ones(d_model))
        self.beta = torch.nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Normalize each vector along the last dimension, then scale and shift."""
        mu = torch.mean(x, dim=-1, keepdim=True)
        # Population variance (unbiased=False): divide by N rather than N - 1.
        sigma_sq = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        denom = torch.sqrt(sigma_sq + self.eps)
        normalized = (x - mu) / denom
        return self.gamma * normalized + self.beta

# Demo LayerNorm: after normalization each token's feature mean should be ~0
ln = SimpleLayerNorm(d_model)
x = torch.randn((3, d_model))
x_ln = ln(x)
print('LayerNorm output shape:', x_ln.size())
print('Per-token mean (after LN):', torch.mean(x_ln, dim=-1))

LayerNorm output shape: torch.Size([3, 16])
Per-token mean (after LN): tensor([-3.7253e-08, -7.4506e-09,  2.6077e-08], grad_fn=<MeanBackward1>)


## RMSNorm (from scratch)
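RMSNorm divides each token by the root-mean-square of its features (no mean-centering) and applies a learned gain: $y = g \odot \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2+\epsilon}}$.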

In [5]:
class SimpleRMSNorm(torch.nn.Module):
    """RMS normalization over the last dimension with a learned per-feature gain.

    Args:
        d_model: size of the normalized (last) dimension.
        eps: constant added to the mean square before the square root.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        # Gain starts at 1, so the layer is a pure RMS normalization at init.
        self.g = torch.nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Divide each vector by its root-mean-square, then apply the gain ``g``."""
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        root_mean_square = torch.sqrt(mean_square + self.eps)
        return (x / root_mean_square) * self.g

# Demo RMSNorm
rms = SimpleRMSNorm(d_model)

# Build the input explicitly rather than probing globals() for a leftover `x`,
# so the cell gives the same kind of result after "Restart & Run All" as it
# does mid-session.
x_input = torch.randn(3, d_model)

x_r = rms(x_input)

# Divide out the learned gain `g` so the check measures the normalization itself.
# `g` has shape (d_model,) and broadcasts over the token axis without reshaping.
unscaled = x_r / rms.g

per_token_rms = torch.sqrt((unscaled ** 2).mean(dim=-1))

print('RMSNorm output shape:', x_r.shape)
print('Per-token RMS (should be close to 1 after scaling):', per_token_rms)

RMSNorm output shape: torch.Size([3, 16])
Per-token RMS (should be close to 1 after scaling): tensor([1.0000, 1.0000, 1.0000], grad_fn=<SqrtBackward0>)


## Residual Connection Example
Demonstrate how residual + normalization + FFN combine in a transformer block pattern.

In [6]:
# Simple transformer block (FFN + LayerNorm + residual)
class SimpleBlock(torch.nn.Module):
    """FFN sub-layer of a transformer block: norm -> FFN -> residual add -> norm.

    Args:
        d_model: feature width of the tokens entering and leaving the block.
        d_ff: hidden width passed to the feed-forward network.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ln1 = SimpleLayerNorm(d_model)
        self.ffn = SimpleFFN(d_model, d_ff, activation='gelu')
        self.ln2 = SimpleLayerNorm(d_model)

    def forward(self, x):
        """Run the FFN path on ``x`` and return a tensor of the same shape.

        Attention would normally precede this step; only the FFN path is shown.
        """
        skip = x  # untouched input for the residual connection
        ffn_out = self.ffn(self.ln1(x))
        return self.ln2(ffn_out + skip)

# Demo block: five random tokens in, five tokens of the same width out
block = SimpleBlock(d_model, d_ff)
x = torch.randn((5, d_model))
y = block(x)
print('Block output shape:', y.size())

Block output shape: torch.Size([5, 16])


## Final: Library Comparison
Show how `torch.nn.LayerNorm` and a simple feed-forward built from `torch.nn.Sequential` correspond to our implementations.

In [7]:
# Library counterparts of the from-scratch modules: built-in LayerNorm and an
# nn.Sequential FFN (Linear -> GELU -> Linear).
# Note: torch.nn.GELU() defaults to the exact erf-based GELU, while
# SimpleFFN.gelu uses the tanh approximation, so values differ slightly.
ln_torch = torch.nn.LayerNorm(normalized_shape=d_model)
ffn_torch = torch.nn.Sequential(
    torch.nn.Linear(in_features=d_model, out_features=d_ff),
    torch.nn.GELU(),
    torch.nn.Linear(in_features=d_ff, out_features=d_model),
)
x = torch.randn((3, d_model))
print('PyTorch LayerNorm output shape:', ln_torch(x).size())
print('PyTorch FFN output shape:', ffn_torch(x).size())

PyTorch LayerNorm output shape: torch.Size([3, 16])
PyTorch FFN output shape: torch.Size([3, 16])
