In [1]:
import torch
import math
import matplotlib.pyplot as plt
import numpy as np

## Position-wise Feed-Forward Network
Equation: $\mathrm{FFN}(x)=W_2\sigma(W_1 x + b_1)+b_2$
We'll implement a tiny FFN and show how it transforms token vectors.

In [2]:
import json

# Load English-Farsi translation data
with open('../.data/en_fa_train.jsonl', 'r', encoding='utf-8') as f:
  samples = [json.loads(line) for line in f]

# Demo: apply FFN to embeddings of sentences
from transformers import AutoTokenizer, AutoModel

# Get embeddings from a trained model for a sentence
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = AutoModel.from_pretrained('distilbert-base-uncased')

# Use first English sentence from the dataset
text = samples[0]['input'] # "Pack your stuff."
tokens = tokenizer(text, return_tensors='pt')

with torch.no_grad():
  outputs = model(**tokens, output_hidden_states=True)
  # Get embeddings from the last hidden state
  embeddings = outputs.last_hidden_state[0] # (seq_len, 768)

print(f"Text: {text}")
print(f"Embeddings shape from pretrained model: {embeddings.shape}")

# Define SimpleFFN
class SimpleFFN(torch.nn.Module):
  def __init__(self, d_model, d_ff, activation='gelu'):
    super().__init__()
    self.W1 = torch.nn.Linear(d_model, d_ff)
    self.W2 = torch.nn.Linear(d_ff, d_model)
    self.activation = activation

  def gelu(self, x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

  def forward(self, x):
    if self.activation == 'gelu':
      hidden = self.gelu(self.W1(x))
    else:
      hidden = torch.relu(self.W1(x))
    return self.W2(hidden)

# Now apply our FFN to these embeddings
d_model = embeddings.shape[1] # 768
d_ff = d_model * 4 # Common ratio
ffn = SimpleFFN(d_model, d_ff, activation='gelu')

with torch.no_grad():
  y = ffn(embeddings)

print(f"FFN output shape: {y.shape}")
print(f"The FFN took token embeddings and produced refined representations")

Text: I invited my foolish friend Jay around for tennis because I thought he'd make me look good.
Embeddings shape from pretrained model: torch.Size([22, 768])
FFN output shape: torch.Size([22, 768])
The FFN took token embeddings and produced refined representations


## Layer Normalization (from scratch)
LayerNorm normalizes across features for each token: $\hat{x}=\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}$ followed by scale and shift.

In [3]:
class SimpleLayerNorm(torch.nn.Module):
  def __init__(self, d_model, eps=1e-5):
    super().__init__()
    self.eps = eps
    self.gamma = torch.nn.Parameter(torch.ones(d_model))
    self.beta = torch.nn.Parameter(torch.zeros(d_model))

  def forward(self, x):
    # x: (seq_len, d_model)
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + self.eps)
    return self.gamma * x_hat + self.beta

# Demo LayerNorm - apply to embeddings from the sentence above
ln = SimpleLayerNorm(d_model)
x_ln = ln(embeddings)
print('LayerNorm output shape:', x_ln.shape)
print('Per-token mean (after LN):', x_ln.mean(dim=-1))

LayerNorm output shape: torch.Size([22, 768])
Per-token mean (after LN): tensor([ 1.1176e-08, -1.8626e-09,  0.0000e+00, -8.6923e-09,  2.4835e-09,
        -6.2088e-10,  6.2088e-09,  8.6923e-09, -6.8297e-09,  1.1797e-08,
         2.4835e-09, -4.9671e-09,  1.2418e-09, -2.4835e-09, -1.1331e-08,
        -2.4835e-09, -5.5879e-09, -4.3462e-09,  0.0000e+00, -8.5371e-09,
         4.1910e-09,  0.0000e+00], grad_fn=<MeanBackward1>)


## RMSNorm (from scratch)
RMSNorm divides by the root-mean-square of features (no centering).

In [4]:
class SimpleRMSNorm(torch.nn.Module):
  def __init__(self, d_model, eps=1e-8):
    super().__init__()
    self.eps = eps
    self.g = torch.nn.Parameter(torch.ones(d_model))

  def forward(self, x):
    # x: (seq_len, d_model)
    rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + self.eps)
    return (x / rms) * self.g

# Demo RMSNorm - apply to embeddings from the sentence above
rms = SimpleRMSNorm(d_model)
x_input = embeddings # Use embeddings from cell above
x_r = rms(x_input)

# Compute per-token RMS after removing the learned scale `g`.
# Ensure proper broadcasting by unsqueezing `g` to shape (1, d_model).
if hasattr(rms, 'g'):
  unscaled = x_r / rms.g.unsqueeze(0)
else:
  unscaled = x_r

per_token_rms = torch.sqrt((unscaled ** 2).mean(dim=-1))

print('RMSNorm output shape:', x_r.shape)
print('Per-token RMS (should be close to 1 after scaling):', per_token_rms)

RMSNorm output shape: torch.Size([22, 768])
Per-token RMS (should be close to 1 after scaling): tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SqrtBackward0>)


## Residual Connection Example
Demonstrate how residual + normalization + FFN combine in a transformer block pattern.

In [5]:
# Simple transformer block (FFN + LayerNorm + residual)
class SimpleBlock(torch.nn.Module):
  def __init__(self, d_model, d_ff):
    super().__init__()
    self.ln1 = SimpleLayerNorm(d_model)
    self.ffn = SimpleFFN(d_model, d_ff, activation='gelu')
    self.ln2 = SimpleLayerNorm(d_model)

  def forward(self, x):
    # x: (seq_len, d_model)
    # Usually attention goes before this; here we demo FFN path
    residual = x
    x = self.ln1(x)
    x = self.ffn(x)
    x = x + residual
    x = self.ln2(x)
    return x

# Demo block - apply to embeddings from above
block = SimpleBlock(d_model, d_ff)
y = block(embeddings)
print('Block output shape:', y.shape)

Block output shape: torch.Size([22, 768])


## Final: Library Comparison
Show how `torch.nn.LayerNorm` and a simple feed-forward from `transformers` correspond to our implementations.

In [6]:
# Compare with PyTorch LayerNorm and a simple nn.Sequential FFN
ln_torch = torch.nn.LayerNorm(d_model)
ffn_torch = torch.nn.Sequential(
  torch.nn.Linear(d_model, d_ff),
  torch.nn.GELU(),
  torch.nn.Linear(d_ff, d_model),
)
x = torch.randn(3, d_model)
print('PyTorch LayerNorm output shape:', ln_torch(x).shape)
print('PyTorch FFN output shape:', ffn_torch(x).shape)

PyTorch LayerNorm output shape: torch.Size([3, 768])
PyTorch FFN output shape: torch.Size([3, 768])
