# Transformers from Scratch in PyTorch

This notebook implements a Transformer encoder from scratch, covering:
1. Multi-Head Self-Attention
2. Position-wise Feedforward Network
3. Transformer Encoder Block
4. Sinusoidal Positional Encoding
5. Complete Transformer Encoder
6. Example usage and testing (full model and individual components)
7. Visualization of the positional encodings

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# Environment report: library version and whether a CUDA device is visible.
for _label, _value in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{_label}: {_value}")

## 1. Multi-Head Self-Attention

Core mechanism of transformers. Key ideas:
- Projects input to Q, K, V
- Computes scaled dot-product attention per head: softmax(QK^T / √d_k)V, where d_k = d_model / n_heads is the per-head dimension
- Multiple heads learn different attention patterns
- Concatenate heads and project to output

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input to queries, keys and values, splits each into
    ``n_heads`` heads of width ``d_model // n_heads``, computes
    ``softmax(Q K^T / sqrt(d_head)) V`` per head, concatenates the heads and
    applies a final linear projection.

    Args:
        d_model: Feature width of the input and output. Must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Linear projections for Q, K, V. Each keeps the full d_model width;
        # the heads are split out afterwards by reshaping.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Final output projection applied to the concatenated heads.
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional tensor broadcastable to the attention-score shape
                ``(batch, n_heads, seq_len, seq_len)``. Entries where
                ``mask == 0`` are hidden from attention. For example, a
                ``(seq_len, seq_len)`` lower-triangular matrix gives causal
                attention, and a ``(batch, 1, 1, seq_len)`` tensor hides padded
                keys. A plain ``(batch, seq_len)`` padding mask does NOT
                broadcast correctly; unsqueeze it first.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        B, T, _ = x.size()

        def split_heads(t):
            # (B, T, d_model) -> (B, n_heads, T, d_head)
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        q = split_heads(self.W_q(x))
        k = split_heads(self.W_k(x))
        v = split_heads(self.W_v(x))

        # Scaled dot-product scores: (B, n_heads, T, T)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        if mask is not None:
            # Fill with the most negative finite value rather than -inf.
            # If every key of a query row is masked, softmax over an all -inf
            # row is 0/0 = NaN, and that NaN then spreads through later layers.
            # A finite fill yields a well-defined (uniform) row instead. Rows
            # with at least one visible key are unaffected, because
            # exp(min - max) underflows to exactly 0.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = F.softmax(scores, dim=-1)  # (B, n_heads, T, T)
        context = torch.matmul(attn, v)  # (B, n_heads, T, d_head)

        # Concatenate heads: (B, n_heads, T, d_head) -> (B, T, d_model)
        context = context.transpose(1, 2).contiguous().view(B, T, self.d_model)

        # Final linear projection
        return self.W_o(context)

## 2. Position-wise Feedforward Network

Two-layer MLP applied independently to each position:
- Expands to d_ff (typically 4 × d_model)
- Applies ReLU, then dropout
- Projects back to d_model

In [None]:
class PositionwiseFFN(nn.Module):
    """Two-layer MLP applied identically and independently at every position.

    Computes ``fc2(dropout(relu(fc1(x))))``. The last dimension is expanded
    from ``d_model`` to ``d_ff`` and then projected back to ``d_model``.

    Args:
        d_model: Input/output feature width.
        d_ff: Hidden width of the expansion layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``(..., d_model)`` to ``(..., d_model)`` through a ``d_ff`` hidden layer."""
        hidden = self.dropout(F.relu(self.fc1(x)))  # (..., d_ff)
        return self.fc2(hidden)

## 3. Transformer Encoder Block

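Post-LayerNorm encoder block (the arrangement used in the original Transformer paper). Each sub-layer's output goes through dropout, is added back to the sub-layer's input (residual connection), and the sum is layer-normalized:
1. Multi-head self-attention + residual + layer norm
2. Feedforward + residual + layer norm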

This is the fundamental building block that stacks to form deep transformers.

In [None]:
class TransformerBlock(nn.Module):
    """Post-LayerNorm Transformer encoder block.

    Two residual sub-layers, each followed by LayerNorm::

        x = norm1(x + dropout1(self_attn(x, mask)))
        x = norm2(x + dropout2(ffn(x)))

    Args:
        d_model: Feature width of the residual stream.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feedforward network.
        dropout: Dropout probability used inside the FFN and on both sub-layer
            outputs before the residual addition.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, n_heads)
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)

        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run attention then feedforward sub-layers on ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional attention mask, passed unchanged to ``self_attn``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.dropout1(self.self_attn(x, mask=mask))
        x = self.norm1(x + attended)

        transformed = self.dropout2(self.ffn(x))
        return self.norm2(x + transformed)

## 4. Positional Encoding

Since transformers have no built-in sense of position (unlike RNNs/CNNs), we add sinusoidal position encodings:
- PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
- PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    For position ``pos`` and pair index ``i``::

        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    The table is precomputed for ``max_len`` positions and registered as a
    buffer, so it is not a trainable parameter.

    Args:
        d_model: Feature width. May be odd, in which case the last column is
            a sine.
        max_len: Maximum sequence length supported by ``forward``.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / d_model) for each even column 2i, computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns. When d_model is odd
        # that is one fewer than the even columns, so trim div_term to match;
        # otherwise the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``seq_len`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)`` with
                ``seq_len <= max_len``.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

## 5. Complete Transformer Encoder

Putting it all together:
1. Token embedding (scaled by √d_model) + positional encoding, followed by dropout
2. Stack of N transformer blocks
3. Final layer normalization (no task-specific output head is included; see Next Steps)

In [None]:
class TransformerEncoder(nn.Module):
    """Token-level Transformer encoder.

    Pipeline: embedding (scaled by sqrt(d_model)) -> positional encoding ->
    dropout -> ``n_layers`` TransformerBlocks -> LayerNorm.

    Args:
        vocab_size: Number of token ids accepted by the embedding table.
        d_model: Model width.
        n_heads: Attention heads per block.
        d_ff: Hidden width of each block's feedforward network.
        n_layers: Number of stacked TransformerBlocks.
        max_len: Longest sequence the positional encoding supports.
        dropout: Dropout probability on the embeddings and inside each block.
    """

    def __init__(self, vocab_size: int, d_model: int, n_heads: int, 
                 d_ff: int, n_layers: int, max_len: int = 5000, 
                 dropout: float = 0.1):
        super().__init__()

        self.d_model = d_model

        # Input stage: token lookup, fixed sinusoidal positions, dropout.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.dropout = nn.Dropout(dropout)

        layers = [TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(layers)

        # NOTE(review): every post-LN TransformerBlock already ends with a
        # LayerNorm, so this is a second normalization of the same activations.
        # Removing it would change the state_dict, so it is left in place.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.
            mask: Optional attention mask forwarded unchanged to every block.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        scaled = self.embedding(x) * math.sqrt(self.d_model)
        hidden = self.dropout(self.pos_encoding(scaled))

        for layer in self.blocks:
            hidden = layer(hidden, mask)

        return self.norm(hidden)

## 6. Testing and Usage Example

Let's create a base-size encoder (6 layers, d_model = 512, 8 heads) and test it with random token indices.

In [None]:
# Model hyperparameters (module-level names are reused by later cells).
vocab_size = 10000
d_model = 512
n_heads = 8
d_ff = 2048
n_layers = 6
max_len = 512
dropout = 0.1

# Build the encoder from the hyperparameters above.
model = TransformerEncoder(
    vocab_size=vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    n_layers=n_layers,
    max_len=max_len,
    dropout=dropout,
)

# Tally all parameters and the subset that will receive gradients.
total_params = trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("\nModel architecture:")
print(model)

In [None]:
# Test forward pass on random token ids.
batch_size = 32
seq_len = 128

# Random token indices in [0, vocab_size)
x = torch.randint(0, vocab_size, (batch_size, seq_len))

# eval() disables dropout; no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    output = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Expected output shape: (batch_size={batch_size}, seq_len={seq_len}, d_model={d_model})")

# Actually check the shape rather than relying on eyeballing the prints above.
assert output.shape == (batch_size, seq_len, d_model), f"unexpected output shape {tuple(output.shape)}"

## 7. Testing Individual Components

Let's test each component separately to understand their behavior.

In [None]:
# Shape/parameter-count check for a standalone attention layer.
print("=== Testing Multi-Head Self-Attention ===")
mha = MultiHeadSelfAttention(d_model=512, n_heads=8)
x_test = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
attn_out = mha(x_test)
mha_param_count = sum(p.numel() for p in mha.parameters())
print(f"Input: {x_test.shape}\nOutput: {attn_out.shape}\nParameters: {mha_param_count:,}\n")

In [None]:
# Shape/parameter-count check for a standalone feedforward layer.
print("=== Testing Position-wise FFN ===")
ffn = PositionwiseFFN(d_model=512, d_ff=2048)
x_test = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
ffn_out = ffn(x_test)
ffn_param_count = sum(p.numel() for p in ffn.parameters())
print(f"Input: {x_test.shape}\nOutput: {ffn_out.shape}\nParameters: {ffn_param_count:,}\n")

In [None]:
# Shape/parameter-count check for a single encoder block.
print("=== Testing Transformer Block ===")
block = TransformerBlock(d_model=512, n_heads=8, d_ff=2048)
x_test = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
block_out = block(x_test)
block_param_count = sum(p.numel() for p in block.parameters())
print(f"Input: {x_test.shape}\nOutput: {block_out.shape}\nParameters: {block_param_count:,}\n")

In [None]:
# Shape check for the positional encoding and its precomputed buffer.
print("=== Testing Positional Encoding ===")
pos_enc = PositionalEncoding(d_model=512, max_len=100)
x_test = torch.randn(2, 50, 512)  # (batch, seq_len, d_model)
pos_out = pos_enc(x_test)
print(
    f"Input: {x_test.shape}\n"
    f"Output: {pos_out.shape}\n"
    f"Positional encoding buffer shape: {pos_enc.pe.shape}"
)

## 8. Visualization (Optional)

Visualize the sinusoidal positional encodings. (Attention weights are not visualized here because `MultiHeadSelfAttention.forward` does not return them.)

In [None]:
import matplotlib.pyplot as plt

# Heatmap of the positional-encoding table: x-axis is position, y-axis is
# feature dimension (hence the transpose of the (max_len, d_model) array).
pe = PositionalEncoding(d_model=128, max_len=100)
pos_encoding = pe.pe[0].numpy()  # (max_len, d_model)

fig, ax = plt.subplots(figsize=(12, 6))
image = ax.imshow(pos_encoding.T, aspect='auto', cmap='RdBu')
fig.colorbar(image, ax=ax)
ax.set_xlabel('Position')
ax.set_ylabel('Dimension')
ax.set_title('Positional Encoding Visualization')
fig.tight_layout()
plt.show()

## Next Steps

To extend this implementation:

1. **Add a decoder**: Implement masked self-attention and cross-attention for encoder-decoder architecture
2. **Add task heads**: Classification head, language modeling head, etc.
3. **Training loop**: Implement training with a real dataset
4. **Advanced features**: 
   - Relative positional encodings
   - Pre-layer normalization
   - Different attention variants (sparse, linear, etc.)
5. **Optimization**: Flash attention, gradient checkpointing, mixed precision

For robotics applications, you might want to:
- Replace token embeddings with continuous state encoders
- Add action prediction heads
- Integrate with RL algorithms (e.g., Decision Transformer)