# 05. The Decoder-Only Transformer Model

We have built all the necessary components:
1.  **Tokenizer** (Notebook 01)
2.  **Embeddings** (Notebook 02)
3.  **Attention** (Notebook 03)
4.  **Transformer Block** (Notebook 04)

Now, we will assemble these pieces into the full **Decoder-Only Transformer** architecture, similar to GPT-2.

## Architecture Overview

The full model consists of:
1.  **Token Embeddings**: Map token IDs to vectors.
2.  **Positional Embeddings**: Add position information.
3.  **Dropout**: Applied to the sum of embeddings.
4.  **Stack of Transformer Blocks**: The core processing units.
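5.  **Final LayerNorm**: Normalize the residual stream before the output projection (each block applies LayerNorm only to its sub-layer inputs, so the last block's output is otherwise un-normalized).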
6.  **Language Modeling Head**: Project back to vocabulary size to get logits.

## Weight Tying
A common practice in modern LLMs (like GPT-2) is to **tie the weights** of the token embedding layer and the language modeling head. This is possible because both are `vocab_size × n_embd` matrices, so the head can simply reuse the embedding matrix: `lm_head.weight = token_emb.weight`. This removes `vocab_size × n_embd` parameters and often improves performance.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Tuple

# Re-defining Config and Components for self-containment
@dataclass
class ModelConfig:
    """Hyperparameters for the decoder-only model.

    The defaults correspond to the GPT-2 small configuration. Positional
    field order: n_embd, n_head, n_layer, n_positions, vocab_size, dropout,
    bias.
    """

    # Width of the embedding vectors / residual stream.
    n_embd: int = 768
    # Number of attention heads; must evenly divide n_embd.
    n_head: int = 12
    # Number of stacked transformer blocks.
    n_layer: int = 12
    # Maximum sequence length (rows in the positional-embedding table).
    n_positions: int = 1_024
    # Number of token ids the model can embed and predict.
    vocab_size: int = 50_257
    # Dropout probability shared by every dropout layer in the model.
    dropout: float = 0.1
    # Whether the Linear layers carry a bias term.
    bias: bool = True

# MultiHeadAttention, FeedForward and TransformerBlock (from Notebooks 03-04)
# are re-defined below so this notebook runs on its own.
# Ideally they would live in `src/model.py` and be imported from there instead.

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection.

    Maps [batch, seq_len, n_embd] to [batch, seq_len, n_embd]. A
    lower-triangular mask prevents each position from attending to later
    positions.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # A single Linear emits queries, keys and values concatenated on the last dim.
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Ones on/below the diagonal, shaped [1, 1, P, P] so it broadcasts over
        # batch and heads; forward() slices it to the actual sequence length.
        n_pos = config.n_positions
        causal = torch.tril(torch.ones(n_pos, n_pos))
        self.register_buffer("mask", causal[None, None, :, :])

    def _split_heads(self, t):
        """Reshape [B, T, n_embd] into [B, n_head, T, head_dim]."""
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, width = x.shape
        fused = self.qkv_proj(x)
        queries, keys, values = map(self._split_heads, fused.split(self.n_embd, dim=2))

        # Scaled dot-product scores; future positions are set to -inf before softmax.
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = (queries @ keys.transpose(-2, -1)) * scale
        visible = self.mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # Merge heads back to [B, T, n_embd] and project.
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.out_proj(merged))

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n_embd -> 4*n_embd), GELU, Linear back to n_embd, dropout."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        hidden_width = 4 * config.n_embd  # 4x expansion of the hidden layer
        self.fc1 = nn.Linear(config.n_embd, hidden_width, bias=config.bias)
        self.fc2 = nn.Linear(hidden_width, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        activated = F.gelu(self.fc1(x))
        projected = self.fc2(activated)
        return self.dropout(projected)

class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then MLP, each as a residual branch."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.ffn = FeedForward(config)

    def forward(self, x):
        # Each sub-layer reads a normalized copy of the stream; its output is added back.
        for norm, sublayer in ((self.ln1, self.attn), (self.ln2, self.ffn)):
            x = x + sublayer(norm(x))
        return x

## 1. The DecoderLM Class

Here is the implementation of the full model.

In [2]:
class DecoderLM(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned positional embeddings feed a stack of
    ``TransformerBlock``s, followed by a final LayerNorm and a linear
    language-modeling head whose weight is tied to the token-embedding matrix.
    """
    
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        
        # Token embeddings: one n_embd vector per vocabulary entry
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        
        # Learned positional embeddings: one n_embd vector per position < n_positions
        self.pos_emb = nn.Embedding(config.n_positions, config.n_embd)
        
        # Dropout on the summed embeddings
        self.drop = nn.Dropout(config.dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])
        
        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd)
        
        # Output projection (language modeling head)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight tying: both matrices are [vocab_size, n_embd], so the head reuses
        # the embedding Parameter itself. model.parameters() yields it only once.
        self.lm_head.weight = self.token_emb.weight
        
        # Initialize weights. The tied matrix is visited twice (as an Embedding
        # weight and as a Linear weight); both branches use N(0, 0.02), so the
        # second visit just redraws from the same distribution.
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        """Initialize weights (GPT-2 style): N(0, 0.02) for Linear/Embedding weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            input_ids: Input token IDs [batch, seq_len]
            targets: Target token IDs for computing loss [batch, seq_len].
                Positions whose target is -1 are excluded from the loss
                (set padded positions to -1). May be non-contiguous, e.g.
                a shifted slice ``data[:, 1:]``.
        
        Returns:
            logits: Output logits [batch, seq_len, vocab_size]
            loss: Mean cross-entropy loss, or None if targets is None
        
        Raises:
            AssertionError: if seq_len exceeds config.n_positions.
        """
        device = input_ids.device
        b, t = input_ids.size()
        
        assert t <= self.config.n_positions, f"Sequence length {t} exceeds maximum {self.config.n_positions}"
        
        # Token embeddings
        tok_emb = self.token_emb(input_ids)  # [batch, seq_len, n_embd]
        
        # Positional embeddings (broadcast over the batch dimension)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # [1, seq_len]
        pos_emb = self.pos_emb(pos)  # [1, seq_len, n_embd]
        
        # Combine embeddings
        x = self.drop(tok_emb + pos_emb)
        
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)
        
        # Final layer norm
        x = self.ln_f(x)
        
        # Language modeling head
        logits = self.lm_head(x)  # [batch, seq_len, vocab_size]
        
        # Compute loss if targets provided
        loss = None
        if targets is not None:
            # reshape, not view: shifted targets such as data[:, 1:] are
            # non-contiguous, and .view(-1) raises a RuntimeError on them.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1  # target positions equal to -1 do not contribute to the loss
            )
        
        return logits, loss

## 2. Verification

Let's instantiate a small model and check its parameter count and forward pass.

In [3]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Parameters with ``requires_grad=False`` are skipped. Tied parameters are
    counted once because ``model.parameters()`` yields each shared tensor a
    single time.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Create a small config for testing
small_config = ModelConfig(
    n_embd=128,
    n_head=4,
    n_layer=2,
    n_positions=128,
    vocab_size=1000
)

model = DecoderLM(small_config)
print(f"Model parameters: {count_parameters(model):,}")

# Dummy input: a batch of 2 sequences of 32 random token ids, plus random targets
batch_size, seq_len = 2, 32
x = torch.randint(0, small_config.vocab_size, (batch_size, seq_len))
targets = torch.randint(0, small_config.vocab_size, (batch_size, seq_len))

# Shape/loss checks only, so skip building the autograd graph.
with torch.no_grad():
    logits, loss = model(x)
    _, loss_with_targets = model(x, targets)

print("Input shape:", x.shape)
print("Logits shape:", logits.shape)
# Derive the expected shape from the config instead of hard-coding it.
assert logits.shape == (batch_size, seq_len, small_config.vocab_size)
# Without targets no loss is computed; with targets it must be a finite scalar.
assert loss is None
assert loss_with_targets is not None and torch.isfinite(loss_with_targets).item()
print("Verification successful!")

Model parameters: 541,184
Input shape: torch.Size([2, 32])
Logits shape: torch.Size([2, 32, 1000])
Verification successful!
