# Verification Notebook

This notebook verifies the hierarchical CoT dataset by:
1. Loading a small language model and running vanilla inference as a baseline
2. Fine-tuning the model on hierarchical CoT data with `[THOUGHT]`, `[SOLUTION]`, `[RETURN]` special tokens
3. Running inference with a custom generate loop that prunes intermediate reasoning when `[RETURN]` is produced
4. Comparing token efficiency (tokens generated and context length) between vanilla and hierarchical approaches, and printing predicted vs. expected answers so answer quality can be inspected

In [1]:
# %load_ext autoreload

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from huggingface_hub import notebook_login

import model
import re
import os

In [3]:
# notebook_login()
# Never hardcode credentials in a notebook: notebooks and their outputs get
# shared and committed. Provide HF_TOKEN through the environment (or a
# Colab/Kaggle secret) before starting the kernel, or call notebook_login().
# NOTE(security): a token was previously committed in this cell; revoke it at
# https://huggingface.co/settings/tokens.
if not os.environ.get("HF_TOKEN"):
    print(
        "WARNING: HF_TOKEN is not set; gated downloads and push_to_hub may fail "
        "until you authenticate (e.g. notebook_login())."
    )

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if not torch.cuda.is_available():
    print("CUDA not available.")
else:
    # Snapshot of this process's CUDA allocator state on the current device,
    # reported in MiB (1024**2 bytes).
    _dev_idx = torch.cuda.current_device()
    total_memory = torch.cuda.get_device_properties(_dev_idx).total_memory
    allocated_memory = torch.cuda.memory_allocated(_dev_idx)
    reserved_memory = torch.cuda.memory_reserved(_dev_idx)

    print(
        f"GPU {_dev_idx} Memory Status:",
        f"  Total Memory: {total_memory / (1024**2):.2f} MiB",
        f"  Allocated Memory: {allocated_memory / (1024**2):.2f} MiB",
        f"  Reserved (Cached) Memory: {reserved_memory / (1024**2):.2f} MiB",
        f"  Free in Cache: {(reserved_memory - allocated_memory) / (1024**2):.2f} MiB",
        sep="\n",
    )

    print("\nDetailed Memory Summary:")
    print(torch.cuda.memory_summary(device=_dev_idx, abbreviated=True))


# Configuration
MODEL_NAME = "nvidia/OpenMath-Nemotron-1.5B"
DATASET_NAME = "anujjamwal/openmathreasoning-hierarchical-cot"
MAX_NEW_TOKENS = 8192    # generation budget per problem (vanilla and hierarchical)
MAX_SEQ_LENGTH = 32768   # training examples are truncated / padded to this length
# In this normal string literal "\\boxed{{}}" yields the text \boxed{{}}, which
# str.format turns into \boxed{}. The separator before the problem must be real
# newlines ("\n\n"); writing "\\n\\n" would put the literal characters
# backslash-n into every prompt.
PROMPT_TEMPLATE = "Solve the following math problem. Make sure to put the answer (and only answer) inside \\boxed{{}}.\n\n{problem}"

# Checkpoint configuration
CHECKPOINT_DIR = "./checkpoints/hierarchical-cot"
HF_CHECKPOINT_REPO = "anujjamwal/OpenMath-Nemotron-1.5B-hierarchical"

Using device: cuda
GPU 0 Memory Status:
  Total Memory: 81152.75 MiB
  Allocated Memory: 0.00 MiB
  Reserved (Cached) Memory: 0.00 MiB
  Free in Cache: 0.00 MiB

Detailed Memory Summary:
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Requested memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|--------------------------------

In [6]:
import gc
import sys

def clear_cache(*extra_vars):
    """Free GPU memory held by the notebook's models and optimizer.

    Deletes the default model/optimizer globals plus any names given in
    `extra_vars`. It then drops the interpreter's references to the last
    unhandled exception, runs the Python garbage collector and flushes the
    CUDA caching allocator.

    gc.collect() must run before torch.cuda.empty_cache(). The allocator can
    only release blocks whose tensors are no longer referenced from Python, so
    empty_cache() alone does nothing for tensors that are still live.

    Args:
        *extra_vars: Additional global variable names to delete. Names that
            are not defined are ignored.
    """
    # 'loaded_model' aliases the model when a checkpoint is resumed
    # (base_model = loaded_model), so it must go too or the weights stay reachable.
    _defaults = ('hier_model', 'base_model', 'optimizer', 'vanilla_model', 'loaded_model')
    namespace = globals()
    for var in _defaults + tuple(extra_vars):
        namespace.pop(var, None)

    # After an unhandled exception (e.g. a CUDA OOM) the interpreter keeps it in
    # sys.last_type / last_value / last_traceback (and sys.last_exc on 3.12+).
    # The traceback's frames keep their locals alive, possibly including a
    # partially loaded model, so clear those references before collecting.
    for attr in ("last_type", "last_value", "last_traceback", "last_exc"):
        if hasattr(sys, attr):
            setattr(sys, attr, None)

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Load tokenizer and base model (before adding special tokens so vanilla inference is unaffected)
vanilla_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if vanilla_tokenizer.pad_token is None:
    vanilla_tokenizer.pad_token = vanilla_tokenizer.eos_token

# Release models/optimizer left over from earlier runs before loading a new copy.
clear_cache()

# Load in bfloat16 with low_cpu_mem_usage=True to limit peak CPU RAM while
# loading, then move the model to `device` explicitly. The inference cell sends
# its inputs to `device`, so the model must live there too. device_map="auto"
# leaves placement to accelerate's dispatch hooks, which were observed to put
# the model on CPU instead of GPU.
vanilla_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to(device)
vanilla_model.eval()

print(f"Model: {MODEL_NAME}")
print(f"Parameters: {sum(p.numel() for p in vanilla_model.parameters()) / 1e6:.1f}M")
print(f"Vocab size: {len(vanilla_tokenizer)}")
print(f"Model device: {next(vanilla_model.parameters()).device}")

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Loading weights:   0%|          | 0/338 [00:00<?, ?it/s]

AcceleratorError: CUDA error: out of memory
Search for `cudaErrorMemoryAllocation' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Load the dataset from the Hub and carve out a deterministic 90/10 train/test split.
dataset = load_dataset(DATASET_NAME, split="train")

print(
    f"Dataset: {DATASET_NAME}",
    f"Total examples: {len(dataset)}",
    f"Columns: {dataset.column_names}",
    sep="\n",
)

split = dataset.train_test_split(test_size=0.1, seed=42)
train_data, test_data = split["train"], split["test"]

print(f"\nTrain: {len(train_data)} examples", f"Test:  {len(test_data)} examples", sep="\n")

# Eyeball the first held-out example: the first 300 characters of the problem
# and its reference answer.
example = test_data[0]
print(
    f"\n--- Example Problem (truncated) ---",
    example["problem"][:300],
    f"\n--- Expected Answer ---",
    example["expected_answer"],
    sep="\n",
)

Dataset: anujjamwal/openmathreasoning-hierarchical-cot
Total examples: 10
Columns: ['expected_answer', 'problem_type', 'problem_source', 'generation_model', 'pass_rate_72b_tir', 'problem', 'generated_solution', 'inference_mode', 'used_in_kaggle', 'hierarchical_cot']

Train: 9 examples
Test:  1 examples

--- Example Problem (truncated) ---
Calculate the integral

\[
\int^{\frac{3\pi}{2}}_{\frac{\pi}{2}} \left|\left(\frac{2}{x^3}+\frac{1}{x}\right)\sin x\right|dx
\]

--- Expected Answer ---
\(\frac{2}{\pi} + \frac{32}{9\pi^2}\)


In [None]:
def extract_boxed_answer(text):
    """Return the contents of the last ``\\boxed{...}`` in ``text``.

    Braces are matched with a depth counter, so nested groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned whole. A backslash escapes the
    character after it: LaTeX ``\\{`` / ``\\}`` are literal braces, not
    grouping, so they are skipped rather than counted. Otherwise an answer like
    ``\\boxed{\\left\\{ x \\right.}`` would look unbalanced.

    Args:
        text: Model output to search.

    Returns:
        The boxed content (possibly empty), or None if ``text`` has no
        ``\\boxed{`` or the last one is never closed (e.g. truncated output).
    """
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    start += len(marker)

    depth = 1
    pos = start
    while pos < len(text):
        ch = text[pos]
        if ch == "\\":
            # Skip the escaped character (covers \{, \} and \\).
            pos += 2
            continue
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:pos]
        pos += 1
    return None

In [None]:
# Baseline: greedy (do_sample=False) generation with the unmodified model,
# recording the boxed answer and token counts for each example.
# NOTE: this loops over `train_data` (the split the hierarchical model is later
# fine-tuned on), not `test_data`. The hierarchical inference cell iterates the
# same split, so the two result lists line up example-by-example.
vanilla_results = []
vanilla_model.eval()

for i, example in enumerate(tqdm(train_data, desc="Vanilla inference")):
    problem = example["problem"]
    prompt = PROMPT_TEMPLATE.format(problem=problem)

    inputs = vanilla_tokenizer(prompt, return_tensors="pt").to(device)
    input_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = vanilla_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            # Pass the pad id explicitly instead of relying on generate()'s
            # fallback, which can log a warning on every call when the model's
            # generation config has no pad token.
            pad_token_id=vanilla_tokenizer.pad_token_id,
        )

    # Everything after the prompt tokens is newly generated.
    generated_tokens = outputs.shape[1] - input_length
    output_text = vanilla_tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    predicted_answer = extract_boxed_answer(output_text)

    vanilla_results.append(
        {
            "problem": problem,
            "output": output_text,
            "predicted_answer": predicted_answer,
            "expected_answer": example["expected_answer"],
            "tokens_generated": generated_tokens,
            "total_context_length": outputs.shape[1],
        }
    )

    # Drop GPU tensors before the next (potentially long) generation.
    del outputs, inputs

    print(f"\n[Example {i}]")
    print(f"  Problem:   {problem[:100]}...")
    print(f"  Predicted: {predicted_answer}")
    print(f"  Expected:  {example['expected_answer'][:80]}")
    print(f"  Tokens generated: {generated_tokens}")

print(f"\n--- Vanilla Summary ---")
if vanilla_results:
    print(
        f"Avg tokens generated:  {np.mean([r['tokens_generated'] for r in vanilla_results]):.1f}"
    )
    print(
        f"Avg context length:    {np.mean([r['total_context_length'] for r in vanilla_results]):.1f}"
    )
else:
    print("No vanilla results to summarise.")

Vanilla inference:   0%|          | 0/9 [00:00<?, ?it/s]

AcceleratorError: CUDA error: out of memory
Search for `cudaErrorMemoryAllocation' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Load Model For Training
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Free the vanilla model (and any earlier training objects) before loading a second copy.
clear_cache()

# Probe for FlexAttention; the import is only used to detect availability.
try:
    from torch.nn.attention.flex_attention import create_block_mask  # noqa: F401
    _attn_impl = "flex_attention"
    print("FlexAttention available — using memory-efficient block-sparse masking")
except ImportError:
    _attn_impl = "sdpa"
    print("WARNING: FlexAttention not available (requires PyTorch >= 2.5).")
    print("         Falling back to SDPA. Training will OOM at seq_len > ~4096.")

# Load in bfloat16 and move to `device` explicitly: the training loop sends
# every batch to `device`, so the weights must live there as well.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation=_attn_impl,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to(device)

print(f"Model device: {next(base_model.parameters()).device}")

FlexAttention available — using memory-efficient block-sparse masking


Loading weights:   0%|          | 0/338 [00:00<?, ?it/s]

AcceleratorError: CUDA error: out of memory
Search for `cudaErrorMemoryAllocation' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# --- Add special tokens and prepare for hierarchical training ---
# Register the hierarchy markers as special tokens so each one maps to a single
# token id (the ids are printed after the model is wrapped below).
# NOTE(review): SPECIAL_TOKENS is not defined anywhere in this notebook. It is
# presumably provided by the local `model` module, but that module is imported
# as `import model`, so the bare name would need `model.SPECIAL_TOKENS` or a
# `from model import ...`. Confirm, otherwise this raises NameError on a fresh kernel.
num_added = tokenizer.add_special_tokens(
    {"additional_special_tokens": SPECIAL_TOKENS}
)
print(f"Added {num_added} special tokens: {SPECIAL_TOKENS}")
print(f"New vocab size: {len(tokenizer)}")


# --- Checkpoint detection and loading ---
from pathlib import Path

# Defaults for a fresh run; overwritten below if a checkpoint with training
# state is loaded.
start_epoch = 0
training_losses = []
checkpoint_path = Path(CHECKPOINT_DIR)


def _load_model_from_dir(model_dir):
    """Load model, tokenizer, and optional training state from a local directory.
    Returns (model, tokenizer, training_state_dict) or (None, None, None) if not found."""
    p = Path(model_dir)
    if not (p / "config.json").exists():
        return None, None, None
    m = AutoModelForCausalLM.from_pretrained(
        str(model_dir), dtype=torch.bfloat16
    ).to(device)
    # Load the saved tokenizer so its vocab (including special tokens) exactly
    # matches the embedding matrix that was saved with the model.
    tok = AutoTokenizer.from_pretrained(str(model_dir))
    state = None
    state_path = p / "training_state.pt"
    if state_path.exists():
        state = torch.load(state_path, map_location=device)
    return m, tok, state


loaded_model, loaded_tokenizer, training_state = None, None, None

# Checkpoint resume is currently disabled: with the lookups below commented out,
# `loaded_model` stays None and the fresh-training branch always runs.
# # 1. Try local checkpoint
# if checkpoint_path.exists():
#     print(f"Trying local checkpoint: {checkpoint_path}")
#     loaded_model, loaded_tokenizer, training_state = _load_model_from_dir(checkpoint_path)

# # 2. Fall back to HF Hub
# if loaded_model is None:
#     try:
#         from huggingface_hub import snapshot_download
#         print(f"Trying HF Hub checkpoint: {HF_CHECKPOINT_REPO}")
#         hf_dir = snapshot_download(HF_CHECKPOINT_REPO)
#         loaded_model, loaded_tokenizer, training_state = _load_model_from_dir(hf_dir)
#     except Exception as e:
#         print(f"No HF checkpoint available ({type(e).__name__}), starting fresh")

if loaded_model is None:
    # Fresh run: grow the embedding matrix to cover the newly added special tokens.
    base_model.resize_token_embeddings(len(tokenizer))
    print("No checkpoint found — starting fresh training")
else:
    base_model = loaded_model
    tokenizer = loaded_tokenizer
    if training_state is None:
        print("Checkpoint loaded; no training state found, starting from epoch 0")
    else:
        start_epoch = training_state.get("epoch", 0)
        training_losses = training_state.get("losses", [])
        print(f"Resuming from epoch {start_epoch} ({len(training_losses)} epochs already logged)")

# Wrap the model with the hierarchical CoT generate loop.
# NOTE(review): CausalLMModelWithHierarchicalCot is not defined in this notebook;
# presumably it comes from the local `model` module. Confirm the import.
hier_model = CausalLMModelWithHierarchicalCot(base_model, tokenizer)
print(f"[THOUGHT] token id: {hier_model.thought_token_id}")
print(f"[SOLUTION] token id: {hier_model.solution_token_id}")
print(f"[RETURN] token id: {hier_model.return_token_id}")


# --- Training dataset ---
class HierarchicalCoTDataset(Dataset):
    """Tokenized (prompt, hierarchical CoT) pairs for causal-LM fine-tuning.

    Each item encodes ``PROMPT_TEMPLATE.format(problem=...) + "\\n" +
    hierarchical_cot + eos_token``. It is truncated to ``max_length`` tokens and
    right-padded to exactly ``max_length``. Each item is a dict of 1-D tensors:

    - ``input_ids``
    - ``attention_mask``: 1 = real token, 0 = padding.
    - ``labels``: the input ids with prompt and padding positions set to -100,
      so the loss covers only the hierarchical CoT and the trailing EOS.

    Args:
        hf_dataset: Iterable of examples with "problem" and "hierarchical_cot" keys.
        tokenizer: Hugging Face tokenizer; its pad token must be set.
        max_length: Fixed sequence length of every item.
    """

    def __init__(self, hf_dataset, tokenizer, max_length=MAX_SEQ_LENGTH):
        # Skip examples whose hierarchical_cot is missing (None) or blank.
        self.data = [
            ex for ex in hf_dataset if (ex["hierarchical_cot"] or "").strip()
        ]
        self.tokenizer = tokenizer
        self.max_length = max_length
        print(f"  Training examples after filtering empty CoTs: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        prompt = PROMPT_TEMPLATE.format(problem=example["problem"])
        target = example["hierarchical_cot"]

        full_text = prompt + "\n" + target + self.tokenizer.eos_token

        encoding = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        real_ids = encoding["input_ids"].squeeze(0)
        real_mask = encoding["attention_mask"].squeeze(0)
        n_real = real_ids.shape[0]

        # Right-pad by hand instead of using padding="max_length". The latter
        # follows tokenizer.padding_side, and with left padding the prompt would
        # no longer start at index 0. That breaks the prompt masking below and
        # the `[:max_real]` trimming in dynamic_collate_fn.
        input_ids = torch.full(
            (self.max_length,), self.tokenizer.pad_token_id, dtype=real_ids.dtype
        )
        input_ids[:n_real] = real_ids
        attention_mask = torch.zeros(self.max_length, dtype=real_mask.dtype)
        attention_mask[:n_real] = real_mask

        # Mask prompt tokens so loss is only on the hierarchical CoT.
        # NOTE(review): this assumes tokenizing prompt + "\n" on its own yields
        # exactly the leading tokens of full_text. A BPE merge across the
        # boundary (e.g. a target starting with "\n") could shift it by one.
        prompt_encoding = self.tokenizer(prompt + "\n", return_tensors="pt")
        prompt_length = prompt_encoding["input_ids"].shape[1]

        labels = input_ids.clone()
        labels[:prompt_length] = -100
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

Added 3 special tokens: ['[THOUGHT]', '[SOLUTION]', '[RETURN]']
New vocab size: 151668
No checkpoint found — starting fresh training
[THOUGHT] token id: 151665
[SOLUTION] token id: 151666
[RETURN] token id: 151667


In [None]:
# --- Memory tracing helpers ---
def _mem(label=""):
    alloc    = torch.cuda.memory_allocated()  / 1e9
    reserved = torch.cuda.memory_reserved()   / 1e9
    peak     = torch.cuda.max_memory_allocated() / 1e9
    total    = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"  [MEM] {label:<40} alloc={alloc:.2f}GB  reserved={reserved:.2f}GB  peak={peak:.2f}GB / {total:.1f}GB")

def _reset_peak():
    torch.cuda.reset_peak_memory_stats()

_mem("baseline (start of training cell)")

# -----------------------------------------------------------------------
# STEP 1: Training setup
# -----------------------------------------------------------------------

# Recompute activations during the backward pass instead of keeping them from
# the forward pass: slower, but it is what lets long sequences fit in memory.
base_model.gradient_checkpointing_enable()

train_ds = HierarchicalCoTDataset(train_data, tokenizer)

# Number of real (non-padding) tokens in each training example.
lengths = [
    train_ds[idx]["attention_mask"].sum().item()
    for idx in range(len(train_ds))
]
print(
    f"\nActual token lengths:  min={min(lengths)}  max={max(lengths)}  mean={sum(lengths)/len(lengths):.0f}  padded_to={MAX_SEQ_LENGTH}\n"
    f"Padding overhead without dynamic collation: {MAX_SEQ_LENGTH / (sum(lengths)/len(lengths)):.1f}x\n"
)

def dynamic_collate_fn(batch):
    """Trim a batch of right-padded, equal-length examples to its longest real sequence.

    The kept width is the largest attention-mask sum in the batch, rounded up
    to a multiple of 128 (the FlexAttention block size) so the block grid stays
    aligned. With batch_size=1 this is essentially the example's own length,
    avoiding the cost of static MAX_SEQ_LENGTH padding.

    Returns:
        Dict with "input_ids", "attention_mask" and "labels", each stacked to
        shape (len(batch), width).
    """
    longest = max(item["attention_mask"].sum().item() for item in batch)
    # Ceiling division: smallest multiple of 128 that is >= longest.
    width = -(-longest // 128) * 128

    return {
        key: torch.stack([item[key][:width] for item in batch])
        for key in ("input_ids", "attention_mask", "labels")
    }

import gc

# Seeded shuffling (same seed as the train/test split) so reruns see the
# examples in the same order.
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    collate_fn=dynamic_collate_fn,
    generator=torch.Generator().manual_seed(42),
)

NUM_EPOCHS = 30
LEARNING_RATE = 2e-5

# Drop an optimizer left over from a previous run of this cell before creating
# a new one. gc.collect() must run before empty_cache(): the optimizer state
# tensors only become reclaimable once Python has released its references.
if "optimizer" in globals():
    del globals()["optimizer"]
    gc.collect()
    torch.cuda.empty_cache()

optimizer = torch.optim.AdamW(
    hier_model.parameters(), lr=LEARNING_RATE, weight_decay=0.01
)
_mem("after optimizer creation")

# -----------------------------------------------------------------------
# STEP 2: Training loop
# -----------------------------------------------------------------------
hier_model.train()
for epoch in range(start_epoch, NUM_EPOCHS):
    epoch_losses = []
    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}")

    for step, batch in enumerate(progress):
        # Memory snapshots only for the first two steps of the first epoch run.
        trace = (epoch == start_epoch and step < 2)
        if trace:
            _reset_peak()

        input_ids      = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels         = batch["labels"].to(device)
        real_len       = int(attention_mask.sum())

        if trace: _mem(f"[ep{epoch+1} s{step}] after .to(device)  real_len={real_len}")

        optimizer.zero_grad(set_to_none=True)

        outputs = hier_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        if trace: _mem(f"[ep{epoch+1} s{step}] after forward")

        loss = outputs.loss
        # Release the rest of the model output (e.g. logits) before backward;
        # `loss` alone keeps the autograd graph alive.
        del outputs

        loss.backward()
        if trace: _mem(f"[ep{epoch+1} s{step}] after backward")

        torch.nn.utils.clip_grad_norm_(hier_model.parameters(), 1.0)
        optimizer.step()

        # One device->host sync per step.
        step_loss = loss.item()
        epoch_losses.append(step_loss)
        progress.set_postfix(loss=f"{step_loss:.4f}")
        del loss

    # Store a plain Python float. numpy scalars in training_state.pt are rejected
    # by torch.load's weights_only=True mode (the default since PyTorch 2.6),
    # which would make the checkpoint impossible to resume with default settings.
    avg_loss = float(np.mean(epoch_losses))
    training_losses.append(avg_loss)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} — Average Loss: {avg_loss:.4f}")
    torch.cuda.empty_cache()

print("\nTraining complete!")

checkpoint_path.mkdir(parents=True, exist_ok=True)
base_model.save_pretrained(checkpoint_path)
tokenizer.save_pretrained(checkpoint_path)
torch.save(
    # float() also normalises any numpy losses carried over from a resumed state.
    {"epoch": NUM_EPOCHS, "losses": [float(x) for x in training_losses]},
    checkpoint_path / "training_state.pt",
)
print(f"Checkpoint saved to {checkpoint_path}")

In [None]:
# Publish the fine-tuned checkpoint to the Hugging Face Hub: the weights, the
# tokenizer (with the added special tokens) and training_state.pt.
# Requires: `huggingface-cli login`  OR  HF_TOKEN env variable set
from huggingface_hub import HfApi

print(f"Pushing to https://huggingface.co/{HF_CHECKPOINT_REPO} ...")

base_model.push_to_hub(
    HF_CHECKPOINT_REPO,
    commit_message="Add hierarchical CoT fine-tuned model",
)
tokenizer.push_to_hub(
    HF_CHECKPOINT_REPO,
    commit_message="Add tokenizer with special tokens",
)

# training_state.pt is written separately with torch.save, so push_to_hub does
# not include it; upload it as an extra file.
HfApi().upload_file(
    repo_id=HF_CHECKPOINT_REPO,
    path_in_repo="training_state.pt",
    path_or_fileobj=str(checkpoint_path / "training_state.pt"),
    commit_message="Add training state",
)

print(f"Done! Model available at: https://huggingface.co/{HF_CHECKPOINT_REPO}")

In [None]:
# Run the fine-tuned hierarchical model on the same examples (`train_data`) as
# the vanilla baseline, so the two result lists line up example-by-example.
hier_results = []
hier_model.eval()

for i, example in enumerate(tqdm(train_data, desc="Hierarchical inference")):
    problem = example["problem"]
    prompt = PROMPT_TEMPLATE.format(problem=problem)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_length = inputs["input_ids"].shape[1]

    # The model was just trained with autograd enabled. Disable it here so the
    # decoding loop cannot record a graph across every generation step.
    # NOTE(review): hier_model.generate is defined outside this notebook and may
    # already do this internally; wrapping it again is harmless. It returns
    # (output_ids, tokens_generated, peak_context), per the unpacking below.
    with torch.no_grad():
        output_ids, tokens_generated, peak_context = hier_model.generate(
            input_ids=inputs["input_ids"],
            max_new_tokens=MAX_NEW_TOKENS,
        )

    # Decode with and without special tokens
    output_with_special = tokenizer.decode(
        output_ids[0][input_length:], skip_special_tokens=False
    )
    output_clean = tokenizer.decode(
        output_ids[0][input_length:], skip_special_tokens=True
    )
    predicted_answer = extract_boxed_answer(output_clean)

    hier_results.append(
        {
            "problem": problem,
            "output": output_with_special,
            "output_clean": output_clean,
            "predicted_answer": predicted_answer,
            "expected_answer": example["expected_answer"],
            "tokens_generated": tokens_generated,
            "peak_context_length": peak_context,
            "final_context_length": output_ids.shape[1],
        }
    )

    print(f"\n[Example {i}]")
    print(f"  Problem:   {problem[:100]}...")
    print(f"  Predicted: {predicted_answer}")
    print(f"  Expected:  {example['expected_answer'][:80]}")
    print(f"  Tokens generated: {tokens_generated}")
    print(f"  Peak context:     {peak_context}")
    print(f"  Final context:    {output_ids.shape[1]}")

# --- Side-by-side comparison ---
print(f"\n{'=' * 60}")
print(f"{'Metric':<30} {'Vanilla':>12} {'Hierarchical':>15}")
print(f"{'=' * 60}")
v_gen = np.mean([r["tokens_generated"] for r in vanilla_results])
h_gen = np.mean([r["tokens_generated"] for r in hier_results])
v_ctx = np.mean([r["total_context_length"] for r in vanilla_results])
h_ctx = np.mean([r["final_context_length"] for r in hier_results])
h_peak = np.mean([r["peak_context_length"] for r in hier_results])
print(f"{'Avg tokens generated':<30} {v_gen:>12.1f} {h_gen:>15.1f}")
print(f"{'Avg context length':<30} {v_ctx:>12.1f} {h_ctx:>15.1f}")
print(f"{'Avg peak context':<30} {'—':>12} {h_peak:>15.1f}")
print(f"{'=' * 60}")

In [None]:
# Four-panel comparison of the two inference runs plus the training-loss curve,
# followed by a printed summary.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle("Vanilla vs Hierarchical CoT Comparison", fontsize=14, fontweight="bold")

# The bar charts index both result lists by position, so they must describe the
# same examples. Check this up front (e.g. after an interrupted inference run)
# instead of failing later with an opaque shape mismatch inside ax.bar.
n = len(hier_results)
assert len(vanilla_results) == n, (
    f"vanilla_results has {len(vanilla_results)} entries but hier_results has {n}; "
    "re-run both inference cells over the same examples before plotting"
)
x = np.arange(n)
width = 0.35

# 1. Tokens generated per example
ax = axes[0, 0]
v_tokens = [r["tokens_generated"] for r in vanilla_results]
h_tokens = [r["tokens_generated"] for r in hier_results]
ax.bar(x - width / 2, v_tokens, width, label="Vanilla", color="steelblue", alpha=0.8)
ax.bar(
    x + width / 2, h_tokens, width, label="Hierarchical", color="coral", alpha=0.8
)
ax.set_xlabel("Test Example")
ax.set_ylabel("Tokens Generated")
ax.set_title("Tokens Generated per Example")
ax.set_xticks(x)
ax.legend()

# 2. Context window usage
ax = axes[0, 1]
v_context = [r["total_context_length"] for r in vanilla_results]
h_final = [r["final_context_length"] for r in hier_results]
h_peak = [r["peak_context_length"] for r in hier_results]
ax.bar(
    x - width / 2,
    v_context,
    width,
    label="Vanilla (total)",
    color="steelblue",
    alpha=0.8,
)
ax.bar(
    x + width / 2,
    h_final,
    width,
    label="Hierarchical (final)",
    color="coral",
    alpha=0.8,
)
ax.scatter(
    x + width / 2,
    h_peak,
    color="red",
    zorder=5,
    label="Hierarchical (peak)",
    marker="^",
    s=80,
)
ax.set_xlabel("Test Example")
ax.set_ylabel("Context Length (tokens)")
ax.set_title("Context Window Usage")
ax.set_xticks(x)
ax.legend()

# 3. Context reduction percentage (positive = hierarchical final context is shorter)
ax = axes[1, 0]
savings = [
    (v - h) / v * 100 if v > 0 else 0 for v, h in zip(v_context, h_final)
]
colors = ["green" if s > 0 else "red" for s in savings]
ax.bar(x, savings, color=colors, alpha=0.7)
ax.axhline(y=0, color="black", linewidth=0.5)
if savings:
    ax.axhline(
        y=np.mean(savings),
        color="blue",
        linestyle="--",
        label=f"Mean: {np.mean(savings):.1f}%",
    )
ax.set_xlabel("Test Example")
ax.set_ylabel("Context Reduction (%)")
ax.set_title("Context Savings from Hierarchical Pruning")
ax.set_xticks(x)
ax.legend()

# 4. Training loss curve
ax = axes[1, 1]
ax.plot(
    range(1, len(training_losses) + 1),
    training_losses,
    "o-",
    color="purple",
    linewidth=2,
)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("verification_results.png", dpi=150, bbox_inches="tight")
plt.show()

# Final summary
print("\n" + "=" * 60)
print("VERIFICATION SUMMARY")
print("=" * 60)
print(f"{'Vanilla avg tokens generated:':<42} {np.mean(v_tokens):.1f}")
print(f"{'Hierarchical avg tokens generated:':<42} {np.mean(h_tokens):.1f}")
print(f"{'Vanilla avg context length:':<42} {np.mean(v_context):.1f}")
print(f"{'Hierarchical avg final context:':<42} {np.mean(h_final):.1f}")
print(f"{'Hierarchical avg peak context:':<42} {np.mean(h_peak):.1f}")
print(f"{'Average context reduction:':<42} {np.mean(savings):.1f}%")