# Hierarchical Chain-of-Thought Training

Fine-tune Qwen3-0.6B on the OpenMathReasoning Hierarchical CoT dataset using `HCotTrainer`.

## Setup

Clone the repo (Colab) or configure `sys.path` so that `model` and `training` packages are importable.

In [1]:
import sys, os

# When running in Colab, clone the repo and add lib/ to the path
if "google.colab" in sys.modules:
    if not os.path.exists("cs224n-final-project"):
        !git clone https://github.com/anujjamwal/cs224n-final-project.git
    else:
        !cd cs224n-final-project && git pull
    sys.path.insert(0, "cs224n-final-project/lib")
else:
    # Local: notebook lives inside lib/ already
    sys.path.insert(0, os.path.dirname(os.path.abspath("__file__")))

Cloning into 'cs224n-final-project'...
remote: Enumerating objects: 97, done.[K
remote: Counting objects: 100% (97/97), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 97 (delta 45), reused 80 (delta 28), pack-reused 0 (from 0)[K
Receiving objects: 100% (97/97), 200.04 KiB | 1.64 MiB/s, done.
Resolving deltas: 100% (45/45), done.


In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset

from model import generate
from model.model import THOUGHT_TOKEN, SOLUTION_TOKEN, RETURN_TOKEN, SPECIAL_TOKENS


## Load Model and Tokenizer

In [3]:
MODEL_NAME = "nvidia/OpenMath-Nemotron-1.5B"
model_repo_id = "anujjamwal/OpenMath-Nemotron-1.5B-hcot"

In [None]:
import os
os.environ['HF_TOKEN'] = ''

## Load and Tokenize Dataset

In [5]:
eval_ds = load_dataset("davidanugraha/OpenMathReasoning-Sampled", split="train")

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


README.md:   0%|          | 0.00/539 [00:00<?, ?B/s]

data/train-00000-of-00003.parquet:   0%|          | 0.00/141M [00:00<?, ?B/s]

data/train-00001-of-00003.parquet:   0%|          | 0.00/149M [00:00<?, ?B/s]

data/train-00002-of-00003.parquet:   0%|          | 0.00/179M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/92544 [00:00<?, ? examples/s]

In [6]:
# Prepare dataset in completion format
def preprocess(example):
    prompt = "Solve the following math problem. Make sure to put the answer (and only answer) inside \\boxed{}."
    
    assistant_content = f"<think>\n{example['generated_solution']}\n</think>\n\\boxed{{{example['expected_answer']}}}"
    
    return {
        "prompt": [
            {"role": "system", "content": prompt},
            {"role": "user", "content": example["question"]},
        ],
        "completion": [
            {"role": "assistant", "content": assistant_content},
        ],
    }

eval_prep_ds = eval_ds.map(preprocess, remove_columns=eval_ds.column_names)


Map:   0%|          | 0/92544 [00:00<?, ? examples/s]

## Benchmark

In [7]:
import re
import pandas as pd

def extract_boxed(text):
    """Extract content from the last \\boxed{...}, handling nested braces."""
    pattern = r'\\boxed\{'
    matches = list(re.finditer(pattern, text))
    if not matches:
        return None
    start = matches[-1].end()
    depth = 1
    i = start
    while i < len(text) and depth > 0:
        if text[i] == '{':
            depth += 1
        elif text[i] == '}':
            depth -= 1
        i += 1
    return text[start:i-1].strip() if depth == 0 else None

In [8]:
indices_seen = list(range(60, 80))
indices_unseen = list(range(1020, 1040))
indices = indices_seen
results = []

In [9]:
BATCH_SIZE = 10

def batch_tokenize(tok, prompts, device):
    """Tokenize multiple chat prompts with left-padding for batched generation."""
    prev_side = tok.padding_side
    tok.padding_side = "left"
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    # Tokenize each conversation individually
    encoded = [
        tok.apply_chat_template(p, add_generation_prompt=True, return_tensors="pt", return_dict=True)
        for p in prompts
    ]

    # Left-pad to max length in the batch
    max_len = max(e["input_ids"].shape[1] for e in encoded)
    input_ids = torch.full((len(encoded), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(encoded), max_len), dtype=torch.long)

    for i, e in enumerate(encoded):
        seq_len = e["input_ids"].shape[1]
        input_ids[i, max_len - seq_len:] = e["input_ids"][0]
        attention_mask[i, max_len - seq_len:] = 1

    tok.padding_side = prev_side
    return {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)}


def run_batched_benchmark(cur_model, tok, indices, modes, batch_size=BATCH_SIZE):
    """Run benchmark in batches for the given modes.

    Parameters
    ----------
    modes : list of (mode_name, generate_func, extra_kwargs) tuples
    """
    thought_id = tok.convert_tokens_to_ids(THOUGHT_TOKEN)
    solution_id = tok.convert_tokens_to_ids(SOLUTION_TOKEN)
    return_id = tok.convert_tokens_to_ids(RETURN_TOKEN)
    eos_id = tok.eos_token_id

    all_results = []
    with torch.no_grad():
        for batch_start in range(0, len(indices), batch_size):
            batch_idx = indices[batch_start:batch_start + batch_size]
            prompts = [eval_prep_ds[idx]["prompt"] for idx in batch_idx]
            expected_answers = [eval_ds[idx]["expected_answer"].strip() for idx in batch_idx]

            inp = batch_tokenize(tok, prompts, cur_model.device)

            for mode_name, gen_func, extra_kw in modes:
                out = cur_model.generate(
                    **inp,
                    max_new_tokens=8192,
                    thought_token_id=thought_id,
                    solution_token_id=solution_id,
                    return_token_id=return_id,
                    eos_token_id=eos_id,
                    custom_generate=gen_func,
                    **extra_kw,
                )

                for j, idx in enumerate(batch_idx):
                    decoded = tok.decode(out.sequences[j], skip_special_tokens=False)
                    predicted = extract_boxed(decoded)
                    all_results.append(dict(
                        idx=idx, mode=mode_name,
                        prompt_tokens=out.prompt_tokens,
                        generated_tokens=out.generated_tokens,
                        total_tokens_processed=out.total_tokens_processed[j],
                        output_tokens=out.output_tokens[j],
                        prune_events=out.prune_events[j],
                        tokens_pruned=out.tokens_pruned[j],
                        wall_time=out.wall_time_seconds / len(batch_idx),
                        expected=expected_answers[j],
                        predicted=predicted,
                        correct=(predicted is not None and predicted == expected_answers[j]),
                    ))

            done = min(batch_start + batch_size, len(indices))
            print(f"[{done}/{len(indices)}] batch done")

    return all_results


In [10]:
base_model = AutoModelForCausalLM.from_pretrained(
  MODEL_NAME,
  dtype=torch.bfloat16,
  device_map='auto'
)
base_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model.eval()

config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/338 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/121 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotar

In [11]:
hcot_modes = [
    ("Baseline", generate.generate_standard, {"use_cache": True}),
]

baseline_batched_results = run_batched_benchmark(base_model, base_tokenizer, indices, hcot_modes, batch_size=BATCH_SIZE)
df_batched_baseline = pd.DataFrame(baseline_batched_results)
print(f"\nBatched benchmark complete: {len(indices)} examples x {len(hcot_modes)} modes = {len(df_batched_baseline)} runs")

Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Passing `generation_config` together with generation-related arguments=({'use_cache'}) is deprecated and will be removed in future versions. Please pass either a `generation_config` object OR all generation parameters explicitly, but not both.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Both `max_new_tokens` (=8192) and `max_length`(=8324) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
A custom stopping criteria of type <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> will take precedence. Please check the docstring of <class 'transfor

[10/20] batch done
[20/20] batch done

Batched benchmark complete: 20 examples x 1 modes = 20 runs


In [12]:
df_batched_baseline

Unnamed: 0,idx,mode,prompt_tokens,generated_tokens,total_tokens_processed,output_tokens,prune_events,tokens_pruned,wall_time,expected,predicted,correct
0,60,Baseline,132,8192,8324,8324,0,0,18.130596,\( A = \begin{pmatrix} -1 & 1 \\ 1 & -1 \end{p...,A = \begin{pmatrix} 1 & 0 \\ 0 & 0 \end{pmatri...,False
1,61,Baseline,132,8192,8324,8324,0,0,18.130596,\(\frac{n\varphi(n)}{2}\),\frac{n \phi(n)}{2},False
2,62,Baseline,132,8192,8324,8324,0,0,18.130596,3,3,True
3,63,Baseline,132,8192,8324,8324,0,0,18.130596,\(\frac{2}{3}\),\frac{2}{3},False
4,64,Baseline,132,8192,8324,8324,0,0,18.130596,\( 2(1 - \sqrt{3}) \le a < 0 \),,False
5,65,Baseline,132,8192,8324,8324,0,0,18.130596,1964,1964,True
6,66,Baseline,132,8192,8324,8324,0,0,18.130596,\(\frac{15309}{256}\),\frac{15309}{256},False
7,67,Baseline,132,8192,8324,8324,0,0,18.130596,"\( (p, q) = (p, 2) \) for any prime \( p \) an...","(p, q) \text{ where } q = 2 \text{ or } p = 3",False
8,68,Baseline,132,8192,8324,8324,0,0,18.130596,\(-7 \ln|x-1| + \frac{1}{2(x-1)^2} + \frac{4}{...,-7 \ln|x-1| + \frac{4}{x-1} + \frac{1}{2(x-1)^...,False
9,69,Baseline,132,8192,8324,8324,0,0,18.130596,005,5,False


In [13]:
del(base_model)
del(base_tokenizer)
torch.cuda.empty_cache()

In [15]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [16]:
save_path = f'/content/drive/MyDrive/df_batched_baseline.csv'
df_batched_baseline.to_csv(save_path, index=False)

In [17]:
model = AutoModelForCausalLM.from_pretrained(
  model_repo_id,
  dtype=torch.bfloat16,
  device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(model_repo_id)
model.eval()

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/338 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/130 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/858 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/628 [00:00<?, ?B/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151670, 1536, padding_idx=151643)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), e

In [18]:
hcot_modes = [
    ("HCoT Prune", generate.generate, {"use_cache": False}),
    ("HCoT Mask",  generate.generate_with_mask, {"use_cache": False, "min_token_length": 2048}),
]

batched_results = run_batched_benchmark(model, tokenizer, indices, hcot_modes, batch_size=BATCH_SIZE)
df_batched = pd.DataFrame(batched_results)
print(f"\nBatched benchmark complete: {len(indices)} examples x {len(hcot_modes)} modes = {len(df_batched)} runs")

KeyboardInterrupt: 

In [None]:
save_path = f'/content/drive/MyDrive/df_batched.csv'
df_batched.to_csv(save_path, index=False)

In [None]:
hcot_modes = [
    ("HCoT Prune (CACHED)", generate.generate, {"use_cache": True}),
    ("HCoT Mask (CACHED)",  generate.generate_with_mask, {"use_cache": True, "min_token_length": 2048})
]

batched_results = run_batched_benchmark(model, tokenizer, indices, hcot_modes, batch_size=BATCH_SIZE)
df_batched_cached = pd.DataFrame(batched_results)
print(f"\nBatched benchmark complete: {len(indices)} examples x {len(hcot_modes)} modes = {len(df_batched_cached)} runs")

In [None]:
save_path = f'/content/drive/MyDrive/df_batched_cached.csv'
df_batched_cached.to_csv(save_path, index=False)

In [None]:
del(model)
del(tokenizer)
torch.cuda.empty_cache()

In [None]:
df = pd.concat([df_batched, df_batched_cached, df_batched_baseline], ignore_index=True)

### Summary Statistics

In [None]:
# --- Summary table grouped by mode ---
summary = df.groupby("mode").agg(
    num_examples=("idx", "count"),
    avg_prompt_tokens=("prompt_tokens", "mean"),
    avg_generated_tokens=("generated_tokens", "mean"),
    avg_total_tokens_processed=("total_tokens_processed", "mean"),
    avg_output_tokens=("output_tokens", "mean"),
    total_prune_events=("prune_events", "sum"),
    total_tokens_pruned=("tokens_pruned", "sum"),
    avg_wall_time=("wall_time", "mean"),
    accuracy=("correct", "mean"),
).reindex(["Standard", "HCoT Prune", "HCoT Prune (CACHED)", "HCoT Mask", "HCoT Mask (CACHED)"])

# Compute token reduction % relative to Standard
std_avg = summary.loc["Standard", "avg_total_tokens_processed"]
summary["token_reduction_pct"] = ((std_avg - summary["avg_total_tokens_processed"]) / std_avg * 100).round(1)

print(summary.to_string())
print(f"\n--- Token Processing Reduction vs Standard ---")
for mode in ["HCoT Prune", "HCoT Prune (CACHED)", "HCoT Mask", "HCoT Mask (CACHED)"]:
    pct = summary.loc[mode, "token_reduction_pct"]
    print(f"  {mode}: {pct:+.1f}%")

### Per-Example Correctness

In [None]:
# Pivot correctness by example to compare modes side by side
correctness = df.pivot(index="idx", columns="mode", values="correct")[["Standard", "HCoT Prune", "HCoT Prune (CACHED)", "HCoT Mask", "HCoT Mask (CACHED)"]]
correctness.columns = [f"{c} correct" for c in correctness.columns]

# Add expected answer for context
expected_map = df.drop_duplicates("idx").set_index("idx")["expected"]
correctness.insert(0, "expected", expected_map)

# Add predicted answers per mode
for mode in ["Standard", "HCoT Prune", "HCoT Prune (CACHED)", "HCoT Mask", "HCoT Mask (CACHED)"]:
    pred = df[df["mode"] == mode].set_index("idx")["predicted"]
    correctness[f"{mode} predicted"] = pred

pd.set_option("display.max_colwidth", 40)
pd.set_option("display.max_rows", 50)
display(correctness)

### Plots

In [None]:
import matplotlib.pyplot as plt
import numpy as np

modes = ["Standard", "HCoT Prune", "HCoT Prune (CACHED)", "HCoT Mask", "HCoT Mask (CACHED)"]
colors = ["#4C72B0", "#DD8452", "#55A868"]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# --- 1. Average Total Tokens Processed (bar) ---
ax = axes[0, 0]
vals = [summary.loc[m, "avg_total_tokens_processed"] for m in modes]
bars = ax.bar(modes, vals, color=colors)
ax.set_ylabel("Avg Total Tokens Processed")
ax.set_title("Total Tokens Processed per Example")
for bar, v in zip(bars, vals):
    ax.text(bar.get_x() + bar.get_width()/2, v, f"{v:,.0f}", ha="center", va="bottom", fontsize=9)

# --- 2. Per-example tokens processed (scatter) ---
ax = axes[0, 1]
for mode, color in zip(modes, colors):
    sub = df[df["mode"] == mode].sort_values("idx")
    ax.plot(range(len(sub)), sub["total_tokens_processed"].values, "o-", label=mode, color=color, markersize=3, alpha=0.8)
ax.set_xlabel("Example (sorted by idx)")
ax.set_ylabel("Total Tokens Processed")
ax.set_title("Tokens Processed per Example")
ax.legend(fontsize=8)

# --- 3. Accuracy comparison (bar) ---
ax = axes[1, 0]
accs = [summary.loc[m, "accuracy"] * 100 for m in modes]
bars = ax.bar(modes, accs, color=colors)
ax.set_ylabel("Accuracy (%)")
ax.set_title("Answer Accuracy")
ax.set_ylim(0, 100)
for bar, v in zip(bars, accs):
    ax.text(bar.get_x() + bar.get_width()/2, v + 1, f"{v:.1f}%", ha="center", va="bottom", fontsize=9)

# --- 4. Wall time comparison (bar) ---
ax = axes[1, 1]
times = [summary.loc[m, "avg_wall_time"] for m in modes]
bars = ax.bar(modes, times, color=colors)
ax.set_ylabel("Avg Wall Time (s)")
ax.set_title("Average Generation Time")
for bar, v in zip(bars, times):
    ax.text(bar.get_x() + bar.get_width()/2, v, f"{v:.1f}s", ha="center", va="bottom", fontsize=9)

plt.tight_layout()
plt.savefig("benchmark_results.png", dpi=150, bbox_inches="tight")
plt.show()