In [None]:
!pip -q install -U "transformers>=4.46.0" "datasets>=2.19.0" "accelerate>=0.30.0" \
    "peft>=0.11.0" "bitsandbytes>=0.45.0" "jinja2>=3.1.0" "tqdm"

In [None]:
import os, re, json, time, random
from decimal import Decimal, InvalidOperation
from typing import Optional, List, Dict, Any

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [None]:
# CONFIG
# Every tunable of the run lives in this cell; later cells only read these names.

# Hugging Face dataset id / config / split passed to `load_dataset` below.
DATASET_NAME   = "openai/gsm8k"
DATASET_CONFIG = "main"
SPLIT          = "train"

# Model id used for both the tokenizer and the (4-bit quantized) teacher model.
TEACHER_NAME   = "Qwen/Qwen2.5-7B-Instruct"

# Output location; the directory is created as a side effect of running this cell.
OUTPUT_DIR = "/kaggle/working/gsm8k_cot_kd_ready"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# OUT_JSONL is opened in append mode by write_jsonl_buffer; OUT_META is overwritten.
OUT_JSONL = os.path.join(OUTPUT_DIR, "gsm8k_train_cot_kd_ready.jsonl")
OUT_META  = os.path.join(OUTPUT_DIR, "run_meta.json")

# Self-consistency: K candidates per question
K = 6

# Adjust for GPU
# Number of questions per batch; each prompt is repeated K times before
# generation, so one generate() call sees BATCH_SIZE * K sequences.
BATCH_SIZE = 16

# Accuracy-first generation
# Sampling parameters forwarded unchanged to model.generate().
MAX_NEW_TOKENS = 256
TEMPERATURE    = 0.20
TOP_P          = 0.90
REPETITION_PENALTY = 1.08

# Keep policy
KEEP_ALL_CORRECT = False     # False: store only the shortest correct solution per question
MAX_COT_CHARS    = 1200       # character cap applied by enforce_step_format (falsy disables it)

# Buffered write
WRITE_EVERY_N_BATCH = 10  # append buffered records to OUT_JSONL every N batches
FLUSH_AT_END = True       # write whatever is still buffered once the loop finishes

# RANGE_END=None for full split.
# Both bounds are inclusive row indices into the selected split.
RANGE_START = 0
RANGE_END   = None

SEED = 3407

# SEED & TORCH FLAGS
# Seeds Python's `random` and torch's CPU/CUDA generators; no other torch flags are set here.
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
# REGEX
# Extract LAST "#### number" (supports commas/decimals/sign)
# Group 1 is the numeric literal; callers take the last match in the text.
ANS_RE = re.compile(r"####\s*([-+]?\d+(?:,\d{3})*(?:\.\d+)?)")

# A line consisting only of a chat role name, optionally followed by a colon.
ROLE_LINE_RE = re.compile(r"(?im)^\s*(assistant|user|system)\s*:?(\s*)$")

# Remove obvious prompt-leak / meta lines (very important)
# Three kinds of line are matched, optionally behind a list bullet:
#   * a heading keyword immediately followed by ":" (e.g. "Requirements:")
#   * an instruction phrase ending on a word boundary (e.g. "End with ...")
#   * the placeholder "Step N: ..."
# The word boundary is applied only to the phrase alternatives: a "\b" placed
# after ":" or "..." can only match when a word character follows, which made
# lines such as "Requirements:" or "Rules: be brief" slip through.
LEAK_RE = re.compile(
    r"(?im)^\s*(?:[-*•]\s*)?("
    r"(?:requirements|rules|example format|format|instructions|start with|steps):"
    r"|(?:end with|no extra text|do not include|output only)\b"
    r"|step\s+\d+\s*:\s*\.\.\."
    r").*$"
)

# Minimal LaTeX wrappers & common commands
# Each pattern captures the math content between its delimiters (may span lines).
LATEX_INLINE_RE  = re.compile(r"\\\((.*?)\\\)", re.S)
LATEX_DISPLAY_RE = re.compile(r"\\\[(.*?)\\\]", re.S)
# NOTE(review): the single-"$" alternative also matches the text between two
# currency dollar signs ("$5 and $10"), so those "$" characters get stripped.
DOLLAR_MATH_RE   = re.compile(r"\$\$(.*?)\$\$|\$(.*?)\$", re.S)

# PROMPT (KD-ready, low leak risk)
def build_messages_kd(question: str) -> List[Dict[str, str]]:
    """Build the two-message chat (system + user) sent to the teacher.

    Args:
        question: Raw question text; surrounding whitespace is stripped.

    Returns:
        A list with a system message holding the formatting instructions and
        a user message holding the stripped question.
    """
    instruction_lines = (
        "You are a precise math solver.",
        "Output ONLY the solution.",
        "Requirements:",
        "- Start with: Step 1:",
        "- Use short numbered steps with explicit arithmetic.",
        "- Do NOT repeat the question.",
        "- End with exactly: #### <number>",
        "- Do NOT include any other text.",
    )
    # Every instruction line, including the last one, is newline-terminated.
    system_prompt = "".join(f"{line}\n" for line in instruction_lines)

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question.strip()},
    ]

# HELPERS: number parsing / normalization (Decimal-safe)
def extract_final_answer(text: str) -> Optional[str]:
    """Return the number following the last "####" marker in `text`.

    Thousands separators (commas) are removed from the returned string.
    Returns None when `text` contains no "#### <number>" occurrence.
    """
    final_hit = None
    # Walk all matches and keep only the last one.
    for final_hit in ANS_RE.finditer(text):
        pass
    if final_hit is None:
        return None
    number_text = final_hit.group(1)
    return number_text.replace(",", "").strip()

def normalize_decimal_str(s: Optional[str]) -> Optional[str]:
    """
    Normalize numeric string with Decimal to avoid float issues.
    - remove commas/spaces
    - normalize unicode minus
    - if integer => "72" (no .0)

    Args:
        s: Numeric text (or None). Non-string values are converted with str().

    Returns:
        The canonical plain-decimal string ("72", "-5", "72.5"), or None when
        the input is None, empty, not parseable as a number, or not finite
        (NaN / Infinity).
    """
    if s is None:
        return None
    s = str(s).strip().replace(",", "").replace(" ", "")
    s = s.replace("−", "-")
    if s == "":
        return None
    try:
        d = Decimal(s)
    except (InvalidOperation, ValueError):
        return None
    # Decimal() happily parses "NaN", "sNaN" and "Infinity". Those are not
    # answers: int(Infinity) raises OverflowError, and NaN would otherwise be
    # returned as the string "NaN".
    if not d.is_finite():
        return None
    if d == d.to_integral_value():
        return str(int(d))
    # keep a clean decimal form
    return format(d.normalize(), "f").rstrip("0").rstrip(".")

def gold_from_gsm8k_answer(answer_text: str) -> Optional[str]:
    """Extract the final "#### <number>" from a reference answer and normalize it.

    Returns None when no final answer is found or it cannot be normalized.
    """
    raw_gold = extract_final_answer(answer_text)
    return normalize_decimal_str(raw_gold)

# CLEANING / FORMAT ENFORCEMENT
def _latex_to_plain(t: str) -> str:
    """Flatten simple LaTeX markup into plain text.

    Steps, in order:
      1. unwrap \\( \\), \\[ \\], $$ $$ and $ $ delimiters, keeping their content;
      2. rewrite fractions as "(a)/(b)" and \\times / \\cdot / \\div as × · ÷;
      3. drop \\left / \\right sizing prefixes and unwrap \\text{} / \\mathrm{};
      4. delete any remaining backslash commands and all braces.

    Args:
        t: Text that may contain LaTeX.

    Returns:
        The text with the markup above removed or converted.
    """
    # Strip wrappers
    t = LATEX_INLINE_RE.sub(lambda m: m.group(1), t)
    t = LATEX_DISPLAY_RE.sub(lambda m: m.group(1), t)
    t = DOLLAR_MATH_RE.sub(lambda m: (m.group(1) or m.group(2) or ""), t)

    # Convert common commands
    # \frac, \dfrac and \tfrac share one pattern.
    t = re.sub(r"\\[dt]?frac\{([^}]*)\}\{([^}]*)\}", r"(\1)/(\2)", t)
    # The negative lookahead keeps these from firing on longer command names
    # that merely start with the same letters (e.g. \cdots, \leftarrow,
    # \rightarrow); those are removed whole by the generic rule below instead
    # of leaving fragments such as "·s" or "arrow".
    t = re.sub(r"\\times(?![a-zA-Z])", "×", t)
    t = re.sub(r"\\cdot(?![a-zA-Z])", "·", t)
    t = re.sub(r"\\div(?![a-zA-Z])", "÷", t)
    t = re.sub(r"\\(?:left|right)(?![a-zA-Z])", "", t)
    t = re.sub(r"\\text\{([^}]*)\}", r"\1", t)
    t = re.sub(r"\\mathrm\{([^}]*)\}", r"\1", t)

    # Remove remaining commands
    t = re.sub(r"\\[a-zA-Z]+", "", t)
    t = t.replace("{", "").replace("}", "")
    return t

def _renumber_steps(t: str) -> str:
    lines = t.splitlines()
    out = []
    step_no = 0
    for ln in lines:
        m = re.match(r"^\s*Step\s+(\d+)\s*:\s*(.*)$", ln)
        if m:
            step_no += 1
            out.append(f"Step {step_no}: {m.group(2).strip()}")
        else:
            out.append(ln)
    return "\n".join(out).strip()

def enforce_step_format(t: str) -> str:
    """Clean a raw teacher completion into the "Step k: ... #### n" layout.

    The stages run in a fixed order:
      1. drop role-only lines and prompt-leak lines;
      2. discard everything before the first "Step 1:" line, if there is one;
      3. discard everything after the last "#### <number>", if there is one;
      4. flatten LaTeX;
      5. collapse runs of three or more newlines into a blank line;
      6. renumber the steps sequentially;
      7. apply the MAX_COT_CHARS cap. When a "####" marker is present the
         *head* of the text is trimmed so the final answer survives;
         otherwise the tail is cut.

    Args:
        t: Raw decoded completion.

    Returns:
        The cleaned text. It is not guaranteed to be valid; callers check it
        with is_valid_cot.
    """
    # 1) Filter out role lines and leaked instruction lines.
    surviving_lines = [
        line
        for line in t.strip().splitlines()
        if not ROLE_LINE_RE.match(line.strip()) and not LEAK_RE.match(line.strip())
    ]
    body = "\n".join(surviving_lines).strip()

    # 2) Start at the first "Step 1:" line (removes prompt leak / question echo).
    first_step = re.search(r"(?m)^\s*Step\s+1\s*:", body)
    if first_step is not None:
        body = body[first_step.start():].lstrip()

    # 3) End right after the last "#### <number>".
    final_answer = None
    for final_answer in ANS_RE.finditer(body):
        pass
    if final_answer is not None:
        body = body[:final_answer.end()].strip()

    # 4-6) LaTeX removal, blank-line normalization, sequential step numbers.
    body = _latex_to_plain(body)
    body = re.sub(r"\n{3,}", "\n\n", body).strip()
    body = _renumber_steps(body)

    # 7) Length cap.
    if not MAX_COT_CHARS or len(body) <= MAX_COT_CHARS:
        return body

    marker_pos = body.rfind("####")
    if marker_pos == -1:
        return body[:MAX_COT_CHARS].rstrip()

    # Keep a window that ends with the answer line; 40 characters are reserved
    # for the "#### <number>" tail.
    window_start = max(0, marker_pos - (MAX_COT_CHARS - 40))
    return body[window_start:].lstrip()

def is_valid_cot(t: str) -> bool:
    """Check that a cleaned completion has the expected shape.

    A completion is valid when all of the following hold:
      * it contains a "####" marker;
      * it begins (after leading whitespace) with "Step 1:";
      * its last line is exactly "#### <number>";
      * it contains none of the known prompt-leak phrases.
    """
    if "####" not in t:
        return False

    has_step_one_line = re.search(r"(?m)^\s*Step\s+1\s*:", t) is not None
    opens_with_step_one = t.lstrip().startswith("Step 1:")
    if not (has_step_one_line and opens_with_step_one):
        return False

    # Safe to index: the text contains "####", so it has at least one line.
    final_line = t.strip().splitlines()[-1].strip()
    ends_with_answer = (
        re.match(r"^####\s*[-+]?\d+(?:,\d{3})*(?:\.\d+)?\s*$", final_line) is not None
    )
    has_leak = (
        re.search(r"(?i)(end with|no extra text|example format|requirements:|rules:)", t)
        is not None
    )
    return ends_with_answer and not has_leak

# JSONL writer
def write_jsonl_buffer(path: str, buffer: List[Dict[str, Any]]) -> None:
    """Append every record in `buffer` to `path`, one JSON document per line.

    Nothing is written (and the file is not created) when `buffer` is empty.
    Non-ASCII characters are written as-is in UTF-8. The buffer itself is not
    modified.
    """
    if len(buffer) == 0:
        return
    with open(path, "a", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in buffer
        )

# LOAD DATA
# Load the configured split and resolve the inclusive [RANGE_START, RANGE_END]
# window of rows to process.
ds = load_dataset(DATASET_NAME, DATASET_CONFIG)
data = ds[SPLIT]
n_total = len(data)

# None means "through the last row of the split".
RANGE_END = n_total - 1 if RANGE_END is None else RANGE_END

# Both bounds must be valid row indices and the window must be non-empty.
assert 0 <= RANGE_START <= RANGE_END < n_total

n_to_process = RANGE_END - RANGE_START + 1

for status_line in (
    f"Loaded {DATASET_NAME}/{DATASET_CONFIG} split={SPLIT} n={n_total}",
    f"Processing range: [{RANGE_START}, {RANGE_END}] -> {n_to_process} examples",
    f"Output JSONL: {OUT_JSONL}",
):
    print(status_line)

In [None]:
# LOAD TEACHER (4-bit)
# NF4 double-quantized weights with float16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(TEACHER_NAME, use_fast=True)
# Batched generation needs a pad token; fall back to EOS when none is defined.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# left padding is better for batched generation with different prompt lengths
tokenizer.padding_side = "left"

# Arguments shared by the preferred (SDPA) load and the fallback load.
teacher_load_kwargs = dict(
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

try:
    model = AutoModelForCausalLM.from_pretrained(
        TEACHER_NAME,
        attn_implementation="sdpa",
        **teacher_load_kwargs,
    )
    print("[OK] Loaded model with SDPA attention.")
except Exception as sdpa_error:
    # Retry without an explicit attention implementation.
    print(f"[WARN] SDPA failed: {str(sdpa_error)[:200]}")
    model = AutoModelForCausalLM.from_pretrained(TEACHER_NAME, **teacher_load_kwargs)

model.eval()
print(f"Model loaded. Device: {model.device}")

# MAIN LOOP
# For each batch of questions: build chat prompts, sample K completions per
# question, clean and validate them, keep those whose final answer equals the
# gold answer, and buffer one output record per question that has at least one
# correct completion.
start_time = time.time()
processed = 0   # questions seen so far
kept = 0        # questions with at least one correct completion
buffer: List[Dict[str, Any]] = []

num_batches = (n_to_process + BATCH_SIZE - 1) // BATCH_SIZE

for batch_id, start_idx in enumerate(range(RANGE_START, RANGE_END + 1, BATCH_SIZE), start=1):
    # end_idx is exclusive and clamped to the inclusive RANGE_END.
    end_idx = min(start_idx + BATCH_SIZE, RANGE_END + 1)
    batch = data.select(range(start_idx, end_idx))

    questions = batch["question"]
    gold_raw  = batch["answer"]
    gold_nums = [gold_from_gsm8k_answer(a) for a in gold_raw]

    # Build prompts
    batch_messages = [build_messages_kd(q) for q in questions]
    prompts = tokenizer.apply_chat_template(
        batch_messages,
        tokenize=False,
        add_generation_prompt=True
    )

    enc = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to(model.device)

    input_ids = enc["input_ids"]
    attn_mask = enc["attention_mask"]

    # The tokenizer pads on the LEFT, so every row's prompt ends at the same
    # column and the generated tokens of every row start at the padded width.
    # Slicing at the per-row count of non-pad tokens instead would pull the
    # tail of shorter prompts (role header, question text) into the completion.
    prompt_width = input_ids.shape[1]

    # Repeat for K samples
    # Row order after repeat_interleave is q0 x K, q1 x K, ...
    input_ids_rep = input_ids.repeat_interleave(K, dim=0)
    attn_mask_rep = attn_mask.repeat_interleave(K, dim=0)

    # Generate
    with torch.inference_mode():
        seqs = model.generate(
            input_ids=input_ids_rep,
            attention_mask=attn_mask_rep,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=REPETITION_PENALTY,
        )

    # Decode completions only
    raw_all = []
    for i in range(seqs.size(0)):
        gen_tokens = seqs[i, prompt_width:]
        raw = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
        raw_all.append(raw)

    # Process each example
    for j in range(len(questions)):
        q = questions[j]
        gold = gold_nums[j]

        # The K completions of question j are contiguous in raw_all.
        raws = raw_all[j*K:(j+1)*K]
        comps = [enforce_step_format(r) for r in raws]

        correct = []
        for comp in comps:
            if not is_valid_cot(comp):
                continue
            pred = normalize_decimal_str(extract_final_answer(comp))
            if (pred is not None) and (gold is not None) and (pred == gold):
                correct.append(comp)

        processed += 1

        if correct:
            # Shortest first; unless KEEP_ALL_CORRECT only the shortest is stored.
            correct_sorted = sorted(correct, key=len)
            kept_solutions = correct_sorted if KEEP_ALL_CORRECT else [correct_sorted[0]]

            out_obj = {
                # The id embeds the configured split and the row's index in it.
                "id": f"gsm8k_{SPLIT}_{start_idx + j:05d}",
                "dataset": DATASET_NAME,
                "config": DATASET_CONFIG,
                "split": SPLIT,
                "question": q,
                "gold_answer": gold,
                "cot_solutions": kept_solutions,
                "num_correct_in_K": len(correct),
                "K": K,
                "gen_params": {
                    "temperature": TEMPERATURE,
                    "top_p": TOP_P,
                    "max_new_tokens": MAX_NEW_TOKENS,
                    "repetition_penalty": REPETITION_PENALTY,
                },
            }
            buffer.append(out_obj)
            kept += 1

    # Buffered write
    if (batch_id % WRITE_EVERY_N_BATCH == 0) and buffer:
        write_jsonl_buffer(OUT_JSONL, buffer)
        buffer.clear()

    # Periodic cleanup
    if batch_id % 20 == 0:
        torch.cuda.empty_cache()

    # Logging
    if batch_id == 1 or batch_id % 10 == 0 or batch_id == num_batches:
        elapsed = time.time() - start_time
        ex_per_s = processed / max(elapsed, 1e-9)
        keep_rate = kept / max(processed, 1)
        done = (end_idx - RANGE_START)
        remaining = n_to_process - done
        eta_sec = remaining / max(ex_per_s, 1e-9)
        print(
            f"[{end_idx:5d}/{RANGE_END+1}] "
            f"batch={batch_id}/{num_batches} "
            f"kept={kept}/{processed} ({keep_rate:.1%}) "
            f"speed={ex_per_s:.2f} ex/s ETA={eta_sec/60:.1f}m"
        )

# Final flush
# Write whatever is still buffered after the last batch (only if enabled).
if buffer and FLUSH_AT_END:
    write_jsonl_buffer(OUT_JSONL, buffer)
    buffer.clear()

# Save metadata
# max(processed, 1) guards the division when nothing was processed.
final_keep_rate = kept / max(processed, 1)

run_params = dict(
    K=K,
    batch_size=BATCH_SIZE,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    repetition_penalty=REPETITION_PENALTY,
    max_cot_chars=MAX_COT_CHARS,
    keep_all_correct=KEEP_ALL_CORRECT,
)
run_results = dict(
    processed=processed,
    kept=kept,
    keep_rate=final_keep_rate,
)
meta = dict(
    dataset=DATASET_NAME,
    config=DATASET_CONFIG,
    split=SPLIT,
    teacher=TEACHER_NAME,
    range=[RANGE_START, RANGE_END],
    params=run_params,
    results=run_results,
    output_jsonl=OUT_JSONL,
    timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
    seed=SEED,
)

# OUT_META is overwritten on every run.
with open(OUT_META, "w", encoding="utf-8") as meta_file:
    json.dump(meta, meta_file, ensure_ascii=False, indent=2)

total_minutes = (time.time() - start_time) / 60
print(f"\n✓ Done in {total_minutes:.1f}m")
print(f"  Kept: {kept}/{processed} ({final_keep_rate:.1%})")
print(f"  Output: {OUT_JSONL}")

# Sanity check: print last sample
# Best-effort: any failure here is reported, not raised.
try:
    last_record_line = None
    with open(OUT_JSONL, "r", encoding="utf-8") as jsonl_file:
        # Walk the file and remember only its final line.
        for last_record_line in jsonl_file:
            pass
    if last_record_line is not None:
        sample = json.loads(last_record_line)
        print(f"\n[Sample] ID: {sample['id']}")
        print(f"Question: {sample['question'][:120]}...")
        print(f"Solution:\n{sample['cot_solutions'][0]}")
except Exception as check_error:
    print(f"Sanity check error: {check_error}")