# 02 - Supervised Fine-Tuning Training

## Setup

We use **SmolLM-135M** (`HuggingFaceTB/SmolLM-135M`) -- a
135-million-parameter language model small enough to train on a CPU. This model is a
base model: it has been trained only on next-token prediction and has no
instruction-following ability. Our goal is to give it that ability through
supervised fine-tuning.

The training loop in this notebook is **real and functional**. It will run
on a CPU, though slowly (expect a few minutes per epoch). The point is to
understand every component: data loading, tokenization, forward pass, loss
masking, backward pass, and parameter updates.

In [None]:
import json
import copy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(42)

In [None]:
# Load the base model and tokenizer
model_name = "HuggingFaceTB/SmolLM-135M"

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(model_name)

num_params = sum(p.numel() for p in base_model.parameters())
print(f"Model: {model_name}")
print(f"Parameters: {num_params:,}")
print(f"Vocabulary size: {tokenizer.vocab_size:,}")
print(f"Pad token: {tokenizer.pad_token!r} (id={tokenizer.pad_token_id})")
print(f"EOS token: {tokenizer.eos_token!r} (id={tokenizer.eos_token_id})")

## Before SFT: Baseline

Let's see what the base model does when given legal prompts in instruction
format. Since it has never been instruction-tuned, it should not follow
instructions -- it will simply continue the text as a language model.

In [None]:
def generate_text(model, tokenizer, prompt, max_new_tokens=80):
    """Generate text from a prompt using greedy decoding.

    Args:
        model: A HuggingFace causal LM.
        tokenizer: The corresponding tokenizer.
        prompt: Input text.
        max_new_tokens: Maximum tokens to generate.

    Returns:
        The generated text (prompt + completion).
    """
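    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        output_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens. Decoding the full sequence with
    # skip_special_tokens=True can drop the ChatML markers from the prompt, which
    # would break callers that slice the result with output[len(prompt):].
    new_token_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    return prompt + tokenizer.decode(new_token_ids, skip_special_tokens=True)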


# Legal prompts in instruction format
test_prompts = [
    (
        "<|im_start|>system\n"
        "You are a legal research assistant. Answer the question about the "
        "provided court opinion accurately and concisely.<|im_end|>\n"
        "<|im_start|>user\n"
        "What court issued this opinion?\n\n"
        "Before the Court is the appeal of plaintiff James Henderson from the "
        "district court's grant of summary judgment.<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    (
        "<|im_start|>system\n"
        "You are a legal research assistant. Answer the question about the "
        "provided court opinion accurately and concisely.<|im_end|>\n"
        "<|im_start|>user\n"
        "List the key legal citations in this opinion.\n\n"
        "We review the district court's grant of summary judgment de novo. "
        "Anderson v. Liberty Lobby, Inc., 477 U.S. 242 (1986).<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
    (
        "<|im_start|>system\n"
        "You are a legal research assistant. Answer the question about the "
        "provided court opinion accurately and concisely.<|im_end|>\n"
        "<|im_start|>user\n"
        "Summarize this court opinion.\n\n"
        "Plaintiff Maria Thompson brings this action under the Individuals "
        "with Disabilities Education Act challenging the Board of Education's "
        "determination that D.T. is not eligible for special education services.<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
]

print("BASE MODEL OUTPUT (before SFT)")
print("=" * 70)
print("The base model has no instruction-following ability.")
print("It will just continue the text, ignoring the chat template.")
print("=" * 70)

base_outputs = []
for i, prompt in enumerate(test_prompts):
    output = generate_text(base_model, tokenizer, prompt, max_new_tokens=60)
    generated = output[len(prompt):]
    base_outputs.append(generated)
    print(f"\n--- Prompt {i + 1} ---")
    print(f"Instruction: {prompt.split(chr(10))[3][:60]}")
    print(f"Generated: {generated[:200]}")
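
`generate_text` uses greedy decoding (`do_sample=False`): at every step the highest-scoring next token is appended, so the output is deterministic for a given model and prompt. A minimal sketch of that rule, assuming a made-up four-token vocabulary and a hypothetical `toy_score_fn` standing in for the model:

```python
def greedy_decode(score_fn, prompt_ids, eos_id, max_new_tokens):
    """Toy greedy decoder: always append the highest-scoring next token."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        scores = score_fn(ids)  # one score per vocabulary entry
        next_id = max(range(len(scores)), key=lambda i: scores[i])
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


# Hypothetical 4-token vocabulary; scores depend only on the last token.
BIGRAM_SCORES = {
    0: [0.1, 2.0, 0.3, 0.0],  # after token 0, token 1 scores highest
    1: [0.0, 0.1, 1.5, 0.2],  # after token 1, token 2 scores highest
    2: [0.2, 0.1, 0.0, 3.0],  # after token 2, token 3 (EOS) scores highest
    3: [1.0, 0.0, 0.0, 0.0],
}


def toy_score_fn(ids):
    return BIGRAM_SCORES[ids[-1]]


decoded = greedy_decode(toy_score_fn, prompt_ids=[0], eos_id=3, max_new_tokens=10)
print(decoded)  # [0, 1, 2, 3]
```

Because nothing is sampled, any difference between the before and after outputs in this notebook comes from the weights, not from random seeds.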

## Manual Training Loop

Now we implement SFT from scratch. The steps are:

1. Load the instruction dataset (from notebook 01).
2. Tokenize each example with the ChatML template.
3. Create a PyTorch `Dataset` and `DataLoader`.
4. For each batch: forward pass, compute masked cross-entropy loss,
   backward pass, optimizer step.
5. Repeat for several epochs.

The key detail is **loss masking**: we only compute loss on the assistant
response tokens. The instruction tokens still provide context through
attention, but they are excluded as prediction targets, so no loss term is
computed for them.
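
To make the masking concrete, here is a small sketch on random logits rather than a real model. The token ids, vocabulary size, and the 3/3 split between instruction and response are all made up for illustration; the only library behavior relied on is `ignore_index` in `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size = 10
# One sequence of 6 tokens: the first 3 are "instruction", the last 3 "response".
input_ids = torch.tensor([[4, 7, 1, 3, 8, 2]])
labels = input_ids.clone()
labels[:, :3] = -100  # mask the instruction tokens

# Stand-in for model output: random logits of shape (batch, seq_len, vocab).
logits = torch.randn(1, 6, vocab_size)

# Causal LM shift: the logits at position t predict the token at t + 1.
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]

masked_loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
    ignore_index=-100,
)

# Same value computed by hand over the unmasked positions only.
keep = shift_labels.reshape(-1) != -100
manual_loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size)[keep],
    shift_labels.reshape(-1)[keep],
)

print(f"supervised positions: {int(keep.sum())}")
print(f"masked loss: {masked_loss.item():.4f}, manual loss: {manual_loss.item():.4f}")
```

The two losses agree because positions labeled `-100` are dropped from both the sum and the averaging denominator.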

In [None]:
# Load the instruction dataset built in notebook 01
dataset_path = Path("sft_dataset.json")

if not dataset_path.exists():
    # If notebook 01 has not been run, build the dataset here
    print("sft_dataset.json not found -- building dataset from court opinions...")
    opinions_path = Path("../../datasets/sample/court_opinions.jsonl")
    opinions = []
    with open(opinions_path) as f:
        for line in f:
            opinions.append(json.loads(line))

    raw_dataset = []
    for op in opinions:
        text = op["text"]
        sentences = [s.strip() for s in text.split(".") if s.strip()]
        summary = ". ".join(sentences[:2]) + "." if len(sentences) >= 2 else text[:300]

        paragraphs = text.split("\n\n")
        holding = paragraphs[-1].strip()
        for keyword in ["REVERSE", "AFFIRM", "REMAND", "GRANTED", "DENIED"]:
            for sent in sentences:
                if keyword in sent:
                    holding = sent.strip() + "."
                    break

        citation_list = "\n".join(f"- {c}" for c in op["citations"])
        issue_summary = f"In {op['case_name']}, the key legal issue is: {sentences[0]}." if sentences else ""

        raw_dataset.append({"instruction": "Summarize this court opinion.", "input": text, "output": summary})
        raw_dataset.append({"instruction": "What was the holding in this case?", "input": text, "output": holding})
        raw_dataset.append({"instruction": "List the key legal citations in this opinion.", "input": text, "output": citation_list})
        raw_dataset.append({"instruction": "What court issued this opinion?", "input": text, "output": op["court"]})
        raw_dataset.append({"instruction": "What are the key legal issues in this case?", "input": text, "output": issue_summary})
        raw_dataset.append({"instruction": "What is the name of this case?", "input": text, "output": op["case_name"]})
else:
    with open(dataset_path) as f:
        raw_dataset = json.load(f)

print(f"Loaded {len(raw_dataset)} instruction pairs")

In [None]:
CHATML_TEMPLATE = (
    "<|im_start|>system\n"
    "You are a legal research assistant. Answer the question about the "
    "provided court opinion accurately and concisely.<|im_end|>\n"
    "<|im_start|>user\n"
    "{instruction}\n\n"
    "{input}<|im_end|>\n"
    "<|im_start|>assistant\n"
    "{output}<|im_end|>"
)

ASSISTANT_MARKER = "<|im_start|>assistant\n"


def tokenize_example(example, tokenizer, max_length=512):
    """Tokenize a single instruction example with loss mask.

    Formats the example using ChatML, tokenizes it, and creates a label
    tensor where instruction tokens are set to -100 (ignored by loss).

    Args:
        example: Dict with 'instruction', 'input', 'output' fields.
        tokenizer: A HuggingFace tokenizer.
        max_length: Maximum sequence length (truncates if exceeded).

    Returns:
        Dict with 'input_ids', 'labels', and 'attention_mask' tensors.
    """
    text = CHATML_TEMPLATE.format(
        instruction=example["instruction"],
        input=example["input"],
        output=example["output"],
    )

    # Find where the assistant response starts
    marker_pos = text.find(ASSISTANT_MARKER)
    prefix = text[: marker_pos + len(ASSISTANT_MARKER)]
    prefix_token_ids = tokenizer.encode(prefix, add_special_tokens=False)
    response_start = len(prefix_token_ids)

    # Tokenize the full text
    token_ids = tokenizer.encode(text, add_special_tokens=False)

    # Truncate to max_length
    token_ids = token_ids[:max_length]

    # Create labels: -100 for instruction tokens, token_id for response tokens
    labels = [-100] * min(response_start, len(token_ids))
    labels += token_ids[len(labels):]

    # Attention mask: 1 for all real tokens
    attention_mask = [1] * len(token_ids)

    return {
        "input_ids": token_ids,
        "labels": labels,
        "attention_mask": attention_mask,
    }


# Test on one example
test_tokenized = tokenize_example(raw_dataset[3], tokenizer)  # "What court issued this opinion?"
print(f"Input IDs length: {len(test_tokenized['input_ids'])}")
print(f"Labels length:    {len(test_tokenized['labels'])}")
print(f"Masked tokens:    {sum(1 for l in test_tokenized['labels'] if l == -100)}")
print(f"Trained tokens:   {sum(1 for l in test_tokenized['labels'] if l != -100)}")
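
One thing worth checking after tokenization: court opinions can be long, and if the prompt alone reaches `max_length`, truncation removes the entire response. Such an example has no supervised tokens, and a batch made only of such examples typically yields a NaN loss. The sketch below mirrors the masking rule on made-up token ids; `build_labels` and `count_supervised` are hypothetical helpers, not part of the notebook.

```python
IGNORE_INDEX = -100  # value PyTorch's cross-entropy skips by default


def build_labels(token_ids, response_start, max_length):
    """Mirror the masking rule: -100 before response_start, token ids after."""
    token_ids = token_ids[:max_length]
    n_masked = min(response_start, len(token_ids))
    labels = [IGNORE_INDEX] * n_masked + token_ids[n_masked:]
    return token_ids, labels


def count_supervised(labels):
    """Number of positions that will contribute to the loss."""
    return sum(1 for label in labels if label != IGNORE_INDEX)


# Made-up token ids: 8 prompt tokens followed by 4 response tokens.
toy_ids = list(range(100, 112))

_, labels_ok = build_labels(toy_ids, response_start=8, max_length=512)
_, labels_cut = build_labels(toy_ids, response_start=8, max_length=8)

print(count_supervised(labels_ok))   # 4
print(count_supervised(labels_cut))  # 0: truncation removed the whole response
```

In the notebook, the same count on each tokenized example would show whether any opinion is long enough to lose its response.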

In [None]:
class SFTDataset(Dataset):
    """PyTorch Dataset for SFT training.

    Tokenizes instruction pairs and provides them as padded tensors
    with loss masks (labels set to -100 for instruction tokens).
    """

    def __init__(self, examples, tokenizer, max_length=512):
        self.tokenized = []
        for ex in examples:
            self.tokenized.append(tokenize_example(ex, tokenizer, max_length))

    def __len__(self):
        return len(self.tokenized)

    def __getitem__(self, idx):
        item = self.tokenized[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "labels": torch.tensor(item["labels"], dtype=torch.long),
            "attention_mask": torch.tensor(item["attention_mask"], dtype=torch.long),
        }


def collate_fn(batch):
    """Pad a batch of variable-length sequences to the same length.

    Pads input_ids with pad_token_id, labels with -100, and
    attention_mask with 0.
    """
    max_len = max(item["input_ids"].size(0) for item in batch)

    padded_input_ids = []
    padded_labels = []
    padded_attention_mask = []

    for item in batch:
        seq_len = item["input_ids"].size(0)
        pad_len = max_len - seq_len

        padded_input_ids.append(
            F.pad(item["input_ids"], (0, pad_len), value=tokenizer.pad_token_id)
        )
        padded_labels.append(
            F.pad(item["labels"], (0, pad_len), value=-100)
        )
        padded_attention_mask.append(
            F.pad(item["attention_mask"], (0, pad_len), value=0)
        )

    return {
        "input_ids": torch.stack(padded_input_ids),
        "labels": torch.stack(padded_labels),
        "attention_mask": torch.stack(padded_attention_mask),
    }


# Create dataset and dataloader
sft_dataset = SFTDataset(raw_dataset, tokenizer, max_length=512)
dataloader = DataLoader(sft_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

print(f"Dataset size: {len(sft_dataset)} examples")
print(f"Batch size: {dataloader.batch_size}")
print(f"Batches per epoch: {len(dataloader)}")

# Inspect one batch
sample_batch = next(iter(dataloader))
print(f"\nSample batch shapes:")
for key, val in sample_batch.items():
    print(f"  {key}: {val.shape}")
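
The same padding can also be written with `torch.nn.utils.rnn.pad_sequence`. This is an alternative sketch on toy tensors, not the notebook's `collate_fn`; `PAD_ID` is a made-up pad id. The attention mask is built from the sequence lengths rather than by comparing against the pad id, because the pad token may be the same as the EOS token.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # hypothetical pad id for this toy example

sequences = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
label_seqs = [torch.tensor([-100, 6, 7]), torch.tensor([-100, 9])]

padded_ids = pad_sequence(sequences, batch_first=True, padding_value=PAD_ID)
padded_labels = pad_sequence(label_seqs, batch_first=True, padding_value=-100)
# Ones for real tokens, zeros for padding, independent of the pad id's value.
padded_mask = pad_sequence(
    [torch.ones_like(seq) for seq in sequences],
    batch_first=True,
    padding_value=0,
)

print(padded_ids.tolist())     # [[5, 6, 7], [8, 9, 0]]
print(padded_labels.tolist())  # [[-100, 6, 7], [-100, 9, -100]]
print(padded_mask.tolist())    # [[1, 1, 1], [1, 1, 0]]
```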

In [None]:
# Clone the base model for manual training
# (we keep the original base_model untouched for comparison later)
manual_model = copy.deepcopy(base_model)
manual_model.train()

# Training hyperparameters
learning_rate = 5e-5
num_epochs = 3
optimizer = torch.optim.AdamW(manual_model.parameters(), lr=learning_rate)

print("Training configuration:")
print(f"  Learning rate: {learning_rate}")
print(f"  Epochs: {num_epochs}")
print("  Optimizer: AdamW")
print(f"  Batch size: {dataloader.batch_size}")
print(f"  Total steps: {num_epochs * len(dataloader)}")
print()
print("Starting training...")
print("(This will take a few minutes on CPU.)")

In [None]:
# Training loop
loss_history = []

for epoch in range(num_epochs):
    epoch_losses = []
    manual_model.train()

    for step, batch in enumerate(dataloader):
        # Forward pass
        outputs = manual_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],  # labels with -100 masking
        )

        # The model computes cross-entropy loss internally, respecting -100 mask.
        # This is equivalent to:
        #   logits = outputs.logits  # (batch, seq_len, vocab_size)
        #   shift_logits = logits[:, :-1, :].contiguous()
        #   shift_labels = batch["labels"][:, 1:].contiguous()
        #   loss = F.cross_entropy(
        #       shift_logits.view(-1, shift_logits.size(-1)),
        #       shift_labels.view(-1),
        #       ignore_index=-100,
        #   )
        loss = outputs.loss

        # Backward pass
        loss.backward()

        # Optimizer step
        optimizer.step()
        optimizer.zero_grad()

        loss_val = loss.item()
        epoch_losses.append(loss_val)
        loss_history.append(loss_val)

        if (step + 1) % 5 == 0 or step == 0:
            print(
                f"  Epoch {epoch + 1}/{num_epochs}, "
                f"Step {step + 1}/{len(dataloader)}, "
                f"Loss: {loss_val:.4f}"
            )

    avg_loss = np.mean(epoch_losses)
    print(f"Epoch {epoch + 1} complete. Average loss: {avg_loss:.4f}")
    print()

print(f"Training complete.")
print(f"Final loss: {loss_history[-1]:.4f}")
print(f"Total steps: {len(loss_history)}")
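
The loop above calls `optimizer.step()` after every batch. A common variation is gradient accumulation: run several small batches, let `backward()` add their gradients together, and step once. The sketch below uses a tiny linear model and random data (both made up) to show that accumulating four micro-batches, each with its loss divided by the number of accumulation steps, reproduces the gradient of one large batch.

```python
import torch

torch.manual_seed(0)

toy_model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()

# Reference: gradient from one batch of 8 examples.
toy_model.zero_grad()
loss_fn(toy_model(x), y).backward()
full_grad = toy_model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2, each loss scaled by 1 / accum_steps.
accum_steps = 4
toy_model.zero_grad()
for micro_x, micro_y in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    micro_loss = loss_fn(toy_model(micro_x), micro_y) / accum_steps
    micro_loss.backward()  # gradients add up across calls until zero_grad()
accum_grad = toy_model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))
```

With the masked token-level loss used in SFT, micro-batches can contain different numbers of supervised tokens, so this equal scaling is only an approximation of the large-batch loss.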

In [None]:
# Plot training loss curve
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Per-step loss
ax = axes[0]
ax.plot(loss_history, color="steelblue", alpha=0.7, linewidth=1)
ax.set_xlabel("Training step")
ax.set_ylabel("Loss")
ax.set_title("Training Loss (per step)")
ax.grid(alpha=0.3)

# Mark epoch boundaries
steps_per_epoch = len(dataloader)
for e in range(1, num_epochs):
    ax.axvline(x=e * steps_per_epoch, color="red", linestyle="--", alpha=0.5)

# Smoothed loss (rolling average)
ax = axes[1]
window = max(3, len(loss_history) // 10)
if len(loss_history) >= window:
    smoothed = np.convolve(loss_history, np.ones(window) / window, mode="valid")
    ax.plot(range(window - 1, len(loss_history)), smoothed, color="steelblue", linewidth=2)
else:
    ax.plot(loss_history, color="steelblue", linewidth=2)
ax.set_xlabel("Training step")
ax.set_ylabel("Loss (smoothed)")
ax.set_title(f"Training Loss (rolling average, window={window})")
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Initial loss: {loss_history[0]:.4f}")
print(f"Final loss:   {loss_history[-1]:.4f}")
print(f"Reduction:    {loss_history[0] - loss_history[-1]:.4f}")
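
The rolling average uses `mode="valid"`, which returns `len(loss_history) - window + 1` points. That is why the smoothed curve is plotted starting at x-position `window - 1`. A tiny worked example with made-up losses:

```python
import numpy as np

toy_losses = [4.0, 3.0, 2.0, 1.0]  # made-up values for illustration
toy_window = 2

toy_smoothed = np.convolve(toy_losses, np.ones(toy_window) / toy_window, mode="valid")
toy_steps = list(range(toy_window - 1, len(toy_losses)))

print(toy_smoothed.tolist())  # [3.5, 2.5, 1.5]
print(toy_steps)              # [1, 2, 3]
```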

## After SFT: Results

Let's run the same legal prompts through the fine-tuned model and compare
with the base model. Even with a tiny dataset and a few epochs, the model
should start showing some instruction-following behavior -- attempting to
respond in the format it learned during training.

In [None]:
print("BEFORE vs AFTER SFT")
print("=" * 70)

manual_model.eval()

for i, prompt in enumerate(test_prompts):
    # Base model output (already computed above)
    base_out = base_outputs[i] if i < len(base_outputs) else ""

    # Fine-tuned model output
    sft_out = generate_text(manual_model, tokenizer, prompt, max_new_tokens=60)
    sft_generated = sft_out[len(prompt):]

    instruction_line = prompt.split("\n")[3] if len(prompt.split("\n")) > 3 else "N/A"

    print(f"\n--- Prompt {i + 1}: {instruction_line[:60]} ---")
    print(f"  BASE:      {base_out[:150]}")
    print(f"  AFTER SFT: {sft_generated[:150]}")

print()
print("=" * 70)
print("The fine-tuned model should show some tendency to follow the instruction")
print("format, even if the content is not perfect. With a 135M parameter model")
print("and only ~30 training examples, the results are modest but the behavioral")
print("shift is visible: the model attempts to respond rather than just continue.")

## Using trl SFTTrainer

The manual training loop above teaches you every component. In practice,
you would use the `trl` library's `SFTTrainer`, which handles:

- Tokenization and chat template formatting
- Padding, batching, and loss masking
- Learning rate scheduling
- Logging and checkpointing
- Gradient accumulation
- Metrics during training
- Distributed training support (multi-GPU)

Below we show how to achieve the same result with much less code.

In [None]:
from trl import SFTTrainer, SFTConfig
from datasets import Dataset as HFDataset


# Prepare data in the format trl expects:
# a HuggingFace Dataset with a "text" column containing the formatted ChatML text.
formatted_texts = []
for ex in raw_dataset:
    formatted_texts.append(
        CHATML_TEMPLATE.format(
            instruction=ex["instruction"],
            input=ex["input"],
            output=ex["output"],
        )
    )

trl_dataset = HFDataset.from_dict({"text": formatted_texts})

print(f"trl dataset: {trl_dataset}")
print(f"Sample text (first 200 chars): {trl_dataset[0]['text'][:200]}...")
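
With a single `text` column, `SFTTrainer` is generally expected to compute loss on every token, not only on the assistant response. Recent trl documentation also describes a prompt-completion dataset format with `prompt` and `completion` columns. Whether loss is then restricted to the completion depends on the installed version, so treat that as an assumption to verify. The sketch below only shows how a formatted ChatML string could be split into those two fields; `split_prompt_completion` is a hypothetical helper and the sample text is made up.

```python
ASSISTANT_MARKER = "<|im_start|>assistant\n"  # same marker as the manual loop


def split_prompt_completion(text, marker=ASSISTANT_MARKER):
    """Split a formatted ChatML string into prompt and completion parts."""
    head, sep, completion = text.partition(marker)
    if not sep:
        raise ValueError("assistant marker not found")
    return {"prompt": head + sep, "completion": completion}


# Made-up example text in the notebook's ChatML layout.
toy_text = (
    "<|im_start|>user\nWhat court issued this opinion?<|im_end|>\n"
    "<|im_start|>assistant\nNinth Circuit<|im_end|>"
)
record = split_prompt_completion(toy_text)
print(record["completion"])  # Ninth Circuit<|im_end|>
```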

In [None]:
# Clone a fresh base model for trl training
trl_model = copy.deepcopy(base_model)

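# Configure SFTTrainer.
# Argument names vary across trl releases: newer versions rename
# `max_seq_length` to `max_length`, and older ones take `tokenizer=` instead of
# `processing_class=`. Adjust to match the installed version.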
training_args = SFTConfig(
    output_dir="./sft_trl_output",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    learning_rate=5e-5,
    logging_steps=5,
    save_strategy="no",
    max_seq_length=512,
    report_to="none",  # disable wandb / tensorboard
)

trainer = SFTTrainer(
    model=trl_model,
    args=training_args,
    train_dataset=trl_dataset,
    processing_class=tokenizer,
)

print("Starting trl SFTTrainer...")
print("(Same learning rate, batch size, and epoch count as the manual loop.)")

In [None]:
# Run training
trl_result = trainer.train()

print(f"\ntrl training complete.")
print(f"Training loss: {trl_result.training_loss:.4f}")
print(f"Total steps: {trl_result.global_step}")

## Comparison

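Let's compare outputs from the manual training loop and the trl trainer.
They should be broadly similar, since both start from the same model and
use the same data, learning rate, batch size, and epoch count. They will
not match exactly: `SFTTrainer` inherits defaults from `Trainer` that the
manual loop does not use, such as a decaying learning-rate schedule and
gradient clipping, and with a plain `text` column it generally computes
loss on every token rather than only on the assistant response.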

In [None]:
print("COMPARISON: Manual Loop vs trl SFTTrainer")
print("=" * 70)

manual_model.eval()
trl_model.eval()

for i, prompt in enumerate(test_prompts):
    manual_out = generate_text(manual_model, tokenizer, prompt, max_new_tokens=60)
    manual_generated = manual_out[len(prompt):]

    trl_out = generate_text(trl_model, tokenizer, prompt, max_new_tokens=60)
    trl_generated = trl_out[len(prompt):]

    instruction_line = prompt.split("\n")[3] if len(prompt.split("\n")) > 3 else "N/A"

    print(f"\n--- Prompt {i + 1}: {instruction_line[:60]} ---")
    print(f"  MANUAL: {manual_generated[:150]}")
    print(f"  TRL:    {trl_generated[:150]}")

print()
print("=" * 70)
print("\nWhy use trl in practice?")
print("  - Handles edge cases in tokenization and padding.")
print("  - Built-in learning rate scheduling and warmup.")
print("  - Logging integration (wandb, tensorboard).")
print("  - Gradient accumulation for effective larger batch sizes.")
print("  - Metrics tracked during training.")
print("  - Checkpoint saving and resumption.")
print("  - Distributed training support (multi-GPU).")
print("\nThe manual loop is for understanding. trl is for production.")
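
As an illustration of the learning rate scheduling mentioned above, here is a sketch of linear warmup followed by linear decay, written as a plain function. The step counts are assumptions (3 epochs of 15 batches, 5 warmup steps), and `linear_warmup_decay` is a hypothetical helper rather than a trl or transformers function.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """Multiplier for the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    remaining = total_steps - step
    return max(0.0, remaining / max(1, total_steps - warmup_steps))


base_lr = 5e-5
schedule_total_steps = 45   # assumed: 3 epochs x 15 batches
schedule_warmup_steps = 5   # assumed warmup length

for step in (0, 5, 25, 45):
    lr = base_lr * linear_warmup_decay(step, schedule_warmup_steps, schedule_total_steps)
    print(f"step {step:2d}: lr = {lr:.2e}")
```

In the manual loop, a function like this could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, with `scheduler.step()` called after each `optimizer.step()`.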

## Exercises

### Exercise (a): Experiment with Hyperparameters

Try different learning rates and batch sizes. For each configuration,
train for 3 epochs and record the final loss and sample outputs.

Suggested configurations:

| Learning Rate | Batch Size | Expected Behavior |
|--------------|-----------|------------------|
| 1e-5 | 2 | Slow convergence, stable |
| 5e-5 | 2 | Moderate convergence |
| 1e-4 | 2 | Fast convergence, possible instability |
| 5e-5 | 1 | Noisier gradients, more updates per epoch |
| 5e-5 | 4 | Smoother gradients, fewer updates per epoch |

```python
# Starter code
configs = [
    {"lr": 1e-5, "batch_size": 2},
    {"lr": 5e-5, "batch_size": 2},
    {"lr": 1e-4, "batch_size": 2},
    {"lr": 5e-5, "batch_size": 1},
    {"lr": 5e-5, "batch_size": 4},
]

for cfg in configs:
    model_copy = copy.deepcopy(base_model)
    opt = torch.optim.AdamW(model_copy.parameters(), lr=cfg["lr"])
    loader = DataLoader(
        sft_dataset, batch_size=cfg["batch_size"],
        shuffle=True, collate_fn=collate_fn,
    )
    # Train for 3 epochs, record losses...
```

Questions:
- Which learning rate converges fastest?
- Does a higher learning rate always mean better results?
- How does batch size affect loss smoothness vs training speed?
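
When comparing batch sizes, keep in mind that the number of optimizer updates changes too. A quick calculation, assuming 30 training examples (use `len(sft_dataset)` for the real count) and a `DataLoader` that keeps the last partial batch:

```python
import math

num_examples = 30  # assumed dataset size
epochs = 3

updates_per_epoch = {
    batch_size: math.ceil(num_examples / batch_size)
    for batch_size in (1, 2, 4)
}

for batch_size, updates in updates_per_epoch.items():
    print(f"batch_size={batch_size}: {updates} updates/epoch, {updates * epochs} total")
```

Under these assumptions the three settings give 30, 15, and 8 updates per epoch, so runs with equal epoch counts do not take equal numbers of optimizer steps.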

### Exercise (b): Observe Overfitting

Train the model for 20+ epochs on this small dataset and observe what
happens.

```python
overfit_model = copy.deepcopy(base_model)
overfit_model.train()
opt = torch.optim.AdamW(overfit_model.parameters(), lr=5e-5)

overfit_losses = []
for epoch in range(25):
    for batch in dataloader:
        outputs = overfit_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        outputs.loss.backward()
        opt.step()
        opt.zero_grad()
        overfit_losses.append(outputs.loss.item())

# Plot the loss -- it should approach 0
plt.plot(overfit_losses)
plt.title("Overfitting: Loss approaches 0")
plt.show()

# Test the overfit model -- outputs may be memorized/repetitive
overfit_model.eval()
for prompt in test_prompts:
    output = generate_text(overfit_model, tokenizer, prompt)
    print(output[len(prompt):][:200])
```

You should observe:
- Loss drops close to 0 (the model has memorized the training data).
- Outputs become repetitive or exactly reproduce training examples.
- On prompts not seen during training, the model may produce garbled text.
- This demonstrates why SFT needs diverse, large datasets in practice.
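
Training loss alone cannot distinguish learning from memorization. One way to see overfitting directly is to hold out some examples and track their loss alongside the training loss. Since the dataset contains several questions about the same opinion, the sketch below holds out whole opinions, so no input text appears in both splits. `split_by_input` is a hypothetical helper and the toy examples are made up.

```python
import random


def split_by_input(examples, holdout_fraction=0.2, seed=0):
    """Hold out whole opinions so no input text appears in both splits."""
    inputs = sorted({ex["input"] for ex in examples})
    random.Random(seed).shuffle(inputs)
    n_holdout = max(1, int(len(inputs) * holdout_fraction))
    holdout_inputs = set(inputs[:n_holdout])
    train = [ex for ex in examples if ex["input"] not in holdout_inputs]
    heldout = [ex for ex in examples if ex["input"] in holdout_inputs]
    return train, heldout


# Made-up stand-in for raw_dataset: 5 opinions x 2 questions each.
toy_examples = [
    {"instruction": question, "input": f"opinion {i}", "output": "placeholder"}
    for i in range(5)
    for question in ("Summarize this court opinion.", "What court issued this opinion?")
]

train_split, heldout_split = split_by_input(toy_examples)
print(len(train_split), len(heldout_split))  # 8 2
```

If the held-out loss starts rising while the training loss keeps falling toward 0, the model is memorizing rather than generalizing.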