In [1]:
import os
import json
import argparse
import numpy as np
import torch
from datasets import load_dataset
from transformers import GenerationConfig, AutoConfig, AutoTokenizer, BitsAndBytesConfig
from vllm import LLM, SamplingParams
import re
import math
from math_verify import parse, verify, LatexExtractionConfig
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


INFO 08-29 03:45:50 [__init__.py:239] Automatically detected platform cuda.


2025-08-29 03:45:52,728	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
# Registry of supported benchmarks, keyed by the short name used to select a run.
# Each entry provides:
#   "args"         - (dataset path, split) pair; unpacked and passed to load_dataset
#   "question_key" - example field that holds the problem statement
#   "answer_key"   - example field that holds the reference answer
# "MMLU-Pro-math" additionally carries "options_key"; nothing visible in this
# notebook reads it yet.
DATASET_MAP = {
    "MATH-500": {
        "args": ("HuggingFaceH4/MATH-500", "test"),
        "question_key": "problem",
        "answer_key": "answer"
    },
    "AIME2024": {
        "args": ("HuggingFaceH4/aime_2024", "train"),
        "question_key": "problem",
        "answer_key": "answer"
    },
    "gpqa": {
        "args": ("hendrydong/gpqa_diamond_mc", "test"),
        "question_key": "problem",
        "answer_key": "solution"
    },
    "gsm8k": {
        "args": ("skrishna/gsm8k_only_answer", "test"),
        "question_key": "text",
        "answer_key": "label"
    },
    "openr1-math": {
        "args": ("open-r1/OpenR1-Math-220k", "train"),
        "question_key": "problem",
        "answer_key": "answer"
    },
    "AIME2025": {
        "args": ("yentinglin/aime_2025", "train"),
        "question_key": "problem",
        "answer_key": "answer"
    },
    "MMLU-Pro-math": {
        "args": ("TIGER-Lab/MMLU-Pro", "test"),
        "options_key": "options",
        "question_key": "question",
        "answer_key": "answer"
    }
}

In [3]:
# Short model aliases -> checkpoint identifiers. The selected identifier is
# passed to AutoConfig / GenerationConfig / AutoTokenizer and to vLLM's LLM.
MODEL_MAP   = {
    "deepseek-qwen-1.5b": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "deepseek-llama3-8b": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "deepseek-qwen-14b": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
    "qwq-32b": "Qwen/QwQ-32B",
    "qwen3-8b": "Qwen/Qwen3-8B",
    "deepseek-qwen3-8b": "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
    "phi4-reasoning-plus": "microsoft/Phi-4-reasoning-plus",
    "nemotron-7b": "nvidia/OpenMath-Nemotron-7B",
}

In [5]:
# Run selection: `dataset` must be a key of DATASET_MAP and `model` a key of
# MODEL_MAP; both are looked up by the cells that follow.
dataset = "gsm8k"
model = "deepseek-llama3-8b"

In [6]:
# Resolve the selected benchmark and load the requested split.
dataset_name, split = DATASET_MAP[dataset]["args"]
ds = load_dataset(dataset_name, split=split)

# Field names used later to read the question and the reference answer
# from each example.
question_key, answer_key = (
    DATASET_MAP[dataset][field] for field in ("question_key", "answer_key")
)

Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 100907.32 examples/s]
Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 84492.07 examples/s]


In [7]:
# Resolve the alias to a checkpoint identifier and fetch its metadata.
model_id = MODEL_MAP[model]
# Maximum position embeddings declared by the model config.
# NOTE(review): not read by any later cell visible here (the engine below uses a fixed 4096).
max_pos = AutoConfig.from_pretrained(model_id).max_position_embeddings
# Generation defaults shipped with the checkpoint; make_params() accepts such an object as `cfg`.
cfg = GenerationConfig.from_pretrained(model_id)
# Tokenizer used for chat-template rendering and for counting response tokens.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [8]:
# Disabled experiment: 4-bit NF4 bitsandbytes quantization settings.
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16
# )

# Start the vLLM engine for the selected checkpoint.
llm = LLM(
    model=model_id,
    max_model_len=4096,  # engine log reports this as max_seq_len=4096
    dtype="half",  # engine log reports dtype=torch.float16
    # Disabled options kept for reference:
    # gpu_memory_utilization=0.7,
    # quantization="bitsandbytes"
)

INFO 08-27 07:03:02 [config.py:717] This model supports multiple tasks: {'score', 'embed', 'classify', 'reward', 'generate'}. Defaulting to 'generate'.
INFO 08-27 07:03:02 [llm_engine.py:240] Initializing a V0 LLM engine (v0.8.5) with config: model='deepseek-ai/DeepSeek-R1-Distill-Llama-8B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Llama-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_exe

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:07<00:07,  7.04s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:14<00:00,  7.01s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:14<00:00,  7.02s/it]



INFO 08-27 07:05:43 [loader.py:458] Loading weights took 14.27 seconds
INFO 08-27 07:05:43 [model_runner.py:1140] Model loading took 14.9889 GiB and 154.309636 seconds
INFO 08-27 07:05:45 [worker.py:287] Memory profiling takes 1.23 seconds
INFO 08-27 07:05:45 [worker.py:287] the current vLLM instance can use total_gpu_memory (31.73GiB) x gpu_memory_utilization (0.90) = 28.56GiB
INFO 08-27 07:05:45 [worker.py:287] model weights take 14.99GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 1.20GiB; the rest of the memory reserved for KV Cache is 12.28GiB.
INFO 08-27 07:05:45 [executor_base.py:112] # cuda blocks: 6287, # CPU blocks: 2048
INFO 08-27 07:05:45 [executor_base.py:117] Maximum concurrency for 4096 tokens per request: 24.56x
INFO 08-27 07:05:49 [model_runner.py:1450] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the 

Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:37<00:00,  1.06s/it]

INFO 08-27 07:06:26 [model_runner.py:1592] Graph capturing finished in 37 secs, took 0.24 GiB
INFO 08-27 07:06:26 [llm_engine.py:437] init engine (profile, create kv cache, warmup model) took 42.44 seconds





In [9]:
def make_params(n: int, budget: int, cfg) -> SamplingParams:
    """
    Create SamplingParams for `n` completions capped at `budget` new tokens.

    `temperature`, `top_k` and `top_p` are copied from `cfg` whenever the
    attribute exists and is not None; otherwise the SamplingParams default
    for that field is left in place.
    """
    kwargs = {"n": n, "max_tokens": budget}
    for field in ("temperature", "top_k", "top_p"):
        value = getattr(cfg, field, None)
        if value is not None:
            kwargs[field] = value
    return SamplingParams(**kwargs)

In [26]:
def apply_chat(prompt: str, tokenizer):
    """
    Render `prompt` as a single user turn using the tokenizer's chat template.

    The template is applied with tokenize=False and add_generation_prompt=True,
    and whatever `tokenizer.apply_chat_template` returns is handed back unchanged.
    """
    messages = [dict(role="user", content=prompt)]
    rendered = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return rendered

In [27]:
# Build one chat-formatted prompt per dataset example, in dataset order
# (the scoring cell later pairs results with examples by index).
prompts = []
for ex in tqdm(ds):
    q = ex[question_key]
    # Instruction asks for the final answer inside \boxed{}, which is the
    # form extract_answer() looks for first.
    prompt = (
                f"Problem: {q}\n\n"
                "Please reason step by step, and put your final answer within \\boxed{}."
            )
    prompts.append(apply_chat(prompt, tokenizer))

100%|██████████| 1319/1319 [00:00<00:00, 6828.61it/s]


In [33]:
# One completion per prompt at temperature 0.0, capped at 3072 generated tokens.
sampling_params = SamplingParams(n=1, temperature=0.0, max_tokens=3072)

In [31]:
# Smoke test: generate for the first prompt only.
# NOTE(review): the error output recorded under this cell mentions a `do_sample`
# keyword that this call does not pass - presumably stale output from an earlier
# version of the cell; re-run to confirm.
llm.generate(prompts[0], sampling_params=sampling_params)

TypeError: LLM.generate() got an unexpected keyword argument 'do_sample'

In [None]:
# Full run over every prompt. The scoring cell reads each result's `.outputs`
# and pairs results with dataset rows by position.
results = llm.generate(prompts=prompts, sampling_params=sampling_params)

Processed prompts:   0%|          | 0/6595 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]



Processed prompts:   0%|          | 20/6595 [02:09<7:29:46,  4.10s/it, est. speed input: 11.49 toks/s, output: 138.95 toks/s]



Processed prompts:   1%|▏         | 95/6595 [04:07<2:18:18,  1.28s/it, est. speed input: 28.91 toks/s, output: 469.42 toks/s]



Processed prompts:   3%|▎         | 165/6595 [05:13<1:21:02,  1.32it/s, est. speed input: 41.42 toks/s, output: 785.41 toks/s]



Processed prompts:   3%|▎         | 180/6595 [06:22<3:35:47,  2.02s/it, est. speed input: 37.46 toks/s, output: 684.68 toks/s]



Processed prompts:   3%|▎         | 215/6595 [07:28<2:53:11,  1.63s/it, est. speed input: 37.44 toks/s, output: 674.88 toks/s]



Processed prompts:   4%|▎         | 245/6595 [08:13<2:55:39,  1.66s/it, est. speed input: 38.61 toks/s, output: 718.85 toks/s]



Processed prompts:   4%|▍         | 260/6595 [08:48<3:30:33,  1.99s/it, est. speed input: 38.38 toks/s, output: 719.39 toks/s]



Processed prompts:   4%|▍         | 275/6595 [09:18<2:53:22,  1.65s/it, est. speed input: 39.13 toks/s, output: 734.85 toks/s]



Processed prompts:   5%|▍         | 305/6595 [09:45<1:30:59,  1.15it/s, est. speed input: 41.60 toks/s, output: 776.07 toks/s]



Processed prompts:   5%|▍         | 310/6595 [10:19<4:25:09,  2.53s/it, est. speed input: 39.88 toks/s, output: 746.22 toks/s]



Processed prompts:   5%|▍         | 320/6595 [10:28<3:11:08,  1.83s/it, est. speed input: 40.75 toks/s, output: 767.60 toks/s]



Processed prompts:   5%|▌         | 360/6595 [11:37<3:10:42,  1.84s/it, est. speed input: 40.75 toks/s, output: 773.12 toks/s]



Processed prompts:   6%|▌         | 370/6595 [11:55<2:55:44,  1.69s/it, est. speed input: 41.14 toks/s, output: 787.14 toks/s]



Processed prompts:   6%|▌         | 390/6595 [12:09<1:23:36,  1.24it/s, est. speed input: 42.88 toks/s, output: 819.76 toks/s]



Processed prompts:   6%|▌         | 410/6595 [13:15<4:03:09,  2.36s/it, est. speed input: 40.73 toks/s, output: 801.46 toks/s]



Processed prompts:   7%|▋         | 435/6595 [13:48<2:06:52,  1.24s/it, est. speed input: 41.66 toks/s, output: 802.74 toks/s]



Processed prompts:   7%|▋         | 450/6595 [14:17<2:17:13,  1.34s/it, est. speed input: 41.32 toks/s, output: 812.41 toks/s]



Processed prompts:   7%|▋         | 460/6595 [14:45<3:52:27,  2.27s/it, est. speed input: 40.74 toks/s, output: 801.92 toks/s]



Processed prompts:   7%|▋         | 480/6595 [15:14<2:46:52,  1.64s/it, est. speed input: 41.24 toks/s, output: 798.36 toks/s]



Processed prompts:   7%|▋         | 485/6595 [15:44<5:00:31,  2.95s/it, est. speed input: 40.47 toks/s, output: 789.19 toks/s]



Processed prompts:   8%|▊         | 495/6595 [16:03<3:58:01,  2.34s/it, est. speed input: 40.58 toks/s, output: 790.50 toks/s]



Processed prompts:   8%|▊         | 505/6595 [16:23<3:47:22,  2.24s/it, est. speed input: 40.76 toks/s, output: 791.80 toks/s]



Processed prompts:   8%|▊         | 520/6595 [16:47<2:56:26,  1.74s/it, est. speed input: 41.08 toks/s, output: 804.04 toks/s]



Processed prompts:   8%|▊         | 525/6595 [16:53<2:43:18,  1.61s/it, est. speed input: 41.31 toks/s, output: 813.55 toks/s]



Processed prompts:   8%|▊         | 530/6595 [17:02<2:50:19,  1.68s/it, est. speed input: 41.49 toks/s, output: 818.92 toks/s]



Processed prompts:   8%|▊         | 545/6595 [18:08<5:20:17,  3.18s/it, est. speed input: 40.37 toks/s, output: 790.69 toks/s]



Processed prompts:   9%|▉         | 580/6595 [19:05<2:44:29,  1.64s/it, est. speed input: 40.77 toks/s, output: 787.36 toks/s]



Processed prompts:   9%|▉         | 615/6595 [19:35<1:29:44,  1.11it/s, est. speed input: 42.23 toks/s, output: 810.81 toks/s]



Processed prompts:  10%|▉         | 645/6595 [20:01<1:16:56,  1.29it/s, est. speed input: 42.99 toks/s, output: 831.84 toks/s]



Processed prompts:  10%|▉         | 650/6595 [20:17<2:31:10,  1.53s/it, est. speed input: 42.86 toks/s, output: 829.29 toks/s]



Processed prompts:  10%|█         | 660/6595 [20:47<3:55:25,  2.38s/it, est. speed input: 42.51 toks/s, output: 819.27 toks/s]



Processed prompts:  10%|█         | 680/6595 [21:11<2:16:16,  1.38s/it, est. speed input: 43.02 toks/s, output: 827.87 toks/s]



Processed prompts:  10%|█         | 690/6595 [21:15<1:31:50,  1.07it/s, est. speed input: 43.37 toks/s, output: 832.22 toks/s]



Processed prompts:  11%|█         | 695/6595 [21:34<2:42:43,  1.65s/it, est. speed input: 43.11 toks/s, output: 823.97 toks/s]



Processed prompts:  11%|█         | 730/6595 [22:28<2:46:34,  1.70s/it, est. speed input: 43.77 toks/s, output: 831.90 toks/s]



Processed prompts:  11%|█▏        | 745/6595 [23:10<3:13:33,  1.99s/it, est. speed input: 43.34 toks/s, output: 825.65 toks/s]



Processed prompts:  12%|█▏        | 760/6595 [23:49<3:27:16,  2.13s/it, est. speed input: 42.94 toks/s, output: 820.14 toks/s]



Processed prompts:  12%|█▏        | 780/6595 [24:07<1:55:19,  1.19s/it, est. speed input: 43.76 toks/s, output: 839.27 toks/s]



Processed prompts:  12%|█▏        | 805/6595 [24:53<1:54:35,  1.19s/it, est. speed input: 44.16 toks/s, output: 843.25 toks/s]



Processed prompts:  12%|█▏        | 820/6595 [25:18<2:26:02,  1.52s/it, est. speed input: 44.28 toks/s, output: 848.27 toks/s]



Processed prompts:  13%|█▎        | 840/6595 [26:04<2:23:50,  1.50s/it, est. speed input: 43.98 toks/s, output: 842.15 toks/s]



Processed prompts:  13%|█▎        | 870/6595 [26:44<2:07:35,  1.34s/it, est. speed input: 44.29 toks/s, output: 846.16 toks/s]



Processed prompts:  13%|█▎        | 875/6595 [27:06<3:37:25,  2.28s/it, est. speed input: 43.92 toks/s, output: 843.94 toks/s]



Processed prompts:  14%|█▎        | 905/6595 [28:23<2:00:19,  1.27s/it, est. speed input: 43.41 toks/s, output: 833.31 toks/s]



Processed prompts:  14%|█▍        | 935/6595 [29:12<2:05:37,  1.33s/it, est. speed input: 43.22 toks/s, output: 830.78 toks/s]



Processed prompts:  14%|█▍        | 950/6595 [29:37<1:56:46,  1.24s/it, est. speed input: 43.32 toks/s, output: 832.57 toks/s]



Processed prompts:  14%|█▍        | 955/6595 [29:40<1:43:23,  1.10s/it, est. speed input: 43.39 toks/s, output: 836.34 toks/s]



Processed prompts:  15%|█▍        | 965/6595 [30:01<2:19:17,  1.48s/it, est. speed input: 43.48 toks/s, output: 841.95 toks/s]



Processed prompts:  15%|█▌        | 1000/6595 [31:04<1:57:40,  1.26s/it, est. speed input: 43.94 toks/s, output: 845.87 toks/s]



Processed prompts:  15%|█▌        | 1005/6595 [31:14<2:16:38,  1.47s/it, est. speed input: 43.88 toks/s, output: 845.57 toks/s]



Processed prompts:  16%|█▌        | 1035/6595 [31:47<1:30:59,  1.02it/s, est. speed input: 44.21 toks/s, output: 848.36 toks/s]



Processed prompts:  16%|█▌        | 1045/6595 [31:54<1:16:14,  1.21it/s, est. speed input: 44.49 toks/s, output: 853.96 toks/s]



Processed prompts:  16%|█▌        | 1070/6595 [32:39<2:02:56,  1.34s/it, est. speed input: 44.37 toks/s, output: 847.82 toks/s]



Processed prompts:  16%|█▋        | 1080/6595 [32:58<2:21:16,  1.54s/it, est. speed input: 44.38 toks/s, output: 849.12 toks/s]



Processed prompts:  17%|█▋        | 1090/6595 [33:09<1:58:27,  1.29s/it, est. speed input: 44.52 toks/s, output: 850.76 toks/s]



Processed prompts:  17%|█▋        | 1120/6595 [34:04<2:02:30,  1.34s/it, est. speed input: 44.47 toks/s, output: 846.38 toks/s]



Processed prompts:  17%|█▋        | 1135/6595 [34:35<2:31:29,  1.66s/it, est. speed input: 44.57 toks/s, output: 849.52 toks/s]



In [15]:
def verify_answer(pred: str, ref: str) -> bool:
    """
    Decide whether a predicted answer matches the reference answer.

    Checks are tried in order and the first one that applies decides:
      1. If the prediction is a base-N literal such as ``(101)_{2}``, its digits
         must equal the reference exactly.
      2. If the reference is a base-N literal, its digits must equal the
         prediction exactly.
      3. If either side contains an exponent ``^{k}`` with k above 10_000, the
         answers are compared as strings with spaces removed, so that nothing
         tries to evaluate an enormous power.
      4. Otherwise both sides are wrapped as inline LaTeX and compared with
         math_verify (floats rounded to 2 places).

    Args:
        pred: Predicted answer; non-string values are converted with ``str``.
        ref: Reference answer; non-string values are converted with ``str``.

    Returns:
        True when the answers are judged equal. False when they differ, when
        either argument is None, or when math_verify raises.
    """
    # ── patterns & threshold ─────────────────────────────────────────────────
    BASE_N_RE    = re.compile(r"^\(?([0-9A-Za-z]+)\)?_\{(\d+)\}$")
    EXP_RE       = re.compile(r"\^\{(\d+)\}")
    MAX_SAFE_EXP = 10_000

    # ── normalize inputs ─────────────────────────────────────────────────────
    if pred is None or ref is None:
        return False
    # str() lets numeric labels (e.g. an int gold answer) through instead of
    # failing on .strip().
    p = str(pred).strip()
    r = str(ref).strip()

    # ── 1) base-N literal in prediction ─────────────────────────────────────
    m = BASE_N_RE.match(p)
    if m:
        return m.group(1) == r

    # ── 2) base-N literal in reference ──────────────────────────────────────
    m = BASE_N_RE.match(r)
    if m:
        return m.group(1) == p

    # ── 3) huge-exponent guard (either side) ────────────────────────────────
    # Both strings are parsed below, so a huge exponent on the reference is
    # just as dangerous as one on the prediction.
    exps = [int(e) for e in EXP_RE.findall(p) + EXP_RE.findall(r)]
    if exps and max(exps) > MAX_SAFE_EXP:
        return p.replace(" ", "") == r.replace(" ", "")

    # ── 4) fallback to math_verify ──────────────────────────────────────────
    cfg = LatexExtractionConfig()
    try:
        g_node = parse(f"\\({r}\\)", extraction_config=[cfg])
        p_node = parse(f"\\({p}\\)", extraction_config=[cfg])
        return verify(g_node, p_node, float_rounding=2)
    except Exception:
        # Best effort: an answer that cannot be parsed or compared is wrong.
        return False

def extract_answer(text):
    """
    Pull the final answer string out of a model response.

    The contents of the last ``boxed{...}`` expression found by
    ``extract_boxed`` are preferred (outer braces removed). When there is no
    boxed expression, the last integer or decimal literal in the text is
    returned as it appears. Returns None for None input, when no number is
    present, or when the last number does not convert to a finite float.
    """
    if text is None:
        return None

    # Preferred source: the last boxed expression.
    boxed = extract_boxed(text)
    if boxed:
        last_box = boxed[-1]
        return last_box[1:-1]  # strip the surrounding '{' and '}'

    # Fallback: the last numeric literal anywhere in the text.
    candidates = re.findall(r'-?\d+\.\d+|-?\d+', text)
    if len(candidates) == 0:
        return None

    final = candidates[-1]
    try:
        value = float(final)
    except (ValueError, OverflowError):
        return None
    # A literal too large for a float becomes inf; treat it as no answer.
    return None if math.isinf(value) else final

def extract_boxed(text):
    """
    Find every brace-balanced ``boxed{...}`` expression in `text`.

    Args:
        text: String to scan.

    Returns:
        A list, in order of appearance, of the brace-delimited argument of
        each match *including* its outer braces (e.g. ``"{42}"``). A boxed
        expression nested inside an already-matched one is not reported
        separately. An opening ``boxed{`` whose braces never balance is
        skipped, and scanning resumes right after it, so well-formed boxed
        expressions that follow are still returned.
    """
    marker = "boxed{"
    matches = []
    pos = 0

    while True:
        hit = text.find(marker, pos)
        if hit == -1:
            break

        open_idx = hit + len(marker) - 1  # index of the opening '{'
        depth = 0
        close_idx = -1
        for j in range(open_idx, len(text)):
            ch = text[j]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:  # matching close for the opening brace
                    close_idx = j
                    break

        if close_idx == -1:
            # Unbalanced: do not give up on the rest of the text; a later
            # boxed expression may still be complete.
            pos = open_idx + 1
            continue

        matches.append(text[open_idx:close_idx + 1])
        pos = close_idx + 1

    return matches

In [17]:
# Score every generation and group the records by sample index:
# runs[rid] collects the rid-th completion of each question.
# NOTE(review): ten buckets are pre-created so the saved layout is unchanged;
# with n=1 sampling only run 0 receives records - confirm this is intended.
runs = {rid: [] for rid in range(10)}
for idx, gen in tqdm(enumerate(results), total=len(results)):
    example = ds[idx]  # results are paired with dataset rows by position
    gold = example[answer_key]
    # verify_answer expects strings; keep None as None so it still scores False.
    gold_ref = None if gold is None else str(gold)
    for rid, out in enumerate(gen.outputs):
        text = out.text.strip()
        # prediction extraction
        pred = extract_answer(text)
        # correctness: verify_answer takes (prediction, reference) in that
        # order; its huge-exponent guard is applied to the first argument.
        try:
            correct = bool(verify_answer(pred, gold_ref))
        except Exception:
            # Best effort: an answer that cannot be checked counts as wrong,
            # but interrupts (KeyboardInterrupt/SystemExit) still propagate.
            correct = False
        # reasoning length (entire response) in tokens
        reasoning_length = len(tokenizer.encode(text, add_special_tokens=False))

        # setdefault avoids a KeyError if more than ten samples were drawn.
        runs.setdefault(rid, []).append({
            "question":         example[question_key],
            "full_response":    text,
            "reasoning_length": reasoning_length,
            "prediction":       pred,
            "gold":             gold,
            "correct":          correct
        })

1319it [02:13,  9.87it/s]


In [18]:
# Save all runs under hint_results/<dataset>/<model>/ so the file is filed under
# the dataset and model that were actually selected for this run, instead of a
# hard-coded model name.
output_dir = f"hint_results/{dataset}/{model}"
os.makedirs(output_dir, exist_ok=True)
output_path = f"{output_dir}/10_runs.json"

# Layout: {"runs": [{"run_id": <int>, "records": [<record>, ...]}, ...]}
payload = {
    "runs": [{"run_id": rid, "records": recs} for rid, recs in runs.items()]
}
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(payload, f, indent=4)