In [None]:
# 🔧 Setup: Run this cell first!
# Detects the compute device, reports library versions, and seeds every RNG
# (Python, NumPy, PyTorch CPU and — if present — all GPUs) for reproducibility.

import random
import sys

import numpy as np
import torch

SEED = 42
cuda_ok = torch.cuda.is_available()

# Pick the compute device and tell the user what was found.
if not cuda_ok:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")
else:
    device = torch.device('cuda')
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✅ GPU available: {gpu_name}")
    print(f"   Memory: {gpu_mem_gb:.1f} GB")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Seed each library's global RNG with the same value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)
if cuda_ok:
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")
%matplotlib inline

# Chain-of-Thought Supervised Fine-Tuning -- Vizuara

In this notebook, we will teach a small language model to produce step-by-step reasoning traces using supervised fine-tuning (SFT). By the end, you will have a model that wraps its thinking in `<think>` tags before answering math questions.

**What you will build:** A fine-tuned GPT-2 model that generates chain-of-thought reasoning for grade-school math problems.

In [None]:
# GPU check and setup: report the PyTorch/CUDA status and (re)define `device`
# so this cell also works when run on its own.
import torch

print(f"PyTorch version: {torch.__version__}")
has_gpu = torch.cuda.is_available()
print(f"CUDA available: {has_gpu}")

device = torch.device("cuda" if has_gpu else "cpu")
if has_gpu:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Running on CPU — training will be slower but still works.")

In [None]:
!pip install -q transformers datasets accelerate

## 1. Why Does This Matter?

Large language models like GPT-4 and DeepSeek-R1 can solve complex math problems by "thinking out loud" — generating intermediate reasoning steps before the final answer. But how does a model learn to do this?

The first step is **supervised fine-tuning on chain-of-thought data**. We take a base model and train it on examples where each answer includes step-by-step reasoning. The model learns the *format* of thinking: how to break problems down, how to show its work, and how to structure a logical argument.

By the end of this notebook, you will understand:
- How to format chain-of-thought training data
- How the SFT loss trains on reasoning tokens
- How to fine-tune a small model to produce reasoning traces
- Why SFT alone is necessary but not sufficient

## 2. Building Intuition

Think about how you learned math in school. Your teacher did not just give you answers — they showed you worked examples. Step by step, they demonstrated *how* to solve a problem, and then you practiced on similar problems.

SFT does exactly the same thing for language models. We show the model many worked examples, and it learns to generate similar step-by-step solutions.

The key idea: we wrap the reasoning in special `<think>...</think>` tags so the model knows where the "thinking" starts and ends.

### Think About This
Before we dive in, consider:
- Why might a model that just predicts the answer directly fail on multi-step problems?
- What advantage does "thinking out loud" provide?

## 3. The Mathematics

The SFT training objective is standard next-token prediction. Given an input sequence, the model learns to predict each token given all previous tokens.

The loss function is the cross-entropy loss over the entire completion sequence (including the reasoning tokens):

$$\mathcal{L}_{\text{SFT}} = -\sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}, x)$$

where:
- $x$ is the input prompt
- $y_t$ is the $t$-th token in the target sequence
- $T$ is the total number of tokens (reasoning + answer)
- $\theta$ represents the model parameters

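**Computational meaning:** For each position in the sequence, the model outputs a probability distribution over the vocabulary. We take the log probability of the correct next token at every target position and sum these up; the negative of that sum is our loss — lower is better. In practice (including inside Hugging Face models when `labels` are passed) the sum is divided by the number of target tokens, giving the *average* per-token loss. The cell below computes both.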

In [None]:
# Let us compute this loss manually to build understanding
import torch
import torch.nn.functional as F

# Four simulated "correct next token" probabilities. In a real model these are
# read off the softmax output at each position.
token_probs = torch.tensor([0.8, 0.7, 0.5, 0.9])
log_probs = token_probs.log()

print("Per-token log probabilities:")
for idx in range(len(token_probs)):
    print(f"  Token {idx + 1}: p = {token_probs[idx]:.1f}, log(p) = {log_probs[idx]:.3f}")

# Negative log-likelihood: the summed loss and its per-token mean.
total_loss = log_probs.sum().neg()
avg_loss = total_loss / len(token_probs)
print(f"\nTotal loss: {total_loss:.3f}")
print(f"Average per-token loss: {avg_loss:.3f}")
print(f"\nAs the model gets better, probabilities increase and loss decreases.")

## 4. Let's Build It -- Component by Component

### 4.1 Loading the Base Model

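We will use GPT-2 (the 124M-parameter `gpt2` checkpoint) as our base model. It is small enough to fine-tune on a single GPU (or, more slowly, on a CPU). It is also large enough to learn the *format* of chain-of-thought reasoning from a handful of examples.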

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "gpt2"

# Pretrained checkpoint and its tokenizer, with the model moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

# GPT-2 has no pad token; reuse EOS so sequences can be padded to a fixed length.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

# Register the reasoning delimiters as special tokens, then grow the embedding
# matrix so the new token IDs have rows.
special_tokens = {"additional_special_tokens": ["<think>", "</think>"]}
tokenizer.add_special_tokens(special_tokens)
model.resize_token_embeddings(len(tokenizer))

n_params = sum(param.numel() for param in model.parameters())
think_open_id = tokenizer.convert_tokens_to_ids('<think>')
think_close_id = tokenizer.convert_tokens_to_ids('</think>')

print(f"Model: {MODEL_NAME}")
print(f"Parameters: {n_params:,}")
print(f"Vocabulary size: {len(tokenizer)}")
print(f"<think> token ID: {think_open_id}")
print(f"</think> token ID: {think_close_id}")

### 4.2 Creating Chain-of-Thought Training Data

For training, we need examples with step-by-step reasoning. Let us create a small dataset of math problems with chain-of-thought solutions.

In [None]:
# Chain-of-thought training examples.
# Each record pairs a question ("prompt") with a completion whose step-by-step
# reasoning sits inside <think>...</think>, followed by a final
# "The answer is ..." line (the evaluation cells split on that phrase).
cot_examples = [
    {
        "prompt": "What is 15% of 80?",
        "completion": "<think>\nStep 1: 15% means 15/100 = 0.15.\nStep 2: 0.15 * 80 = 12.\nStep 3: Let me verify: 10% of 80 is 8, 5% of 80 is 4, so 15% = 8 + 4 = 12. Correct.\n</think>\nThe answer is 12."
    },
    {
        "prompt": "A store sells pencils for $2 each. If you buy 8 pencils and have a $3 coupon, how much do you pay?",
        "completion": "<think>\nStep 1: Each pencil costs $2. I need 8 pencils.\nStep 2: Total before coupon: 8 * 2 = $16.\nStep 3: Apply $3 coupon: 16 - 3 = $13.\n</think>\nThe answer is $13."
    },
    {
        "prompt": "If a rectangle has length 7 and width 4, what is its area?",
        "completion": "<think>\nStep 1: Area of a rectangle = length * width.\nStep 2: Area = 7 * 4 = 28.\n</think>\nThe answer is 28."
    },
    {
        "prompt": "Tom has 15 apples. He gives 3 to each of his 4 friends. How many does he have left?",
        "completion": "<think>\nStep 1: Tom gives away 3 apples per friend to 4 friends.\nStep 2: Total given away: 3 * 4 = 12 apples.\nStep 3: Remaining: 15 - 12 = 3 apples.\n</think>\nThe answer is 3."
    },
    {
        "prompt": "What is 2^5?",
        # "2 multiplied by itself 5 times" would describe 2^6; 2^5 is five 2s.
        "completion": "<think>\nStep 1: 2^5 means multiplying five 2s together: 2 * 2 * 2 * 2 * 2.\nStep 2: 2 * 2 = 4, 4 * 2 = 8, 8 * 2 = 16, 16 * 2 = 32.\n</think>\nThe answer is 32."
    },
]

# Format for training
def format_example(example):
    """Render one CoT record as training text.

    The result is a ``Question: <prompt>`` line followed by the completion
    (reasoning in <think> tags plus the final answer).
    """
    question = example['prompt']
    reasoning_and_answer = example['completion']
    return "Question: {}\n{}".format(question, reasoning_and_answer)

# Preview how the first record looks once rendered for training.
print("=== Formatted Training Example ===")
preview_text = format_example(cot_examples[0])
print(preview_text)
num_examples = len(cot_examples)
print(f"\nTotal examples: {num_examples}")

### 4.3 Tokenizing the Training Data

We tokenize the full sequence (prompt + reasoning + answer) and create input-target pairs for next-token prediction.

In [None]:
def tokenize_examples(examples, tokenizer, max_length=256, mask_prompt=True, add_eos=True):
    """Tokenize chain-of-thought examples into padded tensors for SFT training.

    Each example is rendered with ``format_example`` and encoded to exactly
    ``max_length`` tokens (truncated / right-padded by the tokenizer).

    Args:
        examples: non-empty iterable of dicts with ``"prompt"`` and
            ``"completion"`` keys.
        tokenizer: Hugging Face-style tokenizer with a pad token configured.
        max_length: fixed length of every encoded row.
        mask_prompt: if True, the prompt prefix (the text ``format_example``
            places before the completion) gets label -100. The loss then only
            covers the reasoning and answer tokens, matching
            L_SFT = -sum_t log p(y_t | y_<t, x).
        add_eos: if True, append ``tokenizer.eos_token``. The model then learns
            where a response ends, so generation can stop at EOS instead of
            running to ``max_new_tokens``.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors, each
        of shape ``(len(examples), max_length)``. ``labels`` is -100 at every
        position that must not contribute to the loss.

    Raises:
        ValueError: if ``examples`` is empty.
    """
    input_ids_list = []
    attention_mask_list = []
    labels_list = []

    for ex in examples:
        text = format_example(ex)

        # The prompt prefix is whatever format_example puts in front of the
        # completion; measuring it this way avoids duplicating the template.
        completion_start = text.rfind(ex["completion"])
        prompt_text = text[:completion_start] if completion_start > 0 else ""

        if add_eos and tokenizer.eos_token:
            text = text + tokenizer.eos_token

        encoded = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )

        # Index [0] drops only the batch dim; squeeze() would also collapse a
        # length-1 sequence to a 0-d tensor.
        input_ids = encoded["input_ids"][0]
        attention_mask = encoded["attention_mask"][0]

        # Labels mirror input_ids (the model shifts them internally); -100
        # positions are ignored by the loss. Padding is keyed off the attention
        # mask rather than the token id: GPT-2 reuses EOS as its pad token here,
        # and the real trailing EOS must stay a training target.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        if mask_prompt and prompt_text:
            prompt_ids = tokenizer(prompt_text)["input_ids"]
            # Only mask leading tokens that really match the prompt's own
            # encoding, in case tokenization differs at the prompt/completion seam.
            n_prompt = 0
            for p_id, t_id in zip(prompt_ids, input_ids.tolist()):
                if p_id != t_id:
                    break
                n_prompt += 1
            labels[:n_prompt] = -100

        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)
        labels_list.append(labels)

    if not input_ids_list:
        raise ValueError("tokenize_examples() needs at least one example")

    return {
        "input_ids": torch.stack(input_ids_list),
        "attention_mask": torch.stack(attention_mask_list),
        "labels": torch.stack(labels_list),
    }

# Tokenize the whole CoT dataset into fixed-length tensors.
train_data = tokenize_examples(cot_examples, tokenizer)
input_shape = train_data['input_ids'].shape
label_shape = train_data['labels'].shape
print(f"Input shape: {input_shape}")
print(f"Labels shape: {label_shape}")

# Decode the first row (special tokens kept) as a sanity check.
first_row = train_data["input_ids"][0]
decoded = tokenizer.decode(first_row, skip_special_tokens=False)
preview = decoded[:200]
print(f"\nDecoded first example (first 200 chars):\n{preview}")

## 5. Your Turn -- TODO Exercises

### TODO 1: Implement the SFT Training Loop

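Complete the training loop below. Moving the batch to `device` (the GPU when one is available) is already done for you. Fill in each `???` to:
1. Run the forward pass through the model
2. Extract the loss from the model outputs
3. Backpropagate
4. Update the weights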

In [None]:
from torch.optim import AdamW

# AdamW (weight decay 0.01) over all model parameters, including the embedding
# rows added for the <think> / </think> tokens.
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
NUM_EPOCHS = 50  # Small dataset, so we train for many epochs

# One loss value per epoch; the verification cell and the loss plot read this.
losses = []

# Full-batch training: each epoch is a single optimizer step over every example
# in `train_data` (no DataLoader / mini-batching).
for epoch in range(NUM_EPOCHS):
    model.train()  # training mode (e.g. dropout active)

    # ============ TODO ============
    # Move data to device
    batch = {k: v.to(device) for k, v in train_data.items()}

    # Forward pass — the model computes the loss internally
    # when you pass both input_ids and labels
    # Hint: the keys of `batch` (input_ids, attention_mask, labels) match the
    # model's keyword arguments; label positions set to -100 are ignored.
    outputs = ???  # YOUR CODE HERE: call model with the batch

    loss = ???  # YOUR CODE HERE: extract the loss from outputs

    # Backward pass
    optimizer.zero_grad()  # clear gradients left over from the previous epoch
    ???  # YOUR CODE HERE: backpropagate
    ???  # YOUR CODE HERE: update weights
    # ==============================

    # .item() stores a plain float (no graph kept alive across epochs).
    losses.append(loss.item())

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {loss.item():.4f}")

In [None]:
# Verification: Run this cell to check your training
# (one loss per epoch was recorded, and the last is lower than the first).
assert len(losses) == NUM_EPOCHS, f"Expected {NUM_EPOCHS} loss values, got {len(losses)}"
first_loss = losses[0]
last_loss = losses[-1]
assert last_loss < first_loss, f"Loss should decrease! First: {first_loss:.4f}, Last: {last_loss:.4f}"

print(f"Training complete!")
print(f"Initial loss: {first_loss:.4f}")
print(f"Final loss: {last_loss:.4f}")
print(f"Loss reduction: {((first_loss - last_loss) / first_loss * 100):.1f}%")

### Visualization Checkpoint: Training Loss Curve

In [None]:
import matplotlib.pyplot as plt

# `losses[i]` is the loss after epoch i + 1 (the training log is 1-based), so
# plot against 1..N rather than the list index 0..N-1.
epochs = range(1, len(losses) + 1)

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(epochs, losses, color='#2196F3', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('SFT Training Loss', fontsize=14)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 6. Putting It All Together

Now let us test our fine-tuned model. We will give it a math problem and see if it generates a chain-of-thought before answering.

In [None]:
def generate_with_cot(model, tokenizer, question, max_new_tokens=200,
                      do_sample=True, temperature=0.7, top_p=0.9):
    """Sample a chain-of-thought response to ``question`` from ``model``.

    The prompt uses the same ``Question: <question>`` line (plus newline) that
    ``format_example`` puts in front of each training completion, so a
    fine-tuned model should continue with ``<think> ... </think>``.

    Args:
        model: causal LM exposing ``generate`` (e.g. a Hugging Face model).
        tokenizer: matching tokenizer; must define ``eos_token_id``.
        question: question text to answer.
        max_new_tokens: maximum number of generated tokens.
        do_sample, temperature, top_p: decoding settings forwarded to
            ``model.generate`` (defaults match the original hard-coded values).

    Returns:
        The decoded prompt plus continuation. Special tokens are kept so the
        ``<think>`` tags stay visible.

    Note:
        Leaves ``model`` in eval mode.
    """
    prompt = f"Question: {question}\n"
    # Send inputs to wherever the model's weights live instead of relying on a
    # global `device` variable.
    model_device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(model_device)

    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=False)

# Sanity check: the model should reproduce the CoT format on questions it saw in training.
print("=== Testing on Training Examples ===\n")
for sample in cot_examples[:2]:
    question = sample["prompt"]
    expected = sample["completion"].split('The answer is')[-1].strip()
    response = generate_with_cot(model, tokenizer, question)
    print(f"Q: {question}")
    print(f"Model: {response}")
    print(f"Expected answer: {expected}")
    print("-" * 60)

### TODO 2: Test on New Problems

Test the model on problems it has not seen during training. Does it generalize?

In [None]:
# ============ TODO ============
# Create 3 new math problems and test the model on them.
# Observe whether the model produces <think> tags and
# whether the reasoning is correct.
#
# new_questions = [
#     "What is 25% of 200?",
#     ???,
#     ???,
# ]
#
# for q in new_questions:
#     response = generate_with_cot(model, tokenizer, q)
#     print(f"Q: {q}")
#     print(f"Model: {response}\n")
# ==============================

## 7. Training and Results

Let us analyze what the model learned.

In [None]:
# Probe unseen questions and report whether both <think> and </think> appear.
print("=== Analysis of Model Behavior ===\n")

test_questions = [
    "What is 20% of 50?",
    "If you have 10 cookies and eat 3, how many are left?",
    "What is 6 times 7?",
]

for q in test_questions:
    response = generate_with_cot(model, tokenizer, q)
    has_think = all(tag in response for tag in ("<think>", "</think>"))
    print(f"Q: {q}")
    print(f"Has <think> tags: {has_think}")
    print(f"Response: {response[:300]}")
    print("-" * 60)

## 8. Final Output

Our SFT-trained model can now produce chain-of-thought reasoning traces. Notice, however, what it has and has not learned:
- The model learns the *format* of reasoning (using `<think>` tags)
- But the *quality* of reasoning is limited to mimicking the training examples
- On novel problems, the reasoning may be plausible-looking but incorrect

This is exactly why we need reinforcement learning (covered in Notebook 2) — to teach the model which reasoning strategies actually lead to correct answers.

## 9. Reflection and Next Steps

### Think About This
1. Why does SFT alone teach format but not quality?
2. What would happen if we trained on millions of CoT examples instead of 5?
3. How is this similar to / different from how humans learn to solve math problems?

### What Comes Next
In Notebook 2, we will implement GRPO (Group Relative Policy Optimization) from scratch — the RL algorithm that teaches a model to reason *well*, not just to reason *plausibly*.

### Key Takeaway
SFT is the foundation of reasoning model training. It teaches the model the structure of step-by-step thinking. But structure without correctness is not enough — that is where reinforcement learning comes in.