# 01 - Building an Instruction Dataset for SFT

## Context

Supervised fine-tuning (SFT) is how base language models learn to follow
instructions. A base model -- one trained only on next-token prediction --
has no concept of "question" and "answer." It simply continues whatever
text you give it. If you prompt it with a legal question, it might continue
with more questions, a textbook-style paragraph, or something entirely
off-topic. It has knowledge embedded in its weights, but no reliable way to
surface that knowledge on command.

SFT solves this by training the model on instruction-response pairs. After
enough examples of "here is an instruction" followed by "here is a good
response," the model learns the pattern: given an instruction, produce a
response in the same style.

**CoCounsel context:** The difference between a model that rambles about law
and one that answers legal questions precisely is SFT. A base model might
continue a court opinion passage with plausible-sounding but directionless
text. After SFT, the same model can summarize an opinion, extract holdings,
or list citations -- because it has learned the instruction-response pattern.

In this notebook, we build an instruction-tuning dataset from court opinions,
format it using chat templates, and understand loss masking -- which tokens
the model actually learns to generate during training.

## Instruction Tuning Format

Modern instruction-tuned models use a **chat template** that delineates
roles: system, user, and assistant.

```
<|system|>
You are a legal research assistant.
<|user|>
Summarize this court opinion: [text]
<|assistant|>
The court held that...
```

The exact tokens vary by model family:

| Format | System token | User token | Assistant token |
|--------|-------------|------------|----------------|
| ChatML | `<\|im_start\|>system` | `<\|im_start\|>user` | `<\|im_start\|>assistant` |
| Llama 3 | `<\|start_header_id\|>system<\|end_header_id\|>` | `<\|start_header_id\|>user<\|end_header_id\|>` | `<\|start_header_id\|>assistant<\|end_header_id\|>` |

**Why formatting matters:** The model learns the *pattern* of the template.
It learns that text after the assistant header is what it should generate,
and text after the user header is the instruction it should follow. If the
formatting is inconsistent, the model cannot learn this pattern reliably.

In this notebook we use the **ChatML** format because it is simple and
widely supported. The key structure is:

```
<|im_start|>system
You are a legal research assistant.<|im_end|>
<|im_start|>user
{instruction}
{input text}<|im_end|>
<|im_start|>assistant
{response}<|im_end|>
```
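
To make the structure concrete, here is a minimal sketch that renders a list of role/content messages into the ChatML layout shown above. The helper name `render_chatml` and its `add_generation_prompt` flag are illustrative assumptions, not part of any library; real tokenizers ship their own chat templates, which may differ in whitespace details.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Render {'role', 'content'} dicts as ChatML text (hypothetical helper)."""
    turns = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>"
        for m in messages
    ]
    text = "\n".join(turns)
    if add_generation_prompt:
        # At inference time the prompt ends with an open assistant turn
        text += "\n<|im_start|>assistant\n"
    return text


messages = [
    {"role": "system", "content": "You are a legal research assistant."},
    {"role": "user", "content": "What court issued this opinion?\n\n[opinion text]"},
    {"role": "assistant", "content": "[court name]"},
]

# Full conversation, as used for training
training_text = render_chatml(messages)
# Prompt only, ending where the model is expected to start generating
inference_prompt = render_chatml(messages[:2], add_generation_prompt=True)
print(training_text)
```

The same layout serves both purposes: training text includes the assistant turn, while an inference prompt stops right after the assistant header.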

## Setup

In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer

## Building the Dataset

We load five court opinions from our sample data and transform each one into
multiple instruction-response pairs. For each opinion we generate several
types of questions:

- **Summarize** -- a brief summary of the case
- **Holding** -- what the court decided
- **Citations** -- key legal citations referenced
- **Court** -- which court issued the opinion
- **Key issues** -- the central legal questions

The code also adds a sixth question that asks for the case name. This gives
us 6 instruction pairs per opinion, for a total of 30 training examples.
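
Because the responses are produced by simple string heuristics, some can come out empty (for example, an opinion with no recorded citations). The sketch below is one possible sanity check to run on the pairs before formatting them. The function name `find_invalid_pairs` and the toy data are assumptions for illustration only.

```python
REQUIRED_FIELDS = ("instruction", "input", "output")


def find_invalid_pairs(pairs):
    """Return (index, reason) tuples for pairs with missing or blank fields."""
    problems = []
    for i, pair in enumerate(pairs):
        for field in REQUIRED_FIELDS:
            value = pair.get(field)
            if not isinstance(value, str):
                problems.append((i, f"missing or non-string field: {field}"))
            elif not value.strip():
                problems.append((i, f"empty field: {field}"))
    return problems


# Toy data (not from the real dataset): one valid pair and two broken ones
toy_pairs = [
    {"instruction": "What court issued this opinion?", "input": "Some opinion text.", "output": "Example Court"},
    {"instruction": "List the key legal citations in this opinion.", "input": "Some opinion text.", "output": ""},
    {"instruction": "Summarize this court opinion.", "input": "Some opinion text."},
]

problems = find_invalid_pairs(toy_pairs)
for index, reason in problems:
    print(f"pair {index}: {reason}")
```

An empty response is worth catching early: with loss masking, it would leave the model with almost nothing to learn from that example.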

In [None]:
# Load court opinions
data_path = Path("../../datasets/sample/court_opinions.jsonl")
opinions = []
with open(data_path) as f:
    for line in f:
        opinions.append(json.loads(line))

print(f"Loaded {len(opinions)} court opinions")
for op in opinions:
    print(f"  - {op['case_name']} ({op['court']})")

In [None]:
def build_instruction_pairs(opinions):
    """Transform court opinions into instruction-response pairs.

    Each opinion generates multiple training examples with different
    instruction types: summarize, holding, citations, court, issues,
    and case name.

    Returns:
        List of dicts with 'instruction', 'input', and 'output' fields.
    """
    dataset = []

    for op in opinions:
        text = op["text"]
        case_name = op["case_name"]
        court = op["court"]
        citations = op["citations"]

        # --- 1. Summarize ---
        # Use the first two sentences as a rough summary. Splitting on "." is
        # naive: abbreviations such as "v." or "U.S." also count as sentence ends.
        sentences = [s.strip() for s in text.split(".") if s.strip()]
        summary = ". ".join(sentences[:2]) + "." if len(sentences) >= 2 else text[:300]
        dataset.append({
            "instruction": "Summarize this court opinion.",
            "input": text,
            "output": summary,
        })

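        # --- 2. Holding ---
        # Default to the last paragraph, where the disposition usually appears
        paragraphs = text.split("\n\n")
        holding = paragraphs[-1].strip() if paragraphs else text[-300:]
        # Prefer the first sentence containing a disposition keyword, checking
        # keywords in priority order and stopping at the first match. (The
        # earlier nested loop only broke out of the inner loop, so a later
        # keyword could overwrite an earlier match.)
        keywords = ["REVERSE", "AFFIRM", "REMAND", "GRANTED", "DENIED"]
        holding_sentence = next(
            (sent for keyword in keywords for sent in sentences if keyword in sent),
            None,
        )
        if holding_sentence is not None:
            holding = holding_sentence.strip() + "."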
        dataset.append({
            "instruction": "What was the holding in this case?",
            "input": text,
            "output": holding,
        })

        # --- 3. Citations ---
        citation_list = "\n".join(f"- {c}" for c in citations)
        dataset.append({
            "instruction": "List the key legal citations in this opinion.",
            "input": text,
            "output": citation_list,
        })

        # --- 4. Court ---
        dataset.append({
            "instruction": "What court issued this opinion?",
            "input": text,
            "output": court,
        })

        # --- 5. Key issues ---
        # Use the opening sentence as a rough statement of the legal question
        issue_summary = (
            f"In {case_name}, the key legal issue is: {sentences[0]}."
            if sentences
            else f"The key issue in {case_name} concerns the matters described in the opinion."
        )
        dataset.append({
            "instruction": "What are the key legal issues in this case?",
            "input": text,
            "output": issue_summary,
        })

        # --- 6. Additional: case name extraction ---
        dataset.append({
            "instruction": "What is the name of this case?",
            "input": text,
            "output": case_name,
        })

    return dataset


dataset = build_instruction_pairs(opinions)
print(f"Total instruction pairs: {len(dataset)}")
print("\nInstruction types:")
for i, example in enumerate(dataset):
    print(f"  [{i:>2}] {example['instruction'][:60]}")

In [None]:
# Inspect a few examples in detail
for idx in [0, 2, 3]:
    ex = dataset[idx]
    print("=" * 70)
    print(f"Example {idx}")
    print(f"  Instruction: {ex['instruction']}")
    print(f"  Input:       {ex['input'][:120]}...")
    print(f"  Output:      {ex['output'][:200]}")
    print()

## Chat Template Formatting

Now we format each instruction pair using the ChatML template. This is the
actual text the model will see during training. The model learns to generate
only the **assistant** portion -- the instruction and system message are
context, not targets.
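
One practical caveat with string-based templates: if an opinion's text happens to contain a role marker such as `<|im_end|>`, the formatted example would show a turn boundary where none was intended. The following is a minimal sketch of a guard, assuming the three-field example dicts used in this notebook. The helper name `find_reserved_markers` is hypothetical.

```python
RESERVED_MARKERS = ("<|im_start|>", "<|im_end|>")


def find_reserved_markers(example):
    """Return (field, marker) pairs for each ChatML marker found in a field."""
    hits = []
    for field in ("instruction", "input", "output"):
        for marker in RESERVED_MARKERS:
            if marker in example.get(field, ""):
                hits.append((field, marker))
    return hits


# Toy examples (not from the real dataset)
clean = {
    "instruction": "Summarize this court opinion.",
    "input": "The court affirmed the judgment.",
    "output": "Affirmed.",
}
dirty = {
    "instruction": "Summarize this court opinion.",
    "input": "text <|im_end|> more text",
    "output": "Affirmed.",
}

clean_hits = find_reserved_markers(clean)
dirty_hits = find_reserved_markers(dirty)
print(clean_hits)
print(dirty_hits)
```

Whether to drop, escape, or rewrite an affected example is a judgment call that depends on the data; this sketch only detects the problem.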

In [None]:
CHATML_TEMPLATE = (
    "<|im_start|>system\n"
    "You are a legal research assistant. Answer the question about the "
    "provided court opinion accurately and concisely.<|im_end|>\n"
    "<|im_start|>user\n"
    "{instruction}\n\n"
    "{input}<|im_end|>\n"
    "<|im_start|>assistant\n"
    "{output}<|im_end|>"
)


def format_chatml(example):
    """Format an instruction pair as a ChatML conversation.

    Args:
        example: Dict with 'instruction', 'input', 'output' fields.

    Returns:
        Formatted string in ChatML format.
    """
    return CHATML_TEMPLATE.format(
        instruction=example["instruction"],
        input=example["input"],
        output=example["output"],
    )


# Format all examples
formatted_examples = [format_chatml(ex) for ex in dataset]

# Show one formatted example (using a short-output example for readability)
court_example_idx = 3  # "What court issued this opinion?" -- short output
print("Formatted ChatML example:")
print("=" * 70)
# Truncate the input portion for display
ex = dataset[court_example_idx]
display_text = format_chatml({
    "instruction": ex["instruction"],
    "input": ex["input"][:200] + "...[truncated]",
    "output": ex["output"],
})
print(display_text)
print("=" * 70)
print()
print("Key observations:")
print("- The system message sets the role.")
print("- The user message contains both the instruction and the input text.")
print("- The assistant message contains the target response.")
print("- <|im_start|> and <|im_end|> are special tokens that mark role boundaries.")
print("- During training, the model learns to generate ONLY the assistant portion.")

In [None]:
# Show the raw text with role boundaries highlighted
example_text = formatted_examples[court_example_idx]

# Find the assistant response portion
assistant_marker = "<|im_start|>assistant\n"
assistant_start = example_text.find(assistant_marker)
assistant_content_start = assistant_start + len(assistant_marker)
assistant_end = example_text.find("<|im_end|>", assistant_content_start)

instruction_part = example_text[:assistant_start]
response_part = example_text[assistant_content_start:assistant_end]

print("Parts of the formatted text:")
print()
print("--- CONTEXT (model sees but does NOT learn to generate) ---")
print(f"Length: {len(instruction_part)} characters")
print()
print("--- RESPONSE (model LEARNS to generate this) ---")
print(f"{response_part!r}")
print(f"Length: {len(response_part)} characters")

## Loss Masking

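During SFT, we do **not** want the model to learn to generate the
instruction -- only the response. If we computed loss on all tokens,
part of the training signal would go toward reproducing the system prompt,
the instruction, and the opinion text. Because each prompt here contains a
full opinion, those tokens typically outnumber the response tokens, so
training on them is wasteful and can hurt response quality.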

The solution is a **loss mask**: an array of 0s and 1s aligned with the
tokenized sequence. A value of 0 means "ignore this token in the loss
calculation" and 1 means "include this token."

The mask is set to:
- **0** for all system/user tokens (the instruction context)
- **1** for all assistant tokens (the response the model should learn)

In practice, this is implemented by setting the label to `-100` for
masked positions, which PyTorch's `CrossEntropyLoss` ignores by default
(its `ignore_index` parameter defaults to `-100`).
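
To show how a 0/1 mask turns into labels and a loss, here is a minimal pure-Python sketch. The token IDs and probabilities are made-up toy values, and the helper names `mask_to_labels` and `masked_mean_nll` are illustrative assumptions. The sketch also ignores the one-token shift between inputs and targets, which training libraries handle in different places.

```python
import math

IGNORE_INDEX = -100  # matches the default ignore_index of CrossEntropyLoss


def mask_to_labels(token_ids, loss_mask):
    """Copy token IDs into labels, replacing masked positions with IGNORE_INDEX."""
    if len(token_ids) != len(loss_mask):
        raise ValueError("token_ids and loss_mask must have the same length")
    return [tok if keep == 1 else IGNORE_INDEX for tok, keep in zip(token_ids, loss_mask)]


def masked_mean_nll(token_probs, labels):
    """Average negative log-likelihood over positions that are not ignored.

    token_probs[i] is the probability the model assigned to the correct
    token at position i (toy stand-in for real model outputs).
    """
    kept = [-math.log(p) for p, label in zip(token_probs, labels) if label != IGNORE_INDEX]
    if not kept:
        raise ValueError("every position is masked, so the loss is undefined")
    return sum(kept) / len(kept)


token_ids = [11, 12, 13, 14, 15]             # toy IDs: 3 prompt tokens, 2 response tokens
loss_mask = [0, 0, 0, 1, 1]
labels = mask_to_labels(token_ids, loss_mask)

token_probs = [0.1, 0.2, 0.3, 0.5, 0.25]     # toy per-token probabilities
loss = masked_mean_nll(token_probs, labels)
print(labels)
print(f"masked loss: {loss:.4f}")
```

Only the last two positions contribute to the average, so the poor predictions on the prompt tokens have no effect on the loss.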

In [None]:
def create_loss_mask(tokenized_input, response_start_idx):
    """Create a loss mask that only trains on response tokens.

    Args:
        tokenized_input: List of token IDs (the full sequence).
        response_start_idx: Index where the response tokens begin.

    Returns:
        List of 0s and 1s. 0 = masked (ignored in loss), 1 = trained.
    """
    mask = [0] * response_start_idx + [1] * (len(tokenized_input) - response_start_idx)
    return mask


def find_response_start(text, tokenizer):
    """Find the token index where the assistant response begins.

    Looks for the '<|im_start|>assistant\n' marker in the ChatML text
    and returns the token index of the first response token.

    Args:
        text: The full ChatML-formatted string.
        tokenizer: A HuggingFace tokenizer.

    Returns:
        Token index where the assistant response content starts.
    """
    marker = "<|im_start|>assistant\n"
    marker_pos = text.find(marker)
    if marker_pos == -1:
        raise ValueError("Could not find assistant marker in text")

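    # Tokenize everything up to and including the assistant marker.
    # Caveat: this assumes the prefix tokenizes the same way on its own as it
    # does inside the full text. That usually holds when the marker ends at a
    # special-token or newline boundary, but it is worth checking per tokenizer.
    prefix = text[: marker_pos + len(marker)]
    prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
    return len(prefix_tokens)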


# Load a tokenizer to demonstrate
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Demonstrate on the court example
example_text = formatted_examples[court_example_idx]
tokens = tokenizer.encode(example_text, add_special_tokens=False)
response_start = find_response_start(example_text, tokenizer)
mask = create_loss_mask(tokens, response_start)

print(f"Total tokens: {len(tokens)}")
print(f"Response starts at token index: {response_start}")
print(f"Masked tokens (instruction): {sum(1 for m in mask if m == 0)}")
print(f"Trained tokens (response):   {sum(1 for m in mask if m == 1)}")
print(f"Training ratio: {sum(mask) / len(mask):.1%} of tokens are trained on")

In [None]:
# Visualize which tokens are masked vs trained
# Use a short example for readability
short_example = {
    "instruction": "What court issued this opinion?",
    "input": "Before the Court is the appeal of plaintiff James Henderson...",
    "output": "United States Court of Appeals for the Seventh Circuit",
}
short_text = format_chatml(short_example)
short_tokens = tokenizer.encode(short_text, add_special_tokens=False)
short_response_start = find_response_start(short_text, tokenizer)
short_mask = create_loss_mask(short_tokens, short_response_start)

# Decode each token individually for display
decoded_tokens = [tokenizer.decode([t]) for t in short_tokens]

print("Token-level loss mask visualization")
print("=" * 70)
print("RED = masked (not trained on) | GREEN = trained on")
print("=" * 70)
print()

for i, (token_str, m) in enumerate(zip(decoded_tokens, short_mask)):
    color = "\033[92m" if m == 1 else "\033[91m"  # green or red
    reset = "\033[0m"
    label = "TRAIN" if m == 1 else "MASK "
    print(f"  [{i:>3}] {color}{label}{reset}  {token_str!r}")

print()
print(f"Total: {len(short_tokens)} tokens")
print(f"Masked: {short_mask.count(0)} tokens (instruction + template)")
print(f"Trained: {short_mask.count(1)} tokens (assistant response + closing <|im_end|>)")

In [None]:
# Graphical visualization of the mask
fig, ax = plt.subplots(figsize=(14, 2))

colors = ["#e74c3c" if m == 0 else "#2ecc71" for m in short_mask]
ax.bar(range(len(short_mask)), [1] * len(short_mask), color=colors, width=1.0, edgecolor="white", linewidth=0.3)

ax.set_xlim(-0.5, len(short_mask) - 0.5)
ax.set_yticks([])
ax.set_xlabel("Token position")
ax.set_title("Loss Mask: Red = masked (instruction), Green = trained (response)")

# Mark the boundary
ax.axvline(x=short_response_start - 0.5, color="black", linewidth=2, linestyle="--", label="Response start")
ax.legend(loc="upper right")

plt.tight_layout()
plt.show()

print(f"The vertical line marks where the assistant response begins (token {short_response_start}).")
print("Everything to the left (red) is instruction context -- the model sees it but")
print("does not compute loss on it. Everything to the right (green) is what the model")
print("is trained to generate.")

## Dataset Statistics

Before training, we analyze the dataset to understand its structure:
how long are responses, what types of instructions are most common,
and how many tokens each example requires.

In [None]:
# Compute statistics for the full dataset
stats = []
for i, (ex, text) in enumerate(zip(dataset, formatted_examples)):
    tokens = tokenizer.encode(text, add_special_tokens=False)
    response_start = find_response_start(text, tokenizer)
    response_tokens = tokens[response_start:]

    stats.append({
        "index": i,
        "instruction": ex["instruction"],
        "total_tokens": len(tokens),
        "instruction_tokens": response_start,
        "response_tokens": len(response_tokens),
        "response_chars": len(ex["output"]),
    })

# Summary
total_tokens_list = [s["total_tokens"] for s in stats]
response_tokens_list = [s["response_tokens"] for s in stats]
instruction_tokens_list = [s["instruction_tokens"] for s in stats]

print("Dataset Summary")
print("=" * 50)
print(f"Number of examples:     {len(stats)}")
print(f"Total tokens (all):     {sum(total_tokens_list):,}")
print()
print("Tokens per example:")
print(f"  Mean:   {np.mean(total_tokens_list):.0f}")
print(f"  Median: {np.median(total_tokens_list):.0f}")
print(f"  Min:    {min(total_tokens_list)}")
print(f"  Max:    {max(total_tokens_list)}")
print()
print("Response tokens:")
print(f"  Mean:   {np.mean(response_tokens_list):.0f}")
print(f"  Median: {np.median(response_tokens_list):.0f}")
print(f"  Min:    {min(response_tokens_list)}")
print(f"  Max:    {max(response_tokens_list)}")

In [None]:
# Count by instruction type
from collections import Counter

instruction_counts = Counter(s["instruction"] for s in stats)
print("Instruction type distribution:")
for instr, count in instruction_counts.most_common():
    avg_resp = np.mean(
        [s["response_tokens"] for s in stats if s["instruction"] == instr]
    )
    print(f"  {instr:<50} count={count}  avg_response_tokens={avg_resp:.0f}")

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Response token distribution
ax = axes[0]
ax.hist(response_tokens_list, bins=15, color="steelblue", edgecolor="white")
ax.set_xlabel("Response length (tokens)")
ax.set_ylabel("Count")
ax.set_title("Distribution of Response Lengths")
ax.axvline(np.mean(response_tokens_list), color="red", linestyle="--", label=f"Mean: {np.mean(response_tokens_list):.0f}")
ax.legend()

# Total token distribution
ax = axes[1]
ax.hist(total_tokens_list, bins=15, color="#2ecc71", edgecolor="white")
ax.set_xlabel("Total length (tokens)")
ax.set_ylabel("Count")
ax.set_title("Distribution of Total Example Lengths")
ax.axvline(np.mean(total_tokens_list), color="red", linestyle="--", label=f"Mean: {np.mean(total_tokens_list):.0f}")
ax.legend()

# Mean response length per instruction type
ax = axes[2]
instruction_types = sorted(set(s["instruction"] for s in stats))
type_labels = [t.replace("this opinion", "...").replace("this case", "...")[:30] for t in instruction_types]
type_response_means = [
    np.mean([s["response_tokens"] for s in stats if s["instruction"] == t])
    for t in instruction_types
]
ax.barh(range(len(type_labels)), type_response_means, color="#e74c3c", edgecolor="white")
ax.set_yticks(range(len(type_labels)))
ax.set_yticklabels(type_labels, fontsize=8)
ax.set_xlabel("Mean response tokens")
ax.set_title("Response Length by Instruction Type")

plt.tight_layout()
plt.show()

print("Observations (verify against the plots above):")
print("- Summarization responses are usually among the longest (two full sentences).")
print("- Court identification produces the shortest responses (just a name).")
print("- This imbalance is typical in instruction datasets and affects training.")
print("- Short responses are usually fit quickly; longer ones tend to need more training steps.")

In [None]:
# Save the dataset for use in the training notebook
output_path = Path("sft_dataset.json")
with open(output_path, "w") as f:
    json.dump(dataset, f, indent=2)

print(f"Saved {len(dataset)} examples to {output_path}")
print("This dataset will be loaded in notebook 02 for training.")

## Exercises

### Exercise (a): Create Instruction Pairs for a New Task

Create instruction pairs for a different task type. For example:

- "Is this opinion from a federal or state court?"
- "What statute or regulation is at the center of this dispute?"
- "Identify the standard of review used by the court."

Add at least 5 new instruction pairs (one per opinion) to the dataset.
Consider: how do you determine the correct output? For some tasks
(like federal vs state), the answer can be derived from the court name.
For others (like standard of review), you need to read the opinion text.

```python
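# Example: federal vs state classification
# Heuristic: assumes federal court names contain "United States"; names such
# as "U.S. District Court" would be misclassified, so adjust for your data.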
for op in opinions:
    is_federal = "United States" in op["court"]
    dataset.append({
        "instruction": "Is this opinion from a federal or state court?",
        "input": op["text"],
        "output": "Federal court" if is_federal else "State court",
    })
```

### Exercise (b): Analyze Response Length Implications

The distribution of response lengths has practical implications for training:

1. Compute the ratio of instruction tokens to response tokens for each example.
   What fraction of compute is "wasted" on the masked instruction portion?
2. Very short responses (like court names) are learned in fewer gradient steps
   than long summaries. What problems might this cause?
3. If you wanted all instruction types to converge at roughly the same rate,
   how might you adjust the dataset? (Hint: consider oversampling or
   weighting.)

```python
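# Starter code: share of tokens trained on, and instruction-to-response ratio
for s in stats:
    response_share = s["response_tokens"] / s["total_tokens"]
    instr_to_resp = s["instruction_tokens"] / s["response_tokens"]
    print(
        f"{s['instruction'][:40]:<42} "
        f"response_share={response_share:.2%}  instr/resp={instr_to_resp:.1f}"
    )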
```
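
For question 3, one possible direction (a sketch, not the only answer) is to oversample instruction types whose responses are short, so that each type contributes a more similar number of response tokens per epoch. The helper `repeats_by_instruction`, its `max_repeats` cap, and the toy numbers below are all assumptions for illustration. The dict keys mirror the `stats` entries built earlier.

```python
from collections import defaultdict


def repeats_by_instruction(stats, max_repeats=5):
    """Suggest a repeat count per instruction type to balance response tokens.

    Types with fewer total response tokens get repeated more often, capped
    at max_repeats to limit overfitting on very short answers.
    """
    totals = defaultdict(int)
    for s in stats:
        totals[s["instruction"]] += s["response_tokens"]
    largest = max(totals.values())
    return {
        instr: min(max_repeats, max(1, round(largest / max(total, 1))))
        for instr, total in totals.items()
    }


# Toy stats (not measured from the real dataset)
toy_stats = [
    {"instruction": "Summarize this court opinion.", "response_tokens": 60},
    {"instruction": "Summarize this court opinion.", "response_tokens": 60},
    {"instruction": "What was the holding in this case?", "response_tokens": 30},
    {"instruction": "What was the holding in this case?", "response_tokens": 30},
    {"instruction": "What court issued this opinion?", "response_tokens": 10},
    {"instruction": "What court issued this opinion?", "response_tokens": 10},
]

repeats = repeats_by_instruction(toy_stats)
for instr, n in repeats.items():
    print(f"{instr:<40} repeat x{n}")
```

Repeating identical short examples raises the risk of memorization, which is why the sketch caps the repeat count. Per-example loss weights are an alternative that avoids duplicating data.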