
# StrategyQA — Fine‑Tune CoT (Teacher→Student) **Exactly like the article**

**What this notebook does (faithful to Fine‑Tune‑CoT):**
- Uses a **teacher LLM** to generate *reasoning (CoT) + a final “Yes/No”* for each question.
- **Filters** examples where the teacher’s `Final:` matches the gold label.
- **Fine‑tunes your student model** so that **input = question** and **target/output = rationale + `Final: Yes/No`**.
- At inference, the student gets **only the question** and must **produce reasoning and a final answer**.

> This is the canonical Fine‑tune‑CoT setup from the article/paper (rationale as supervision, not as input).
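
The filter-and-format step above can be sketched in a few lines. This is a minimal illustration with hand-written teacher outputs in place of real API calls; `build_sft_records` and the toy data are hypothetical names used only here, not part of the notebook's pipeline.

```python
import re

# Matches a line such as "Final: Yes" (case-insensitive).
FINAL_LINE = re.compile(r"^\s*Final:\s*(Yes|No)\s*$", re.IGNORECASE | re.MULTILINE)

def build_sft_records(examples, teacher_outputs):
    """Keep (question -> rationale + Final) pairs whose final answer matches the gold label."""
    records = []
    for ex, rationale in zip(examples, teacher_outputs):
        m = FINAL_LINE.search(rationale)
        if m is None:
            continue  # teacher did not produce a final line
        predicted = m.group(1).capitalize()
        if predicted != ex["answer"]:
            continue  # teacher disagreed with the gold label
        # The rationale is the training target, never part of the input.
        records.append({"input": ex["question"], "output": rationale.strip()})
    return records

examples = [
    {"question": "Do birds have wings?", "answer": "Yes"},
    {"question": "Can penguins fly?", "answer": "No"},
]
teacher_outputs = [
    "1. Birds are defined partly by having wings.\nFinal: Yes",
    "1. Penguins are birds.\n2. Birds fly.\nFinal: Yes",  # wrong on purpose, gets filtered out
]
records = build_sft_records(examples, teacher_outputs)
print(records)
```

Only the first example survives the filter, because the second rationale ends in a final answer that contradicts the gold label.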


In [None]:
!pip install -q datasets transformers peft openai bitsandbytes accelerate python-dotenv "huggingface_hub[hf_xet]"


In [None]:
# === CONFIG ===
# >>>> Put your student model here (HF Hub id or local path) <<<<
BASE_MODEL = "microsoft/phi-2"  # change to the student model you use

# Teacher settings (OpenAI by default; you can add your provider)
TEACHER_PROVIDER = "openai"   # options: "openai", "dummy"
TEACHER_MODEL = "gpt-4o-mini" # cheaper GPT-4 class; change to your choice
TEMPERATURE = 0.2
MAX_TEACHER_TOKENS = 400

# Data and training
SEED = 42
MAX_TRAIN_SAMPLES = 2000   # sub-sample to control budget
MAX_SEQ_LEN = 192          # token budget for student training (input + output kept short)
EPOCHS = 3
LR = 2e-4
BATCH_SIZE = 8             # effective batch size (the training cell uses batch size 1 x 8 gradient-accumulation steps)

# Paths
DATA_DIR = "data/strategyqa"
TRAIN_JSON = f"{DATA_DIR}/train.jsonl"
DEV_JSON   = f"{DATA_DIR}/dev.jsonl"
OUT_DIR = "out/cot_student"
OUT_DIR_ASK = "out/cot_student_ask"
TEACHER_DUMP = f"{DATA_DIR}/train_teacher_dump.jsonl"
SFT_JSONL = f"{DATA_DIR}/train_sft_cot.jsonl"   # input=question, output=rationale+Final
PRED_JSON = f"{OUT_DIR}/dev_preds.json"

# Config (reuses your globals when present)
STUDENT_DRAFT_MODEL = TEACHER_MODEL          # you can set a different model if you want
DEV_FRAC_STUDENTQ = 0.1                      # 10% of kept rows go to dev
TEACHER_DUMP_STUDENTQ = f"{DATA_DIR}/teacher_dump_studentq.jsonl"
SFT_STUDENTQ_TRAIN = f"{DATA_DIR}/train_sft_cot_studentq.jsonl"
SFT_STUDENTQ_DEV   = f"{DATA_DIR}/dev_sft_cot_studentq.jsonl"



# Auth (if needed)
# %env HF_TOKEN=
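# Do not paste API keys into the notebook; they end up in saved outputs and version control.
# Read the key from the environment, or prompt for it once per session.
import os, getpass
if not os.getenv("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OPENAI_API_KEY: ")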




In [None]:

# If running on Colab or a fresh env, uncomment:
# !pip install -q transformers==4.43.3 peft==0.12.0 datasets accelerate bitsandbytes tiktoken openai==1.35.10


In [None]:

import os, json, random, re, time, math, pathlib
from dataclasses import dataclass
from typing import Dict, List, Any

import torch
from transformers import (AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer)
from peft import LoraConfig, get_peft_model, TaskType

random.seed(SEED)

# Final-line parser.
# Accepts: "Final: Yes", "Final: No", "Final answer: Yes", "Final answer: No"
# Allows punctuation/words after Yes/No, and matches anywhere in the text.
FINAL_RE = re.compile(
    r"(?im)\bfinal(?:\s*answer)?\s*[:\-–]\s*(yes|no)\b"
)

def extract_final(text: str) -> str:
    if not text:
        return ""
    # prefer the LAST occurrence in case multiple appear
    last = None
    for m in FINAL_RE.finditer(text):
        last = m
    if last:
        yn = last.group(1).strip().lower()
        return "Yes" if yn.startswith("y") else "No"
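    # fallback: if "final" appears near the end, take the last standalone yes/no after it
    # (whole-word match, so "not" or "know" are not mistaken for "no")
    tail = text[-200:].lower()
    idx = tail.rfind("final")
    if idx != -1:
        words = re.findall(r"\b(yes|no)\b", tail[idx:])
        if words:
            return "Yes" if words[-1] == "yes" else "No"
    return ""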

PROMPT_TMPL = (
    "You are a careful reasoner. Solve the yes/no question step-by-step.\n"
    "Then output \"Final: Yes\" or \"Final: No\".\n\n"
    "Q: {q}\nReasoning:\n"
)
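
A quick check of what the final-line pattern accepts may help when adjusting the teacher prompt. The sketch below repeats the same regular expression so that it runs on its own; `last_final` is a hypothetical, simplified stand-in for `extract_final` (it omits the tail heuristic).

```python
import re

# Same pattern as FINAL_RE above, repeated so the snippet is self-contained.
FINAL_RE = re.compile(r"(?im)\bfinal(?:\s*answer)?\s*[:\-–]\s*(yes|no)\b")

def last_final(text: str) -> str:
    """Return 'Yes'/'No' from the last matching final marker, or '' if none is found."""
    matches = FINAL_RE.findall(text)
    if not matches:
        return ""
    return "Yes" if matches[-1].lower() == "yes" else "No"

samples = {
    "Step 1...\nFinal: yes": "Yes",
    "Final answer - No.": "No",
    "Final: No\nOn reflection...\nFinal: Yes": "Yes",  # last occurrence wins
    "The answer is probably yes": "",                   # no 'Final' marker
}
for text, expected in samples.items():
    assert last_final(text) == expected
```

Taking the last match matters when the teacher revises itself mid-rationale; the earlier `Final:` line is then ignored.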


In [None]:

# This cell expects local StrategyQA jsonl files at data/strategyqa/{train,dev}.jsonl
# Each line: {"question": "...", "answer": "Yes/No"}
# If not present, we create a tiny demo split to let the notebook run end-to-end.

path_train = pathlib.Path(TRAIN_JSON)
path_dev   = pathlib.Path(DEV_JSON)
os.makedirs(DATA_DIR, exist_ok=True)

if not (path_train.exists() and path_dev.exists()):
    print("[WARN] Local StrategyQA not found. Creating a tiny demo set (10 train, 5 dev).")
    toy_train = [
        {"question":"Is water wet?", "answer":"Yes"},
        {"question":"Is 2 greater than 3?", "answer":"No"},
        {"question":"Do birds have wings?", "answer":"Yes"},
        {"question":"Is the Sun a planet?", "answer":"No"},
        {"question":"Can penguins fly?", "answer":"No"},
        {"question":"Is Peru in South America?", "answer":"Yes"},
        {"question":"Does the Amazon River flow through Peru?", "answer":"Yes"},
        {"question":"Is 10 an even number?", "answer":"Yes"},
        {"question":"Do spiders have six legs?", "answer":"No"},
        {"question":"Is Earth flat?", "answer":"No"},
    ]
    toy_dev = [
        {"question":"Is the Nile a river?", "answer":"Yes"},
        {"question":"Is zero a negative number?", "answer":"No"},
        {"question":"Do fish breathe air with lungs?", "answer":"No"},
        {"question":"Is Mount Everest the tallest mountain on Earth?", "answer":"Yes"},
        {"question":"Is chocolate salty by default?", "answer":"No"},
    ]
    with open(TRAIN_JSON, "w", encoding="utf-8") as f:
        for r in toy_train: f.write(json.dumps(r)+"\n")
    with open(DEV_JSON, "w", encoding="utf-8") as f:
        for r in toy_dev: f.write(json.dumps(r)+"\n")

train_rows = [json.loads(l) for l in open(TRAIN_JSON, encoding="utf-8")]
dev_rows   = [json.loads(l) for l in open(DEV_JSON, encoding="utf-8")]

print(f"[OK] Train rows: {len(train_rows)} | Dev rows: {len(dev_rows)}")


[WARN] Local StrategyQA not found. Creating a tiny demo set (10 train, 5 dev).
[OK] Train rows: 10 | Dev rows: 5


In [None]:
def call_teacher_llm(prompt: str) -> str:
    if TEACHER_PROVIDER == "dummy":
        final = "Yes" if random.random() < 0.5 else "No"
        return "Because of common knowledge.\nFinal: " + final  # newline before Final
    elif TEACHER_PROVIDER == "openai":
        from openai import OpenAI
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("Please set OPENAI_API_KEY env var for teacher generation.")
        client = OpenAI(api_key=api_key)
        resp = client.chat.completions.create(
            model=TEACHER_MODEL,
            temperature=TEMPERATURE,
            messages=[
                {"role":"system","content":
                  "Answer ONE yes/no question. Use 1–3 brief steps, then EXACTLY one line: 'Final: Yes' or 'Final: No'. "
                  "No extra questions, no examples, keep ≤80 tokens total."},
                # {"role": "system", "content":
                #  "Answer ONE yes/no question. Show brief step-by-step reasoning, then on a new line write EXACTLY one of: 'Final: Yes' or 'Final: No'. "
                #  "Do NOT include extra examples, exercises, or additional questions."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=MAX_TEACHER_TOKENS,
            stop=["\n\nQ:", "\nQ:"]   # <- prevent starting another question
        )
        return resp.choices[0].message.content.strip()
    else:
        raise ValueError(f"Unknown TEACHER_PROVIDER: {TEACHER_PROVIDER}")


In [None]:
FINAL_LINE_RE = re.compile(r"^\s*Final:\s*(Yes|No)\s*$", re.IGNORECASE | re.MULTILINE)

def sanitize_teacher_old(text: str) -> str:
    """
    Keep reasoning up to and including the FIRST 'Final:' line.
    Drop anything after (examples, extra Q:, etc).
    Also normalize 'Final: yes/no' casing.
    """
    if not text:
        return ""
    # Find first Final line
    m = FINAL_LINE_RE.search(text)
    if not m:
        return ""
    end = m.end()
    head = text[:end]
    # Normalize the final line to 'Final: Yes' or 'Final: No'
    final_norm = f"Final: {'Yes' if m.group(1).lower().startswith('y') else 'No'}"
    # Replace the matched segment's final line with normalized version
    head = FINAL_LINE_RE.sub(final_norm, head)
    return head.strip()

MAX_COT_TOKENS = 96  # keep short

def sanitize_teacher(text: str) -> str:
    if not text:
        return ""
    # Drop anything from an accidental follow-up prompt onwards ("\nQ:" also covers "\n\nQ:")
    text = text.split("\nQ:")[0]
    m = FINAL_LINE_RE.search(text)
    if not m:
        return ""
    final_norm = f"Final: {'Yes' if m.group(1).lower().startswith('y') else 'No'}"
    # Coarse character cap on the rationale only (applied before tokenization),
    # so an over-long rationale can never push the Final line out of the result
    body = text[:m.start()].strip()[:800].rstrip()
    return f"{body}\n{final_norm}".strip()


def normalize_label(s) -> str:
    # Handle boolean values directly
    if isinstance(s, bool):
        return "Yes" if s else "No"

    # Convert non-bool to string for processing
    s = str(s).strip().lower()
    if s.startswith("y"):
        return "Yes"
    if s.startswith("n"):
        return "No"
    return "Yes" if s in {"true", "1"} else "No"

In [None]:

# Build SFT dataset exactly as in Fine-tune-CoT:
# input = question
# output = teacher rationale + 'Final: Yes/No'
# Keep only examples where Final == gold answer.


# Subsample for budget

train_rows = [json.loads(l) for l in open(TRAIN_JSON, encoding="utf-8")]


rows = train_rows[:]
random.shuffle(rows)
if MAX_TRAIN_SAMPLES and len(rows) > MAX_TRAIN_SAMPLES:
    rows = rows[:MAX_TRAIN_SAMPLES]

kept = 0
with open(TEACHER_DUMP, "w", encoding="utf-8") as dump_f, \
     open(SFT_JSONL, "w", encoding="utf-8") as sft_f:

    for r in rows:
        q = r["question"].strip()
        gold = normalize_label(r["answer"])
        prompt = PROMPT_TMPL.format(q=q)

        # single teacher call; rows without a parsable 'Final:' line are dropped below
        raw = call_teacher_llm(prompt)
        out = sanitize_teacher(raw)
        final = extract_final(out)
        dump_f.write(json.dumps({"question": q, "gold": gold, "teacher_full": raw}, ensure_ascii=False) + "\n")

        if final and final == gold:
            sft_f.write(json.dumps({"input": q, "output": out}, ensure_ascii=False) + "\n")
            kept += 1


print(f"[STAT] Kept training examples after filtering: {kept}")
print(f"[OK] Wrote teacher dump: {TEACHER_DUMP}")
print(f"[OK] Wrote SFT jsonl:   {SFT_JSONL}")

# After writing SFT_JSONL:
print("SFT size:", sum(1 for _ in open(SFT_JSONL, "r", encoding="utf-8")))
for i, line in zip(range(3), open(SFT_JSONL, "r", encoding="utf-8")):
    ex = json.loads(line)
    print("SFT SAMPLE", i, "INPUT:", ex["input"])
    print("SFT SAMPLE", i, "OUTPUT:", ex["output"])
    print("----")



[STAT] Kept training examples after filtering: 801
[OK] Wrote teacher dump: data/strategyqa/train_teacher_dump.jsonl
[OK] Wrote SFT jsonl:   data/strategyqa/train_sft_cot.jsonl
SFT size: 801
SFT SAMPLE 0 INPUT: If a baby was born on Halloween would they be a Scorpio?
SFT SAMPLE 0 OUTPUT: 1. Halloween is on October 31.
2. Scorpio is from October 23 to November 21.
3. A baby born on Halloween falls within the Scorpio date range.
Final: Yes
----
SFT SAMPLE 1 INPUT: Is a slime mold safe from cerebral palsy?
SFT SAMPLE 1 OUTPUT: 1. Slime molds are single-celled organisms and do not have a nervous system or brain.
2. Cerebral palsy is a neurological disorder affecting humans, specifically their motor functions.
3. Since slime molds are not human and do not have a nervous system, they cannot be affected by cerebral palsy.
Final: Yes
----
SFT SAMPLE 2 INPUT: Does the word swastika have meaning in sanskrit?
SFT SAMPLE 2 OUTPUT: 1. The word "swastika" originates from the Sanskrit word "svastika,
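
The build cell calls the teacher once per question and drops any row without a parsable final line. If more of those rows should be recovered, one option is a small retry wrapper. The sketch below is an assumption about how that could look, not something the notebook does; `call_with_retries` is a hypothetical helper, and a lambda stands in for `call_teacher_llm` so the snippet runs without an API key.

```python
import re

FINAL_LINE = re.compile(r"^\s*Final:\s*(Yes|No)\s*$", re.IGNORECASE | re.MULTILINE)

def call_with_retries(teacher_fn, prompt, attempts=3):
    """Call teacher_fn until its output contains a 'Final: Yes/No' line, or attempts run out."""
    last = ""
    for _ in range(attempts):
        last = teacher_fn(prompt)
        if FINAL_LINE.search(last):
            break
    return last

# Hypothetical stand-in for the teacher: fails once, then answers properly.
replies = iter(["I am not sure.", "1. Water is a liquid.\nFinal: Yes"])
result = call_with_retries(lambda p: next(replies), "Q: Is water wet?")
print(result)
```

Each retry is an extra API call, so the attempt limit trades teacher cost against the number of kept examples.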

In [None]:
# === Build SFT with student-questions -> teacher-CoT ===
# Produces:
#   data/strategyqa/train_sft_cot_studentq.jsonl
#   data/strategyqa/dev_sft_cot_studentq.jsonl

import os, json, random, re, pathlib
from collections import Counter

random.seed(SEED)

def _openai_chat(messages, model, temperature, max_tokens=400, stop=None):
    from openai import OpenAI
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("Set OPENAI_API_KEY to use provider='openai'")
    client = OpenAI(api_key=api_key)
    resp = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop
    )
    return resp.choices[0].message.content.strip()

def gen_student_draft(q: str) -> str:
    """
    Return a STRICTLY formatted student draft for a yes/no question.

    Output format (exactly two lines):
        Answer: <Yes/No>
        Questions: <Q1>? <Q2>?

    - Tries up to 3 times to get a clean response from the student model.
    - Enforces casing ('Yes'/'No'), exactly two questions ending with '?',
      and removes any extra content.
    - Falls back to a minimal valid draft if formatting fails.

    Depends on:
      - TEACHER_PROVIDER ("openai" or "dummy")
      - STUDENT_DRAFT_MODEL (model string for _openai_chat)
      - _openai_chat(messages, model, temperature, max_tokens, stop)
    """
    import re, random

    # Strict validator: "Answer: Yes|No" + "Questions: ...? ...?"
    # Allows extra whitespace but requires 1–2 questions ending with '?'
    DRAFT_RE = re.compile(
        r"(?is)^\s*answer\s*:\s*(yes|no)\s*\n\s*questions\s*:\s*(.+\?)\s*(.+\?)?\s*$"
    )

    def _normalize_and_fix(raw: str) -> str:
        """Coerce to the exact two-line format; clip to at most two questions."""
        if not raw:
            return ""

        raw = raw.strip()

        # Pull "Answer: ..." (Yes/No) robustly
        m_ans = re.search(r"(?is)^\s*answer\s*:\s*(yes|no)\b", raw)
        ans = m_ans.group(1).strip().lower() if m_ans else ""
        ans = "Yes" if ans.startswith("y") else ("No" if ans.startswith("n") else "")

        # Pull the questions line
        m_qs = re.search(r"(?is)questions\s*:\s*(.*)$", raw)
        qs_raw = m_qs.group(1).strip() if m_qs else ""

        # Split on '?' and keep the first two non-empty fragments, then append '?'
        parts = [p.strip() for p in qs_raw.split("?") if p.strip()]
        parts = parts[:2] if parts else []

        # If fewer than 2 valid questions, synthesize generic but relevant ones
        while len(parts) < 2:
            if len(parts) == 0:
                parts.append(f"What key facts determine the answer to '{q[:80]}'")
            else:
                parts.append("Are there any known exceptions or edge cases")

        # Ensure each ends with '?'
        q1 = (parts[0] + "?").replace("??", "?")
        q2 = (parts[1] + "?").replace("??", "?")

        if not ans:
            return ""

        return f"Answer: {ans}\nQuestions: {q1} {q2}"

    if TEACHER_PROVIDER == "dummy":
        yn = "Yes" if random.random() < 0.5 else "No"
        q1 = f"What facts are required to verify '{q[:80]}'?"
        q2 = "Do timelines or definitions change the outcome?"
        return f"Answer: {yn}\nQuestions: {q1} {q2}"

    elif TEACHER_PROVIDER == "openai":
        sys = (
            "You are a careful student. For the ONE yes/no question, propose a tentative answer "
            "and ask 1–2 short clarifying sub-questions to verify it.\n"
            "Respond in EXACTLY this format (two lines only):\n"
            "Answer: <Yes/No>\n"
            "Questions: <Q1>? <Q2>?\n"
            "Rules: No extra text before/after. Keep questions compact (≤12 words each)."
        )

        # Up to 3 attempts to get clean format; slightly higher temperature for better drafts
        for _ in range(3):
            draft = _openai_chat(
                [{"role": "system", "content": sys},
                 {"role": "user",   "content": f"Q: {q}"}],
                model=STUDENT_DRAFT_MODEL,
                temperature=0.5,
                max_tokens=120,
                stop=None
            )

            fixed = _normalize_and_fix(draft)
            if fixed and DRAFT_RE.match(fixed):
                return fixed

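        # Final fallback: reuse the last coerced draft if it has an answer; otherwise
        # build a generic draft (the tentative answer here is an arbitrary default)
        if fixed:
            return fixed
        return (f"Answer: No\nQuestions: What key facts determine the answer to '{q[:80]}'? "
                "Are there any known exceptions or edge cases?")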

    else:
        raise ValueError(f"Unknown TEACHER_PROVIDER={TEACHER_PROVIDER}")

def gen_teacher_cot(q: str, student_draft: str) -> str:
    """Teacher CoT that explicitly references student's questions; must end with 'Final: Yes/No'."""
    if TEACHER_PROVIDER == "dummy":
        yn = "Yes" if random.random() < 0.5 else "No"
        return f"1) Address sub-questions.\n2) Conclude based on timelines.\nFinal: {yn}"
    elif TEACHER_PROVIDER == "openai":
        sys = (
            "You are the TEACHER. Use the student's sub-questions to structure a brief, factual chain-of-thought. "
            "Address each student sub-question explicitly; then write EXACTLY one line: 'Final: Yes' OR 'Final: No'. "
            "Do NOT add other questions or extra examples."
        )
        user = (
            f"Original question:\n{q}\n\n"
            f"Student draft:\n{student_draft}\n\n"
            "Write step-by-step reasoning that references the student's questions, then a single final line."
        )
        return _openai_chat(
            [{"role":"system","content":sys},{"role":"user","content":user}],
            model=TEACHER_MODEL,
            temperature=0.2,
            max_tokens=min(300, MAX_TEACHER_TOKENS),
            stop=["\n\nQ:", "\nQ:"]   # prevent new QA blocks
        )
    else:
        raise ValueError(f"Unknown TEACHER_PROVIDER={TEACHER_PROVIDER}")

# ---- Build the new SFT ----
in_path = pathlib.Path(TRAIN_JSON)
assert in_path.exists(), f"Missing {TRAIN_JSON}"
src_rows = [json.loads(l) for l in open(TRAIN_JSON, encoding="utf-8")]

rows = src_rows[:]
random.shuffle(rows)
if MAX_TRAIN_SAMPLES and len(rows) > MAX_TRAIN_SAMPLES:
    rows = rows[:MAX_TRAIN_SAMPLES]

kept = []
dropped_no_final = 0
dropped_mismatch = 0
with open(TEACHER_DUMP_STUDENTQ, "w", encoding="utf-8") as dump_f:
    for r in rows:
        q = r["question"].strip()
        gold = normalize_label(r["answer"])

        student_draft = gen_student_draft(q)
        raw_cot = gen_teacher_cot(q, student_draft)
        cot = sanitize_teacher(raw_cot)
        final = extract_final(cot)

        print(f"Q: {q}\nGOLD: {gold}\nSTUDENT_DRAFT:\n{student_draft}\nTEACHER_COT:\n{cot}\n")

        dump_f.write(json.dumps({
            "question": q,
            "gold": gold,
            "student_draft": student_draft,
            "teacher_full": raw_cot
        }, ensure_ascii=False) + "\n")

        if not final:
            dropped_no_final += 1
            continue
        if final != gold:
            dropped_mismatch += 1
            continue

        kept.append({
            "input": q,                # train input stays the original question
            "student_draft": student_draft,  # stored for analysis/debug (not used as input)
            "output": cot              # teacher CoT that references student's questions + 'Final: ...'
        })

print(f"[BUILD] kept={len(kept)}  dropped_no_final={dropped_no_final}  dropped_mismatch={dropped_mismatch}")

# ---- Stratified dev split by parsed label ----
yes = [x for x in kept if extract_final(x["output"]) == "Yes"]
no  = [x for x in kept if extract_final(x["output"]) == "No"]
random.shuffle(yes); random.shuffle(no)

total = len(kept)
dev_n = max(1, int(round(total * DEV_FRAC_STUDENTQ)))
dev_yes_n = int(round(dev_n * len(yes) / max(1, total)))
dev_no_n  = dev_n - dev_yes_n

dev = yes[:dev_yes_n] + no[:dev_no_n]
train = yes[dev_yes_n:] + no[dev_no_n:]
random.shuffle(dev); random.shuffle(train)

with open(SFT_STUDENTQ_TRAIN, "w", encoding="utf-8") as f:
    for o in train: f.write(json.dumps(o, ensure_ascii=False) + "\n")
with open(SFT_STUDENTQ_DEV, "w", encoding="utf-8") as f:
    for o in dev: f.write(json.dumps(o, ensure_ascii=False) + "\n")

print(f"[WRITE] train={len(train)}  dev={len(dev)}")
print("[DEV LABEL COUNTS]", Counter(extract_final(o["output"]) for o in dev))

# ---- Peek a few samples ----
for i, o in enumerate(dev[:3]):
    print("==== SAMPLE", i)
    print("Q:", o["input"])
    print("STUDENT_DRAFT:\n", o["student_draft"])
    print("TEACHER_COT:\n", o["output"])
    print("------------------------")


AssertionError: Missing data/strategyqa/train.jsonl

In [None]:
import json, random
from collections import Counter

random.seed(42)

SRC = "data/strategyqa/train_sft_cot.jsonl"  # filtered SFT file written by the build cell above
DEV_OUT = "data/strategyqa/dev_sft_cot.jsonl"
TRAIN_OUT = "data/strategyqa/train_sft_cot_reduced.jsonl"

DEV_N = 100  # how many you want for dev
DEV_FRAC = 0.1

# ====== Load & parse ======
rows = []
with open(SRC, "r", encoding="utf-8") as f:
    for line in f:
        o = json.loads(line)
        o_out = o.get("output", "")
        label = extract_final(o_out)
        if label in ("Yes", "No"):
            rows.append({"input": o["input"], "output": o_out, "answer": label})
        # If no label found, drop the row silently (can log if you want)

print(f"[LOAD] total rows read: {len(rows)}")
print("[LABEL COUNTS]", Counter(r["answer"] for r in rows))

# ====== Compute dev size ======
total = len(rows)
if DEV_N is None:
    DEV_N = max(1, int(round(total * DEV_FRAC)))
DEV_N = min(DEV_N, total - 1)  # leave at least one for train

# ====== Stratified split by label ======
yes = [r for r in rows if r["answer"] == "Yes"]
no  = [r for r in rows if r["answer"] == "No"]
random.shuffle(yes); random.shuffle(no)

# Proportional allocation
dev_yes_n = int(round(DEV_N * len(yes) / max(1, total)))
dev_no_n  = DEV_N - dev_yes_n

dev = yes[:dev_yes_n] + no[:dev_no_n]
train = yes[dev_yes_n:] + no[dev_no_n:]
random.shuffle(dev); random.shuffle(train)

# ====== Save (drop the helper 'answer' field to keep original schema if you prefer) ======
with open(DEV_OUT, "w", encoding="utf-8") as f:
    for r in dev:
        f.write(json.dumps({"input": r["input"], "output": r["output"]}, ensure_ascii=False) + "\n")

with open(TRAIN_OUT, "w", encoding="utf-8") as f:
    for r in train:
        f.write(json.dumps({"input": r["input"], "output": r["output"]}, ensure_ascii=False) + "\n")

print(f"[SPLIT] train={len(train)}  dev={len(dev)}")
print("[DEV LABEL COUNTS]", Counter(extract_final(r["output"]) for r in dev))
print("[TRAIN LABEL COUNTS]", Counter(extract_final(r["output"]) for r in train))


[LOAD] total rows read: 801
[LABEL COUNTS] Counter({'No': 483, 'Yes': 318})
[SPLIT] train=701  dev=100
[DEV LABEL COUNTS] Counter({'No': 60, 'Yes': 40})
[TRAIN LABEL COUNTS] Counter({'No': 423, 'Yes': 278})
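
The split sizes in this output follow directly from the proportional allocation in the cell. A small worked example, using the label counts printed above (`allocate_dev` is a hypothetical helper that mirrors the two allocation lines):

```python
def allocate_dev(n_yes: int, n_no: int, dev_n: int):
    """Split dev_n between the two labels in proportion to their counts."""
    total = n_yes + n_no
    dev_yes = int(round(dev_n * n_yes / max(1, total)))
    dev_no = dev_n - dev_yes
    return dev_yes, dev_no

# Label counts reported above: 318 Yes, 483 No, with DEV_N = 100.
dev_yes, dev_no = allocate_dev(318, 483, 100)
train_yes, train_no = 318 - dev_yes, 483 - dev_no
print(dev_yes, dev_no, train_yes, train_no)  # 40 60 278 423
```

Here 100 × 318 / 801 ≈ 39.7, which rounds to 40 Yes examples; the remaining 60 dev slots go to No.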


In [None]:
import torch, os
torch.cuda.empty_cache()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:

class CoTDataset(torch.utils.data.Dataset):
    def __init__(self, jsonl_path, tokenizer, max_len=MAX_SEQ_LEN):
        self.rows = [json.loads(l) for l in open(jsonl_path, encoding="utf-8")]
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self): return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        # Format the input exactly like the inference prompt used in generate_answer
        q = f"Q: {r['input']}\nReasoning:\n"
        y = r["output"]
        # Pack as: [prompt] -> [rationale + Final] + EOS
        # We'll concatenate and mask so that loss applies to the output tokens only.
        enc_inp = self.tok(q, truncation=True, max_length=self.max_len)
        enc_out = self.tok(y, truncation=True, max_length=self.max_len)

        # Append EOS so the student learns to stop after the Final line
        out_ids = enc_out["input_ids"] + [self.tok.eos_token_id]
        input_ids = enc_inp["input_ids"] + out_ids
        # Build labels: ignore index for the input portion
        labels = [-100]*len(enc_inp["input_ids"]) + out_ids
        if len(input_ids) > self.max_len:
            input_ids = input_ids[:self.max_len]
            labels = labels[:self.max_len]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long)
        }
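
The masking in `__getitem__` can be seen on plain lists. The sketch below uses made-up token ids instead of a tokenizer; `pack_example` is a hypothetical helper that mirrors the concatenate, mask, and truncate steps.

```python
IGNORE_INDEX = -100  # positions with this label are skipped by PyTorch's cross-entropy loss

def pack_example(prompt_ids, output_ids, max_len):
    """Concatenate prompt and output ids; mask the prompt so only output tokens are scored."""
    input_ids = (prompt_ids + output_ids)[:max_len]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + output_ids)[:max_len]
    return input_ids, labels

# Hypothetical token ids, for illustration only.
input_ids, labels = pack_example([11, 12, 13], [21, 22, 23, 24], max_len=6)
print(input_ids)  # [11, 12, 13, 21, 22, 23]
print(labels)     # [-100, -100, -100, 21, 22, 23]
```

Note that truncation cuts from the end of the sequence, so a rationale longer than the budget loses its final tokens, which is where the `Final:` line sits. Keeping teacher rationales short avoids that.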

In [None]:
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
import torch

def _detect_lora_targets(model):
    names = [n for n,_ in model.named_modules()]
    # Phi checkpoints that expose a fused attention projection (Wqkv)
    if any(n.endswith("Wqkv") or ".Wqkv" in n for n in names):
        return ["Wqkv", "out_proj"]
    # Llama/Mistral-style naming (separate q_proj / v_proj)
    if any(".q_proj" in n for n in names):
        return ["q_proj", "v_proj"]
    # GPT-2 style
    if any(".c_attn" in n for n in names):
        return ["c_attn", "c_proj"]
    # fallback: try common names
    return ["Wqkv","out_proj","q_proj","v_proj","k_proj","o_proj","c_attn","c_proj"]

def make_model_and_tokenizer():
    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
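        # bf16 only where the GPU supports it; otherwise fall back to fp16
        bnb_4bit_compute_dtype=torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16,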
    )

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_cfg,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=False,
    )
    model = prepare_model_for_kbit_training(model)

    targets = _detect_lora_targets(model)
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=targets
    )
    model = get_peft_model(model, lora_cfg)

    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    # sanity
    trainable = [n for n,p in model.named_parameters() if p.requires_grad]
    print(f"Trainable tensors: {len(trainable)}")
    assert len(trainable) > 0, "No trainable parameters — LoRA didn't attach."
    return model, tok
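
The sanity check at the end of `make_model_and_tokenizer` counts trainable tensors, and the training log in this notebook reports 128. That number is consistent with LoRA attaching to two modules per layer, assuming a 32-layer checkpoint and the `q_proj`/`v_proj` branch of `_detect_lora_targets` (both are assumptions about the recorded run; the log does not state them). Each adapted linear layer contributes two trainable tensors, `lora_A` and `lora_B`:

```python
def lora_trainable_tensors(num_layers: int, targets_per_layer: int) -> int:
    """Each LoRA-adapted linear layer adds two trainable tensors: lora_A and lora_B."""
    return num_layers * targets_per_layer * 2

# Assumed values: 32 layers, 2 target modules per layer (q_proj and v_proj).
count = lora_trainable_tensors(num_layers=32, targets_per_layer=2)
print(count)  # 128
```

A different count after swapping `BASE_MODEL` usually just reflects a different layer count or a different branch of the target detection.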

In [None]:

os.environ["WANDB_DISABLED"] = "true"
os.makedirs(OUT_DIR, exist_ok=True)  # change the output dir if needed
model, tok = make_model_and_tokenizer()
ds = CoTDataset(TRAIN_OUT, tok, MAX_SEQ_LEN)  # TRAIN_OUT is the reduced train split; change if needed

args = TrainingArguments(
    output_dir=OUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=LR,
    num_train_epochs=EPOCHS,
    fp16=True,               # or bf16=True if supported
    logging_steps=20,
    save_strategy="epoch",
    report_to="none",        # <- proper way (no WANDB prompt)
    dataloader_num_workers=2
)


trainer = Trainer(model=model, args=args, train_dataset=ds)
trainer.train()
trainer.save_model(OUT_DIR)
tok.save_pretrained(OUT_DIR)

print("[OK] Training complete. Model saved to", OUT_DIR)





Trainable tensors: 128


Step,Training Loss
20,1.5569
40,0.9588
60,0.8221
80,0.8279
100,0.7299
120,0.7283
140,0.7463
160,0.7438
180,0.6966
200,0.6736


[OK] Training complete. Model saved to out/cot_student


In [None]:
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnFinal(StoppingCriteria):
    def __init__(self, tok):
        self.ids_yes = tok.encode("Final: Yes", add_special_tokens=False)
        self.ids_no  = tok.encode("Final: No",  add_special_tokens=False)
    def __call__(self, input_ids, scores, **kwargs):
        seq = input_ids[0].tolist()
        return (len(seq) >= len(self.ids_yes) and seq[-len(self.ids_yes):] == self.ids_yes) or \
               (len(seq) >= len(self.ids_no)  and seq[-len(self.ids_no):]  == self.ids_no)

def generate_answer(model, tok, q: str) -> str:
    # 1) Prompt exactly like training input (no instruction line to copy)
    prompt = f"Q: {q}\nReasoning:\n"
    enc = tok(prompt, return_tensors="pt", truncation=True, max_length=192)
    enc = {k: v.to(model.device) for k, v in enc.items()}
    prompt_len = enc["input_ids"].shape[1]

    # --- pass 1: generate reasoning; block 'Final:' so it can't jump early ---
    bad_ids_final = [tok.encode("Final:", add_special_tokens=False)]
    bad_qas = [tok.encode(w, add_special_tokens=False) for w in ["\nQ:", "\n\nQ:", "\nA:", "\n\nA:"]]
    bad_words_ids = bad_ids_final + bad_qas

    with torch.inference_mode():
        out1 = model.generate(
            **enc,
            max_new_tokens=96,
            min_new_tokens=24,         # force some reasoning tokens
            do_sample=False,
            num_beams=1,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.eos_token_id,
            bad_words_ids=bad_words_ids,
            no_repeat_ngram_size=3,
            repetition_penalty=1.05,
        )

    # Extract only the newly generated reasoning text
    reasoning = tok.decode(out1[0][prompt_len:], skip_special_tokens=True).strip()

    # 2) Append 'Final:' and let the model decide Yes/No only (tiny decode)
    prompt2 = prompt + (reasoning + "\n" if reasoning else "") + "Final:"
    enc2 = tok(prompt2, return_tensors="pt", truncation=True, max_length=192+96)
    enc2 = {k: v.to(model.device) for k, v in enc2.items()}

    with torch.inference_mode():
        out2 = model.generate(
            **enc2,
            max_new_tokens=3,          # just ' Yes' or ' No'
            do_sample=False,
            num_beams=1,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.eos_token_id,
        )

    text = tok.decode(out2[0], skip_special_tokens=True)
    final = extract_final(text)
    print(f"first final {final}")
    if not final:
        # Fallback when extract_final finds nothing: look for the label directly.
        # "Final:" follows a newline in prompt2, so match without a leading space.
        tl = text.lower().strip()
        final = "Yes" if ("final: yes" in tl or tl.endswith("yes")) else "No"  # default to "No"

    print("---- RAW OUTPUT ----")
    print(text)
    print("--------------------")
    print(f"final answer is: {final}")
    return final
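
The fallback at the end of `generate_answer` can also be isolated into a small pure function, which makes it easier to test on strings such as `Final: No, sun` from the run below. This is a minimal sketch, not the notebook's `extract_final`: the name `parse_final_label` and the regex are assumptions introduced for illustration.

```python
import re
from typing import Optional

# Hypothetical helper (not defined elsewhere in this notebook).
# Matches "Final:" followed by Yes/No as a whole word, case-insensitively.
_FINAL_RE = re.compile(r"final:\s*(yes|no)\b", re.IGNORECASE)

def parse_final_label(text: str) -> Optional[str]:
    """Return 'Yes' or 'No' from the last 'Final:' marker, or None if absent."""
    matches = _FINAL_RE.findall(text)
    if not matches:
        return None
    # Use the last marker so text echoed from the prompt does not win.
    return "Yes" if matches[-1].lower() == "yes" else "No"
```

Returning `None` instead of silently defaulting to `"No"` lets the caller count unparseable generations separately from wrong answers.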


In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Define the target folder inside your Drive
drive_target = '/content/drive/MyDrive/LLM_Models/cot_student_out'
drive_target_ask = '/content/drive/MyDrive/LLM_Models/cot_student_out_ask'


Mounted at /content/drive


In [None]:
# Copy the OUT_DIR to Drive
import shutil, os
if os.path.exists(OUT_DIR):
    shutil.copytree(OUT_DIR, drive_target, dirs_exist_ok=True)
    print(f"Saved model to: {drive_target}")
else:
    print(f"OUT_DIR '{OUT_DIR}' does not exist!")

if os.path.exists(OUT_DIR_ASK):
    shutil.copytree(OUT_DIR_ASK, drive_target_ask, dirs_exist_ok=True)
    print(f"Saved model to: {drive_target_ask}")
else:
    print(f"OUT_DIR_ASK '{OUT_DIR_ASK}' does not exist!")

Saved model to: /content/drive/MyDrive/LLM_Models/cot_student_out
OUT_DIR_ASK 'out/cot_student_ask' does not exist!


In [None]:
# 0) Mount Drive (Colab)
from google.colab import drive
drive.mount('/content/drive')

# 1) Point to the saved model folder on Drive
LOAD_DIR = '/content/drive/MyDrive/LLM_Models/cot_student_out'  # <-- change to your path
PRED_JSON = '/content/drive/MyDrive/LLM_Models/cot_student_out__preds.json'  # optional: save preds to Drive too


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:

# 2) Load tokenizer + PEFT model from Drive
import os, torch, json
from transformers import AutoTokenizer, BitsAndBytesConfig
from peft import AutoPeftModelForCausalLM

# (optional) sanity: show what's in the folder
# This cell loads from drive_target_ask; point it at OUT_DIR instead to use the local (non-Drive) model.
print("Drive model dir contents:", os.listdir(drive_target_ask))

# Read the dev set (one JSON object per line) and close the file afterwards
with open(DEV_JSON, encoding="utf-8") as f:
    dev_rows = [json.loads(l) for l in f]

tok_eval = AutoTokenizer.from_pretrained(drive_target_ask, use_fast=True)
if tok_eval.pad_token_id is None:
    tok_eval.pad_token = tok_eval.eos_token  # ensure padding exists

bnb_cfg_inf = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
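    # Use bf16 only when the GPU reports support for it; otherwise fall back to fp16
    bnb_4bit_compute_dtype=torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16,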
)

model_eval = AutoPeftModelForCausalLM.from_pretrained(
    drive_target_ask,
    device_map="auto",
    quantization_config=bnb_cfg_inf,
    torch_dtype="auto"
).eval()

# Fast inference switches (apply AFTER loading)
try:
    model_eval.gradient_checkpointing_disable()
except Exception:
    pass
model_eval.config.use_cache = True

# 3) Run evaluation (fix the denominator if you early-break)
preds, correct, num = [], 0, 0
for r in dev_rows:
    q, gold = r["input"], r["output"]
    gold = extract_final(gold)
    print(f"the q is {q} !! and gold is: {gold}")
    pred = generate_answer(model_eval, tok_eval, q)
    preds.append({"question": q, "pred": pred, "gold": gold})
    correct += int(pred == gold)
    print(f"pred is gold? {pred == gold}")
    num += 1
    if num == 500:   # remove this break if you want full dev
        break

acc = correct / max(1, num)  # <-- use num, not len(dev_rows), when you break
print(f"[DEV ACC] {acc:.3f}  ({correct}/{num})")

with open(PRED_JSON, "w", encoding="utf-8") as f:
    json.dump(preds, f, ensure_ascii=False, indent=2)
print(f"[OK] Wrote predictions to: {PRED_JSON}")


Drive model dir contents: ['checkpoint-76', 'checkpoint-152', 'adapter_model.safetensors', 'tokenizer_config.json', 'checkpoint-228', 'README.md', 'training_args.bin', 'adapter_config.json', 'special_tokens_map.json', 'added_tokens.json', 'tokenizer.json', 'vocab.json', 'merges.txt']



the q is Do sun bears stay active during winter? !! and gold is: Yes
first final No
---- RAW OUTPUT ----
Q: Do sun bears stay active during winter?
Reasoning:
1. Sun bears are known to be primarily active during the day, which suggests that they may not be as active during nighttime hours.
2. Winter is a season characterized by colder temperatures and shorter daylight hours, which could further limit their activity.
3. Sun bear behavior is influenced by environmental factors, including temperature and light availability, which would likely affect their activity levels during winter.
4. Other bear species, such as polar bears, are adapted to cold environments and
Final: No, sun
--------------------
final answer is: No
pred is gold? False
the q is Would downloading Mario 64 on an emulator be legal? !! and gold is: No
first final No
---- RAW OUTPUT ----
Q: Would downloading Mario 64 on an emulator be legal?
Reasoning:
1. The original Mario 64 game was released in 1996 and is protected by 


### Notes & Tips
- **Faithfulness to the article**: This trains the student to *produce* CoT+Final from only the question (no rationale conditioning at input).
- **Filtering**: Only teacher generations that match the gold label are kept (strongly recommended).
- **Budget**: Lower `MAX_TRAIN_SAMPLES` if teacher API cost is a concern. StrategyQA CoTs are short; `MAX_SEQ_LEN=256` is usually enough.
- **Your student model**: Replace `BASE_MODEL` with the exact model you use.
- **Teacher**: Swap to your preferred provider if not using OpenAI (`TEACHER_PROVIDER="dummy"` lets you dry-run the pipeline).
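
The filtering tip above can be sketched as a small function that keeps a teacher generation only when its `Final:` label equals the gold label, then builds the student pair (input = question, output = rationale + `Final: Yes/No`). This is an illustrative sketch: the function name `build_student_rows` and the sample keys `question`, `teacher_text`, and `gold` are assumptions, not names taken from the notebook's pipeline.

```python
import re
from typing import Dict, Iterable, List

# Matches a line that starts with "Final:" followed by Yes or No.
FINAL_RE = re.compile(r"^\s*Final:\s*(Yes|No)\b", re.IGNORECASE | re.MULTILINE)

def build_student_rows(samples: Iterable[Dict[str, str]]) -> List[Dict[str, str]]:
    """Keep only teacher generations whose final label matches the gold label.

    Each sample is assumed (hypothetically) to carry 'question',
    'teacher_text' and 'gold' keys, with gold being 'Yes' or 'No'.
    """
    rows = []
    for s in samples:
        m = FINAL_RE.search(s["teacher_text"])
        if m is None:
            continue  # teacher gave no parseable final answer
        label = m.group(1).capitalize()
        if label != s["gold"]:
            continue  # teacher disagreed with the gold label
        rationale = s["teacher_text"][: m.start()].strip()
        # The rationale is part of the target, never part of the input.
        rows.append({"input": s["question"], "output": f"{rationale}\nFinal: {label}"})
    return rows
```

Dropping mismatched samples shrinks the training set, so it is worth logging how many rows survive before fine-tuning.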
