# Dialogue GPT: Final Project

## Overview

Put everything together to train your own language model from scratch. You'll assemble all components from Days 1-4, implement KV caching for efficient generation, train on a text corpus, and analyze the results.

---

## Part 1: Model Architecture

### Required Components

Assemble your model using implementations from previous days:

**From Day 1:**
- Multi-head self-attention
- Causal masking for autoregressive generation

**From Day 2:**
- Implement 3 variants with different positional encodings:
  - Sinusoidal positional encoding
  - RoPE (Rotary Position Embedding)
  - ALiBi (Attention with Linear Biases)

**From Day 3:**
- Character-level or BPE tokenizer (your choice)

**From Day 4:**
- Layer normalization
- Transformer blocks (attention + FFN + residual)
- Full decoder LM with embedding layer and output head

**Additional Modules:**
- **Training scripts**
- **KV caching** for efficient generation (see below)

### Model Specifications

Choose reasonable hyperparameters:
- `d_model`: 128-256
- `num_heads`: 4-8
- `num_layers`: 4-8
- `d_ff`: 4 * d_model
- `max_seq_len`: 128-512
- `vocab_size`: depends on tokenizer

### KV Caching Implementation

During autoregressive generation, computing attention for all previous tokens at each step is wasteful. **KV caching** stores previously computed Key and Value matrices.

**How it works:**
1. On first forward pass, compute K and V for all input tokens
2. For each new token, only compute K and V for that token
3. Concatenate new K, V with cached values
4. Compute attention using cached KV pairs

**Implementation requirements:**
- Modify your attention module to accept and return optional `past_kv` cache
- Cache format: list of tuples `[(K_layer1, V_layer1), (K_layer2, V_layer2), ...]`
- Support both training mode (no cache) and generation mode (with cache)
- Ensure cache works correctly with all three positional encodings

**Expected speedup:** 3-10x faster generation for sequences > 100 tokens

---

## Part 2: Training Setup

### Dataset Options

Choose a corpus to train on. Some examples are:

1. **Complete Works of Shakespeare** (~5MB)
   - Download: `wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt`
   
2. **OpenWebText subset** (download ~100MB sample)
   - Download: `https://huggingface.co/datasets/Skylion007/openwebtext`
   
### Training Details

**Setup:**
- Train/validation split: 90/10
- Batch size: 32-64 (depends on GPU memory)
- Sequence length: Match your `max_seq_len`

**Optimizer:**
- AdamW with weight decay (0.01)
- Implement learning rate schedule: warmup + cosine decay (see below)
  - Max LR: 3e-4 to 1e-3
  - Min LR: 1e-5
  - Warmup: 5-10% of total steps

**Training:**
- Train for at least 10,000 steps (more is better)
- Evaluate every 500-1000 steps
- Track: training loss, validation loss, perplexity
- Save best checkpoint based on validation loss

**Training Loop Pseudocode:**
```python
# Setup
optimizer = AdamW(model.parameters(), weight_decay=0.01)
criterion = CrossEntropyLoss()
best_val_loss = infinity
global_step = 0

for epoch in range(num_epochs):
    model.train()
    
    for batch_inputs, batch_targets in train_loader:
        # 1. Update learning rate based on schedule
        lr = get_lr(global_step, warmup_steps, total_steps, max_lr, min_lr)
        set_optimizer_lr(optimizer, lr)
        
        # 2. Forward pass
        logits = model(batch_inputs)  # (batch, seq_len, vocab_size)
        
        # 3. Compute loss (reshape for cross-entropy)
        loss = criterion(
            logits.view(-1, vocab_size),  # (batch*seq_len, vocab_size)
            batch_targets.view(-1)         # (batch*seq_len,)
        )
        
        # 4. Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        global_step += 1
    
    # 5. Periodic validation
    if (epoch + 1) % eval_interval == 0:
        val_loss = evaluate(model, val_loader)
        
        # 6. Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model)

# 7. Restore best model for inference
load_checkpoint(model, best_checkpoint)
```

### Learning Rate Schedule Implementation

Modern transformers use **warmup followed by cosine decay** instead of a fixed learning rate.

**Why warmup?**
- Model weights start randomly initialized with large, unstable gradients
- Starting with a small LR and gradually increasing prevents the model from diverging early
- Allows optimizer momentum statistics (in Adam) to stabilize

**Warmup phase** (first `warmup_steps` steps):
```python
lr = max_lr * (step / warmup_steps)  # Linear increase from 0 to max_lr
```

**Cosine decay phase** (remaining steps):
```python
progress = (step - warmup_steps) / (total_steps - warmup_steps)
lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

**Implementation:**
- Calculate `total_steps = num_epochs * steps_per_epoch`
- At each training step, compute current LR and update optimizer: `optimizer.param_groups[0]['lr'] = lr`
- Cosine decay smoothly reduces LR from `max_lr` to `min_lr`, helping the model converge better than linear decay

**Expected behavior:** LR increases linearly during warmup, then follows a smooth cosine curve down to minimum

### Perplexity Metric

**Perplexity** is the standard metric for evaluating language models. It measures how "surprised" the model is by the test data.

**Formula:**
```python
perplexity = math.exp(average_cross_entropy_loss)
```

**Interpretation:**
- Perplexity represents the **effective vocabulary size** the model is choosing from at each step
- Perplexity of 10 means: on average, the model is as uncertain as if choosing uniformly from 10 tokens
- **Lower is better** - perplexity of 1 would mean perfect prediction
- Random guessing gives perplexity equal to vocabulary size

**Example:** If your model has validation loss of 2.5, perplexity = e^2.5 ≈ 12.2

**Implementation:**
```python
def compute_perplexity(model, dataloader):
    total_loss = 0
    total_tokens = 0
    for batch in dataloader:
        loss = compute_loss(model, batch)
        total_loss += loss.item() * batch.size(0)
        total_tokens += batch.size(0)
    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)
```

---

## Part 3: Analysis & Experiments

### 3.1 Training Curves

Plot for each positional encoding variant:
- Training loss vs. steps
- Validation loss vs. steps  
- Perplexity over time

Also create one learning rate schedule plot showing:
- Warmup phase (linear increase)
- Cosine decay phase (smooth decrease)
- Label the transition point between phases

### 3.2 Positional Encoding Comparison

Create a comparison table:

| Metric | Sinusoidal | RoPE | ALiBi |
|--------|-----------|------|-------|
| Final train loss | | | |
| Final val loss | | | |
| Final perplexity | | | |
| Training time | | | |
| Best generation quality (subjective 1-5) | | | |

**Analysis questions:**
- Which positional encoding converged fastest?
- Which achieved lowest validation loss?
- Which generalizes best to longer sequences than training length?

### 3.3 Attention Pattern Visualization

For your best model, visualize attention patterns:

1. **Select 3-5 interesting generated sequences**
2. **Plot attention heatmaps** for different layers and heads:
   - Early layers vs. late layers
   - Different attention heads in the same layer
3. **Analyze patterns:**
   - Do you see local vs. global attention patterns?
   - Do different heads specialize (e.g., previous token, syntactic structure)?
   - How does attention change across layers?

**Hint:** Extract attention weights from your model during forward pass. Plot using `plt.imshow()` or `sns.heatmap()`.

### 3.4 KV Caching Performance

Measure the speedup from KV caching:

1. **Generate 500 tokens** with and without KV caching
2. **Time both approaches** and calculate speedup ratio
3. **Plot generation time vs. sequence length** (test at lengths: 50, 100, 200, 500)
4. **Verify correctness:** Ensure outputs are identical with/without caching

**Expected results:**
- Speedup should increase with sequence length
- Plot should show linear time (with cache) vs quadratic (without cache)

### 3.5 Text Generation

Generate text with different settings:

**Prompts to try:**
- Continuation of famous quotes from your corpus
- Generic prompts like "Once upon a time"
- Unusual prompts to test generalization

**Sampling strategies:**
- Temperature: {0.3, 0.7, 1.0, 1.5}
- Top-p: {0.5, 0.9, 0.95}

**Show 5-10 best examples** demonstrating:
- Coherent long-form generation (200+ tokens)
- Diverse outputs from same prompt
- Model's understanding of style/domain

---

## Part 4: Deliverable

### Final Jupyter Notebook

Your submission should be a **single well-organized notebook** containing:

#### 1. Introduction
- Brief description of your model architecture
- Dataset choice and statistics
- Hyperparameters used

#### 2. Training Results
- All training curves (losses, perplexity over time)
- Learning rate schedule plot showing warmup and cosine decay
- Comparison table for positional encodings
- Discussion: What worked? What didn't?

#### 3. KV Caching Analysis
- Speedup measurements and plots
- Correctness verification
- Discussion of implementation challenges

#### 4. Attention Analysis
- 5-10 attention heatmaps from interesting examples
- Written analysis of patterns observed
- Comparison across layers and heads

#### 5. Text Generation Showcase
- 10-15 generation examples with different settings
- Best examples demonstrating quality
- Failure cases and analysis

#### 6. Conclusions
- What did you learn about LLMs?
- Surprising findings?
- What would you improve with more time/compute?
- Impact of KV caching on production systems

### Code Quality
- Clean, well-commented code
- Reproducible (set random seeds)
- Model checkpoint saved
- Include requirements.txt if using external libraries

---

## Grading Rubric (100 points)

| Component | Points |
|-----------|--------|
| **Model Implementation** | 25 |
| - Correct assembly of all components | 10 |
| - Three positional encoding variants | 8 |
| - KV caching implementation | 7 |
| **Training** | 20 |
| - Proper training loop with validation | 6 |
| - Warmup + cosine LR schedule implemented | 6 |
| - Training for sufficient steps | 4 |
| - Clear training curves + perplexity tracking | 4 |
| **Positional Encoding Comparison** | 15 |
| - All three variants trained | 8 |
| - Quantitative comparison | 4 |
| - Thoughtful analysis | 3 |
| **KV Caching Analysis** | 10 |
| - Correct speedup measurements | 5 |
| - Plots and correctness verification | 3 |
| - Implementation discussion | 2 |
| **Attention Visualization** | 12 |
| - Multiple attention heatmaps | 6 |
| - Insightful pattern analysis | 6 |
| **Text Generation** | 13 |
| - Diverse examples | 7 |
| - Quality analysis | 6 |
| **Presentation** | 5 |
| - Organized notebook | 3 |
| - Clear writing | 2 |

---

## Tips

- **Start small:** Debug with tiny model on small dataset first
- **Save checkpoints:** Training takes time, don't lose progress
- **Use GPU:** Essential for reasonable training time (Google Colab free tier works)
- **Monitor validation:** Stop if overfitting occurs
- **Compare fairly:** Use same random seed for different positional encodings
- **Attention extraction:** Modify your attention class to return weights as well as outputs
- **KV caching tips:** Test with and without cache on short sequences first to verify correctness
- **Cache dimensions:** Ensure cache tensors have correct shapes: `(batch_size, num_heads, seq_len, d_k)`
- **LR schedule:** Plot your learning rate over training steps to verify warmup and decay work correctly
- **Debugging training:** If loss explodes (NaN), reduce max_lr or increase warmup steps

---

## Additional Tips

If your model performs poorly, you can experiment with:
1. **Try different architectures** (vary depth vs. width)
2. **Compare tokenization strategies** (char vs BPE)
3. **Fine-tune on a second domain** (transfer learning)
4. **Implement beam search** instead of sampling
5. **Multi-query attention** (MQA) or grouped-query attention (GQA) for faster inference

---

In [7]:
""" DO NOT CHANGE """
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

np.random.seed(2025)
torch.manual_seed(2025)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print("Setup complete!")

ModuleNotFoundError: No module named 'numpy'

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
    def forward(self, query, key, value, mask=None, past_kv=None):
        batch_size = query.size(0)
        seq_len_q = query.size(1)
        seq_len_k_new = key.size(1)
        Q_proj = self.W_Q(query)
        K_proj = self.W_K(key)
        V_proj = self.W_V(value)
        Q = Q_proj.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K_proj.view(batch_size, seq_len_k_new, self.num_heads, self.d_k).transpose(1, 2)
        V = V_proj.view(batch_size, seq_len_k_new, self.num_heads, self.d_k).transpose(1, 2)
        if past_kv is not None:
            K_past, V_past = past_kv
            K = torch.cat([K_past, K], dim=2)
            V = torch.cat([V_past, V], dim=2)
        present_kv = (K, V)
        seq_len_k_full = K.size(2)
        scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)

        weights = torch.softmax(scores, dim=-1)
        output_attn = weights @ V
        output_attn = output_attn.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        output = self.W_O(output_attn)

        return output, weights, present_kv

In [3]:
class SinusoidalPE(nn.Module):
    def __init__(self, d_model, max_len=5000):

        """
        Args:
            d_model: model dimension
            max_len: maximum sequence length
        """

        super().__init__()
        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position=torch.arange(0, max_len).unsqueeze(1)
        bottom=torch.exp(torch.arange(0, d_model,2)*(-np.log(10000)/d_model))
        pe[:,0::2] = torch.sin(position*bottom)
        pe[:,1::2] = torch.cos(position*bottom)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):

        """
        Args:
            x: (batch, seq_len, d_model)

        Returns:
            x + positional_encoding: (batch, seq_len, d_model)
        """

        return x + self.pe[:, :x.shape[1]]

In [4]:
class RoPE(nn.Module):
    def __init__(self, d_model, base=10000):
        super().__init__()
        self.d_model = d_model
        self.base = base
        frq = 1/(self.base**(torch.arange(0, d_model, 2)/d_model))
        self.register_buffer('inv_freq',frq)

    def _rotate_half(self, x):
        rotated_x = torch.empty_like(x)
        rotated_x[:,:, 0::2] = -x[:,:, 1::2]
        rotated_x[:,:, 1::2] = x[:,:, 0::2]
        return rotated_x

    def apply_rope(self, q, k, positions):
        freqs = positions.unsqueeze(-1)*self.inv_freq
        freqs_expanded=torch.stack((freqs, freqs), dim=-1).flatten(-2, -1)
        cos_vals=torch.cos(freqs_expanded)
        sin_vals=torch.sin(freqs_expanded)
        q_rot=q*cos_vals+self._rotate_half(q)*sin_vals
        k_rot=k*cos_vals+ self._rotate_half(k)*sin_vals
        return q_rot, k_rot

In [5]:
class ALiBi(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads=num_heads

        slopes = torch.pow(2,-8 /num_heads*torch.arange(1, num_heads + 1, dtype=torch.float))
        self.register_buffer('slopes', slopes.unsqueeze(1).unsqueeze(1))
    def get_alibi_bias(self, seq_len):
        positions = torch.arange(0, seq_len, dtype=torch.float)
        distance_matrix = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
        alibi_bias = -self.slopes * distance_matrix.unsqueeze(0)
        return alibi_bias
    def forward(self, attention_scores, seq_len):
        alibi_bias = self.get_alibi_bias(seq_len)
        return attention_scores + alibi_bias

In [6]:
class MyLayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        """
        Layer normalization.

        Args:
            d_model: feature dimension
            eps: small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        self.gamma =nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """
        Apply layer normalization.

        Args:
            x: (..., d_model) tensor

        Returns:
            normalized tensor of same shape
        """
        mean = torch.mean(x, dim=-1, keepdim=True)
        var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma*x_norm + self.beta

In [7]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=None):
        """
        Decoder transformer block with pre-norm.

        Args:
            d_model: model dimension
            num_heads: number of attention heads
            d_ff: feed-forward hidden dimension (default: 4 * d_model)
        """
        super().__init__()
        d_ff = d_ff or 4 * d_model

        self.ln1 = MyLayerNorm(d_model)
        self.ln2 = MyLayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, num_heads)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff,d_model)
        )
    def forward(self, x, past_kv=None):
        """
        Forward pass with pre-norm residual connections.
        Args:
            x: (batch, seq_len, d_model)
            past_kv: optional tuple of (K_cached, V_cached) for KV caching for this block.
                     (batch, num_heads, prev_seq_len, d_k) per K/V.
        Returns:
            output: (batch, seq_len, d_model)
            present_kv: updated KV cache for this block.
                        (K_present, V_present)
        """
        attn_output, _, present_kv = self.attn(self.ln1(x), past_kv=past_kv)
        x = x + attn_output
        x = x + self.ffn(self.ln2(x))
        return x, present_kv

In [4]:
class DecoderLM(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_seq_len=512):
        """
        Decoder-only language model.

        Args:
            vocab_size: size of vocabulary
            d_model: model dimension
            num_heads: attention heads per layer
            num_layers: number of transformer blocks
            max_seq_len: maximum sequence length (used for learned positional embeddings)
        """
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len

        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)

        self.blocks = nn.ModuleList([TransformerBlock(d_model, num_heads) for _ in range(num_layers)])

        self.ln_f = MyLayerNorm(d_model)

        self.head = nn.Linear(d_model, vocab_size)
    def forward(self, x, past_kv_list=None):
        """
        Forward pass.

        Args:
            x: (batch, seq_len) token IDs. During generation with caching, seq_len is typically 1.
            past_kv_list: optional list of tuples `[(K_layer1, V_layer1), ...]` for KV caching.
                          Each (K, V) is (batch, num_heads, prev_seq_len, d_k).

        Returns:
            logits: (batch, seq_len, vocab_size)
            present_kv_list: list of updated KV caches for all layers.
                             `[(K_present_layer1, V_present_layer1), ...]`
        """
        batch_size, seq_len = x.shape

        tok_emb = self.token_embed(x)

        if past_kv_list is not None:
            past_seq_len = past_kv_list[0][0].shape[2] if past_kv_list else 0
            current_position_idx = past_seq_len
            positions_to_embed = torch.tensor([current_position_idx], device=x.device) # (1,)
            pos_emb = self.pos_embed(positions_to_embed) # (1, d_model)
            # Expand to (1, 1, d_model) for broadcasting with tok_emb (batch_size, 1, d_model)
            pos_emb = pos_emb.unsqueeze(0)
        else:
            # Full sequence processing (training or first pass of generation)
            positions = torch.arange(seq_len, device=x.device) # (seq_len,)
            pos_emb = self.pos_embed(positions) # (seq_len, d_model)
            # Expand to (1, seq_len, d_model) for broadcasting with tok_emb (batch_size, seq_len, d_model)
            pos_emb = pos_emb.unsqueeze(0)


        h = tok_emb + pos_emb # (batch, seq_len, d_model)

        present_kv_list = []
        for i, block in enumerate(self.blocks):
            past_kv_i = past_kv_list[i] if past_kv_list else None
            h, present_kv_i = block(h, past_kv=past_kv_i)
            present_kv_list.append(present_kv_i)

        h = self.ln_f(h) # (batch, seq_len, d_model)
        logits = self.head(h) # (batch, seq_len, vocab_size)

        return logits, present_kv_list # Return updated KV cache list

NameError: name 'nn' is not defined

In [9]:
class CausalSelfAttention(nn.Module):
    """
    Causal self-attention layer that wraps MultiHeadAttention and applies a causal mask.
    Handles passing through KV cache.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)

    def forward(self, x, past_kv=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            past_kv: Optional tuple of (K_cached, V_cached) for KV caching.
        Returns:
            output: (batch, seq_len, d_model)
            attn_weights: (batch, num_heads, seq_len, seq_len_k_full)
            present_kv: Updated KV cache (K_present, V_present)
        """
        batch_size, seq_len, d_model = x.shape

        if past_kv is not None:
            mask = None
        else:
            mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
            mask = mask.masked_fill(mask, float('-inf'))
        output, attn_weights, present_kv = self.mha(x, x, x, mask=mask, past_kv=past_kv)
        return output, attn_weights, present_kv

In [10]:
import os

# Create a directory for data if it doesn't exist
if not os.path.exists('data'):
    os.makedirs('data')

# Download the dataset
!wget -O data/input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-12-28 22:37:35--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘data/input.txt’


2025-12-28 22:37:35 (213 MB/s) - ‘data/input.txt’ saved [1115394/1115394]



In [11]:
from collections import Counter

class BPETokenizer:
    def __init__(self):
        self.merges = []
        self.vocab = {"<UNK>": 0}
        self.id_to_token = {0: "<UNK>"}
        self.next_id = 1

    def _get_pairs(self, tokens):
        """
        Get all adjacent pairs and their counts.

        Args:
            tokens: list of tokens

        Returns:
            Counter of pairs
        """
        pairs = Counter()
        for i in range(len(tokens) - 1):
            pairs[(tokens[i], tokens[i+1])] += 1
        return pairs


    def _merge_pair(self, tokens, pair, new_token):
        """
        Merge all occurrences of pair into new_token.

        Args:
            tokens: list of tokens
            pair: tuple (token1, token2) to merge
            new_token: the merged token

        Returns:
            new list of tokens with merges applied
        """
        new_tokens = []
        i = 0
        while i<len(tokens):
            if i+1<len(tokens) and (tokens[i], tokens[i+1]) == pair:
                new_tokens.append(new_token)
                i+=2
            else:
                new_tokens.append(tokens[i])
                i+=1
        return new_tokens


    def fit(self, text, num_merges=10):
        """
        Learn BPE merges from text.

        Args:
            text: training text
            num_merges: number of merge operations to perform
        """
        self.merges = []
        self.vocab = {"<UNK>": 0}
        self.id_to_token = {0: "<UNK>"}
        self.next_id = 1
        initial_chars = sorted(list(set(text)))
        for char in initial_chars:
            if char not in self.vocab:
                self.vocab[char] = self.next_id
                self.id_to_token[self.next_id] = char
                self.next_id += 1

        current_tokens = list(text)

        for _ in range(num_merges):
            pairs = self._get_pairs(current_tokens)
            if not pairs:
                break
            most_frequent_pair = max(pairs, key=pairs.get)
            new_token_str = "".join(most_frequent_pair)
            if new_token_str not in self.vocab:
                self.vocab[new_token_str] = self.next_id
                self.id_to_token[self.next_id] = new_token_str
                self.next_id += 1
                self.merges.append((most_frequent_pair, new_token_str))
            else:
                if (most_frequent_pair, new_token_str) not in self.merges:
                    self.merges.append((most_frequent_pair, new_token_str))

            current_tokens = self._merge_pair(current_tokens, most_frequent_pair, new_token_str)

    def encode(self, text):
        """
        Encode text using learned BPE merges.

        Args:
            text: string to encode

        Returns:
            list of token IDs
        """
        initial_tokens_list = []
        for char in text:
            if char not in self.vocab:
                self.vocab[char] = self.next_id
                self.id_to_token[self.next_id] = char
                self.next_id += 1
            initial_tokens_list.append(char)
        tokens_to_merge = list(initial_tokens_list)
        for pair, new_token_str in self.merges:
            tokens_to_merge = self._merge_pair(tokens_to_merge, pair, new_token_str)
        encoded_ids = [self.vocab.get(token, self.vocab["<UNK>"]) for token in tokens_to_merge]
        return encoded_ids

    def decode(self, ids):
        """
        Convert token IDs back to text.

        Args:
            ids: list of token IDs

        Returns:
            decoded string
        """
        decoded_tokens = [self.id_to_token.get(id, "<UNK>") for id in ids]
        return "".join(decoded_tokens)


with open('data/input.txt', 'r') as f:
    text_data = f.read()

tokenizer = BPETokenizer()
tokenizer.fit(text_data, num_merges=1000)


encoded_data = tokenizer.encode(text_data)

print(f"Text data loaded. Length: {len(text_data)} characters")
print(f"Tokenizer fitted. Vocabulary size: {len(tokenizer.vocab)}")
print(f"Encoded data length: {len(encoded_data)} tokens")
print(f"First 10 encoded tokens: {encoded_data[:10]}")

Text data loaded. Length: 1115394 characters
Tokenizer fitted. Vocabulary size: 1066
Encoded data length: 416705 tokens
First 10 encoded tokens: [536, 988, 77, 794, 465, 348, 416, 42, 837, 529]


In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

max_seq_len = 256
batch_size = 32


data = torch.tensor(encoded_data, dtype=torch.long)

train_ratio = 0.9
split_idx = int(train_ratio * len(data))

train_data = data[:split_idx]
val_data = data[split_idx:]

class TextDataset(Dataset):
    def __init__(self, data, max_seq_len):
        self.data = data
        self.max_seq_len = max_seq_len

    def __len__(self):
        if len(self.data) < self.max_seq_len + 1:
            return 0
        return len(self.data) - self.max_seq_len

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.max_seq_len + 1]
        inputs = chunk[:-1]
        targets = chunk[1:]
        return inputs, targets
train_dataset = TextDataset(train_data, max_seq_len)
val_dataset = TextDataset(val_data, max_seq_len)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Training data length: {len(train_data)}")
print(f"Validation data length: {len(val_data)}")
print(f"Training dataset sequences: {len(train_dataset)}")
print(f"Validation dataset sequences: {len(val_dataset)}")
print(f"Number of training batches: {len(train_dataloader)}")
print(f"Number of validation batches: {len(val_dataloader)}")

# Example of one batch
if len(train_dataloader) > 0:
    sample_inputs, sample_targets = next(iter(train_dataloader))
    print(f"\nSample input batch shape: {sample_inputs.shape}")
    print(f"Sample target batch shape: {sample_targets.shape}")
    print(f"Sample inputs[0][:10]: {sample_inputs[0][:10].tolist()}")
    print(f"Sample targets[0][:10]: {sample_targets[0][:10].tolist()}")


Training data length: 375034
Validation data length: 41671
Training dataset sequences: 374778
Validation dataset sequences: 41415
Number of training batches: 11712
Number of validation batches: 1295

Sample input batch shape: torch.Size([32, 256])
Sample target batch shape: torch.Size([32, 256])
Sample inputs[0][:10]: [94, 635, 45, 389, 859, 70, 42, 131, 76, 77]
Sample targets[0][:10]: [635, 45, 389, 859, 70, 42, 131, 76, 77, 844]


In [13]:
import torch.optim as optim

vocab_size = len(tokenizer.vocab)
d_model = 256
num_heads = 8
num_layers = 6
model = DecoderLM(vocab_size, d_model, num_heads, num_layers, max_seq_len=max_seq_len)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
model.to(device)
optimizer = optim.AdamW(model.parameters(), weight_decay=0.01)

print(f"Model initialized with vocab_size={vocab_size}, d_model={d_model}, num_heads={num_heads}, num_layers={num_layers}")
print(f"Model is on device: {next(model.parameters()).device}")
print("Optimizer (AdamW) initialized.")

Using device: cuda
Model initialized with vocab_size=1066, d_model=256, num_heads=8, num_layers=6
Model is on device: cuda:0
Optimizer (AdamW) initialized.


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import Counter


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None, past_kv=None):
        batch_size = query.size(0)
        seq_len_q = query.size(1)
        seq_len_k_new = key.size(1)

        Q_proj = self.W_Q(query)
        K_proj = self.W_K(key)
        V_proj = self.W_V(value)

        Q = Q_proj.view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = K_proj.view(batch_size, seq_len_k_new, self.num_heads, self.d_k).transpose(1, 2)
        V = V_proj.view(batch_size, seq_len_k_new, self.num_heads, self.d_k).transpose(1, 2)

        if past_kv is not None:
            K_past, V_past = past_kv
            K = torch.cat([K_past, K], dim=2)
            V = torch.cat([V_past, V], dim=2)
        present_kv = (K, V)

        seq_len_k_full = K.size(2)

        scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)

        weights = torch.softmax(scores, dim=-1)
        output_attn = weights @ V

        output_attn = output_attn.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        output = self.W_O(output_attn)

        return output, weights, present_kv


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)

    def forward(self, x, past_kv=None):
        batch_size, seq_len, d_model = x.shape

        if past_kv is not None:
            mask = None
        else:
            mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
            mask = mask.masked_fill(mask, float('-inf'))

        output, attn_weights, present_kv = self.mha(x, x, x, mask=mask, past_kv=past_kv)
        return output, attn_weights, present_kv


class MyLayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        mean = torch.mean(x, dim=-1, keepdim=True)
        var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta


class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=None):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.ln1 = MyLayerNorm(d_model)
        self.ln2 = MyLayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, num_heads)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x, past_kv=None):
        attn_output, _, present_kv = self.attn(self.ln1(x), past_kv=past_kv)
        x = x + attn_output
        x = x + self.ffn(self.ln2(x))
        return x, present_kv


class DecoderLM(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_seq_len=512):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len

        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)

        self.blocks = nn.ModuleList([TransformerBlock(d_model, num_heads) for _ in range(num_layers)])

        self.ln_f = MyLayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x, past_kv_list=None):
        batch_size, seq_len = x.shape

        tok_emb = self.token_embed(x)

        if past_kv_list is not None:
            past_seq_len = past_kv_list[0][0].shape[2] if past_kv_list else 0
            current_position_idx = past_seq_len
            positions_to_embed = torch.tensor([current_position_idx], device=x.device)
            pos_emb = self.pos_embed(positions_to_embed)
            h = tok_emb + pos_emb.unsqueeze(0)
        else:
            positions = torch.arange(seq_len, device=x.device)
            pos_emb = self.pos_embed(positions)
            h = tok_emb + pos_emb.unsqueeze(0)

        present_kv_list = []
        for i, block in enumerate(self.blocks):
            past_kv_i = past_kv_list[i] if past_kv_list else None
            h, present_kv_i = block(h, past_kv=past_kv_i)
            present_kv_list.append(present_kv_i)

        h = self.ln_f(h)
        logits = self.head(h)

        return logits, present_kv_list

vocab_size = len(tokenizer.vocab)

d_model = 256
num_heads = 8
num_layers = 6

model = DecoderLM(vocab_size, d_model, num_heads, num_layers, max_seq_len=max_seq_len)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

model.to(device)

optimizer = optim.AdamW(model.parameters(), weight_decay=0.01)

print(f"Model initialized with vocab_size={vocab_size}, d_model={d_model}, num_heads={num_heads}, num_layers={num_layers}")
print(f"Model is on device: {next(model.parameters()).device}")
print("Optimizer (AdamW) initialized.")

ModuleNotFoundError: No module named 'torch'

In [15]:
import math
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

def get_lr(step, warmup_steps, total_steps, max_lr, min_lr):
    if step < warmup_steps:
        return max_lr * (step / warmup_steps)
    elif step > total_steps:
        return min_lr
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

def set_optimizer_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    total_tokens = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
            logits, _ = model(batch_inputs)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                batch_targets.view(-1)
            )
            total_loss += loss.item() * batch_targets.numel()
            total_tokens += batch_targets.numel()
    avg_loss = total_loss / total_tokens
    model.train()
    return avg_loss

def compute_perplexity(avg_loss):
    return math.exp(avg_loss)


num_epochs = 10
total_training_steps = num_epochs * len(train_dataloader)
warmup_steps = int(0.1 * total_training_steps)
max_lr = 3e-4
min_lr = 1e-5

eval_interval = 2
save_interval = 2

best_val_loss = float('inf')
global_step = 0


train_losses = []
val_losses = []
perplexities = []
lrs = []

print("Starting training")
for epoch in range(num_epochs):
    model.train()
    epoch_train_losses = []
    for batch_inputs, batch_targets in train_dataloader:
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
        lr = get_lr(global_step, warmup_steps, total_training_steps, max_lr, min_lr)
        set_optimizer_lr(optimizer, lr)
        lrs.append(lr)

        logits, _ = model(batch_inputs)
        loss = criterion(
            logits.view(-1, logits.size(-1)),
            batch_targets.view(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_train_losses.append(loss.item())
        global_step += 1

    avg_epoch_train_loss = sum(epoch_train_losses) / len(epoch_train_losses)
    train_losses.append(avg_epoch_train_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_epoch_train_loss:.4f}, LR: {lrs[-1]:.6f}")

    if (epoch + 1) % eval_interval == 0:
        val_loss = evaluate(model, val_dataloader, criterion, device)
        perplexity = compute_perplexity(val_loss)
        val_losses.append(val_loss)
        perplexities.append(perplexity)
        print(f"Validation Loss: {val_loss:.4f}, Perplexity: {perplexity:.2f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model_checkpoint.pt')
            print(f"Saved new best model with validation loss: {best_val_loss:.4f}")

print("Training finished.")

metrics = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'perplexities': perplexities,
    'lrs': lrs,
    'best_val_loss': best_val_loss
}

print("Training metrics collected and stored.")

Starting training
Epoch 1/10, Train Loss: 1.3008, LR: 0.000300
Epoch 2/10, Train Loss: 0.0168, LR: 0.000291
Validation Loss: 0.0271, Perplexity: 1.03
Saved new best model with validation loss: 0.0271
Epoch 3/10, Train Loss: 0.0149, LR: 0.000266
Epoch 4/10, Train Loss: 0.0137, LR: 0.000228
Validation Loss: 0.0268, Perplexity: 1.03
Saved new best model with validation loss: 0.0268
Epoch 5/10, Train Loss: 0.0125, LR: 0.000180
Epoch 6/10, Train Loss: 0.0109, LR: 0.000130
Validation Loss: 0.0278, Perplexity: 1.03
Epoch 7/10, Train Loss: 0.0088, LR: 0.000083
Epoch 8/10, Train Loss: 0.0065, LR: 0.000044
Validation Loss: 0.0311, Perplexity: 1.03
Epoch 9/10, Train Loss: 0.0046, LR: 0.000019
Epoch 10/10, Train Loss: 0.0035, LR: 0.000010
Validation Loss: 0.0336, Perplexity: 1.03
Training finished.
Training metrics collected and stored.


In [2]:
# Load the trained model checkpoint
print("Loading model checkpoint...")

# Recreate the model with the same architecture
model_test = DecoderLM(vocab_size, d_model, num_heads, num_layers, max_seq_len=max_seq_len)
model_test.to(device)

# Load the checkpoint
checkpoint = torch.load('best_model_checkpoint.pt', map_location=device)
model_test.load_state_dict(checkpoint)
model_test.eval()

print("Model loaded successfully!")
print(f"Model is on device: {next(model_test.parameters()).device}")


Loading model checkpoint...


NameError: name 'DecoderLM' is not defined

In [None]:
def generate_text(model, tokenizer, prompt, max_new_tokens=200, temperature=1.0, top_k=None, top_p=None, use_kv_cache=True):
    """
    Generate text using the model with KV caching support.
    
    Args:
        model: The trained DecoderLM model
        tokenizer: The BPETokenizer
        prompt: Starting text prompt
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        top_k: If set, only sample from top k tokens
        top_p: If set, use nucleus sampling with this probability
        use_kv_cache: Whether to use KV caching for faster generation
    
    Returns:
        Generated text string
    """
    model.eval()
    
    # Encode the prompt
    input_ids = tokenizer.encode(prompt)
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
    
    generated_ids = input_ids.copy()
    past_kv_list = None
    
    with torch.no_grad():
        # First pass: process the entire prompt
        if use_kv_cache:
            logits, past_kv_list = model(input_tensor, past_kv_list=None)
            # Only use the last token's logits for next token prediction
            next_token_logits = logits[0, -1, :] / temperature
        else:
            # Without KV cache, we'll process the full sequence each time (slower)
            logits, _ = model(input_tensor, past_kv_list=None)
            next_token_logits = logits[0, -1, :] / temperature
        
        # Generate tokens one by one
        for _ in range(max_new_tokens):
            # Apply top-k filtering if specified
            if top_k is not None:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                # Create a new tensor with -inf for non-top-k tokens
                filtered_logits = torch.full_like(next_token_logits, float('-inf'))
                filtered_logits[top_k_indices] = top_k_logits
                next_token_logits = filtered_logits
            
            # Apply top-p (nucleus) sampling if specified
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')
            
            # Sample from the distribution
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1).item()
            
            generated_ids.append(next_token_id)
            
            # Prepare next input (single token for KV caching)
            if use_kv_cache:
                next_input = torch.tensor([[next_token_id]], dtype=torch.long).to(device)
                logits, past_kv_list = model(next_input, past_kv_list=past_kv_list)
                next_token_logits = logits[0, -1, :] / temperature
            else:
                # Without cache, process the full sequence again
                input_tensor = torch.tensor([generated_ids], dtype=torch.long).to(device)
                logits, _ = model(input_tensor, past_kv_list=None)
                next_token_logits = logits[0, -1, :] / temperature
    
    # Decode the generated text
    generated_text = tokenizer.decode(generated_ids)
    return generated_text

print("Text generation function created!")


In [None]:
# Test the model with various prompts
print("=" * 80)
print("Testing Model Text Generation")
print("=" * 80)

# Test prompts inspired by Shakespeare
test_prompts = [
    "To be or not to be",
    "Once upon a time",
    "The king said",
    "Romeo and Juliet",
    "All the world's a stage"
]

print("\nGenerating text with different prompts...\n")

for i, prompt in enumerate(test_prompts, 1):
    print(f"\n{'='*80}")
    print(f"Test {i}: Prompt = '{prompt}'")
    print(f"{'='*80}")
    
    try:
        generated = generate_text(
            model_test, 
            tokenizer, 
            prompt, 
            max_new_tokens=150,
            temperature=0.8,
            use_kv_cache=True
        )
        print(f"\nGenerated text:\n{generated}\n")
    except Exception as e:
        print(f"Error generating text: {e}")
        import traceback
        traceback.print_exc()


In [None]:
# Test with different temperature settings
print("\n" + "=" * 80)
print("Testing Different Temperature Settings")
print("=" * 80)

prompt = "To be or not to be"
temperatures = [0.3, 0.7, 1.0, 1.5]

for temp in temperatures:
    print(f"\n{'='*80}")
    print(f"Temperature = {temp}")
    print(f"{'='*80}")
    try:
        generated = generate_text(
            model_test,
            tokenizer,
            prompt,
            max_new_tokens=100,
            temperature=temp,
            use_kv_cache=True
        )
        print(f"\nGenerated text:\n{generated}\n")
    except Exception as e:
        print(f"Error: {e}")


In [None]:
# Test KV caching speedup
import time

print("\n" + "=" * 80)
print("Testing KV Caching Performance")
print("=" * 80)

prompt = "The king said to his court"
num_tokens = 200

# Test with KV cache
start_time = time.time()
generated_with_cache = generate_text(
    model_test,
    tokenizer,
    prompt,
    max_new_tokens=num_tokens,
    temperature=0.8,
    use_kv_cache=True
)
time_with_cache = time.time() - start_time

# Test without KV cache
start_time = time.time()
generated_without_cache = generate_text(
    model_test,
    tokenizer,
    prompt,
    max_new_tokens=num_tokens,
    temperature=0.8,
    use_kv_cache=False
)
time_without_cache = time.time() - start_time

speedup = time_without_cache / time_with_cache

print(f"\nGeneration time with KV cache: {time_with_cache:.4f} seconds")
print(f"Generation time without KV cache: {time_without_cache:.4f} seconds")
print(f"Speedup: {speedup:.2f}x")
print(f"\nGenerated text (with cache):\n{generated_with_cache[:200]}...")
print(f"\nGenerated text (without cache):\n{generated_without_cache[:200]}...")
print(f"\nOutputs match: {generated_with_cache == generated_without_cache}")
