# Transformer From Scratch (PyTorch)

This notebook builds a **mini Transformer** (the key components) from scratch using **PyTorch**:

- Scaled Dot-Product Attention
- Multi-Head Attention
- Positional Encoding
- Transformer Encoder & Decoder layers (simple versions)

This is an educational implementation (not optimized). It demonstrates the main ideas so you can experiment and extend it.

## 1. Imports & Setup

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

## 2. Scaled Dot-Product Attention

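Given queries Q, keys K and values V, attention is:
$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
where $d_k$ is the dimension of each query/key vector (`q.size(-1)` in the code below). We add an optional mask that blocks attention to selected key positions, e.g. padding tokens, or future tokens in the decoder's causal self-attention.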

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q: queries, shape (batch, heads, seq_len_q, depth).
        k: keys, shape (batch, heads, seq_len_k, depth).
        v: values, shape (batch, heads, seq_len_k, depth) -- same sequence
            length as ``k``.
        mask: optional tensor broadcastable to
            (batch, heads, seq_len_q, seq_len_k), e.g. (batch, 1, 1, seq_len_k)
            for a key-padding mask or (1, 1, seq_len_q, seq_len_k) for a causal
            mask. Positions where ``mask == 0`` (or ``False``) get zero weight.

    Returns:
        (output, attn): ``output`` has shape (batch, heads, seq_len_q, depth);
        ``attn`` holds the attention weights, shape
        (batch, heads, seq_len_q, seq_len_k).
    """
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (batch, heads, seq_q, seq_k)
    if mask is not None:
        # Use the most negative finite value rather than -inf: if every key in a
        # row is masked (e.g. a padded query position), softmax over a row of
        # -inf returns NaN, whereas a row of equal finite values gives a
        # harmless uniform distribution. Partially masked rows still get
        # exactly zero weight on the masked keys (exp underflows to 0).
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    attn = F.softmax(scores, dim=-1)
    output = torch.matmul(attn, v)  # (batch, heads, seq_q, depth)
    return output, attn

# quick sanity check: shapes, normalization, and masking behave as documented
B, H, Lq, Lk, D = 2, 4, 5, 6, 16
q = torch.rand(B, H, Lq, D)
k = torch.rand(B, H, Lk, D)
v = torch.rand(B, H, Lk, D)
out, att = scaled_dot_product_attention(q, k, v)
assert out.shape == (B, H, Lq, D) and att.shape == (B, H, Lq, Lk)
# Each query's attention weights form a probability distribution over the keys.
assert torch.allclose(att.sum(dim=-1), torch.ones(B, H, Lq))
# A key-padding mask of shape (batch, 1, 1, seq_k) gives the masked key zero weight.
pad_mask = torch.ones(B, 1, 1, Lk)
pad_mask[..., -1] = 0
_, masked_att = scaled_dot_product_attention(q, k, v, pad_mask)
assert torch.all(masked_att[..., -1] == 0)
print('out shape', out.shape, 'att shape', att.shape)

## 3. Multi-Head Attention

We project inputs to multiple heads, apply scaled dot-product attention in each head, then concatenate and project back.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Projects queries/keys/values with learned linear maps and splits each
    projection into ``num_heads`` heads of size ``depth = d_model // num_heads``.
    Scaled dot-product attention runs in every head in parallel. The heads are
    then concatenated and passed through a final linear projection.

    Args:
        d_model: model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, 'd_model must be divisible by num_heads'
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)  # output projection after concat

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, heads, seq_len, depth)."""
        batch, seq_len, _ = x.size()
        return x.view(batch, seq_len, self.num_heads, self.depth).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of split_heads: (batch, heads, seq_len, depth) -> (batch, seq_len, d_model)."""
        # transpose() returns a non-contiguous view; view() needs contiguous memory.
        x = x.transpose(1, 2).contiguous()
        batch, seq_len, heads, depth = x.size()
        return x.view(batch, seq_len, heads * depth)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: (batch, seq_len_q, d_model) queries.
            k: (batch, seq_len_k, d_model) keys.
            v: (batch, seq_len_k, d_model) values.
            mask: optional tensor broadcastable to
                (batch, heads, seq_len_q, seq_len_k), e.g. a (batch, 1, 1, seq_len_k)
                padding mask or a (1, 1, seq_len_q, seq_len_k) causal mask.
                Positions where ``mask == 0`` are not attended to.

        Returns:
            (out, attn_weights): ``out`` is (batch, seq_len_q, d_model) and
            ``attn_weights`` is (batch, heads, seq_len_q, seq_len_k).
        """
        q = self.split_heads(self.wq(q))
        k = self.split_heads(self.wk(k))
        v = self.split_heads(self.wv(v))
        attn_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
        out = self.dense(self.combine_heads(attn_output))
        return out, attn_weights

# quick shape check: self-attention over a (batch=2, seq=10, d_model=64) input
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.rand(2, 10, 64)
o, w = mha(x, x, x)
assert o.shape == x.shape           # (batch, seq_len, d_model)
assert w.shape == (2, 8, 10, 10)    # (batch, heads, seq_q, seq_k)
print('mha out', o.shape, 'att weights', w.shape)

## 4. Positional Encoding
Transformers need positional information since they have no recurrence or convolution.
We'll implement the sinusoidal positional encoding from the original paper.

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature dimension of the inputs. Odd values are supported;
            the last (sine) column then has no cosine partner.
        max_len: longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / d_model), computed via exp/log; one entry per even
        # column, i.e. ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns; trim div_term so an
        # odd d_model does not cause a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: follows .to(device) and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))  # shape (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: (batch, seq_len, d_model) input embeddings.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f'sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}'
            )
        return x + self.pe[:, :seq_len]

# check: on an all-zero input the output is exactly the positional encoding table
pe = PositionalEncoding(d_model=64)
test = torch.zeros(1, 10, 64)
encoded = pe(test)
assert encoded.shape == (1, 10, 64)
assert torch.equal(encoded, pe.pe[:, :10])
print(encoded.shape)

## 5. Transformer Encoder Layer
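Each encoder layer = Multi-Head Self-Attention → Add & Norm → Feed-Forward → Add & Norm. This is the post-norm arrangement, `LayerNorm(x + Dropout(Sublayer(x)))`, as in the original paper.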

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Acts on the last dimension only (d_model -> d_ff -> d_model), so every
    sequence position is transformed independently with shared weights.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        expand = nn.Linear(d_model, d_ff)
        project = nn.Linear(d_ff, d_model)
        # Keep the single `net` Sequential so parameter names stay net.0.* / net.3.*.
        self.net = nn.Sequential(expand, nn.ReLU(), nn.Dropout(dropout), project)

    def forward(self, x):
        """Map (..., d_model) -> (..., d_model)."""
        result = self.net(x)
        return result

class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Sub-layers: multi-head self-attention, then a position-wise feed-forward
    network. Each sub-layer output goes through dropout, is added back to its
    input (residual connection) and is then layer-normalized.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, sublayer_out, norm):
        """Dropout the sub-layer output, add the residual, then normalize."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, src_mask=None):
        """x: (batch, seq_len, d_model); ``src_mask`` is passed to self-attention."""
        attended, _ = self.mha(x, x, x, src_mask)
        hidden = self._add_and_norm(x, attended, self.norm1)
        return self._add_and_norm(hidden, self.ffn(hidden), self.norm2)

# quick check: an encoder layer preserves the (batch, seq_len, d_model) shape
enc_layer = EncoderLayer(d_model=64, num_heads=8, d_ff=256)
dummy = torch.rand(2, 12, 64)
out = enc_layer(dummy)
assert out.shape == dummy.shape
print('encoder out', out.shape)

## 6. Transformer Decoder Layer (basic)
Each decoder layer applies masked self-attention, encoder-decoder (cross-)attention, and a feed-forward network, each followed by Add & Norm.

In [None]:
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Sub-layers, in order: masked self-attention over the target,
    cross-attention from the target to the encoder output, and a position-wise
    feed-forward network. Each is followed by dropout, a residual connection
    and LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_mha = MultiHeadAttention(d_model, num_heads)
        self.enc_dec_mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, memory_mask=None):
        """
        Args:
            x: (batch, tgt_len, d_model) decoder input.
            enc_output: (batch, src_len, d_model) encoder output ("memory").
            tgt_mask: mask for self-attention, e.g. a causal mask so position i
                only sees positions <= i; ``None`` attends everywhere.
            memory_mask: mask for cross-attention over ``enc_output``.

        Returns:
            (batch, tgt_len, d_model) tensor.
        """
        h = x
        # Masked self-attention: queries, keys and values all come from the target.
        self_ctx, _ = self.self_mha(h, h, h, tgt_mask)
        h = self.norm1(h + self.dropout(self_ctx))
        # Cross-attention: target queries attend over the encoder output.
        cross_ctx, _ = self.enc_dec_mha(h, enc_output, enc_output, memory_mask)
        h = self.norm2(h + self.dropout(cross_ctx))
        # Position-wise feed-forward.
        h = self.norm3(h + self.dropout(self.ffn(h)))
        return h

# sanity check: output shape, plus causality of the masked self-attention
dec_layer = DecoderLayer(d_model=64, num_heads=8, d_ff=256)
dec_layer.eval()  # disable dropout so the two forward passes below are comparable
tgt = torch.rand(2, 8, 64)
mem = torch.rand(2, 12, 64)
causal_mask = torch.tril(torch.ones(8, 8)).view(1, 1, 8, 8)
out = dec_layer(tgt, mem, tgt_mask=causal_mask)
assert out.shape == tgt.shape
# Changing only the last target position must leave all earlier outputs unchanged.
tgt_perturbed = tgt.clone()
tgt_perturbed[:, -1] += 1.0
out_perturbed = dec_layer(tgt_perturbed, mem, tgt_mask=causal_mask)
assert torch.allclose(out[:, :-1], out_perturbed[:, :-1], atol=1e-6)
print('decoder out', out.shape)

## 7. Small Transformer Encoder / Decoder Stacks
Assemble layers into encoder and decoder stacks and add input/output embeddings + positional encodings.

In [None]:
class MiniTransformer(nn.Module):
    """Small encoder-decoder Transformer.

    Token embeddings plus sinusoidal positional encodings feed a stack of
    encoder layers and a stack of decoder layers. A final linear layer projects
    the decoder output to target-vocabulary logits.

    NOTE: unlike the original paper, embeddings are not scaled by sqrt(d_model).

    Args:
        src_vocab_size: source vocabulary size.
        tgt_vocab_size: target vocabulary size (also the number of output logits).
        d_model, num_heads, d_ff: dimensions shared by every layer.
        num_enc_layers: depth of the encoder stack.
        num_dec_layers: depth of the decoder stack.
        max_len: longest source/target sequence the positional table covers.
        dropout: dropout rate passed to every encoder/decoder layer
            (default 0.1, the layers' own default).
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=64, num_heads=8, d_ff=256,
                 num_enc_layers=2, num_dec_layers=2, max_len=100, dropout=0.1):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.enc_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_enc_layers)])
        self.dec_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_dec_layers)])
        self.output_proj = nn.Linear(d_model, tgt_vocab_size)

    def make_tgt_mask(self, tgt_seq_len, device=None):
        """Return a causal mask of shape (1, 1, tgt_len, tgt_len).

        Entry [..., i, j] is 1 when j <= i and 0 otherwise, so position i only
        attends to itself and earlier positions.

        Args:
            tgt_seq_len: target sequence length.
            device: device on which to create the mask. Defaults to the device
                of this model's parameters rather than a global, so the mask
                always lives where the attention scores are computed.
        """
        if device is None:
            device = next(self.parameters()).device
        return torch.tril(torch.ones((tgt_seq_len, tgt_seq_len), device=device))[None, None]

    def forward(self, src, tgt, src_mask=None):
        """Run the encoder over ``src`` and the decoder over ``tgt``.

        Args:
            src: (batch, src_len) source token ids.
            tgt: (batch, tgt_len) target token ids (decoder input).
            src_mask: optional source padding mask of shape (batch, 1, 1, src_len);
                0 marks positions to ignore. It is applied in encoder
                self-attention and decoder cross-attention. ``None`` (the
                default) disables source masking.

        Returns:
            (batch, tgt_len, tgt_vocab_size) logits.
        """
        src_x = self.pos_enc(self.src_embed(src))
        for layer in self.enc_layers:
            src_x = layer(src_x, src_mask)
        memory = src_x

        tgt_x = self.pos_enc(self.tgt_embed(tgt))
        # Build the mask where the inputs live, not on a module-level device.
        tgt_mask = self.make_tgt_mask(tgt.size(1), device=tgt.device)
        for layer in self.dec_layers:
            tgt_x = layer(tgt_x, memory, tgt_mask=tgt_mask, memory_mask=src_mask)
        return self.output_proj(tgt_x)  # (batch, tgt_len, tgt_vocab)

# quick forward pass with dummy token ids
src_vocab, tgt_vocab = 100, 100
model = MiniTransformer(src_vocab, tgt_vocab).to(device)
src = torch.randint(0, src_vocab, (2, 12)).to(device)
tgt = torch.randint(0, tgt_vocab, (2, 8)).to(device)
with torch.no_grad():  # shape check only: no need to build an autograd graph
    out = model(src, tgt)
assert out.shape == (2, 8, tgt_vocab)
print('MiniTransformer output shape', out.shape)  # (batch, tgt_len, tgt_vocab)

## 8. Toy Example: Forward Pass & Loss
We'll run a single training step on random data to illustrate usage. This is NOT real training.

In [None]:
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
optim.zero_grad()
# Teacher forcing: the decoder sees tokens 0..T-2 and must predict tokens 1..T-1.
logits = model(src, tgt[:, :-1])  # (batch, tgt_len - 1, tgt_vocab)
target = tgt[:, 1:]               # next-token targets; same device as tgt
# CrossEntropyLoss expects (N, C) logits and (N,) class indices, so flatten
# batch and time. Take C from the logits instead of a separate vocab constant.
loss = criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
loss.backward()
optim.step()
print('toy training loss:', loss.item())

## 9. Notes & Next Steps
- This is a compact educational implementation. Real-world Transformers add many refinements: tuned dropout settings, careful weight initialization, label smoothing, optimizers such as AdamW combined with learning-rate warmup schedules, positional-embedding variants, etc.
- To extend:
  - Implement masking for padded tokens in encoder/decoder.
  - Add attention visualization.
  - Train on a real dataset (machine translation toy dataset or tokenize and use small corpus).
  - Compare to `torch.nn.Transformer` or Hugging Face `transformers` implementations.

✅ You now have a readable from-scratch Transformer you can experiment with!