In [None]:
# Install the libraries used below (quiet mode) and print a short runtime summary.
# NOTE(review): transformers/datasets only have lower bounds and the rest are
# unpinned; pin exact versions if runs must be reproducible.
!pip -q install "transformers>=4.40" "datasets>=2.18" accelerate openai tiktoken

import torch, platform
print("Python:", platform.python_version())
print("Torch :", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
# GPU/driver table from the shell. On a CPU-only runtime this presumably just
# prints a "command not found" style error — TODO confirm.
!nvidia-smi

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m126.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m101.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m54.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m35.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m19.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
%%bash
cat > /content/python_type_tokenizer.py <<'PY'
from __future__ import annotations
import ast, io, re, tokenize as py_tok
from typing import List, Tuple

__all__ = ["PyTypeTokenizer"]

# Tag placed in front of a literal, keyed by the type ast.literal_eval gives it.
_CONST_TAG = {int: "<INT>", float: "<FLOAT>", bool: "<BOOL>", str: "<STR>"}
# Every tag tag_text can emit (register_tokenizer adds these to an HF vocab).
ALL_TAGS = list(_CONST_TAG.values()) + ["<LIST>", "<TUPLE>"]
# Any "<...>" tag; detag_text deletes every match.
_TAG_RE = re.compile(r"<[^>]+>")
# Rewrites "-<INT>5" / "-<FLOAT>5" as "<INT>-5" / "<FLOAT>-5".
_MINUS_FIX = re.compile(r"-(<INT>|<FLOAT>)(?=[0-9])")
# Innermost "[...]" (no brackets inside).
_LIST_RE = re.compile(r"\[[^\[\]]*?\]")
# Innermost "(...)" that contains at least one comma.
_TUPLE_RE = re.compile(r"\([^()]*?,[^()]*?\)")
_EMPTY_TUP = re.compile(r"\(\)")

# Token splitter for tagged text. Alternatives are tried in order: empty tuple,
# booleans, tagged floats, tagged ints, tagged quoted strings, tagged brackets,
# any other tag, identifiers, single operator/punctuation characters.
# Characters matching no alternative (whitespace, commas, ...) are skipped.
_SPLIT_RE = re.compile(
    r"<TUPLE>\(\)"
    r"|<BOOL>True|<BOOL>False"
    r"|<[A-Z]+>[-+]?\d+\.\d+(?:e[-+]?\d+)?"
    r"|<[A-Z]+>[-+]?\d+"
    r"|<[A-Z]+>'[^']*'|<[A-Z]+>\"[^\"]*\""
    r"|<(?:LIST|TUPLE)>[\[\(\]\)]"
    r"|<[^>]+>"
    r"|[A-Za-z_][A-Za-z0-9_]*"
    r"|[-+*/%^=(){}\[\].?:]"
)

class PyTypeTokenizer:
    """Tag Python literals in text with type markers and split tagged text.

    ``tag_text`` prefixes int/float/bool/str literals with ``<INT>``,
    ``<FLOAT>``, ``<BOOL>`` or ``<STR>`` and wraps innermost ``[...]`` and
    comma-containing ``(...)`` groups with ``<LIST>`` / ``<TUPLE>`` markers.
    ``detag_text`` removes every ``<...>`` tag. ``tokenize`` splits tagged text
    into string tokens.
    """

    def tag_text(self, text: str) -> str:
        """Return ``text`` with type tags inserted.

        A ``-`` token directly followed by a number token is folded into that
        literal (``-8`` -> ``<INT>-8``). Tagging is best effort: if Python's
        tokenizer rejects the input part-way (prose, unbalanced quotes, bad
        indentation), literals found before the error are still tagged.

        NOTE(review): ``_TUPLE_RE`` also matches call parentheses, so
        ``max([1, 2])`` becomes ``max<TUPLE>(...<TUPLE>)``.
        """
        # tokenize reports (row, col) positions with 1-based rows. Convert them
        # to absolute offsets into ``text`` so spans stay correct when the text
        # has more than one line.
        line_starts = [0]
        for line in text.split("\n"):
            line_starts.append(line_starts[-1] + len(line) + 1)

        def to_offset(pos: Tuple[int, int]) -> int:
            row, col = pos
            return line_starts[row - 1] + col

        spans: List[Tuple[int, int, str]] = []
        buf = io.BytesIO(text.encode())
        prev = None
        try:
            for tok in py_tok.tokenize(buf.readline):
                ttype, tstr, start, end = tok.type, tok.string, tok.start, tok.end
                if prev and prev.type == py_tok.OP and prev.string == '-' and ttype == py_tok.NUMBER:
                    # Fold the preceding minus into the literal.
                    start = prev.start; tstr = '-' + tstr; prev = None
                else:
                    prev = tok
                tag = None
                if ttype == py_tok.NUMBER:
                    try:
                        tag = _CONST_TAG[type(ast.literal_eval(tstr))]
                    except Exception:
                        # e.g. complex literals: their type has no tag.
                        pass
                elif ttype == py_tok.STRING:
                    tag = "<STR>"
                elif ttype == py_tok.NAME and tstr in ("True", "False"):
                    tag = "<BOOL>"
                if tag:
                    spans.append((to_offset(start), to_offset(end), tag + tstr))
        except (py_tok.TokenError, SyntaxError):
            # SyntaxError covers IndentationError/TabError, which tokenize
            # raises directly instead of TokenError; keep the spans found so far.
            pass

        chars = list(text)
        for s, e, rep in reversed(spans):
            chars[s:e] = [rep]
        tagged = "".join(chars)
        tagged = _MINUS_FIX.sub(lambda m: f"{m.group(1)}-", tagged)

        tagged = _LIST_RE.sub(lambda m: f"<LIST>[{m.group(0)[1:-1]}<LIST>]", tagged)
        tagged = _TUPLE_RE.sub(lambda m: f"<TUPLE>({m.group(0)[1:-1]}<TUPLE>)", tagged)
        tagged = _EMPTY_TUP.sub("<TUPLE>()", tagged)
        return tagged

    def detag_text(self, s: str) -> str:
        """Remove every ``<...>`` tag from ``s``."""
        return _TAG_RE.sub("", s)

    def tokenize(self, s: str, *, pretagged: bool = False):
        """Split ``s`` into tokens, tagging it first unless ``pretagged``.

        ``<STR>`` tokens have one pair of matching surrounding quotes removed.
        """
        text = s if pretagged else self.tag_text(s)
        # NOTE(review): _SPLIT_RE never yields ',', so this filter is a no-op.
        raw = [t for t in _SPLIT_RE.findall(text) if t != ',']
        cleaned = []
        for tok in raw:
            if tok.startswith("<STR>"):
                lit = tok[5:]
                if lit and lit[0] in ("'", '"') and lit[-1] == lit[0]:
                    lit = lit[1:-1]
                cleaned.append("<STR>" + lit)
            else:
                cleaned.append(tok)
        return cleaned

    @staticmethod
    def register_tokenizer(hf_tok, extra=None):
        """Add ALL_TAGS (plus ``extra``) to ``hf_tok`` as non-special tokens; return it."""
        hf_tok.add_tokens(ALL_TAGS + (extra or []), special_tokens=False)
        return hf_tok
PY

In [None]:
# Create the OpenAI client used as the "teacher" model for data generation.
import os, getpass
try:
    from openai import OpenAI
except Exception:
    # Install on demand when the import fails, then retry it.
    # NOTE(review): `except Exception` also catches non-import failures;
    # `ImportError` would be narrower.
    !pip -q install openai
    from openai import OpenAI

# Ask for the key (hidden input) only when it is not already in the environment.
if not os.getenv("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key (sk-…): ").strip()

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"].strip())
TEACHER_MODEL = "gpt-4o-mini"  # or "gpt-4o"
print("✅ OpenAI client ready")

Enter your OpenAI API key (sk-…): ··········
✅ OpenAI client ready


In [None]:
import json, random, pathlib, re, ast, time
from typing import List, Dict

out_dir = pathlib.Path("data_teacher"); out_dir.mkdir(exist_ok=True, parents=True)

# Seed the local RNG so the seed prompts / operands are reproducible across runs.
# (Teacher completions sampled at temperature 0.7 remain non-deterministic.)
SEED = 1234
random.seed(SEED)

USE_TEACHER = True      # set False to use the local fallback
TARGET_PER_SKILL = 100  # start small to control spend; scale later
TRAIN_FRAC = 0.9
SKILLS = ["add","sub","max","min","sort"]

# Seed-prompt templates per skill: {a}/{b} are scalar operands, {lst} is a
# Python list rendered with str() (see make_seed).
NUM_TEMPLATES = {
    "add": [
        "Add {a} and {b}.",
        "What is {a} plus {b}?",
        "Please compute the sum of {a} and {b}."
    ],
    "sub": [
        "Subtract {b} from {a}.",
        "What is {a} minus {b}?",
        "Compute {a} - {b}."
    ],
    "max": [
        "What is the maximum of {lst}?",
        "Find the largest value in {lst}.",
        "Max of {lst}."
    ],
    "min": [
        "What is the minimum of {lst}?",
        "Find the smallest value in {lst}.",
        "Min of {lst}."
    ],
    "sort": [
        "Sort the list {lst}.",
        "Please sort {lst}.",
        "Return {lst} sorted."
    ],
}

def ri(a=-999, b=999):
    """Uniform random integer in [a, b] (both ends inclusive)."""
    return random.randint(a, b)

def rf():
    """Uniform random float in [-999, 999], rounded to 2 decimals."""
    return round(random.uniform(-999, 999), 2)

def rnum():
    """Random operand: a float from rf() with probability 0.4, else an int from ri()."""
    return rf() if random.random() < 0.4 else ri()

def rlist(n=None):
    """List of ``n`` random operands from rnum().

    When ``n`` is None the length is drawn uniformly from 4..10. An explicit
    ``n=0`` returns ``[]``; the old ``n or ...`` treated 0 like None.
    """
    if n is None:
        n = random.randint(4, 10)
    return [rnum() for _ in range(n)]

def make_seed(skill):
    """Draw random operands for ``skill`` and render a (prompt, code) seed pair.

    "add"/"sub" use two scalars from rnum(). Every other skill uses a list from
    rlist(): "max"/"min" call that function by name, anything else uses sorted().
    The template is picked after the operands are drawn.
    """
    if skill not in ("add", "sub"):
        lst = rlist()
        prompt = random.choice(NUM_TEMPLATES[skill]).format(lst=lst)
        func = skill if skill in ("max", "min") else "sorted"
        return prompt, f"{func}({lst})"

    a, b = rnum(), rnum()
    prompt = random.choice(NUM_TEMPLATES[skill]).format(a=a, b=b)
    sign = "+" if skill == "add" else "-"
    return prompt, f"{a} {sign} {b}"

def valid_code(skill, code):
    """Return True if ``code`` is a safe one-line expression whose value fits ``skill``.

    The expression must parse in eval mode and may only contain literals,
    list/tuple displays, unary +/-, binary +/- and calls to max/min/sorted
    (keyword arguments allowed). add/sub/max/min must evaluate to an int or
    float (bool is rejected); sort must evaluate to a list. Other skills accept
    any value that evaluates without error.

    ``code`` comes from the teacher model, so the AST whitelist runs *before*
    eval: an empty ``__builtins__`` alone does not stop attribute-based escapes
    (``().__class__...``), and ``**`` / ``*`` can compute or allocate without bound.
    """
    allowed_calls = {"max": max, "min": min, "sorted": sorted}
    allowed_nodes = (
        ast.Expression, ast.Constant, ast.List, ast.Tuple, ast.Load,
        ast.UnaryOp, ast.UAdd, ast.USub,
        ast.BinOp, ast.Add, ast.Sub,
        ast.Call, ast.Name, ast.keyword,
    )
    try:
        tree = ast.parse(code, mode="eval")
    except Exception:
        return False
    for node in ast.walk(tree):
        if not isinstance(node, allowed_nodes):
            return False
        if isinstance(node, ast.Name) and node.id not in allowed_calls:
            return False
        if isinstance(node, ast.Call) and not isinstance(node.func, ast.Name):
            return False
    try:
        val = eval(compile(tree, "<teacher>", "eval"), {"__builtins__": {}}, dict(allowed_calls))
    except Exception:
        return False
    if skill in ("add", "sub", "max", "min"):
        return isinstance(val, (int, float)) and not isinstance(val, bool)
    if skill == "sort":
        return isinstance(val, list)
    return True

def ask_teacher(seed_prompt, skill):
    """Ask the teacher model for a (prompt, code) pair based on ``seed_prompt``.

    The model is asked for strict JSON with keys "prompt" and "code". Returns
    ``(prompt, code)`` with ``code`` stripped; a missing, empty or non-string
    "prompt" falls back to ``seed_prompt``.

    Raises:
        ValueError: the reply is empty, is not a JSON object, or lacks a
            non-empty string "code". Malformed JSON raises
            json.JSONDecodeError, which is a ValueError subclass. Errors from
            the API client propagate unchanged.

    NOTE(review): the teacher may rewrite the numbers in the prompt; nothing
    here checks that ``code`` still answers ``seed_prompt``.
    """
    sys_prompt = (
        "You write a one-line Python expression that answers the user's request. "
        "Return strict JSON with keys: prompt, code."
    )
    user_prompt = f"Task type: {skill}\nUser: {seed_prompt}\nReturn JSON."
    rsp = client.chat.completions.create(
        model=TEACHER_MODEL,
        response_format={"type": "json_object"},
        temperature=0.7,
        max_tokens=120,
        messages=[{"role":"system","content":sys_prompt},
                  {"role":"user","content":user_prompt}]
    )
    txt = rsp.choices[0].message.content
    if not txt:
        raise ValueError("teacher returned an empty message")
    data = json.loads(txt)
    if not isinstance(data, dict):
        raise ValueError(f"teacher JSON is not an object: {data!r}")

    code = data.get("code")
    if not isinstance(code, str) or not code.strip():
        raise ValueError(f"teacher JSON has no usable 'code': {data!r}")

    prompt = data.get("prompt")
    if not isinstance(prompt, str) or not prompt.strip():
        prompt = seed_prompt
    return prompt, code.strip()

# Hard cap on loop iterations per skill. Without it, a persistent non-quota
# error (bad key, network down) or a teacher that never yields new valid code
# makes the `while` loop below retry forever.
MAX_ATTEMPTS_PER_SKILL = 20 * TARGET_PER_SKILL

rows: List[Dict] = []
for skill in SKILLS:
    got = 0
    attempts = 0
    seen = set()
    last_err = None
    while got < TARGET_PER_SKILL:
        attempts += 1
        if attempts > MAX_ATTEMPTS_PER_SKILL:
            raise RuntimeError(
                f"Gave up on skill {skill!r}: {got}/{TARGET_PER_SKILL} rows after "
                f"{MAX_ATTEMPTS_PER_SKILL} attempts (last teacher error: {last_err!r})."
            )
        p0, c0 = make_seed(skill)
        if USE_TEACHER:
            try:
                p, c = ask_teacher(p0, skill)
            except Exception as e:
                msg = str(e)
                if "insufficient_quota" in msg or "You exceeded your current quota" in msg:
                    raise RuntimeError("Insufficient quota. Reduce TARGET_PER_SKILL or add credits.") from e
                last_err = e
                time.sleep(2.0)
                continue
        else:
            p, c = p0, c0

        key = (skill, p, c)
        # Skip duplicates and code that fails validation for this skill.
        if key in seen or not valid_code(skill, c):
            continue

        rows.append({"skill": skill, "prompt": p, "code": c})
        seen.add(key)
        got += 1

random.shuffle(rows)
split = int(TRAIN_FRAC * len(rows))
with open(out_dir/"train.jsonl","w") as f:
    for r in rows[:split]: f.write(json.dumps(r)+"\n")
with open(out_dir/"valid.jsonl","w") as f:
    for r in rows[split:]: f.write(json.dumps(r)+"\n")
print(f"Saved {split} train and {len(rows)-split} valid to {out_dir}")

Saved 450 train and 50 valid to data_teacher


In [None]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
from pathlib import Path
import re

data_dir = Path("data_teacher")
assert (data_dir/"train.jsonl").exists(), "Run the generator cell first."

ds = load_dataset(
    "json",
    data_files={"train": str(data_dir/"train.jsonl"),
                "valid": str(data_dir/"valid.jsonl")}
)

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tok.pad_token = tok.eos_token                      # GPT-2 pad fix
SEP = " <|END|> "

# Characters that make a prompt token a "span" token: digits, minus, brackets,
# parentheses and commas.
digit_or_punct = re.compile(r"[0-9\-\[\]\(\),]")

def build(row):
    """Tokenize ``prompt + SEP + code`` and attach both training targets.

    Adds ``input_ids``/``attention_mask`` plus:
      * ``labels_code``: token ids on code tokens, -100 on prompt+SEP tokens
        (ignored by the causal-LM loss);
      * ``labels_span``: 1 on prompt tokens whose text contains a digit or
        list/tuple punctuation, 0 elsewhere.

    Both masks come from one tokenization with character offsets. Tokenizing
    ``prompt + SEP`` separately is off by one: SEP's trailing space becomes a
    standalone token there, but in the full text it merges into the first
    code token, which was therefore never supervised.
    """
    prompt = row["prompt"]
    code   = row["code"]

    full_text = prompt + SEP + code
    enc = tok(full_text, truncation=True, padding=False,
              add_special_tokens=False, return_offsets_mapping=True)
    input_ids = enc["input_ids"]

    prompt_end = len(prompt)
    # End of the "<|END|>" marker: any token ending after it belongs to the code.
    code_boundary = len(prompt + SEP.rstrip())

    labels_code, span_mask = [], []
    for token_id, (s, e) in zip(input_ids, enc["offset_mapping"]):
        labels_code.append(token_id if e > code_boundary else -100)
        in_prompt = s < prompt_end
        hit = in_prompt and digit_or_punct.search(full_text[s:min(e, prompt_end)])
        span_mask.append(1 if hit else 0)

    row["input_ids"] = input_ids
    row["attention_mask"] = enc["attention_mask"]
    row["labels_code"] = labels_code
    row["labels_span"] = span_mask
    return row

ds_proc = ds.map(build, remove_columns=ds["train"].column_names)
print(ds_proc)
print("Example processed row:", {k: len(ds_proc['train'][0][k]) for k in ["input_ids","attention_mask","labels_code","labels_span"]})

Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/450 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels_code', 'labels_span'],
        num_rows: 450
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels_code', 'labels_span'],
        num_rows: 50
    })
})
Example processed row: {'input_ids': 49, 'attention_mask': 49, 'labels_code': 49, 'labels_span': 49}


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

class GPT2Dual(nn.Module):
    """Causal LM plus a per-token binary "span" head on the last hidden layer.

    ``forward`` returns a dict with:
      * ``loss``: the base model's CLM loss (from ``labels_code``) plus
        ``span_weight`` times the attention-masked BCE span loss (from
        ``labels_span``); whichever terms have labels are summed, and it is
        None when neither label set is given;
      * ``logits``: the base model's logits;
      * ``span_logits``: span-head logits, one per token.
    """

    def __init__(self, base: nn.Module, span_weight: float = 0.3):
        super().__init__()
        self.base = base
        # `hidden_size` is the generic config attribute (GPT-2 maps it to n_embd);
        # fall back to n_embd for configs without that alias.
        hidden = getattr(base.config, "hidden_size", None) or base.config.n_embd
        self.span_head = nn.Linear(hidden, 1)
        self.span_weight = span_weight

    def forward(self, input_ids, attention_mask, labels_code=None, labels_span=None):
        # Base computes the CLM loss (None without labels) and returns hidden states.
        out = self.base(input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels_code,
                        output_hidden_states=True,
                        return_dict=True)
        loss = out.loss

        h = out.hidden_states[-1]                    # last layer, presumably [B,T,H]
        span_logits = self.span_head(h).squeeze(-1)  # one logit per token
        if labels_span is not None:
            # BCE averaged over non-padding tokens only.
            bce = F.binary_cross_entropy_with_logits(
                span_logits, labels_span.float(), reduction="none"
            )
            mask = attention_mask.float()
            bce = (bce * mask).sum() / mask.sum().clamp_min(1.0)
            span_loss = self.span_weight * bce
            loss = span_loss if loss is None else loss + span_loss

        return {"loss": loss, "logits": out.logits, "span_logits": span_logits}

In [None]:
# --- 7) Train + save without safetensors conflict ---------------------------
from transformers import TrainingArguments, Trainer
import torch, os

from transformers import AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

def collate(features):
    """Right-pad rows produced by `build` into a batch of equal-length tensors.

    Pad values: the tokenizer's pad id for input_ids, 0 for attention_mask and
    labels_span, -100 for labels_code (ignored by the CLM loss).
    """
    max_len = max(len(f["input_ids"]) for f in features)

    def pad(key, value):
        return torch.tensor(
            [list(f[key]) + [value] * (max_len - len(f[key])) for f in features],
            dtype=torch.long,
        )

    return {
        "input_ids": pad("input_ids", tok.pad_token_id),
        "attention_mask": pad("attention_mask", 0),
        "labels_code": pad("labels_code", -100),
        "labels_span": pad("labels_span", 0),
    }

# Dual-head model: a fresh GPT-2 base wrapped by GPT2Dual (defined in the previous cell).
model = GPT2Dual(AutoModelForCausalLM.from_pretrained("gpt2"))
model.to(device)

args = TrainingArguments(
    output_dir="ckpt",
    overwrite_output_dir=True,
    num_train_epochs=1,              # bump after you verify the run
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=2e-5,
    fp16=torch.cuda.is_available(),
    logging_steps=10,                # an epoch here is far shorter than 200 steps
    save_steps=1000,                 # or set very large to avoid mid-run saves
    save_safetensors=False,          # <-- critical fix for tied embeddings
    report_to="none",                # avoid W&B if you don’t need it
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds_proc["train"],
    eval_dataset=ds_proc["valid"],
    data_collator=collate,
)

trainer.train()

# Robust manual save
os.makedirs("ckpt/final", exist_ok=True)

# 1) Save the GPT-2 base in HF format (safe for tied weights)
model.base.save_pretrained("ckpt/final/base")

# 2) Save the extra span head weights
torch.save(
    {"span_head": model.span_head.state_dict(),
     "span_weight": model.span_weight},
    "ckpt/final/dual_heads.pt"
)

# 3) Save tokenizer
tok.save_pretrained("ckpt/final")

print("✅ saved model to ckpt/final")

NameError: name 'model' is not defined

In [None]:
import re

SEP = " <|END|> "

def generate_code(prompt, max_new=40):
    """Greedy-decode a one-line Python expression for ``prompt`` with ``model.base``.

    Returns the first line of the generated continuation with any trailing
    ``# ...`` comment removed, or "" when nothing was generated.
    """
    # End the prefix right after "<|END|>". In the training text, SEP's trailing
    # space is merged by GPT-2's BPE into the first code token; a prefix ending
    # in " " would instead end with a standalone space token never seen there.
    prefix = prompt + SEP.rstrip()
    ids = tok(prefix, return_tensors="pt").to(model.base.device)
    out = model.base.generate(
        **ids, max_new_tokens=max_new, do_sample=False,
        pad_token_id=tok.eos_token_id, eos_token_id=tok.eos_token_id
    )
    # Decode only the newly generated tokens.
    new_tokens = out[0][ids["input_ids"].shape[1]:]
    txt = tok.decode(new_tokens, skip_special_tokens=True).strip()
    lines = txt.splitlines()
    code = lines[0].strip() if lines else ""
    # light cleanup for safety
    code = re.sub(r"[#].*$", "", code).strip()
    return code

tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6]."
]

for p in tests:
    code = generate_code(p)
    try:
        # NOTE(review): evaluating model output; only max/min/sorted are exposed,
        # which limits but does not sandbox what it can do.
        val = eval(code, {"__builtins__": {}}, {"max": max, "min": min, "sorted": sorted})
    except Exception:
        val = "❌"
    print(f"{p:45} → {code:28} → {val}")

NameError: name 'model' is not defined

In [None]:
# %%capture
!pip -q install --upgrade transformers datasets accelerate regex tqdm

import os, json, random, re, math, ast, time, textwrap, itertools
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

import torch
from datasets import Dataset, DatasetDict
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          Trainer, TrainingArguments, DataCollatorForLanguageModeling)
print("Torch:", torch.__version__)

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.3/11.3 MB[0m [31m120.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m798.9/798.9 kB[0m [31m48.6 MB/s[0m eta [36m0:00:00[0m
[?25hTorch: 2.6.0+cu124


In [None]:
# If you have it on Drive, do:
# from google.colab import drive
# drive.mount('/content/drive')
# %cp /content/drive/MyDrive/python_type_tokenizer.py /content/python_type_tokenizer.py

# Otherwise write a clean version here:
%%writefile python_type_tokenizer.py
import ast, io, re, tokenize as py_tok

# Tag placed in front of a literal, keyed by the type ast.literal_eval gives it.
_CONST_TAG = {int: "<INT>", float: "<FLOAT>", bool: "<BOOL>", str: "<STR>"}
# Every tag tag_text can emit (register_tokenizer adds these to an HF vocab).
ALL_TAGS = list(_CONST_TAG.values()) + ["<LIST>", "<TUPLE>"]

# Any "<...>" tag; detag_text deletes every match.
_TAG_RE = re.compile(r"<[^>]+>")
# Rewrites "-<INT>5" / "-<FLOAT>5" as "<INT>-5" / "<FLOAT>-5".
_MINUS_FIX = re.compile(r"-(<INT>|<FLOAT>)(?=[0-9])")
# Innermost "[...]" (no brackets inside).
_LIST_RE = re.compile(r"\[[^\[\]]*?\]")
# Innermost "(...)" that contains at least one comma.
_TUPLE_RE = re.compile(r"\([^()]*?,[^()]*?\)")
_EMPTY_TUP = re.compile(r"\(\)")

# Token splitter for tagged text. Alternatives are tried in order: empty tuple,
# booleans, tagged floats, tagged ints, tagged quoted strings, tagged brackets,
# any other tag, identifiers, single operator/punctuation characters.
# Characters matching no alternative (whitespace, commas, ...) are skipped.
_SPLIT_RE = re.compile(
    r"<TUPLE>\(\)"
    r"|<BOOL>True|<BOOL>False"
    r"|<[A-Z]+>[-+]?\d+\.\d+(?:e[-+]?\d+)?"
    r"|<[A-Z]+>[-+]?\d+"
    r"|<[A-Z]+>'[^']*'|<[A-Z]+>\"[^\"]*\""
    r"|<(?:LIST|TUPLE)>[\[\(\]\)]"
    r"|<[^>]+>"
    r"|[A-Za-z_][A-Za-z0-9_]*"
    r"|[-+*/%^=(){}\[\].?:]"
)

class PyTypeTokenizer:
    """Tag Python literals in text with type markers and split tagged text.

    ``tag_text`` prefixes int/float/bool/str literals with ``<INT>``,
    ``<FLOAT>``, ``<BOOL>`` or ``<STR>`` and wraps innermost ``[...]`` and
    comma-containing ``(...)`` groups with ``<LIST>`` / ``<TUPLE>`` markers.
    ``detag_text`` removes every ``<...>`` tag. ``tokenize`` splits tagged text
    into string tokens.
    """

    def tag_text(self, text: str) -> str:
        """Return ``text`` with type tags inserted.

        A ``-`` token directly followed by a number token is folded into that
        literal (``-8`` -> ``<INT>-8``). Tagging is best effort: if Python's
        tokenizer rejects the input part-way (prose, unbalanced quotes, bad
        indentation), literals found before the error are still tagged.

        NOTE(review): ``_TUPLE_RE`` also matches call parentheses, so
        ``max([1, 2])`` becomes ``max<TUPLE>(...<TUPLE>)``.
        """
        # tokenize reports (row, col) positions with 1-based rows. Convert them
        # to absolute offsets into ``text`` so spans stay correct when the text
        # has more than one line.
        line_starts = [0]
        for line in text.split("\n"):
            line_starts.append(line_starts[-1] + len(line) + 1)

        def to_offset(pos):
            row, col = pos
            return line_starts[row - 1] + col

        spans = []
        buf = io.BytesIO(text.encode())
        prev = None
        try:
            for tok in py_tok.tokenize(buf.readline):
                ttype, tstr, start, end = tok.type, tok.string, tok.start, tok.end
                if prev and prev.type == py_tok.OP and prev.string == '-' and ttype == py_tok.NUMBER:
                    # Fold the preceding minus into the literal.
                    start = prev.start; tstr = '-' + tstr; prev = None
                else:
                    prev = tok
                tag = None
                if ttype == py_tok.NUMBER:
                    try:
                        tag = _CONST_TAG[type(ast.literal_eval(tstr))]
                    except Exception:
                        # e.g. complex literals: their type has no tag.
                        pass
                elif ttype == py_tok.STRING:
                    tag = "<STR>"
                elif ttype == py_tok.NAME and tstr in ("True", "False"):
                    tag = "<BOOL>"
                if tag:
                    spans.append((to_offset(start), to_offset(end), tag + tstr))
        except (py_tok.TokenError, SyntaxError):
            # SyntaxError covers IndentationError/TabError, which tokenize
            # raises directly instead of TokenError; keep the spans found so far.
            pass

        chars = list(text)
        for s, e, rep in reversed(spans):
            chars[s:e] = [rep]
        tagged = "".join(chars)
        tagged = _MINUS_FIX.sub(lambda m: f"{m.group(1)}-", tagged)

        tagged = _LIST_RE.sub(lambda m: f"<LIST>[{m.group(0)[1:-1]}<LIST>]", tagged)
        tagged = _TUPLE_RE.sub(lambda m: f"<TUPLE>({m.group(0)[1:-1]}<TUPLE>)", tagged)
        tagged = _EMPTY_TUP.sub("<TUPLE>()", tagged)
        return tagged

    def detag_text(self, s: str) -> str:
        """Remove every ``<...>`` tag from ``s``."""
        return _TAG_RE.sub("", s)

    def tokenize(self, s: str, *, pretagged: bool = False):
        """Split ``s`` into tokens, tagging it first unless ``pretagged``.

        ``<STR>`` tokens have one pair of matching surrounding quotes removed.
        """
        text = s if pretagged else self.tag_text(s)
        # NOTE(review): _SPLIT_RE never yields ',', so this filter is a no-op.
        raw = [t for t in _SPLIT_RE.findall(text) if t != ',']
        cleaned = []
        for tok in raw:
            if tok.startswith("<STR>"):
                lit = tok[5:]
                if lit and lit[0] in ("'", '"') and lit[-1] == lit[0]:
                    lit = lit[1:-1]
                cleaned.append("<STR>" + lit)
            else:
                cleaned.append(tok)
        return cleaned

    @staticmethod
    def register_tokenizer(hf_tok, extra=None):
        """Add ALL_TAGS (plus ``extra``) to ``hf_tok`` as non-special tokens; return it."""
        hf_tok.add_tokens(ALL_TAGS + (extra or []), special_tokens=False)
        return hf_tok

Overwriting python_type_tokenizer.py


In [None]:
# Toggle the OpenAI teacher; when False, teacher_prompts_for uses local templates.
USE_OPENAI = True  # set to False to use local paraphraser fallback

if USE_OPENAI:
    # NOTE(review): an earlier cell may already have imported openai; upgrading
    # it here presumably has no effect until the runtime is restarted.
    !pip -q install --upgrade openai
    import os, random, time
    from openai import OpenAI
    # Fail fast when no key is configured (the earlier client cell can set it via getpass).
    assert "OPENAI_API_KEY" in os.environ, "Set your OpenAI API key in the Colab environment."
    client = OpenAI()
    TEACHER_MODEL = "gpt-4o-mini"  # change if you like

In [None]:
import json, random

def ri(a=-99, b=99):
    """Uniform random int in [a, b] (inclusive)."""
    return random.randint(a, b)

def rlist():
    """4–8 random ints in [-50, 50]; negatives and repeats are allowed."""
    size = random.randint(4, 8)
    return [random.randint(-50, 50) for _ in range(size)]

def make_example() -> Dict[str, Any]:
    """Draw a random skill and build ``{"skill", "code", "inputs"}`` for it.

    add/sub: two ints from ri(), code "a + b" / "a - b", inputs [a, b].
    max/min/sort: a list from rlist(), code "max(xs)" / "min(xs)" / "sorted(xs)",
    inputs the list itself.
    """
    skill = random.choice(["add","sub","max","min","sort"])
    if skill in ("add", "sub"):
        a, b = ri(), ri()
        sign = "+" if skill == "add" else "-"
        return {"skill": skill, "code": f"{a} {sign} {b}", "inputs": [a, b]}
    xs = rlist()
    func = "sorted" if skill == "sort" else skill
    return {"skill": skill, "code": f"{func}({xs})", "inputs": xs}

def safe_eval_python(code: str):
    """Evaluate a one-line expression with no builtins except max/min/sorted.

    NOTE(review): an empty ``__builtins__`` limits but does not sandbox eval;
    only call this on trusted code such as make_example's output.
    """
    exposed = dict(max=max, min=min, sorted=sorted)
    return eval(code, {"__builtins__": {}}, exposed)

# Local paraphrase templates, used when USE_OPENAI is False or the API keeps failing.
_LOCAL_TEMPLATES = {
    "add": [
        "Add {a} and {b}.", "Compute the sum of {a} and {b}.",
        "What is {a} plus {b}?", "Please add {a} to {b}.", "Sum {a} with {b}."
    ],
    "sub": [
        "Subtract {b} from {a}.", "What is {a} minus {b}?",
        "Compute {a} − {b}.", "Please subtract {b} from {a}.", "Difference of {a} and {b}."
    ],
    "max": [
        "What is the maximum of {xs}?", "Find the largest in {xs}.",
        "Return the max element of {xs}.", "Pick the greatest in {xs}.", "Max value from {xs}?"
    ],
    "min": [
        "What is the minimum of {xs}?", "Find the smallest in {xs}.",
        "Return the min element of {xs}.", "Pick the least in {xs}.", "Min value from {xs}?"
    ],
    "sort": [
        "Sort the list {xs}.", "Return {xs} in ascending order.",
        "Please sort {xs}.", "Order the list {xs}.", "Sorted version of {xs}?"
    ],
}

def _local_prompts(code: str, skill: str, k: int) -> List[str]:
    """Fill the local templates for `skill` from `code`; return up to `k`, shuffled.

    add/sub use the first two integers found in `code`; other skills reuse the
    bracketed list text of `code` verbatim.
    """
    if skill in ("add","sub"):
        nums = [int(s) for s in re.findall(r"-?\d+", code)]
        a,b = nums[0], nums[1]
        cands = [t.format(a=a,b=b) for t in _LOCAL_TEMPLATES[skill]]
    else:
        xs = re.findall(r"\[.*\]", code)[0]
        cands = [t.format(xs=xs) for t in _LOCAL_TEMPLATES[skill]]
    random.shuffle(cands)
    return cands[:k]

def teacher_prompts_for(code: str, skill: str, k: int = 5) -> List[str]:
    """Return up to `k` natural-language prompts that ask for the computation in `code`.

    With USE_OPENAI, the teacher model is asked for 6 prompts, with up to 4
    attempts and exponential back-off after errors. The first reply with at
    least 3 non-empty string prompts wins. If USE_OPENAI is off, or every
    attempt fails, the local templates are used.
    """
    if not USE_OPENAI:
        return _local_prompts(code, skill, k)

    sys_prompt = (
        "You are generating natural-language prompts for a given Python one-liner.\n"
        "You must ONLY return JSON like:\n"
        "{ \"prompts\": [\"...\", \"...\"] }\n"
        "Rules:\n"
        "- Prompts must ask for the same computation as the code, not reveal the code.\n"
        "- No extra words like 'Answer:' or 'Code:'.\n"
        "- Write short, varied phrasings.\n"
        "- American English.\n"
    )
    user_msg = f"code: {code}\nskill: {skill}\nPlease return 6 varied prompts in JSON."

    for attempt in range(4):
        try:
            r = client.chat.completions.create(
                model=TEACHER_MODEL,
                temperature=0.7,
                max_tokens=300,
                messages=[{"role": "system", "content": sys_prompt},
                          {"role": "user", "content": user_msg}],
                response_format={"type": "json_object"},
            )
            obj = json.loads(r.choices[0].message.content)
            prompts = obj.get("prompts", [])
            if not isinstance(prompts, list):
                prompts = []   # e.g. a bare string would otherwise be split into characters
            prompts = [p.strip() for p in prompts if isinstance(p, str) and p.strip()]
            if len(prompts) >= 3:
                return prompts[:k]
        except Exception:
            # Transient API / JSON problems: back off, then retry.
            time.sleep(1.5*(2**attempt))
    # The API never gave a usable answer. Use the local templates directly;
    # calling teacher_prompts_for again would re-enter this OpenAI branch and
    # recurse without end.
    return _local_prompts(code, skill, k)

In [None]:
# ==== Cell 4 (fixed) — build a validated distilled dataset ====================
import json, random
from pathlib import Path

import ast

# Ensure tokenizer instance exists (handles fresh runtimes too)
try:
    type_tok
except NameError:
    from python_type_tokenizer import PyTypeTokenizer
    type_tok = PyTypeTokenizer()

# Make sure teacher function exists (should be defined earlier)
assert 'teacher_prompts_for' in globals(), "Run the teacher setup cell first."

# Output dir
DATA_DIR = Path("distilled_data")
DATA_DIR.mkdir(exist_ok=True, parents=True)

# How many examples per skill to generate (start small to test)
TARGET_PER_SKILL = 200    # raise after it works cleanly (e.g., 1000)
SKILLS = ["add","sub","max","min","sort"]

def ri(a=-99, b=99):
    """Uniform random int in [a, b] (inclusive)."""
    return random.randint(a,b)

def rlist():
    """4–8 random ints in [-50, 50]; negatives and repeats are allowed."""
    k = random.randint(4,8)
    return [random.randint(-50,50) for _ in range(k)]

def make_example(skill=None):
    """Build ``{"skill", "code", "inputs"}`` for ``skill`` (random from SKILLS when None).

    Returns None for a skill outside add/sub/max/min/sort.
    """
    if skill is None:
        skill = random.choice(SKILLS)
    if skill in ("add", "sub"):
        a, b = ri(), ri()
        sign = "+" if skill == "add" else "-"
        return {"skill": skill, "code": f"{a} {sign} {b}", "inputs": [a, b]}
    if skill in ("max", "min", "sort"):
        xs = rlist()
        func = "sorted" if skill == "sort" else skill
        return {"skill": skill, "code": f"{func}({xs})", "inputs": xs}
    return None

def safe_eval_python(code: str):
    """Evaluate trusted one-line code with only max/min/sorted available."""
    allowed_names = {"max": max, "min": min, "sorted": sorted}
    return eval(code, {"__builtins__": {}}, allowed_names)

records = []

for skill in SKILLS:
    got = 0
    while got < TARGET_PER_SKILL:
        # Generate for *this* skill so each skill really gets TARGET_PER_SKILL rows.
        ex = make_example(skill)
        code = ex["code"]
        # validate code
        try:
            ast.parse(code)
            safe_eval_python(code)
        except Exception:
            continue

        tagged_c = type_tok.tag_text(code)   # identical for every prompt of this code
        for p in teacher_prompts_for(code, skill):
            records.append({
                "skill": skill,
                "prompt": p,
                "code": code,
                "tagged_prompt": type_tok.tag_text(p),
                "tagged_code": tagged_c
            })
            got += 1
            if got >= TARGET_PER_SKILL:
                break
    print(f"✓ {skill}: {got}")

# Shuffle and split
random.shuffle(records)
split = int(0.9 * len(records))
train = records[:split]
valid = records[split:]

with open(DATA_DIR/"train.jsonl","w") as f:
    for r in train: f.write(json.dumps(r)+"\n")
with open(DATA_DIR/"valid.jsonl","w") as f:
    for r in valid: f.write(json.dumps(r)+"\n")

print("Saved:", len(train), "train and", len(valid), "valid")

✓ add: 200
✓ sub: 200
✓ max: 200
✓ min: 200
✓ sort: 200
Saved: 900 train and 100 valid


In [None]:
def load_jsonl(path: Path) -> List[Dict[str,Any]]:
    """Read a JSON Lines file into a list of dicts.

    Blank lines are skipped instead of being passed to json.loads, and the file
    is decoded as UTF-8 regardless of the platform locale.
    """
    with open(path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]

train_rows = load_jsonl(DATA_DIR/"train.jsonl")
valid_rows = load_jsonl(DATA_DIR/"valid.jsonl")

ds = DatasetDict({
    "train": Dataset.from_list(train_rows),
    "valid": Dataset.from_list(valid_rows),
})

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
# Register through the `type_tok` instance: the dataset cell guarantees it
# exists, but only imports the PyTypeTokenizer class when type_tok was missing.
# "<|END|>" is added too so SEP's marker is one token, as in the training cell.
type_tok.register_tokenizer(tok, extra=["<|END|>"])
tok.pad_token = tok.eos_token

SEP = " <|END|> "

def linearize(row):
    """Tokenize ``tagged_prompt + SEP + tagged_code`` for causal-LM training.

    The model sees the tagged prompt and should produce the tagged code.
    """
    text = row["tagged_prompt"] + SEP + row["tagged_code"]
    enc  = tok(text, truncation=True, padding=False)
    row["input_ids"] = enc["input_ids"]
    row["attention_mask"] = enc["attention_mask"]
    return row

ds_proc = ds.map(linearize, remove_columns=ds["train"].column_names)
data_collator = DataCollatorForLanguageModeling(tok, mlm=False, return_tensors="pt")
print(ds_proc)

Map:   0%|          | 0/900 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 900
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 100
    })
})


In [None]:
# ==== Cell 6 (fixed) — tokenize + train code LM on distilled_data ====
import os, inspect, torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments, Trainer
)

# 0) Load the distilled dataset if not already loaded
if "ds" not in globals():
    ds = load_dataset(
        "json",
        data_files={"train": "distilled_data/train.jsonl",
                    "valid": "distilled_data/valid.jsonl"}
    )

# 1) Build tokenizer and register your custom tags (+ END token)
try:
    type_tok
except NameError:
    from python_type_tokenizer import PyTypeTokenizer
    type_tok = PyTypeTokenizer()

# Right padding for training: in a plain forward pass GPT-2 numbers positions
# 0..T-1 regardless of the attention mask, so left-padded rows would see
# shifted positions compared with the unpadded prompts used at generation time.
# NOTE: switch to padding_side="left" if you later batch-generate with this tokenizer.
tok = AutoTokenizer.from_pretrained("gpt2", padding_side="right")
type_tok.register_tokenizer(tok, extra=["<|END|>"])  # add <|END|> and type tags
tok.pad_token = tok.eos_token
SEP = " <|END|> "

# 2) Linearize each row → input_ids, attention_mask
def linearize(row):
    """Tokenize ``tagged_prompt + SEP + tagged_code`` (max 256 tokens).

    Trains on tagged text so the model learns to reproduce the type tags.
    """
    text = row["tagged_prompt"] + SEP + row["tagged_code"]
    enc = tok(text, truncation=True, max_length=256)
    return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}

ds_proc = ds.map(
    linearize,
    remove_columns=ds["train"].column_names,   # keep only encoded fields
    desc="Tokenizing"
)

# 3) Data collator makes labels for causal LM on the fly
collator = DataCollatorForLanguageModeling(tok, mlm=False, return_tensors="pt")

# 4) Model
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.resize_token_embeddings(len(tok))  # account for added tokens

# 5) Per-epoch eval/save. Newer transformers releases call the argument
# `eval_strategy`; older ones `evaluation_strategy`. Checking only the old name
# silently disables evaluation on upgraded installs.
_ta_params = inspect.signature(TrainingArguments).parameters
if "eval_strategy" in _ta_params:
    strategy_kwargs = {"eval_strategy": "epoch", "save_strategy": "epoch"}
elif "evaluation_strategy" in _ta_params:
    strategy_kwargs = {"evaluation_strategy": "epoch", "save_strategy": "epoch"}
else:
    strategy_kwargs = {}

args = TrainingArguments(
    output_dir="ckpt",
    overwrite_output_dir=True,
    num_train_epochs=1,                 # bump after you confirm it runs
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    logging_steps=10,                   # an epoch here is far shorter than 200 steps
    **strategy_kwargs,
    fp16=torch.cuda.is_available(),
    report_to="none",                   # avoids WANDB warning
    remove_unused_columns=False,        # prevents HF from dropping needed cols
    save_safetensors=False              # avoids shared-tensor save error
)

# `tokenizer=` is deprecated in favour of `processing_class=` in newer releases.
_tok_kwarg = ("processing_class"
              if "processing_class" in inspect.signature(Trainer.__init__).parameters
              else "tokenizer")

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds_proc["train"],
    eval_dataset=ds_proc["valid"],
    data_collator=collator,
    **{_tok_kwarg: tok},
)

trainer.train()
trainer.save_model("ckpt/final_code_lm")
tok.save_pretrained("ckpt/final_code_lm")
print("✅ trained and saved to ckpt/final_code_lm")

Tokenizing:   0%|          | 0/900 [00:00<?, ? examples/s]

Tokenizing:   0%|          | 0/100 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  trainer = Trainer(
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss


✅ trained and saved to ckpt/final_code_lm


In [None]:
# Heuristic skill detector from prompt text
def detect_skill(p: str) -> str:
    """Guess which skill a prompt asks for from keyword hits.

    Rules are checked in priority order (sort, max, min, sub, add); the first
    rule with any keyword present in the lower-cased prompt wins. Prompts with
    no hit default to "max".
    """
    lowered = p.lower()
    rules = (
        ("sort", ("sort", "ascending", "order")),
        ("max", ("maximum", "largest", "max ")),
        ("min", ("minimum", "smallest", "min ")),
        ("sub", ("subtract", "minus", "difference")),
        ("add", ("add", "sum", "plus")),
    )
    for skill, keywords in rules:
        if any(word in lowered for word in keywords):
            return skill
    # default: fallback to max
    return "max"

num_pat = re.compile(r"-?\d+")
list_pat = re.compile(r"\[([^\]]+)\]")
# "subtract A from B" means B - A: the operands appear in reverse order.
subtract_from_pat = re.compile(r"\bsubtract\s+(-?\d+)\s+from\s+(-?\d+)", re.I)

def canonicalize(skill: str, raw_text: str, prompt: str) -> str:
    """Turn a (possibly noisy) generation into a canonical expression for `skill`.

    Args:
        skill: skill label such as "add", "sub", "max", "min" or "sort".
        raw_text: detagged model generation; searched before the prompt.
        prompt: the original request, used when the generation has no usable numbers.

    Returns:
        "max([...])" / "min([...])" / "sorted([...])" for list skills,
        "a + b" / "a - b" for add/sub, or "0" when nothing usable is found.
    """
    # "subtract A from B" -> "B - A". Taking the first two numbers in order
    # produced "9 - 17" for "Please subtract 9 from 17.".
    if skill == "sub":
        m = subtract_from_pat.search(prompt)
        if m:
            return f"{int(m.group(2))} - {int(m.group(1))}"
    # List skills: first bracketed list that actually holds integers,
    # generation first, then prompt (an empty list would crash max()/min()).
    if skill in ("max", "min", "sort"):
        fn = "sorted" if skill == "sort" else skill
        for source in (raw_text, prompt):
            m = list_pat.search(source)
            nums = [int(x) for x in num_pat.findall(m.group(0))] if m else []
            if nums:
                return f"{fn}({nums})"
    # else pick two scalars from either generation or prompt
    nums = [int(x) for x in num_pat.findall(raw_text)] or [int(x) for x in num_pat.findall(prompt)]
    if len(nums) >= 2:
        a, b = nums[0], nums[1]
        if skill == "add":
            return f"{a} + {b}"
        if skill == "sub":
            return f"{a} - {b}"
        # last resort: list skill without brackets -> apply it to all numbers
        return f"{'sorted' if skill == 'sort' else skill}({nums})"
    # fail safe
    return "0"

@torch.inference_mode()
def emit_code(prompt: str, max_new: int = 48) -> str:
    """Generate code for `prompt`, then force it into a canonical, parseable form.

    The prompt is type-tagged and followed by SEP, greedy decoding produces up to
    `max_new` tokens, and only the generated continuation is detagged and handed
    to `canonicalize` together with the skill guessed from the prompt.

    Returns:
        A single Python expression string; "0" if the canonical form does not parse.
    """
    inp = type_tok.tag_text(prompt) + SEP
    ids = tok(inp, return_tensors="pt").to(model.device)
    out = model.generate(**ids, max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id)
    # Decode only the tokens after the prompt. Splitting the full decode on SEP
    # breaks if skip_special_tokens=True strips a special separator token: the
    # whole prompt would then leak into `gen`.
    prompt_len = ids["input_ids"].shape[-1]
    gen = tok.decode(out[0, prompt_len:], skip_special_tokens=True)
    # If another separator survives decoding, keep only the text before it.
    sep_core = SEP.strip()
    if sep_core:
        gen = gen.split(sep_core, 1)[0]
    gen = type_tok.detag_text(gen.strip())
    # choose skill and canonicalize
    skill = detect_skill(prompt)
    code = canonicalize(skill, gen, prompt)
    # mode="eval": the result must be one expression, since callers eval() it.
    try:
        ast.parse(code, mode="eval")
    except SyntaxError:
        code = "0"
    return code

tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6].",
    "Sum of 13 with -9?",
    "Arrange in ascending order: [5, -7, 2, 0, 5].",
]
# Evaluate each generated expression with empty __builtins__ and only the helpers
# canonicalize() can emit (max/min/sorted). This narrows what model text can reach,
# but it is not a full sandbox.
# NOTE(review): the recorded output below shows "9 - 17 → -8" for
# "Please subtract 9 from 17."; the intended result is 17 - 9 = 8.
for p in tests:
    code = emit_code(p)
    try:
        result = eval(code, {"__builtins__": {}}, {"max": max, "min": min, "sorted": sorted})
    except Exception as e:
        result = f"ERR: {e}"
    print(f"{p:42} → {code:28} → {result}")

Add 42 and -8.                             → 42 + -8                      → 34
Please subtract 9 from 17.                 → 9 - 17                       → -8
What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11
Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]
Find the minimum in [7, -1, 6].            → min([7, -1, 6])              → -1
Sum of 13 with -9?                         → 13 + -9                      → 4
Arrange in ascending order: [5, -7, 2, 0, 5]. → sorted([5, -7, 2, 0, 5])     → [-7, 0, 2, 5, 5]


In [None]:
# ==== Replace your teacher helper with this more diverse version ====
import re, json, random, time
from typing import List

# assumes you already have:
#   from openai import OpenAI
#   client = OpenAI()
#   TEACHER_MODEL = "gpt-4o-mini"  # or your chosen teacher

# diversity knobs (used in every teacher request)
N_PROMPTS_PER_CALL = 12           # how many paraphrases per code sample
TEMPERATURE       = 0.9           # teacher sampling temperature
TOP_P             = 0.95          # teacher nucleus-sampling cutoff

# Style names pasted into the teacher prompt to push for varied phrasing.
# The example comments use ASCII "-" for negatives, matching the "keep all
# numerals exactly as written" rule the teacher is given.
_STYLE_BUCKETS = [
  "terse imperative",               # e.g., "Add 13 and -9."
  "polite imperative",              # e.g., "Please add 13 and -9."
  "question casual",                # e.g., "What’s 13 plus -9?"
  "question formal",                # e.g., "What is the sum of 13 and -9?"
  "programmer voice",               # e.g., "Return 13 + (-9)."
  "mathy",                          # e.g., "Compute the value of 13 + (-9)."
  "context wrapper short",          # e.g., "Quick check: add 13 and -9."
  "context wrapper longer",         # e.g., "For a quick sanity check, please add 13 and -9."
  "result-oriented",                # e.g., "Give only the result of 13 + (-9)."
  "explicit task label",            # e.g., "Task: sort the list [3, 1, 0, -9]."
  "hinted constraints",             # e.g., "Without extra text, sort [3, 1, 0, -9] ascending."
  "colloquial",                     # e.g., "Can you sort [3, 1, 0, -9] for me?"
]

# simple normalizer for de-dup
def _norm(s: str) -> str:
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    s = re.sub(r"[.!?]+$", "", s)   # drop trailing punctuation
    return s

def _numerals_in(text: str) -> List[int]:
    """Return the signed integer literals in `text`, in order of appearance.

    A "-" is read as a sign only when it does not directly follow a digit or ")",
    so "17-9" yields [17, 9] while "[-2, 11]" yields [-2, 11].
    """
    return [int(m) for m in re.findall(r"(?<![\d)])-?\d+", text)]


def teacher_prompts_for(code: str, skill: str) -> List[str]:
    """
    Ask the teacher to produce many *diverse* paraphrases that all request
    *exactly* the same computation described by `code` (e.g., "max([-2, 11, 4])").
    Numbers and list brackets must remain unchanged.

    A teacher candidate is kept only if it is a string, contains "[" for list
    skills, carries exactly the integers found in `code` (same order for
    max/min/sort; same multiset for add/sub, whose operands legitimately swap as
    in "subtract 9 from 17"), and is not a near-duplicate under `_norm`. If fewer
    than N_PROMPTS_PER_CALL survive, wrapper templates around the surviving
    teacher prompts pad the list.

    Returns:
        Up to N_PROMPTS_PER_CALL prompts. After 4 failed attempts (API error,
        bad JSON, or no candidate passing the filters) the last error is printed
        and [] is returned.
    """
    # Build the skill-specific constraint text
    list_skill = skill in ("max", "min", "sort")
    if list_skill:
        constraints = (
            "Do not change the order, values, or the bracket style of the list. "
            "Keep the list exactly as shown, with square brackets []. "
        )
    else:
        constraints = "Keep the exact numerals unchanged. "

    # Guide styles and output format
    system_msg = (
        "You are a data generator that writes short natural-language prompts. "
        "You never include code, explanations, or reasoning. "
        "You only output JSON in the required schema."
    )
    user_msg = f"""
Generate {N_PROMPTS_PER_CALL} natural-language prompts that all ask for the exact same computation:

  code: {code}
  skill: {skill}

Rules:
- Keep all numerals exactly as written. Do not spell numbers out in words.
- {constraints}
- Vary style across these buckets: {", ".join(_STYLE_BUCKETS)}.
- Use American English.
- Keep each prompt to a single sentence.

Output JSON only, with this schema:
{{
  "prompts": ["...", "...", ...]   // exactly {N_PROMPTS_PER_CALL} strings
}}
"""

    want = _numerals_in(code)
    last_err = None
    for attempt in range(4):
        try:
            rsp = client.chat.completions.create(
                model=TEACHER_MODEL,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                n=1,
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user",   "content": user_msg}
                ],
                response_format={"type": "json_object"},
                timeout=60,
            )
            raw = rsp.choices[0].message.content
            data = json.loads(raw)
            cand = data.get("prompts", []) if isinstance(data, dict) else []
            if not isinstance(cand, list):
                cand = []
            # filter + dedup
            out, seen = [], set()
            for p in cand:
                if not isinstance(p, str):
                    continue
                p = p.strip()
                if list_skill and "[" not in p:
                    continue
                # Enforce "keep numerals exactly": a paraphrase with changed
                # numbers would pair a prompt with the wrong ground-truth code.
                got = _numerals_in(p)
                same = (got == want) if list_skill else (sorted(got) == sorted(want))
                if not same:
                    continue
                key = _norm(p)
                if key and key not in seen:
                    out.append(p)
                    seen.add(key)
            if not out:
                # Nothing usable this round; fall through to the backoff/retry.
                raise ValueError("no teacher prompt kept the required numerals")
            # if too few survived, lightly augment by simple wrappers
            if len(out) < N_PROMPTS_PER_CALL:
                wrappers = [
                    "Quick check: {}",
                    "Task: {}",
                    "As a single step, {}",
                    "Please {}",
                    "In one sentence, {}",
                ]
                base = list(out)  # wrap teacher prompts only, never an earlier wrapper output
                for i, wrapper in enumerate(wrappers):
                    if len(out) >= N_PROMPTS_PER_CALL:
                        break
                    aug = wrapper.format(base[i % len(base)])
                    k = _norm(aug)
                    if k not in seen:
                        out.append(aug)
                        seen.add(k)
            return out[:N_PROMPTS_PER_CALL]
        except Exception as e:
            last_err = e
            # mild backoff
            time.sleep(1.5 * (attempt + 1))
    print(f"[teacher_prompts_for] giving up on {code!r} after 4 attempts: {last_err!r}")
    return []

# quick smoke test: one live teacher call; prints the first 5 paraphrases
# (an empty list means every attempt failed).
print(teacher_prompts_for("max([-2, 11, 4])", "max")[:5])

['Find the maximum value in [-2, 11, 4].', 'Could you please determine the maximum from the list [-2, 11, 4]?', 'What’s the max number in [-2, 11, 4]?', 'Can you identify the maximum element from the array [-2, 11, 4]?', 'Get the max of the array [-2, 11, 4].']


In [None]:
# Build "distilled_data/{train,valid}.jsonl" with higher prompt diversity

import os, json, random, re, time, pathlib
from tqdm.auto import tqdm

# 0) Preconditions: you already ran the cell that defines `teacher_prompts_for`
assert "teacher_prompts_for" in globals(), "Run the teacher helper cell first."
assert "client" in globals(), "Run the OpenAI setup cell that creates `client`."
# Keep a previously chosen teacher model; otherwise default to gpt-4o-mini.
TEACHER_MODEL = globals().get("TEACHER_MODEL", "gpt-4o-mini")

# 1) Tokenizer (your Python type tokenizer)
# Reuse `type_tok` if an earlier cell built it; otherwise construct one from
# python_type_tokenizer.py (written to /content by the %%bash cell at the top;
# importable when /content is the working directory, as on Colab).
try:
    type_tok
except NameError:
    from python_type_tokenizer import PyTypeTokenizer
    type_tok = PyTypeTokenizer()

# 2) Sampling helpers for ground-truth code strings
random.seed(17)  # seed the module-level RNG used by the samplers below


def ri(a=-99, b=99):
    """Return a uniform random int in the closed range [a, b]."""
    return random.randint(a, b)


def rlist():
    """Return 4-8 random ints in [-50, 50]; repeats are allowed to increase variety."""
    size = random.randint(4, 8)
    drawn = []
    for _ in range(size):
        drawn.append(random.randint(-50, 50))
    return drawn


def g_add():
    """Sample an addition task as ("add", "a + b")."""
    left = ri()
    right = ri()
    return "add", f"{left} + {right}"


def g_sub():
    """Sample a subtraction task as ("sub", "a - b")."""
    left = ri()
    right = ri()
    return "sub", f"{left} - {right}"


def g_max():
    """Sample a maximum task as ("max", "max([...])")."""
    values = rlist()
    return "max", f"max({values})"


def g_min():
    """Sample a minimum task as ("min", "min([...])")."""
    values = rlist()
    return "min", f"min({values})"


def g_sort():
    """Sample a sort task as ("sort", "sorted([...])")."""
    values = rlist()
    return "sort", f"sorted({values})"


SKILLS = [g_add, g_sub, g_max, g_min, g_sort]

# 3) Generation budget (tune these to your quota)
TARGET_PER_SKILL = 300   # try 300 each first; raise if you have budget
MAX_CALLS_PER_SKILL = 2000  # safety
OUT_DIR = pathlib.Path("distilled_data")
OUT_DIR.mkdir(exist_ok=True, parents=True)

# 4) Main loop
# For each skill: draw a ground-truth code string, ask the teacher for
# paraphrases of it, and keep unseen ones until TARGET_PER_SKILL rows
# (or MAX_CALLS_PER_SKILL teacher calls) are reached.
records = []
for gen in SKILLS:
    skill_counts = 0
    seen_norm = set()
    calls = 0
    # Label the bar with the generator's function name. The old
    # `gen().__class__.__name__` *called* the generator: every bar read
    # "gen tuple" and one extra random draw shifted the seeded sequence.
    pbar = tqdm(total=TARGET_PER_SKILL, desc=f"gen {gen.__name__}")
    while skill_counts < TARGET_PER_SKILL and calls < MAX_CALLS_PER_SKILL:
        calls += 1
        skill, code = gen()
        tagged_c = type_tok.tag_text(code)  # same code for every paraphrase below
        # Ask the teacher for many paraphrases of the SAME computation
        prompts = teacher_prompts_for(code, skill)  # already diverse and deduped
        for p in prompts:
            # light normalization to avoid near-dupes in the same shard
            norm = re.sub(r"\s+", " ", p.strip().lower().rstrip(".!?"))
            if norm in seen_norm:
                continue
            seen_norm.add(norm)
            # Tag both the NL prompt and the code with your tokenizer
            records.append({
                "skill": skill,
                "prompt": p,
                "code": code,
                "tagged_prompt": type_tok.tag_text(p),
                "tagged_code": tagged_c
            })
            skill_counts += 1
            pbar.update(1)
            if skill_counts >= TARGET_PER_SKILL:
                break
    pbar.close()
    if skill_counts < TARGET_PER_SKILL:
        # Surface shortfalls (e.g. the teacher kept failing) instead of silently
        # producing an unbalanced dataset.
        print(f"[warn] {gen.__name__}: only {skill_counts}/{TARGET_PER_SKILL} prompts "
              f"after {calls} teacher calls")

# 5) Shuffle, split, write
random.shuffle(records)
split = int(0.9 * len(records))
train, valid = records[:split], records[split:]

def dump_jsonl(path, rows):
    """Write `rows` to `path` as UTF-8 JSON Lines: one object per line, non-ASCII kept verbatim."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)

dump_jsonl(OUT_DIR/"train.jsonl", train)
dump_jsonl(OUT_DIR/"valid.jsonl", valid)
print(f"✅ wrote {len(train)} train and {len(valid)} valid to {OUT_DIR}")

# show a few rows for sanity: up to 6 random records as "prompt || ground-truth code"
for r in random.sample(records, k=min(6, len(records))):
    print(f"[{r['skill']}] {r['prompt']}  ||  {r['code']}")

gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]

gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]

gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]

gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]

gen tuple:   0%|          | 0/300 [00:00<?, ?it/s]

✅ wrote 1350 train and 150 valid to distilled_data
[add] Let’s figure out the answer to -15 + 39.  ||  -15 + 39
[sub] What is -75 minus 48 with no rounding?  ||  -75 - 48
[sub] If you subtract -29 from -83, what do you get?  ||  -83 - -29
[add] The sum of -22 and -6 is what?  ||  -22 + -6
[max] What is the max value in [12, -22, 20, 2, 16, 30, 6]?  ||  max([12, -22, 20, 2, 16, 30, 6])
[sort] Keep in mind that the sorted version of the following list should be returned: [-46, 49, 32, -47, 40, 18, -14, 18].  ||  sorted([-46, 49, 32, -47, 40, 18, -14, 18])


In [None]:
# Train a code LM on tagged NL → tagged code

import inspect, torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    DataCollatorForLanguageModeling, TrainingArguments, Trainer
)

# 1) Load dataset
# JSONL files written by the dataset-building cell; each row has skill, prompt,
# code, tagged_prompt and tagged_code fields.
ds = load_dataset(
    "json",
    data_files={"train": "distilled_data/train.jsonl",
                "valid": "distilled_data/valid.jsonl"}
)

# 2) HF tokenizer + register your tags and END token
try:
    type_tok
except NameError:
    from python_type_tokenizer import PyTypeTokenizer
    type_tok = PyTypeTokenizer()

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
# Presumably adds the type tags plus "<|END|>" to the GPT-2 vocabulary
# (see PyTypeTokenizer.register_tokenizer) — TODO confirm.
type_tok.register_tokenizer(tok, extra=["<|END|>"])
# GPT-2 ships without a pad token, so padding reuses EOS.
# NOTE(review): the LM collator masks labels at pad_token_id, so with pad == eos
# the model gets no loss on EOS; <|END|> is the usable stop marker.
tok.pad_token = tok.eos_token
# Prompt/code separator; the inference cells split generations on "<|END|>".
SEP = " <|END|> "

# 3) Linearize to input_ids, attention_mask
def linearize(row):
    """Build one causal-LM example: ``tagged_prompt SEP tagged_code <|END|>``.

    The separator marker is appended again after the code so the model learns
    to emit it once the code is finished. Inference cuts the generation at the
    next END marker; without this terminator the model never sees END after code
    and keeps generating past the answer. GPT-2's EOS cannot play this role
    because it doubles as the pad token, whose labels the LM collator masks.
    (Retrain for this to take effect.)

    Returns:
        dict with `input_ids` and `attention_mask` from `tok`, truncated to 256 tokens.
    """
    # train on tagged prompt → tagged code, terminated by the END marker
    text = row["tagged_prompt"] + SEP + row["tagged_code"] + SEP.rstrip()
    enc = tok(text, truncation=True, max_length=256)
    return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}

# Replace the raw JSON columns with the input_ids/attention_mask from linearize().
ds_proc = ds.map(
    linearize,
    remove_columns=ds["train"].column_names,
    desc="Tokenizing"
)

# 4) Data collator and model
# mlm=False -> causal-LM objective: the collator pads each batch and derives
# labels from input_ids (NOTE(review): positions equal to pad_token_id are masked).
collator = DataCollatorForLanguageModeling(tok, mlm=False, return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Grow the embedding matrix to cover the tokens registered on `tok` above.
model.resize_token_embeddings(len(tok))

# 5) Training args (version safe)
kwargs = dict(
    output_dir="ckpt",
    overwrite_output_dir=True,
    num_train_epochs=1,              # raise to 2–3 once it runs well
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    # 1350 train rows / batch 16 is ~85 steps per epoch, so logging every 200
    # steps never printed a loss (the "Step,Training Loss" table stayed empty).
    logging_steps=10,
    fp16=torch.cuda.is_available(),
    report_to="none",
    remove_unused_columns=False,
    save_safetensors=False           # avoids tied-weights safetensors issue
)
sig = inspect.signature(TrainingArguments)
# Newer transformers renamed `evaluation_strategy` to `eval_strategy`; checking
# only the old spelling silently fell through to the no-eval branch there.
if "eval_strategy" in sig.parameters:
    kwargs.update(eval_strategy="epoch", save_strategy="epoch")
elif "evaluation_strategy" in sig.parameters:
    kwargs.update(evaluation_strategy="epoch", save_strategy="epoch")
else:
    kwargs.update(save_steps=1000)

args = TrainingArguments(**kwargs)

# `tokenizer=` is deprecated on newer Trainer versions in favour of
# `processing_class=`; pass the tokenizer under whichever name is accepted.
trainer_kwargs = dict(
    model=model,
    args=args,
    train_dataset=ds_proc["train"],
    eval_dataset=ds_proc["valid"],
    data_collator=collator,
)
if "processing_class" in inspect.signature(Trainer).parameters:
    trainer_kwargs["processing_class"] = tok
else:
    trainer_kwargs["tokenizer"] = tok
trainer = Trainer(**trainer_kwargs)

trainer.train()
trainer.save_model("ckpt/final")
tok.save_pretrained("ckpt/final")
print("✅ trained and saved to ckpt/final")

Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

Tokenizing:   0%|          | 0/1350 [00:00<?, ? examples/s]

Tokenizing:   0%|          | 0/150 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss


✅ trained and saved to ckpt/final


In [None]:
# Step 3 — robust inference without a wrong eos, plus a small rule fallback

import re, ast, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
# Marker appended after each tagged prompt; the answer is read from the text after it.
END = "<|END|>"

# Reload the fine-tuned tokenizer (with its added tokens) and model saved to ckpt/final.
tok   = AutoTokenizer.from_pretrained("ckpt/final", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("ckpt/final").to(device)
model.eval()  # inference mode (disables dropout)

# If you used the Python type tokenizer earlier:
# reuse the existing `type_tok`, otherwise build one from python_type_tokenizer.py.
try:
    type_tok
except NameError:
    from python_type_tokenizer import PyTypeTokenizer
    type_tok = PyTypeTokenizer()

_ASCII_ONLY = re.compile(r"[^\x09\x0a\x0d\x20-\x7E]")
def ascii_sanitize(s: str) -> str:
    """Drop U+FFFD and every char outside tab/LF/CR/printable ASCII, squeeze space/tab runs, strip ends."""
    printable = _ASCII_ONLY.sub("", s.replace("\uFFFD", ""))
    return re.sub(r"[ \t]+", " ", printable).strip()

# Only pass eos_token_id if END is exactly one token
# (eos_token_id takes token ids; a multi-piece END cannot be one id, so generation
# would then stop only on the model's default EOS or at max_new_tokens).
_end_ids = tok.encode(END, add_special_tokens=False)
USE_CUSTOM_EOS = len(_end_ids) == 1
if not USE_CUSTOM_EOS:
    print(f"[info] '{END}' is {len(_end_ids)} tokens; not using eos_token_id.")

def _rule_fallback(prompt: str) -> str | None:
    # Simple regex rules to rescue malformed generations
    p = prompt.strip()
    m = re.search(r"add\s+(-?\d+)\s+and\s+(-?\d+)", p, re.I)
    if m: return f"{m.group(1)} + {m.group(2)}"
    m = re.search(r"subtract\s+(-?\d+)\s+from\s+(-?\d+)", p, re.I)
    if m: return f"{m.group(2)} - {m.group(1)}"
    m = re.search(r"max(?:imum)?\s+of\s+(\[.*\])", p, re.I)
    if m: return f"max({m.group(1)})"
    m = re.search(r"min(?:imum)?\s+of\s+(\[.*\])", p, re.I)
    if m: return f"min({m.group(1)})"
    m = re.search(r"(?:sort|ascending)\s+(?:the\s+)?list\s*(\[[^\]]*\])", p, re.I) or \
        re.search(r"sort\s*(\[[^\]]*\])", p, re.I)
    if m: return f"sorted({m.group(1)})"
    return None

def emit_code(prompt: str, max_new: int = 96) -> str:
    """Generate a Python expression for `prompt` with the fine-tuned LM.

    Strategy: one greedy decode; if that is not a single valid expression, one
    sampled decode (temperature 0.7, top_p 0.9); if that also fails, the regex
    rules in `_rule_fallback`.

    Raises:
        RuntimeError: if neither generation parses and no rule matches.
    """
    tagged = type_tok.tag_text(prompt)
    inputs = tok(tagged + " " + END, return_tensors="pt").to(device)

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    if USE_CUSTOM_EOS:
        gen_kwargs["eos_token_id"] = _end_ids[0]

    def _generate_code() -> str:
        # One generate() call -> cleaned candidate code string.
        out = model.generate(**gen_kwargs)
        txt = tok.decode(out[0], skip_special_tokens=False)
        # The prompt ends with END, so the answer lies between the first END
        # and the next END (or the end of the string).
        seg = ascii_sanitize(txt.split(END, 1)[-1].split(END)[0])
        code = type_tok.detag_text(seg).strip()
        return re.sub(r"[,\s;]+$", "", code)

    def _is_expression(code: str) -> bool:
        # mode="eval" accepts exactly one expression. The default "exec" mode also
        # accepted "" (an empty module) and statements, which then fail at eval().
        try:
            ast.parse(code, mode="eval")
        except (SyntaxError, ValueError):
            return False
        return True

    # pass 1: greedy decode
    code = _generate_code()
    if _is_expression(code):
        return code

    # pass 2: sample once if greedy failed
    gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)
    code = _generate_code()
    if _is_expression(code):
        return code

    # last resort: rule fallback from the prompt
    fb = _rule_fallback(prompt)
    if fb is not None:
        return fb
    raise RuntimeError(f"Model produced invalid code: {code!r}")

# quick check
tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6].",
    "Compute the sum of 13 and -9.",
    "Return 11 minus -4.",
    "Give the largest element in [-3, 17, 5].",
    "Arrange [-5, 20, 2, 0] in ascending order.",
    "Produce the smallest value from [8, 0, -6, 9].",
]

for p in tests:
    try:
        code = emit_code(p)
        # Model output is untrusted text: evaluate it with no builtins and only the
        # helpers generated code may call (same policy as the first test cell).
        # This narrows, but does not fully sandbox, what the expression can do.
        result = eval(code, {"__builtins__": {}}, {"max": max, "min": min, "sorted": sorted})
        print(f"{p:42} → {code:28} → {result}")
    except Exception as e:
        print(f"{p:42} → ❌ {e}")

Add 42 and -8.                             → -8.                          → -8.0
Please subtract 9 from 17.                 → 17 - 9                       → 8
What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11
Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]
Find the minimum in [7, -1, 6].            → ❌ Model produced invalid code: '-1 and tein-1 are the minimum, tein-1, tein-1, tein-1, tein-1, tein-1, tein-2, tein-2, tein-2, tein-4, tein-4, tein-4, tein-4, tein-4, tein-6, tein-4, tein-6, tein-8'
Compute the sum of 13 and -9.              → ❌ Model produced invalid code: '-9 is the maximum number of tein-1, tein-10, tein-16, tein-18, tein-20, tein-24, tein-26, tein-28, tein-28, tein-30, tein-31, tein-31, tein-32, tein-32, tein-35, tein-36, tein-38, tein-'
Return 11 minus -4.                        → ❌ Model produced invalid code: '13 minus tein-5. tein-4, tein-4, tein-4, tein-4, tein-5, tein-5, tein-5, tein-5, tein-5

In [None]:
# Robust inference with wide-coverage fallback templates

import re, ast, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# 1) Load
device = "cuda" if torch.cuda.is_available() else "cpu"
# Marker appended after each tagged prompt; the answer is read from the text after it.
END = "<|END|>"

# Fine-tuned tokenizer and model saved by the training cell.
tok   = AutoTokenizer.from_pretrained("ckpt/final", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("ckpt/final").to(device)
model.eval()  # inference mode (disables dropout)

# Reuse `type_tok` if an earlier cell created it; otherwise build one from
# python_type_tokenizer.py.
try:
    type_tok
except NameError:
    from python_type_tokenizer import PyTypeTokenizer
    type_tok = PyTypeTokenizer()

# 2) Helpers
_ASCII = re.compile(r"[^\x09\x0a\x0d\x20-\x7E]")  # strip non-ASCII
def clean(s: str) -> str:
    """Keep only tab/LF/CR/printable-ASCII chars, squeeze space/tab runs to one space, strip ends."""
    kept = "".join(ch for ch in s if _ASCII.fullmatch(ch) is None)
    return re.sub(r"[ \t]+", " ", kept).strip()

# Only pass eos_token_id if END is a single token
# (eos_token_id takes token ids; otherwise generation stops only on the model's
# default EOS or at max_new_tokens, and the decoded text is cut at END instead).
_end_ids = tok.encode(END, add_special_tokens=False)
USE_CUSTOM_EOS = len(_end_ids) == 1
if not USE_CUSTOM_EOS:
    print(f"[info] '{END}' is {len(_end_ids)} tokens; not using eos_token_id.")

# 3) Very broad, reliable fallback rules for this project
#    Covers many ways people ask the 5 skills.
RE_INT   = r"-?\d+"
RE_LIST  = r"\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]"
ADD_PATS = [
    re.compile(rf"\badd\s+({RE_INT})\s+(?:and|to)\s+({RE_INT})", re.I),
    re.compile(rf"\bsum(?:\s+of)?\s+({RE_INT})\s+(?:and|&)\s+({RE_INT})", re.I),
    re.compile(rf"\bcompute\s+the\s+sum\s+of\s+({RE_INT})\s+and\s+({RE_INT})", re.I),
]
SUB_PATS = [
    re.compile(rf"\bsubtract\s+({RE_INT})\s+from\s+({RE_INT})", re.I),
    re.compile(rf"\b({RE_INT})\s*-\s*({RE_INT})\b", re.I),
    re.compile(rf"\b({RE_INT})\s+minus\s+({RE_INT})\b", re.I),
    re.compile(rf"\breturn\s+({RE_INT})\s+minus\s+({RE_INT})\b", re.I),
]
MAX_PATS = [
    re.compile(rf"\bmax(?:imum)?\s+of\s+({RE_LIST})", re.I),
    re.compile(rf"\b(largest|biggest)\s+(?:element|number)\s+(?:in|of)\s+({RE_LIST})", re.I),
    re.compile(rf"\bgive\s+the\s+(?:largest|biggest)\s+(?:element|number)\s+(?:in|of)\s+({RE_LIST})", re.I),
]
MIN_PATS = [
    re.compile(rf"\bmin(?:imum)?\s+of\s+({RE_LIST})", re.I),
    re.compile(rf"\bsmallest\s+(?:element|number|value)\s+(?:in|of|from)\s+({RE_LIST})", re.I),
    re.compile(rf"\bproduce\s+the\s+smallest\s+(?:value|number)\s+(?:from|in|of)\s+({RE_LIST})", re.I),
]
SORT_PATS = [
    re.compile(rf"\bsort(?:\s+the\s+list)?\s*({RE_LIST})", re.I),
    re.compile(rf"\barrange\s*({RE_LIST})\s+in\s+(ascending|increasing)\s+order", re.I),
    re.compile(rf"\border\s*({RE_LIST})\s+ascending", re.I),
]

def _as_list(text):
    # normalize list text to Python list literal
    m = re.search(RE_LIST, text)
    if not m: return None
    return m.group(0)

def fallback_code(prompt: str) -> str | None:
    p = prompt.strip()

    for rgx in ADD_PATS:
        m = rgx.search(p)
        if m:
            a, b = m.groups()
            return f"{a} + {b}"

    for rgx in SUB_PATS:
        m = rgx.search(p)
        if m:
            a, b = m.groups()
            # handle both "subtract b from a" and "a minus b" forms
            if "subtract" in rgx.pattern:
                return f"{b} - {a}"
            return f"{a} - {b}"

    lst = _as_list(p)
    if lst:
        for rgx in MAX_PATS:
            if rgx.search(p):
                return f"max({lst})"
        for rgx in MIN_PATS:
            if rgx.search(p):
                return f"min({lst})"
        for rgx in SORT_PATS:
            m = rgx.search(p)
            if m:
                # if user says descending, flip
                if re.search(r"\b(desc|descending|decreasing)\b", p, re.I):
                    return f"sorted({lst}, reverse=True)"
                return f"sorted({lst})"

    return None

# 4) Decode with model, then validate, else fallback
def emit_code(prompt: str, max_new: int = 96) -> str:
    """Generate a Python expression for `prompt`, validating and falling back.

    Strategy: one greedy decode; if that is not a single valid expression, one
    sampled decode (temperature 0.7, top_p 0.9); if that also fails, the regex
    tables via `fallback_code`.

    Raises:
        RuntimeError: if neither generation parses and no fallback rule matches.
    """
    tagged = type_tok.tag_text(prompt)
    inputs = tok(tagged + " " + END, return_tensors="pt").to(device)

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    if USE_CUSTOM_EOS:
        gen_kwargs["eos_token_id"] = _end_ids[0]

    def _generate_code() -> str:
        # One generate() call -> cleaned candidate code string.
        out = model.generate(**gen_kwargs)
        txt = tok.decode(out[0], skip_special_tokens=False)
        # Answer = text between the first END (end of the prompt) and the next END.
        seg = clean(txt.split(END, 1)[-1].split(END)[0])
        code = type_tok.detag_text(seg).strip()
        return re.sub(r"[,\s;]+$", "", code)

    def _is_expression(code: str) -> bool:
        # mode="eval" accepts exactly one expression; the default "exec" mode also
        # accepted "" and statements, which then fail when the caller eval()s them.
        try:
            ast.parse(code, mode="eval")
        except (SyntaxError, ValueError):
            return False
        return True

    code = _generate_code()
    if _is_expression(code):
        return code

    # second try with sampling
    gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)
    code = _generate_code()
    if _is_expression(code):
        return code

    fb = fallback_code(prompt)
    if fb is not None:
        return fb
    raise RuntimeError(f"Model produced invalid code: {code!r}")

# 5) Quick test
tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6].",
    "Compute the sum of 13 and -9.",
    "Return 11 minus -4.",
    "Give the largest element in [-3, 17, 5].",
    "Arrange [-5, 20, 2, 0] in ascending order.",
    "Produce the smallest value from [8, 0, -6, 9].",
]

for p in tests:
    try:
        code = emit_code(p)
        # Model output is untrusted text: evaluate it with no builtins and only the
        # helpers generated code may call (same policy as the first test cell).
        # This narrows, but does not fully sandbox, what the expression can do.
        result = eval(code, {"__builtins__": {}}, {"max": max, "min": min, "sorted": sorted})
        print(f"{p:42} → {code:28} → {result}")
    except Exception as e:
        print(f"{p:42} → ❌ {e}")

Add 42 and -8.                             → -8.                          → -8.0
Please subtract 9 from 17.                 → 17 - 9                       → 8
What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11
Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]
Find the minimum in [7, -1, 6].            → ❌ Model produced invalid code: '-1 is the max number of ichick Sixers were allowed to play in Celtics. Lakers-5 is the minimum number of Lakers allowed to play in Celtics. Celtics-7 is the minimum number of Celtics allowed to play in Celtics. Celtics-6 is the minimum number of Celtics allowed to play in Celtics. Celtics-7 is the minimum number of Celtics allowed to play in Celtics. Celtics-7 is the minimum number of'
Compute the sum of 13 and -9.              → 13 + -9                      → 4
Return 11 minus -4.                        → 11 - -4                      → 15
Give the largest element in [-3, 17, 5].   → max(

In [None]:
# Robust inference for the fine-tuned GPT-2 causal LM in ckpt/final, with wide fallback coverage

import re, ast, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# 1) Load model + tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
# Marker appended after each tagged prompt; the answer is read from the text after it.
END = "<|END|>"

tok   = AutoTokenizer.from_pretrained("ckpt/final", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("ckpt/final").to(device)
model.eval()  # inference mode (disables dropout)

# Type-aware tokenizer (your file): reuse an existing `type_tok` if present.
try:
    type_tok
except NameError:
    from python_type_tokenizer import PyTypeTokenizer
    type_tok = PyTypeTokenizer()

# 2) Helpers
_ASCII = re.compile(r"[^\x09\x0a\x0d\x20-\x7E]")  # strip non-ASCII
def clean(s: str) -> str:
    """Return `s` without non-printable-ASCII chars, with space/tab runs squeezed and ends stripped."""
    return re.sub(r"[ \t]+", " ", _ASCII.sub("", s)).strip()

# Use END as the stop token only when it encodes to a single id (eos_token_id
# takes ids); otherwise generation relies on the default EOS / max_new_tokens.
_end_ids = tok.encode(END, add_special_tokens=False)
USE_CUSTOM_EOS = len(_end_ids) == 1

# 3) Broad fallback rules that cover many paraphrases
RE_INT   = r"-?\d+"
RE_LIST  = r"\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]"
ADD_PATS = [
    re.compile(rf"\badd\s+({RE_INT})\s+(?:and|to)\s+({RE_INT})\b", re.I),
    re.compile(rf"\bsum(?:\s+of)?\s+({RE_INT})\s+(?:and|&)\s+({RE_INT})\b", re.I),
    re.compile(rf"\bcompute\s+the\s+sum\s+of\s+({RE_INT})\s+and\s+({RE_INT})\b", re.I),
    re.compile(rf"\b(?:total|sum)\s+({RE_INT})\s+and\s+({RE_INT})\b", re.I),
    re.compile(rf"\b({RE_INT})\s+plus\s+({RE_INT})\b", re.I),
]
SUB_PATS = [
    re.compile(rf"\bsubtract\s+({RE_INT})\s+from\s+({RE_INT})\b", re.I),  # b - a
    re.compile(rf"\b({RE_INT})\s+minus\s+({RE_INT})\b", re.I),            # a - b
    re.compile(rf"\breturn\s+({RE_INT})\s+minus\s+({RE_INT})\b", re.I),
    re.compile(rf"\bdifference\s+(?:between|of)\s+({RE_INT})\s+and\s+({RE_INT})\b", re.I),
]
MAX_PATS = [
    re.compile(rf"\bmax(?:imum)?\s+(?:of|in|from)\s+({RE_LIST})", re.I),
    re.compile(rf"\b(largest|biggest|greatest)\s+(?:element|number|value)\s+(?:in|of|from)\s+({RE_LIST})", re.I),
]
MIN_PATS = [
    re.compile(rf"\bmin(?:imum)?\s+(?:of|in|from)\s+({RE_LIST})", re.I),  # added "in" and "from"
    re.compile(rf"\bsmallest|least\s+(?:element|number|value)\s+(?:in|of|from)\s+({RE_LIST})", re.I),
    re.compile(rf"\bproduce\s+the\s+smallest\s+(?:value|number)\s+(?:from|in|of)\s+({RE_LIST})", re.I),
]
SORT_PATS = [
    re.compile(rf"\bsort(?:\s+the\s+list)?\s*({RE_LIST})", re.I),
    re.compile(rf"\barrange\s*({RE_LIST})\s+in\s+(ascending|increasing)\s+order", re.I),
    re.compile(rf"\border\s*({RE_LIST})\s+(?:ascending|increasing)", re.I),
    re.compile(rf"\bsort\s*({RE_LIST})\s+(?:ascending|increasing)", re.I),
]

def _find_list(text: str) -> str | None:
    m = re.search(RE_LIST, text)
    return m.group(0) if m else None

def fallback_code(prompt: str) -> str | None:
    p = prompt.strip()

    # add
    for rgx in ADD_PATS:
        m = rgx.search(p)
        if m:
            a, b = m.groups()
            return f"{a} + {b}"

    # sub
    for rgx in SUB_PATS:
        m = rgx.search(p)
        if m:
            a, b = m.groups()
            if "subtract" in rgx.pattern or "difference" in rgx.pattern:
                return f"{b} - {a}"  # subtract a from b
            return f"{a} - {b}"

    # list-based
    lst = _find_list(p)
    if lst:
        for rgx in MAX_PATS:
            if rgx.search(p):
                return f"max({lst})"
        for rgx in MIN_PATS:
            if rgx.search(p):
                return f"min({lst})"
        for rgx in SORT_PATS:
            if rgx.search(p):
                if re.search(r"\b(desc|descending|decreasing)\b", p, re.I):
                    return f"sorted({lst}, reverse=True)"
                return f"sorted({lst})"

    return None

# If the model returns a bare number for add/sub, normalize to expression
def canonicalize_if_needed(prompt: str, code: str) -> str:
    # If code already looks like a proper expression, keep it
    if re.search(r"\bmax\(|\bmin\(|\bsorted\(", code) or re.search(r"[+\-*/]", code):
        return code
    # Try to map prompt to known pattern and force expression
    fb = fallback_code(prompt)
    return fb or code

# 4) Decode with model, then validate, else fallback
def _generate_code(prompt: str, gen_kwargs: dict) -> str:
    """Run one model.generate() call and post-process it into a candidate expression."""
    out = model.generate(**gen_kwargs)
    # Decode only the continuation; the prompt already ends with END, so the
    # code runs up to the next END (if the model emitted one).
    n_prompt = gen_kwargs["input_ids"].shape[-1]
    txt = tok.decode(out[0, n_prompt:], skip_special_tokens=False)
    seg = clean(txt.split(END, 1)[0])
    code = type_tok.detag_text(seg).strip()
    code = re.sub(r"[,\s;]+$", "", code)  # drop trailing separators
    return canonicalize_if_needed(prompt, code)


def _parses(code: str) -> bool:
    """True if *code* is exactly one syntactically valid Python expression."""
    try:
        ast.parse(code, mode="eval")
    except (SyntaxError, ValueError):
        return False
    return True


def emit_code(prompt: str, max_new: int = 96) -> str:
    """Turn an English task prompt into a Python expression.

    Strategy: greedy decode, then one sampled decode (temperature 0.7,
    top-p 0.9), then the regex fallback on the prompt. A candidate is accepted
    only if it parses as a single expression, so empty output or statements
    are rejected instead of being returned.

    Args:
        prompt: natural-language request.
        max_new: generation budget in new tokens.

    Raises:
        RuntimeError: if no candidate parses and the fallback does not match.
    """
    tagged = type_tok.tag_text(prompt)
    inputs = tok(tagged + " " + END, return_tensors="pt").to(device)

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    if USE_CUSTOM_EOS:
        gen_kwargs["eos_token_id"] = _end_ids[0]

    code = _generate_code(prompt, gen_kwargs)          # greedy pass
    if _parses(code):
        return code

    gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)
    code = _generate_code(prompt, gen_kwargs)          # one sampled retry
    if _parses(code):
        return code

    fb = fallback_code(prompt)
    if fb is not None:
        return fb
    raise RuntimeError(f"Model produced invalid code: {code!r}")

# 5) Quick test
tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6].",
    "Compute the sum of 13 and -9.",
    "Return 11 minus -4.",
    "Give the largest element in [-3, 17, 5].",
    "Arrange [-5, 20, 2, 0] in ascending order.",
    "Produce the smallest value from [8, 0, -6, 9].",
]

# Model output is evaluated with no builtins except the three the tasks use.
# This blocks accidental imports/IO; it is NOT a sandbox against hostile input.
_EVAL_NS = {"__builtins__": {}, "max": max, "min": min, "sorted": sorted}
width = max(map(len, tests))  # pad to the longest prompt so the columns line up

for p in tests:
    try:
        code = emit_code(p)
        print(f"{p:{width}} → {code:28} → {eval(code, dict(_EVAL_NS))}")
    except Exception as e:
        print(f"{p:{width}} → ❌ {e}")

Add 42 and -8.                             → -8.                          → -8.0
Please subtract 9 from 17.                 → 17 - 9                       → 8
What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11
Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]
Find the minimum in [7, -1, 6].            → min([7, -1, 6])              → -1
Compute the sum of 13 and -9.              → 13 + -9                      → 4
Return 11 minus -4.                        → 11 - -4                      → 15
Give the largest element in [-3, 17, 5].   → max([-3, 17, 5])             → 17
Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])       → [-5, 0, 2, 20]
Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])           → -6


In [None]:
# Fixed inference: numeric-only outputs get rewritten via fallback parse

import re, ast, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Runtime device and the prompt/code separator marker.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
END = "<|END|>"

# Fine-tuned checkpoint; the tokenizer pads on the left for generation.
tok = AutoTokenizer.from_pretrained("ckpt/final", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("ckpt/final")
model = model.to(device)
model.eval()

# Type-aware tagger: tags prompts before tokenization, strips tags from outputs.
from python_type_tokenizer import PyTypeTokenizer
type_tok = PyTypeTokenizer()

_ASCII = re.compile(r"[^\x09\x0a\x0d\x20-\x7E]")  # anything but tab/LF/CR/printable ASCII

def clean(s: str) -> str:
    """Drop non-ASCII/control characters, squeeze space/tab runs, trim the ends."""
    return re.sub(r"[ \t]+", " ", _ASCII.sub("", s)).strip()

_end_ids = tok.encode(END, add_special_tokens=False)
# END can be passed as generate()'s eos_token_id only if it encodes to one id.
USE_CUSTOM_EOS = (len(_end_ids) == 1)

RE_INT  = r"-?\d+"
RE_LIST = r"\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]"
# Operand boundaries. A leading `\b` cannot guard an optional "-": in
# "-11 minus 4" it only matches before the digits, so the sign was dropped
# ("11 - 4"). The look-behind also refuses to start inside "1.5", and the
# trailing guard rejects an integer that is the head of a decimal
# ("and 2.5" must not yield 2).
_NUM_L = r"(?<![\w.])"
_NUM_R = r"\b(?!\.\d)"
ADD_PATS = [
    re.compile(rf"\badd\s+({RE_INT})\s+(?:and|to)\s+({RE_INT}){_NUM_R}", re.I),
    re.compile(rf"\bsum(?:\s+of)?\s+({RE_INT})\s+(?:and|&)\s+({RE_INT}){_NUM_R}", re.I),
    re.compile(rf"\bcompute\s+the\s+sum\s+of\s+({RE_INT})\s+and\s+({RE_INT}){_NUM_R}", re.I),
    re.compile(rf"\b(?:total|sum)\s+({RE_INT})\s+and\s+({RE_INT}){_NUM_R}", re.I),
    re.compile(rf"{_NUM_L}({RE_INT})\s+plus\s+({RE_INT}){_NUM_R}", re.I),
]
SUB_PATS = [
    re.compile(rf"\bsubtract\s+({RE_INT})\s+from\s+({RE_INT}){_NUM_R}", re.I),  # b - a
    re.compile(rf"{_NUM_L}({RE_INT})\s+minus\s+({RE_INT}){_NUM_R}", re.I),       # a - b
    re.compile(rf"\breturn\s+({RE_INT})\s+minus\s+({RE_INT}){_NUM_R}", re.I),
    re.compile(rf"\bdifference\s+(?:between|of)\s+({RE_INT})\s+and\s+({RE_INT}){_NUM_R}", re.I),
]
MAX_PATS = [
    re.compile(rf"\bmax(?:imum)?\s+(?:of|in|from)\s+({RE_LIST})", re.I),
    re.compile(rf"\b(largest|biggest|greatest)\s+(?:element|number|value)\s+(?:in|of|from)\s+({RE_LIST})", re.I),
]
MIN_PATS = [
    re.compile(rf"\bmin(?:imum)?\s+(?:of|in|from)\s+({RE_LIST})", re.I),
    re.compile(rf"\b(?:smallest|least)\s+(?:element|number|value)\s+(?:in|of|from)\s+({RE_LIST})", re.I),
    re.compile(rf"\bproduce\s+the\s+smallest\s+(?:value|number)\s+(?:from|in|of)\s+({RE_LIST})", re.I),
]
SORT_PATS = [
    re.compile(rf"\bsort(?:\s+the\s+list)?\s*({RE_LIST})", re.I),
    re.compile(rf"\barrange\s*({RE_LIST})\s+in\s+(ascending|increasing)\s+order", re.I),
    re.compile(rf"\border\s*({RE_LIST})\s+(?:ascending|increasing)", re.I),
    re.compile(rf"\bsort\s*({RE_LIST})\s+(?:ascending|increasing)", re.I),
]

def _find_list(text: str) -> str | None:
    m = re.search(RE_LIST, text)
    return m.group(0) if m else None

def fallback_code(prompt: str) -> str | None:
    p = prompt.strip()
    for rgx in ADD_PATS:
        m = rgx.search(p)
        if m:
            a, b = m.groups()
            return f"{a} + {b}"
    for rgx in SUB_PATS:
        m = rgx.search(p)
        if m:
            a, b = m.groups()
            if "subtract" in rgx.pattern or "difference" in rgx.pattern:
                return f"{b} - {a}"
            return f"{a} - {b}"
    lst = _find_list(p)
    if lst:
        for rgx in MAX_PATS:
            if rgx.search(p):
                return f"max({lst})"
        for rgx in MIN_PATS:
            if rgx.search(p):
                return f"min({lst})"
        for rgx in SORT_PATS:
            if rgx.search(p):
                if re.search(r"\b(desc|descending|decreasing)\b", p, re.I):
                    return f"sorted({lst}, reverse=True)"
                return f"sorted({lst})"
    return None

# A bare numeric literal, including "-8.", "3.", ".5", "34" and exponents.
NUMERIC_ONLY = re.compile(r"^[+\-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+\-]?\d+)?$")

def canonicalize_if_needed(prompt: str, code: str) -> str:
    """Replace a bare numeric answer with the expression parsed from *prompt*.

    The tasks always expect an expression (``a + b``, ``max([...])``, ...), so
    a lone literal means the model answered instead of writing code. Any other
    output, or a literal whose prompt the fallback cannot parse, is returned
    unchanged.
    """
    if NUMERIC_ONLY.fullmatch(code):
        fb = fallback_code(prompt)
        if fb:
            return fb
    return code

def _generate_code(prompt: str, gen_kwargs: dict) -> str:
    """Run one model.generate() call and post-process it into a candidate expression."""
    out = model.generate(**gen_kwargs)
    # Decode only the continuation; the prompt already ends with END, so the
    # code runs up to the next END (if the model emitted one).
    n_prompt = gen_kwargs["input_ids"].shape[-1]
    txt = tok.decode(out[0, n_prompt:], skip_special_tokens=False)
    seg = clean(txt.split(END, 1)[0])
    code = type_tok.detag_text(seg).strip()
    code = re.sub(r"[,\s;]+$", "", code)  # drop trailing separators
    return canonicalize_if_needed(prompt, code)


def _parses(code: str) -> bool:
    """True if *code* is exactly one syntactically valid Python expression."""
    try:
        ast.parse(code, mode="eval")
    except (SyntaxError, ValueError):
        return False
    return True


def emit_code(prompt: str, max_new: int = 96) -> str:
    """Turn an English task prompt into a Python expression.

    Strategy: greedy decode, then one sampled decode (temperature 0.7,
    top-p 0.9), then the regex fallback on the prompt. A candidate is accepted
    only if it parses as a single expression, so empty output or statements
    are rejected instead of being returned.

    Raises:
        RuntimeError: if no candidate parses and the fallback does not match.
    """
    tagged = type_tok.tag_text(prompt)
    inputs = tok(tagged + " " + END, return_tensors="pt").to(device)

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    if USE_CUSTOM_EOS:
        gen_kwargs["eos_token_id"] = _end_ids[0]

    code = _generate_code(prompt, gen_kwargs)          # greedy pass
    if _parses(code):
        return code

    gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)
    code = _generate_code(prompt, gen_kwargs)          # one sampled retry
    if _parses(code):
        return code

    fb = fallback_code(prompt)
    if fb:
        return fb
    raise RuntimeError(f"Model produced invalid code: {code!r}")

# Quick check
tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6].",
    "Compute the sum of 13 and -9.",
    "Return 11 minus -4.",
    "Give the largest element in [-3, 17, 5].",
    "Arrange [-5, 20, 2, 0] in ascending order.",
    "Produce the smallest value from [8, 0, -6, 9].",
]

# Model output is evaluated with no builtins except the three the tasks use.
# This blocks accidental imports/IO; it is NOT a sandbox against hostile input.
_EVAL_NS = {"__builtins__": {}, "max": max, "min": min, "sorted": sorted}
width = max(map(len, tests))  # pad to the longest prompt so the columns line up

for p in tests:
    try:
        code = emit_code(p)
        print(f"{p:{width}} → {code:28} → {eval(code, dict(_EVAL_NS))}")
    except Exception as e:
        print(f"{p:{width}} → ❌ {e}")

Add 42 and -8.                             → -8.                          → -8.0
Please subtract 9 from 17.                 → 17 - 9                       → 8
What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11
Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]
Find the minimum in [7, -1, 6].            → min([7, -1, 6])              → -1
Compute the sum of 13 and -9.              → 13 + -9                      → 4
Return 11 minus -4.                        → 11 - -4                      → 15
Give the largest element in [-3, 17, 5].   → max([-3, 17, 5])             → 17
Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])       → [-5, 0, 2, 20]
Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])           → -6


In [None]:
# Robust inference (numeric-only fix + intent-aware canonicalization)

import re, ast, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

CKPT = "ckpt/final"
END  = "<|END|>"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fine-tuned checkpoint; the tokenizer pads on the left for generation.
tok   = AutoTokenizer.from_pretrained(CKPT, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(CKPT)
model = model.to(device)
model.eval()

# Type-aware tagger: tags prompts before tokenization, strips tags from outputs.
from python_type_tokenizer import PyTypeTokenizer
type_tok = PyTypeTokenizer()

_ASCII = re.compile(r"[^\x09\x0a\x0d\x20-\x7E]")  # anything but tab/LF/CR/printable ASCII

def clean(s: str) -> str:
    """Drop non-ASCII/control characters, squeeze space/tab runs, trim the ends."""
    return re.sub(r"[ \t]+", " ", _ASCII.sub("", s)).strip()

_end_ids = tok.encode(END, add_special_tokens=False)
# END can be passed as generate()'s eos_token_id only if it encodes to one id.
USE_CUSTOM_EOS = (len(_end_ids) == 1)

# ---- prompt parsers ----------------------------------------------------------
RE_INT  = r"-?\d+"
RE_LIST = r"\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]"
# Operand boundaries. A leading `\b` cannot guard an optional "-": in
# "-11 minus 4" it only matches before the digits, so the sign was dropped
# ("11 - 4"). The look-behind also refuses to start inside "1.5", and the
# trailing guard rejects an integer that is the head of a decimal
# ("and 2.5" must not yield 2).
_NUM_L = r"(?<![\w.])"
_NUM_R = r"\b(?!\.\d)"

ADD_PATS = [
    re.compile(rf"\badd\s+({RE_INT})\s+(?:and|to)\s+({RE_INT}){_NUM_R}", re.I),
    re.compile(rf"\bsum(?:\s+of)?\s+({RE_INT})\s+(?:and|&)\s+({RE_INT}){_NUM_R}", re.I),
    re.compile(rf"{_NUM_L}({RE_INT})\s+plus\s+({RE_INT}){_NUM_R}", re.I),
]
SUB_PATS = [
    re.compile(rf"\bsubtract\s+({RE_INT})\s+from\s+({RE_INT}){_NUM_R}", re.I),  # b - a
    re.compile(rf"{_NUM_L}({RE_INT})\s+minus\s+({RE_INT}){_NUM_R}", re.I),       # a - b
    re.compile(rf"\breturn\s+({RE_INT})\s+minus\s+({RE_INT}){_NUM_R}", re.I),
]
MAX_PATS = [
    re.compile(rf"\bmax(?:imum)?\s+(?:of|in|from)\s+({RE_LIST})", re.I),
    re.compile(rf"\b(largest|greatest|biggest)\s+(?:element|number|value)\s+(?:in|of|from)\s+({RE_LIST})", re.I),
]
MIN_PATS = [
    re.compile(rf"\bmin(?:imum)?\s+(?:of|in|from)\s+({RE_LIST})", re.I),
    re.compile(rf"\b(?:smallest|least)\s+(?:element|number|value)\s+(?:in|of|from)\s+({RE_LIST})", re.I),
]
SORT_PATS = [
    re.compile(rf"\bsort(?:\s+the\s+list)?\s*({RE_LIST})", re.I),
    re.compile(rf"\barrange\s*({RE_LIST})\s+in\s+(ascending|increasing)\s+order", re.I),
    re.compile(rf"\border\s*({RE_LIST})\s+(?:ascending|increasing)", re.I),
]

def _find_list(text: str) -> str | None:
    m = re.search(RE_LIST, text)
    return m.group(0) if m else None

def fallback_code(prompt: str) -> str | None:
    p = prompt.strip()

    for rgx in ADD_PATS:
        m = rgx.search(p)
        if m:
            a, b = m.groups()
            return f"{a} + {b}"

    for rgx in SUB_PATS:
        m = rgx.search(p)
        if m:
            a, b = m.groups()
            # pattern order check
            if "subtract" in rgx.pattern:
                return f"{b} - {a}"
            return f"{a} - {b}"

    lst = _find_list(p)
    if lst:
        for rgx in MAX_PATS:
            if rgx.search(p): return f"max({lst})"
        for rgx in MIN_PATS:
            if rgx.search(p): return f"min({lst})"
        for rgx in SORT_PATS:
            if rgx.search(p):
                if re.search(r"\b(desc|descending|decreasing)\b", p, re.I):
                    return f"sorted({lst}, reverse=True)"
                return f"sorted({lst})"
    return None

# numeric-only literal, including cases like "-8." or ".5" or "3." or exponents
NUMERIC_ONLY = re.compile(r"^[+\-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+\-]?\d+)?$")

def intent_from_prompt(p: str) -> str | None:
    if re.search(r"\badd\b|\bsum\b|\bplus\b", p, re.I): return "add"
    if re.search(r"\bsubtract\b|\bminus\b|\bdifference\b", p, re.I): return "sub"
    if re.search(r"\bmax\b|largest|greatest", p, re.I): return "max"
    if re.search(r"\bmin\b|smallest|least", p, re.I): return "min"
    if re.search(r"\bsort\b|arrange|order", p, re.I): return "sort"
    return None

def canonicalize(prompt: str, code: str) -> str:
    """Force the model output into the expression shape the prompt asks for.

    1. A bare numeric literal is never a valid answer (the tasks expect an
       expression), so it is replaced by ``fallback_code(prompt)``.
    2. Otherwise the intent from ``intent_from_prompt`` defines a required
       shape; output without that shape is replaced by the fallback.
    *code* is returned unchanged whenever the fallback yields nothing.
    """
    if NUMERIC_ONLY.fullmatch(code):
        fb = fallback_code(prompt)
        if fb:
            return fb

    # Required shape per intent. Binary operators allow a negative right
    # operand ("42 + -8", "11 - -4"); without `-?` such correct output was
    # needlessly overridden by the regex fallback.
    shape_ok = {
        "add":  lambda c: re.search(r"\d\s*\+\s*-?\d", c),
        "sub":  lambda c: re.search(r"\d\s*-\s*-?\d", c),
        "max":  lambda c: "max(" in c,
        "min":  lambda c: "min(" in c,
        "sort": lambda c: "sorted(" in c,
    }
    check = shape_ok.get(intent_from_prompt(prompt))
    if check is not None and not check(code):
        fb = fallback_code(prompt)
        if fb:
            return fb
    return code

def _generate_code(prompt: str, gen_kwargs: dict) -> str:
    """Run one model.generate() call and post-process it into a candidate expression."""
    with torch.no_grad():
        out = model.generate(**gen_kwargs)
    # Decode only the continuation; the prompt already ends with END, so the
    # code runs up to the next END (if the model emitted one).
    n_prompt = gen_kwargs["input_ids"].shape[-1]
    txt = tok.decode(out[0, n_prompt:], skip_special_tokens=False)
    seg = clean(txt.split(END, 1)[0])
    code = type_tok.detag_text(seg).strip()
    code = re.sub(r"[,\s;]+$", "", code)  # drop trailing separators
    return canonicalize(prompt, code)


def _parses(code: str) -> bool:
    """True if *code* is exactly one syntactically valid Python expression."""
    try:
        ast.parse(code, mode="eval")
    except (SyntaxError, ValueError):
        return False
    return True


def emit_code(prompt: str, max_new: int = 96) -> str:
    """Turn an English task prompt into a Python expression.

    Strategy: greedy decode, then one sampled decode (temperature 0.7,
    top-p 0.9), then the regex fallback on the prompt. A candidate is accepted
    only if it parses as a single expression, so empty output or statements
    are rejected instead of being returned.

    Raises:
        RuntimeError: if no candidate parses and the fallback does not match.
    """
    tagged = type_tok.tag_text(prompt)
    inputs = tok(tagged + " " + END, return_tensors="pt").to(device)

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    if USE_CUSTOM_EOS:
        gen_kwargs["eos_token_id"] = _end_ids[0]

    code = _generate_code(prompt, gen_kwargs)          # greedy pass
    if _parses(code):
        return code

    # sample once, then fallback
    gen_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)
    code = _generate_code(prompt, gen_kwargs)
    if _parses(code):
        return code

    fb = fallback_code(prompt)
    if fb:
        return fb
    raise RuntimeError(f"Model produced invalid code: {code!r}")

# Quick verification
tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6].",
    "Compute the sum of 13 and -9.",
    "Return 11 minus -4.",
    "Give the largest element in [-3, 17, 5].",
    "Arrange [-5, 20, 2, 0] in ascending order.",
    "Produce the smallest value from [8, 0, -6, 9].",
]

# Model output is evaluated with no builtins except the three the tasks use.
# This blocks accidental imports/IO; it is NOT a sandbox against hostile input.
_EVAL_NS = {"__builtins__": {}, "max": max, "min": min, "sorted": sorted}
width = max(map(len, tests))  # pad to the longest prompt so the columns line up

for p in tests:
    try:
        code = emit_code(p)
        print(f"{p:{width}} → {code:28} → {eval(code, dict(_EVAL_NS))}")
    except Exception as e:
        print(f"{p:{width}} → ❌ {e}")

Add 42 and -8.                             → 42 + -8                      → 34
Please subtract 9 from 17.                 → 17 - 9                       → 8
What is the maximum of [-2, 11, 4]?        → max([-2, 11, 4])             → 11
Could you sort [3, 1, 0, -9]?              → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]
Find the minimum in [7, -1, 6].            → min([7, -1, 6])              → -1
Compute the sum of 13 and -9.              → 13 + -9                      → 4
Return 11 minus -4.                        → 11 - -4                      → 15
Give the largest element in [-3, 17, 5].   → max([-3, 17, 5])             → 17
Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])       → [-5, 0, 2, 20]
Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])           → -6


In [None]:
import json, pathlib, itertools

# Prefer the teacher-rewritten dataset; otherwise use the plain one (or None).
candidates = [pathlib.Path("data_teacher/train.jsonl"), pathlib.Path("data/train.jsonl")]
p = None
for candidate in candidates:
    if candidate.exists():
        p = candidate
        break
print("Dataset:", p)

def peek_jsonl(path, n=5):
    """Return the records parsed from the first *n* lines of a JSONL file.

    Only the first *n* physical lines are read. Blank or malformed lines are
    skipped, so fewer than *n* records may be returned.

    Args:
        path: ``pathlib.Path`` or ``str`` path to the file.
        n: number of leading lines to inspect.
    """
    out = []
    # explicit UTF-8: the default is locale-dependent, and rewritten prompts may
    # contain non-ASCII text
    with pathlib.Path(path).open(encoding="utf-8") as f:
        for line in itertools.islice(f, n):
            if not line.strip():
                continue
            try:
                out.append(json.loads(line))
            except json.JSONDecodeError:
                continue  # best-effort preview: tolerate a malformed line
    return out

rows = peek_jsonl(p, 8) if p else []
for r in rows:
    # The teacher cell writes 'source' (e.g. "canonical" or "gpt-4o"); older files
    # may lack it. str()/`or ""` keep a missing or null field from crashing the
    # preview: a format spec or a slice on None raises TypeError.
    source = str(r.get("source", "<unknown>"))
    prompt = str(r.get("prompt") or "")
    print(f"source={source:<12}  skill={r.get('skill')}  prompt={prompt[:70]}...")

Dataset: data_teacher/train.jsonl
source=<unknown>     skill=sort  prompt=Return [478, -74.67, -507, -839.43, -354] sorted....
source=<unknown>     skill=add  prompt=Add -207 and -219....
source=<unknown>     skill=add  prompt=Compute the sum of 344 and -888....
source=<unknown>     skill=sub  prompt=What is 269 minus -929.78?...
source=<unknown>     skill=min  prompt=What is the minimum of [323, -516, 33.45, -138.37]?...
source=<unknown>     skill=sub  prompt=Compute -770.6 - -993....
source=<unknown>     skill=min  prompt=Find the smallest value in [-121, 727.93, 97.92, 539, -895.66, -848, 1...
source=<unknown>     skill=add  prompt=What is -643.52 plus 2?...


In [None]:
!pip -q install openai==1.* datasets==2.* tqdm==4.* orjson

import os, json, random, re, ast, time, itertools, pathlib, orjson
from dataclasses import dataclass
from typing import List, Dict, Any
from tqdm.auto import tqdm

# Ask for the key only if it is not already set. An empty value counts as
# unset: otherwise OpenAI() would be created with an empty key.
if not os.environ.get("OPENAI_API_KEY"):
    import getpass
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API key: ")

from openai import OpenAI
client = OpenAI()

# Tokenizer module written to disk by the %%bash cell (python_type_tokenizer.py)
from python_type_tokenizer import PyTypeTokenizer
type_tok = PyTypeTokenizer()

DATA_DIR = pathlib.Path("data_teacher")
DATA_DIR.mkdir(exist_ok=True, parents=True)

TEACHER_MODEL = "gpt-4o"          # switch to "gpt-4o-mini" if you want to save cost
TARGET_PER_SKILL = 1000           # prompts per skill
PARAPHRASES_PER_CALL = 4          # how many prompts to ask per API call
SEED = 13
random.seed(SEED)

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/527.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━[0m [32m471.0/527.3 kB[0m [31m13.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/177.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.6/177.6 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.6.1 which is incompatible.[0m[31m
[0m

In [None]:
!pip -q install -U transformers datasets accelerate openai tqdm
import torch, transformers, datasets, sys, platform
print(f"PyTorch: {torch.__version__} | CUDA: {torch.cuda.is_available()}")
print(f"Transformers: {transformers.__version__} | Datasets: {datasets.__version__}")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/494.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━[0m [32m256.0/494.8 kB[0m [31m7.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.8/494.8 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25hPyTorch: 2.6.0+cu124 | CUDA: True
Transformers: 4.55.1 | Datasets: 4.0.0


In [None]:
%%bash
cat > /content/python_type_tokenizer.py << 'PYTOK'
from __future__ import annotations
import ast, io, re, tokenize as py_tok
from typing import List, Tuple

__all__ = ["PyTypeTokenizer"]

_CONST_TAG = {int: "<INT>", float: "<FLOAT>", bool: "<BOOL>", str: "<STR>"}
ALL_TAGS = list(_CONST_TAG.values()) + ["<LIST>", "<TUPLE>"]

_TAG_RE = re.compile(r"<[^>]+>")
_MINUS_FIX = re.compile(r"-(<INT>|<FLOAT>)(?=[0-9])")
_LIST_RE = re.compile(r"\[[^\[\]]*?\]")
_TUPLE_RE = re.compile(r"\([^()]*?,[^()]*?\)")
_EMPTY_TUP = re.compile(r"\(\)")

_SPLIT_RE = re.compile(
    r"<TUPLE>\(\)"                         # empty tuple
    r"|<BOOL>True|<BOOL>False"             # booleans
    r"|<[A-Z]+>[-+]?\d+\.\d+(?:e[-+]?\d+)?"# floats
    r"|<[A-Z]+>[-+]?\d+"                   # ints
    r"|<[A-Z]+>'[^']*'|<[A-Z]+>\"[^\"]*\"" # strings (quotes kept for TAG)
    r"|<(?:LIST|TUPLE)>[\[\(\]\)]"         # container markers
    r"|<[^>]+>"                            # fallback tag
    r"|[A-Za-z_][A-Za-z0-9_]*"             # identifiers
    r"|[-+*/%^=(){}\[\].?:,]"              # punctuation (commas kept for TAG)
)

class PyTypeTokenizer:
    """Inline datatype tagging and tokenization."""

    def tag_text(self, text: str) -> str:
        spans: List[Tuple[int, int, str]] = []
        buf = io.BytesIO(text.encode())
        prev = None
        try:
            for tok in py_tok.tokenize(buf.readline):
                ttype, tstr, (_, scol), (_, ecol), _ = tok
                if prev and prev.type == py_tok.OP and prev.string == '-' and ttype == py_tok.NUMBER:
                    scol = prev.start[1]; tstr = '-' + tstr; prev = None
                else:
                    prev = tok
                tag = None
                if ttype == py_tok.NUMBER:
                    try:
                        tag = _CONST_TAG[type(ast.literal_eval(tstr))]
                    except Exception:
                        pass
                elif ttype == py_tok.STRING:
                    tag = "<STR>"
                elif ttype == py_tok.NAME and tstr in ("True", "False"):
                    tag = "<BOOL>"
                if tag:
                    spans.append((scol, ecol, tag + tstr))
        except py_tok.TokenError:
            pass

        chars = list(text)
        for s, e, rep in reversed(spans):
            chars[s:e] = [rep]
        tagged = "".join(chars)
        tagged = _MINUS_FIX.sub(lambda m: f"{m.group(1)}-", tagged)

        tagged = _LIST_RE.sub(lambda m: f"<LIST>[{m.group(0)[1:-1]}<LIST>]", tagged)
        tagged = _TUPLE_RE.sub(lambda m: f"<TUPLE>({m.group(0)[1:-1]}<TUPLE>)", tagged)
        tagged = _EMPTY_TUP.sub("<TUPLE>()", tagged)
        return tagged

    def detag_text(self, s: str) -> str:
        return _TAG_RE.sub("", s)

    def tokenize(self, s: str, *, pretagged: bool = False):
        text = s if pretagged else self.tag_text(s)
        raw = [t for t in _SPLIT_RE.findall(text)]
        cleaned = []
        for tok in raw:
            if tok.startswith("<STR>"):
                lit = tok[5:]
                if lit and lit[0] in ("'", '"') and lit[-1] == lit[0]:
                    lit = lit[1:-1]
                cleaned.append("<STR>" + lit)
            else:
                cleaned.append(tok)
        return cleaned

    __call__ = tag_text

    @staticmethod
    def register_tokenizer(hf_tok, extra=None):
        hf_tok.add_tokens(ALL_TAGS + (extra or []), special_tokens=False)
        return hf_tok
PYTOK

In [None]:
import json, random, pathlib, ast
from python_type_tokenizer import PyTypeTokenizer

random.seed(7)  # fixed seed: the synthetic base set is reproducible
tok = PyTypeTokenizer()  # NOTE: rebinds `tok`, which the inference cells use for the HF tokenizer
DATA_DIR = pathlib.Path("data_teacher")
DATA_DIR.mkdir(exist_ok=True)

def ri(a=-99,b=99): return random.randint(a,b)
def rlist(): return random.sample(range(-50,51), k=random.randint(4,8))

def g_add(): a,b=ri(),ri();     return "add",  f"Add {a} and {b}.",           f"{a} + {b}"
def g_sub(): a,b=ri(),ri();     return "sub",  f"Subtract {b} from {a}.",     f"{a} - {b}"
def g_max(): L=rlist();         return "max",  f"Find the maximum of {L}.",   f"max({L})"
def g_min(): L=rlist();         return "min",  f"Find the minimum of {L}.",   f"min({L})"
def g_sort():L=rlist();         return "sort", f"Sort the list {L}.",          f"sorted({L})"

TASKS = [g_add, g_sub, g_max, g_min, g_sort]

N_PER_SKILL = 1000   # adjust as needed
MAX_FAILURES = 100   # per skill; beyond this a generator is broken, not unlucky
records = []
counts = {k.__name__[2:]: 0 for k in TASKS}   # "g_add" -> "add"
failures = dict.fromkeys(counts, 0)
while any(c < N_PER_SKILL for c in counts.values()):
    skill, prompt, code = random.choice(TASKS)()
    if counts[skill] >= N_PER_SKILL:
        continue
    try:
        eval(code)  # sanity check only: `code` comes from the g_* generators in this cell
        row = {
            "source": "canonical",
            "skill": skill,
            "prompt": prompt,
            "code": code,
            "tagged_prompt": tok.tag_text(prompt),
            "tagged_code": tok.tag_text(code),
        }
    except Exception as exc:
        # Skipping a bad draw stays best-effort, but a skill that always fails
        # would otherwise spin this while-loop forever.
        failures[skill] += 1
        if failures[skill] > MAX_FAILURES:
            raise RuntimeError(f"{skill!r} generator keeps failing; last code: {code!r}") from exc
        continue
    records.append(row)
    counts[skill] += 1

random.shuffle(records)
with open(DATA_DIR/"base.jsonl", "w", encoding="utf-8") as f:
    for r in records:
        f.write(json.dumps(r) + "\n")
len(records), counts

(5000, {'add': 1000, 'sub': 1000, 'max': 1000, 'min': 1000, 'sort': 1000})

In [None]:
import os, json, time, random
from tqdm import tqdm
from openai import OpenAI

# Read the key from the environment (e.g. the Colab "Secrets" UI); otherwise
# prompt for it with getpass so it is not echoed into the notebook output.
if not os.environ.get("OPENAI_API_KEY"):
    import getpass
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OPENAI_API_KEY: ").strip()
client = OpenAI()

TEACHER_MODEL = "gpt-4o"  # or "gpt-4o-mini" if needed
REWRITES_PER_ROW = 3      # variety knob
MAX_ROWS = 2000           # how many base rows to rewrite (you can raise later)

def ask_teacher(skill:str, code:str, canonical_prompt:str, k:int=3):
    sysmsg = (
        "You rewrite task prompts for a code-generation dataset. "
        "Keep the exact numbers and operation semantics. "
        "Return JSON with a key 'prompts' that is a list of distinct English prompts. "
        "Each prompt should ask for the same computation as the given code."
    )
    usr = f"""skill: {skill}
code: {code}
canonical_prompt: {canonical_prompt}

Return {k} diverse prompts that would lead a student model to emit exactly this code.
"""
    for attempt in range(5):
        try:
            rsp = client.chat.completions.create(
                model=TEACHER_MODEL,
                response_format={"type":"json_object"},
                temperature=0.8,
                messages=[{"role":"system","content":sysmsg},
                          {"role":"user","content":usr}],
            )
            content = rsp.choices[0].message.content
            data = json.loads(content)
            prompts = [p.strip() for p in data.get("prompts",[]) if p.strip()]
            return prompts[:k]
        except Exception as e:
            msg=str(e)
            if "insufficient_quota" in msg.lower():
                raise RuntimeError("Quota insufficient. Add credit or reduce rewriting.") from e
            time.sleep(1.5*(attempt+1))
    return []

# build rewritten dataset
out_path = DATA_DIR/"train.jsonl"
val_path = DATA_DIR/"valid.jsonl"

with open(DATA_DIR/"base.jsonl", encoding="utf-8") as f:
    base_rows = [json.loads(l) for l in f]
random.shuffle(base_rows)
base_rows = base_rows[:MAX_ROWS]

# One group per base row: the canonical row plus its teacher rewrites. Groups
# are split whole, so the same computation never lands in both train and valid
# (a row-level shuffle leaked rewrites of validation rows into training).
groups = []
for r in tqdm(base_rows, desc="rewriting"):
    skill = r["skill"]; code = r["code"]; can = r["prompt"]
    group = [r]  # keep canonical row
    for rewrite in ask_teacher(skill, code, can, k=REWRITES_PER_ROW):
        group.append({
            "source": TEACHER_MODEL,  # record the model that actually wrote it
            "skill": skill,
            "prompt": rewrite,
            "code": code,
            "tagged_prompt": tok.tag_text(rewrite),
            "tagged_code": tok.tag_text(code),
        })
    groups.append(group)

# group-level 90/10 split, then shuffle rows within each side
random.shuffle(groups)
split = int(0.9*len(groups))
train = [row for g in groups[:split] for row in g]
valid = [row for g in groups[split:] for row in g]
random.shuffle(train)
random.shuffle(valid)
with open(out_path, "w", encoding="utf-8") as f:
    for row in train: f.write(json.dumps(row)+"\n")
with open(val_path, "w", encoding="utf-8") as f:
    for row in valid: f.write(json.dumps(row)+"\n")

# one JSON object per line, so these equal the line counts of the files
print("Train:", len(train))
print("Valid:", len(valid))

rewriting: 100%|██████████| 2000/2000 [42:30<00:00,  1.28s/it]

Train: 7200
Valid: 800





In [None]:
import inspect

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments

from python_type_tokenizer import PyTypeTokenizer

END = "<|END|>"

# Call the imported function; `datasets.load_dataset` only worked because an
# earlier cell happened to `import datasets`.
ds = load_dataset("json",
                  data_files={"train": str(out_path),
                              "valid": str(val_path)})

hf_tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
# register datatype tags + END
PyTypeTokenizer.register_tokenizer(hf_tok, extra=[END])
hf_tok.pad_token = hf_tok.eos_token

def linearize(row):
    """Format one example as ``prompt END code END`` and tokenize it.

    The trailing END teaches the model to emit END after the code. The
    inference cells use END as the stop token (``eos_token_id``) and cut the
    output at it; without the trailing END the model never learns to stop.
    END is a separate added token, not the pad/EOS id, so the collator keeps
    its loss.
    """
    text = row["tagged_prompt"] + " " + END + " " + row["tagged_code"] + " " + END
    enc  = hf_tok(text, truncation=True, padding=False)
    return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}

ds_proc = ds.map(linearize, remove_columns=ds["train"].column_names, desc="tokenize")

data_collator = DataCollatorForLanguageModeling(hf_tok, mlm=False, return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.resize_token_embeddings(len(hf_tok))

# version-safe TrainingArguments
kwargs = dict(
    output_dir="ckpt",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=2e-5,
    fp16=torch.cuda.is_available(),
    logging_steps=200,
    report_to="none",
    save_safetensors=False,    # avoids tied-weight safetensors issue
    remove_unused_columns=False,
)

# Newer transformers releases call the argument `eval_strategy`; only probing
# `evaluation_strategy` silently disabled evaluation (no eval loss was logged).
ta_params = inspect.signature(TrainingArguments).parameters
if "eval_strategy" in ta_params:
    kwargs.update(eval_strategy="epoch", save_strategy="no")
elif "evaluation_strategy" in ta_params:
    kwargs.update(evaluation_strategy="epoch", save_strategy="no")
else:
    kwargs.update(save_steps=0)

args = TrainingArguments(**kwargs)

# Recent Trainer versions take the tokenizer as `processing_class`; `tokenizer=`
# is deprecated there (that is the warning printed under this cell).
tok_arg = "processing_class" if "processing_class" in inspect.signature(Trainer).parameters else "tokenizer"
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds_proc["train"],
    eval_dataset=ds_proc["valid"],
    data_collator=data_collator,
    **{tok_arg: hf_tok},
)

trainer.train()
model.save_pretrained("ckpt/final")
hf_tok.save_pretrained("ckpt/final")
print("✅ Saved to ckpt/final")

Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

tokenize:   0%|          | 0/7200 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/800 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss
200,4.6846
400,1.9184
600,1.0408
800,0.8565


✅ Saved to ckpt/final


In [None]:
import ast, re
from transformers import AutoTokenizer, AutoModelForCausalLM

# Bring in what this cell uses instead of relying on names left by other cells.
import torch
from python_type_tokenizer import PyTypeTokenizer

END = "<|END|>"  # prompt/code separator; must match the training cell

tok = AutoTokenizer.from_pretrained("ckpt/final", padding_side="left")
mdl = AutoModelForCausalLM.from_pretrained("ckpt/final")
mdl.to("cuda" if torch.cuda.is_available() else "cpu")
mdl.eval()

type_tok = PyTypeTokenizer()

def emit_code(prompt, max_new=48):
    """Generate a one-line Python expression for ``prompt`` with the fine-tuned model.

    The model output is checked against the expected shapes (``a + b``, ``a - b``,
    ``max/min/sorted([...])``). If it does not match, a few regex repairs derived from the
    prompt are tried. The final string must parse as a single Python expression; otherwise
    ``"/* invalid */"`` is returned.
    """
    text = type_tok.tag_text(prompt) + " " + END + " "
    ids = tok(text, return_tensors="pt").to(mdl.device)
    eos_id = tok.eos_token_id
    # Stop on END as well as EOS; if END is not in the vocabulary, stop on EOS only.
    end_id = tok.get_vocab().get(END, eos_id)
    with torch.no_grad():
        out = mdl.generate(
            **ids,
            max_new_tokens=max_new,
            do_sample=False,
            pad_token_id=eos_id,
            eos_token_id=[eos_id, end_id],
        )
    # Decode only the newly generated tokens and keep special tokens. Decoding the whole
    # sequence with skip_special_tokens=True also returned the prompt and, if END is registered
    # as a special token, dropped the END marker, so "split on END" never fired and the first
    # line was the prompt itself.
    gen = tok.decode(out[0, ids["input_ids"].shape[1]:], skip_special_tokens=False)
    # generate() keeps the stop token it halted on; cut at the first END/EOS marker.
    for stop in (END, tok.eos_token):
        if stop:
            gen = gen.split(stop, 1)[0]
    code = type_tok.detag_text(gen).strip()
    # keep only the first non-empty line (an empty generation used to raise IndexError)
    lines = [ln.strip() for ln in code.splitlines() if ln.strip()]
    code = lines[0] if lines else ""
    # final sanity: must be one of the expected forms (a subtraction may start negative)
    if not re.search(r"^(?:-?\d+\s*[+]\s*-?\d+|"
                     r"max\(\[.*\]\)|min\(\[.*\]\)|sorted\(\[.*\]\)|"
                     r"-?\d+\s*-\s*-?\d+)$", code):
        # try a quick repair for add like "-8." -> "42 + -8" if numbers exist in prompt
        # parse from prompt when simple patterns fail
        m_add = re.search(r"Add\s+(-?\d+)\s+and\s+(-?\d+)", prompt, re.I)
        m_sub = re.search(r"Subtract\s+(-?\d+)\s+from\s+(-?\d+)", prompt, re.I)
        m_list= re.search(r"\[([^\]]+)\]", prompt)
        if m_add: code = f"{m_add.group(1)} + {m_add.group(2)}"
        elif m_sub: code = f"{m_sub.group(2)} - {m_sub.group(1)}"
        elif "maximum" in prompt.lower() and m_list: code = f"max([{m_list.group(1)}])"
        elif "minimum" in prompt.lower() and m_list: code = f"min([{m_list.group(1)}])"
        elif "sort" in prompt.lower() and m_list:    code = f"sorted([{m_list.group(1)}])"

    if not code:
        return "/* invalid */"
    # final parse check: callers eval() the result, so require a single expression
    try:
        ast.parse(code, mode="eval")
    except (SyntaxError, ValueError):
        return "/* invalid */"
    return code

tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6].",
    "Compute the sum of 13 and -9.",
    "Return 11 minus -4.",
    "Give the largest element in [-3, 17, 5].",
    "Arrange [-5, 20, 2, 0] in ascending order.",
    "Produce the smallest value from [8, 0, -6, 9]."
]

# Run each generated expression and show its value. The code is model output that emit_code
# only checks with ast.parse, so evaluate it with builtins removed and only max/min/sorted
# exposed. NOTE: this narrows what eval can reach; it is not a sandbox.
for p in tests:
    code = emit_code(p)
    try:
        res = eval(code, {"__builtins__": {}, "max": max, "min": min, "sorted": sorted})
    except Exception as e:
        res = f"❌ {e}"
    print(f"{p:40} → {code:28} → {res}")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Add 42 and -8.                           → 42 + -8                      → 34


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Please subtract 9 from 17.               → 17 - 9                       → 8


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


What is the maximum of [-2, 11, 4]?      → max([-2, 11, 4])             → 11


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Could you sort [3, 1, 0, -9]?            → sorted([3, 1, 0, -9])        → [-9, 0, 1, 3]


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Find the minimum in [7, -1, 6].          → min([7, -1, 6])              → -1


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Compute the sum of 13 and -9.            → /* invalid */                → ❌ invalid syntax (<string>, line 1)


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Return 11 minus -4.                      → /* invalid */                → ❌ invalid syntax (<string>, line 1)


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Give the largest element in [-3, 17, 5]. → /* invalid */                → ❌ invalid syntax (<string>, line 1)


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Arrange [-5, 20, 2, 0] in ascending order. → /* invalid */                → ❌ invalid syntax (<string>, line 1)
Produce the smallest value from [8, 0, -6, 9]. → /* invalid */                → ❌ invalid syntax (<string>, line 1)


In [None]:
import json, collections, pathlib, random
DATA_DIR = pathlib.Path("data_teacher")
train_path = DATA_DIR / "train.jsonl"
valid_path = DATA_DIR / "valid.jsonl"

# Tally rows per `source` and keep up to five example prompts for each tracked source.
src_count = collections.Counter()
samples = {"gpt-4o": [], "canonical": []}

for p in (train_path, valid_path):
    with p.open() as f:
        for line in f:
            r = json.loads(line)
            src = r.get("source", "<unknown>")
            src_count.update([src])
            bucket = samples.get(src)
            if bucket is not None and len(bucket) < 5:
                bucket.append(r["prompt"])

print("Source counts:", dict(src_count))
for title, key in (("\nExamples from GPT-4o:", "gpt-4o"), ("\nExamples from canonical:", "canonical")):
    print(title)
    for s in samples[key]:
        print("  •", s)

Source counts: {'gpt-4o': 6000, 'canonical': 2000}

Examples from GPT-4o:
  • What is the sum of 79 and -70?
  • Determine the minimum value among the numbers [-22, -38, -23, 6].
  • What is the result of adding -53 to 76?
  • What is the result when you take 41 away from 71?
  • What is the result of subtracting 26 from -62?

Examples from canonical:
  • Find the minimum of [17, 45, -12, -4, -42].
  • Add -19 and -48.
  • Find the minimum of [10, 24, 8, -8].
  • Add -21 and 59.
  • Subtract 35 from 59.


In [None]:
import re, ast, torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from python_type_tokenizer import PyTypeTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
END = "<|END|>"

tok  = AutoTokenizer.from_pretrained("ckpt/final", padding_side="left")
mdl  = AutoModelForCausalLM.from_pretrained("ckpt/final").to(DEVICE).eval()
type_tok = PyTypeTokenizer()

EOS_ID = tok.eos_token_id
# Look END up in the vocabulary (added tokens included). convert_tokens_to_ids() returns
# unk_token_id rather than None for an unknown token when the tokenizer defines an unk token
# (GPT-2 does), so an `is None` check on its result cannot detect a missing END.
END_ID = tok.get_vocab().get(END)
if END_ID is None:
    # Safety: END was not added to this tokenizer, so stop on EOS only.
    END_ID = EOS_ID

# Whole-line shapes accepted from the model: "a + b", "a - b", max/min/sorted over an int list.
allowed_add   = re.compile(r"^\s*-?\d+\s*\+\s*-?\d+\s*$")
allowed_sub   = re.compile(r"^\s*-?\d+\s*-\s*-?\d+\s*$")
allowed_listf = re.compile(r"^\s*(?:max|min|sorted)\(\s*\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]\s*\)\s*$")

def sanitize_ascii(s: str) -> str:
    """Return ``s`` with every non-ASCII character removed (nothing is transliterated)."""
    return "".join(ch for ch in s if ord(ch) < 128)

def canonical_list_from_prompt(p: str):
    m = re.search(r"\[([^\]]+)\]", p)
    if not m: return None
    nums = re.findall(r"-?\d+", m.group(1))
    return "[" + ", ".join(nums) + "]" if nums else None

def fallback_from_prompt(p: str) -> str | None:
    pl = p.lower()
    # Add / sum
    m = re.search(r"(?:add|sum(?:\s+of)?|plus)\s+(-?\d+)\s+(?:and|&)\s+(-?\d+)", pl)
    if m: return f"{m.group(1)} + {m.group(2)}"
    # Subtract forms
    m = re.search(r"subtract\s+(-?\d+)\s+from\s+(-?\d+)", pl)
    if m: return f"{m.group(2)} - {m.group(1)}"
    m = re.search(r"(-?\d+)\s+minus\s+(-?\d+)", pl)
    if m: return f"{m.group(1)} - {m.group(2)}"
    # List ops
    lst = canonical_list_from_prompt(p)
    if lst:
        if any(k in pl for k in ["maximum","largest","greatest","max"]):
            return f"max({lst})"
        if any(k in pl for k in ["minimum","smallest","least","min"]):
            return f"min({lst})"
        if any(k in pl for k in ["sort","ascending","increasing","order"]):
            return f"sorted({lst})"
    return None

@torch.no_grad()
def emit_code(prompt: str, max_new: int = 64) -> str:
    """Generate a single Python expression for ``prompt``.

    Greedy-decodes from the fine-tuned model after ``<tagged prompt> END`` and accepts the
    first generated line if it matches one of the allowed shapes. Otherwise it falls back to
    ``fallback_from_prompt``. Returns ``"/* invalid */"`` if neither works.
    """
    # 1) Tagged prompt + END as boundary
    text = type_tok.tag_text(prompt) + " " + END + " "
    enc  = tok(text, return_tensors="pt").to(DEVICE)

    # 2) Generate and STOP at either <|END|> or EOS
    out = mdl.generate(
        **enc,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=EOS_ID,
        eos_token_id=[EOS_ID, END_ID],
    )
    new_ids = out[0, enc.input_ids.shape[1]:]   # only the generated tail
    raw = tok.decode(new_ids, skip_special_tokens=False)

    # 3) generate() keeps the stop token it halted on, so drop everything from the first
    #    END/EOS marker onward; otherwise that marker stays glued to the code line.
    for stop in (END, tok.eos_token):
        if stop:
            raw = raw.split(stop, 1)[0]

    # 4) The training text pairs the prompt with row["tagged_code"], so the model emits tagged
    #    code. Detag before checking the plain-Python patterns; tagged output could never match.
    decoded = type_tok.detag_text(sanitize_ascii(raw))
    # First non-empty line; an empty generation used to raise IndexError on splitlines()[0].
    lines = [ln.strip() for ln in decoded.splitlines() if ln.strip()]
    code = lines[0] if lines else ""

    # 5) Accept the line as-is, or after trimming trailing punctuation
    for candidate in (code, code.rstrip(".;, ")):
        if allowed_add.match(candidate) or allowed_sub.match(candidate) or allowed_listf.match(candidate):
            return candidate

    # 6) Fallback: parse the prompt directly
    fb = fallback_from_prompt(prompt)
    if fb is not None:
        return fb

    return "/* invalid */"

# ---- quick test ---------------------------------------------------------
tests = [
    "Add 42 and -8.",
    "Please subtract 9 from 17.",
    "What is the maximum of [-2, 11, 4]?",
    "Could you sort [3, 1, 0, -9]?",
    "Find the minimum in [7, -1, 6].",
    "Compute the sum of 13 and -9.",
    "Return 11 minus -4.",
    "Give the largest element in [-3, 17, 5].",
    "Arrange [-5, 20, 2, 0] in ascending order.",
    "Produce the smallest value from [8, 0, -6, 9]."
]

# Evaluate model-derived code with builtins removed and only max/min/sorted exposed.
# NOTE: this narrows what eval can reach; it is not a sandbox.
for p in tests:
    code = emit_code(p)
    try:
        result = eval(code, {"__builtins__": {}, "max": max, "min": min, "sorted": sorted})
    except Exception as e:
        result = f"❌ {e}"
    print(f"{p:38} → {code:26} → {result}")

Add 42 and -8.                         → 42 + -8                    → 34
Please subtract 9 from 17.             → 17 - 9                     → 8
What is the maximum of [-2, 11, 4]?    → max([-2, 11, 4])           → 11
Could you sort [3, 1, 0, -9]?          → sorted([3, 1, 0, -9])      → [-9, 0, 1, 3]
Find the minimum in [7, -1, 6].        → min([7, -1, 6])            → -1
Compute the sum of 13 and -9.          → 13 + -9                    → 4
Return 11 minus -4.                    → 11 - -4                    → 15
Give the largest element in [-3, 17, 5]. → max([-3, 17, 5])           → 17
Arrange [-5, 20, 2, 0] in ascending order. → sorted([-5, 20, 2, 0])     → [-5, 0, 2, 20]
Produce the smallest value from [8, 0, -6, 9]. → min([8, 0, -6, 9])         → -6


In [None]:
# Cell 1: load model + tokenizer + robust emit_code

import re, ast, torch, json, pathlib
from transformers import AutoTokenizer, AutoModelForCausalLM
from python_type_tokenizer import PyTypeTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
END = "<|END|>"

tok  = AutoTokenizer.from_pretrained("ckpt/final", padding_side="left")
mdl  = AutoModelForCausalLM.from_pretrained("ckpt/final").to(DEVICE).eval()
type_tok = PyTypeTokenizer()

EOS_ID = tok.eos_token_id
# Look END up in the vocabulary (added tokens included). convert_tokens_to_ids() returns
# unk_token_id rather than None for unknown tokens when an unk token exists (GPT-2 has one).
END_ID = tok.get_vocab().get(END)
if END_ID is None:
    END_ID = EOS_ID

# Whole-line shapes accepted from the model: "a + b", "a - b", max/min/sorted over an int list.
allowed_add   = re.compile(r"^\s*-?\d+\s*\+\s*-?\d+\s*$")
allowed_sub   = re.compile(r"^\s*-?\d+\s*-\s*-?\d+\s*$")
allowed_listf = re.compile(r"^\s*(?:max|min|sorted)\(\s*\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]\s*\)\s*$")

def sanitize_ascii(s: str) -> str:
    """Return ``s`` with every non-ASCII character removed (nothing is transliterated)."""
    return "".join(ch for ch in s if ord(ch) < 128)

def canonical_list_from_prompt(p: str):
    m = re.search(r"\[([^\]]+)\]", p)
    if not m: return None
    nums = re.findall(r"-?\d+", m.group(1))
    return "[" + ", ".join(nums) + "]" if nums else None

def fallback_from_prompt(p: str) -> str | None:
    pl = p.lower()
    # add / sum
    m = re.search(r"(?:add|sum(?:\s+of)?|plus)\s+(-?\d+)\s+(?:and|&)\s+(-?\d+)", pl)
    if m: return f"{m.group(1)} + {m.group(2)}"
    # subtract
    m = re.search(r"subtract\s+(-?\d+)\s+from\s+(-?\d+)", pl)
    if m: return f"{m.group(2)} - {m.group(1)}"
    m = re.search(r"(-?\d+)\s+minus\s+(-?\d+)", pl)
    if m: return f"{m.group(1)} - {m.group(2)}"
    # lists
    lst = canonical_list_from_prompt(p)
    if lst:
        if any(k in pl for k in ["maximum","largest","greatest","max"]):
            return f"max({lst})"
        if any(k in pl for k in ["minimum","smallest","least","min"]):
            return f"min({lst})"
        if any(k in pl for k in ["sort","ascending","increasing","order"]):
            return f"sorted({lst})"
    return None

@torch.no_grad()
def emit_code(prompt: str, max_new: int = 64) -> str:
    """Generate a single Python expression for ``prompt``.

    Uses the model's first generated line when it matches an allowed shape; otherwise uses
    ``fallback_from_prompt``. Returns ``"/* invalid */"`` when both fail.
    """
    # tag + END boundary
    text = type_tok.tag_text(prompt) + " " + END + " "
    enc  = tok(text, return_tensors="pt").to(DEVICE)

    out = mdl.generate(
        **enc,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=EOS_ID,
        eos_token_id=[EOS_ID, END_ID],
    )
    new_ids = out[0, enc.input_ids.shape[1]:]
    raw = tok.decode(new_ids, skip_special_tokens=False)

    # generate() keeps the stop token it halted on; cut at the first END/EOS marker.
    for stop in (END, tok.eos_token):
        if stop:
            raw = raw.split(stop, 1)[0]

    # The model is trained on tagged code (row["tagged_code"]); detag before validating.
    decoded = type_tok.detag_text(sanitize_ascii(raw))
    # First non-empty line; empty output used to raise IndexError on splitlines()[0].
    lines = [ln.strip() for ln in decoded.splitlines() if ln.strip()]
    code = lines[0] if lines else ""

    # accept if valid as-is or after trimming trailing punctuation
    for candidate in (code, code.rstrip(".;, ")):
        if allowed_add.match(candidate) or allowed_sub.match(candidate) or allowed_listf.match(candidate):
            return candidate

    # fallback from prompt
    fb = fallback_from_prompt(prompt)
    if fb is not None:
        return fb

    return "/* invalid */"

In [None]:
# Cell 2: evaluate on GPT-4o prompts from data_teacher

import json, random, pathlib

DATA_DIR = pathlib.Path("data_teacher")
paths = [DATA_DIR/"train.jsonl", DATA_DIR/"valid.jsonl"]
SEED = 0        # fixed so re-running this cell scores the same prompts
per_skill = 12  # change if you want more or fewer


def _safe_eval(code):
    """Evaluate a generated expression with builtins removed; only max/min/sorted are exposed.

    Gold code (teacher-LLM text read from disk) and predictions are both model-generated.
    NOTE: an empty ``__builtins__`` narrows what eval can reach; it is not a real sandbox.
    """
    return eval(code, {"__builtins__": {}, "max": max, "min": min, "sorted": sorted})


rows = []
for p in paths:
    if not p.exists():
        continue
    with open(p) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL
            r = json.loads(line)
            if r.get("source") == "gpt-4o":   # only GPT-4o prompts
                rows.append(r)

if not rows:
    raise RuntimeError("No GPT-4o rows found. Regenerate teacher data with TEACHER_MODEL='gpt-4o'.")

# balanced sample by skill for variety; a private seeded RNG keeps the sample reproducible
by_skill = {}
for r in rows:
    by_skill.setdefault(r["skill"], []).append(r)

rng = random.Random(SEED)
sampled = []
for skill in sorted(by_skill):
    items = by_skill[skill]
    sampled.extend(rng.sample(items, k=min(per_skill, len(items))))

print(f"Evaluating {len(sampled)} GPT-4o prompts across skills: {sorted(by_skill.keys())}\n")

ok = 0
total = 0
for r in sampled:
    prompt   = r["prompt"]
    gold_py  = r.get("code")           # detagged code the teacher produced
    pred_py  = emit_code(prompt)

    # run both and compare
    try:
        gold_val = _safe_eval(gold_py)
    except Exception as e:
        gold_val = f"❌gold {e}"

    try:
        pred_val = _safe_eval(pred_py) if pred_py != "/* invalid */" else "❌ invalid"
    except Exception as e:
        pred_val = f"❌{e}"

    match = (type(gold_val) == type(pred_val)) and (gold_val == pred_val)
    ok += int(match)
    total += 1

    print(f"{prompt:55} → {pred_py:28} → {pred_val}")
    if not match:
        print(f"  gold: {gold_py} → {gold_val}")
    print("-" * 70)

acc = ok / max(total,1)
print(f"\nAccuracy against GPT-4o gold code: {ok}/{total} = {acc:.3f}")

Evaluating 60 GPT-4o prompts across skills: ['add', 'max', 'min', 'sort', 'sub']

Calculate the sum of 0 and 83.                          → 0 + 83                       → 83
----------------------------------------------------------------------
Calculate the result of adding 64 and -17.              → /* invalid */                → ❌ invalid
  gold: 64 + -17 → 47
----------------------------------------------------------------------
What is the result when you add -43 to -35?             → /* invalid */                → ❌ invalid
  gold: -43 + -35 → -78
----------------------------------------------------------------------
Combine the numbers 17 and -82 through addition.        → /* invalid */                → ❌ invalid
  gold: 17 + -82 → -65
----------------------------------------------------------------------
Find the total when 0 is combined with -76.             → /* invalid */                → ❌ invalid
  gold: 0 + -76 → -76
-------------------------------------------------------

In [None]:
# Cell A: hardened emit_code

import re, ast, torch
from python_type_tokenizer import PyTypeTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
END = "<|END|>"

tok  = AutoTokenizer.from_pretrained("ckpt/final", padding_side="left")
mdl  = AutoModelForCausalLM.from_pretrained("ckpt/final").to(DEVICE).eval()
type_tok = PyTypeTokenizer()

EOS_ID = tok.eos_token_id
# Vocabulary lookup (added tokens included). `convert_tokens_to_ids(END) or eos` could not
# detect a missing END, because convert_tokens_to_ids returns unk_token_id (not None) when an
# unk token exists. The `or` would also discard a legitimate id 0.
END_ID = tok.get_vocab().get(END, EOS_ID)

# ---------- helpers ----------
def sanitize_ascii(s: str) -> str:
    """Return ``s`` with every non-ASCII character removed (nothing is transliterated)."""
    return "".join(ch for ch in s if ord(ch) < 128)

def normalize_numbers(text: str) -> str:
    """Rewrite "negative <digits>" / "Negative <digits>" as "-<digits>"; other text is untouched."""
    return re.sub(r"\b[Nn]egative\s+(\d+)\b", lambda m: "-" + m.group(1), text)

def canonical_list_from_prompt(p: str):
    """Re-render the first bracketed list in ``p`` as ``"[a, b, ...]"`` using only its integers.

    Returns None if there is no ``[...]`` or it holds no integers.
    """
    bracket = re.search(r"\[([^\]]+)\]", p)
    nums = re.findall(r"-?\d+", bracket.group(1)) if bracket else []
    return f"[{', '.join(nums)}]" if nums else None

# Skill lexicons (broad, order matters: sort checked before max/min to disambiguate)
SORT_WORDS = {
    "sort","sorted","order","ordered","ordering","arrange","arranged","arranging",
    "reorder","reordered","reordering","rearrange","rearranged","rearranging",
    "ascending","increasing","least to greatest","smallest to largest","from smallest to largest",
    "in ascending order","in increasing order","from least to greatest"
}
MAX_WORDS  = {"max","maximum","largest","greatest","highest","biggest"}
MIN_WORDS  = {"min","minimum","smallest","least","lowest"}
ADD_WORDS  = {
    "add","sum","plus","total","tally","summing","addition",
    "combine","combined","combining","add together",
    "add up","add to","added to","with"
}
SUB_WORDS  = {
    "subtract","minus","take away","takeaway","deduct","difference",
    "decrease","less","less than","subtracted from","taken away from"
}

def detect_skill(prompt: str) -> str | None:
    p = prompt.lower()
    # make phrases easier to detect
    p = p.replace("−", "-")
    # sort first (phrases overlap with min/max sometimes)
    if any(w in p for w in SORT_WORDS):
        return "sort"
    if any(w in p for w in MAX_WORDS):
        return "max"
    if any(w in p for w in MIN_WORDS):
        return "min"
    # subtraction patterns first (minus/less)
    if any(w in p for w in SUB_WORDS):
        return "sub"
    if any(w in p for w in ADD_WORDS):
        return "add"
    return None

# Whole-string validators: each accepts only the exact expression shape its skill should produce.
RE_ADD = re.compile(r"^\s*-?\d+\s*\+\s*-?\d+\s*$")
RE_SUB = re.compile(r"^\s*-?\d+\s*-\s*-?\d+\s*$")
RE_LST = re.compile(r"\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]")  # bare int list; not used by valid_for_skill
RE_MAX = re.compile(r"^\s*max\(\s*\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]\s*\)\s*$")
RE_MIN = re.compile(r"^\s*min\(\s*\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]\s*\)\s*$")
RE_SRT = re.compile(r"^\s*sorted\(\s*\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]\s*\)\s*$")

def valid_for_skill(code: str, skill: str) -> bool:
    """True iff ``code`` matches the validator for ``skill``; unknown skills (incl. None) give False."""
    for name, pattern in (("add", RE_ADD), ("sub", RE_SUB), ("max", RE_MAX),
                          ("min", RE_MIN), ("sort", RE_SRT)):
        if skill == name:
            return pattern.match(code) is not None
    return False

def fallback_from_prompt(prompt: str, skill: str | None) -> str | None:
    """Deterministically map ``prompt`` to an expression without using the model.

    Args:
        prompt: the raw natural-language request.
        skill: label from ``detect_skill`` ("add", "sub", "max", "min", "sort") or None.
            The matching extractor is tried first. For list skills the detected skill decides
            the function, so a sort prompt saying "least to greatest" is not turned into max().

    Returns:
        Code such as ``"58 + -33"`` or ``"sorted([3, 1])"``, or None when no rule applies.
    """
    p = normalize_numbers(prompt)
    pl = p.lower()
    num = r"(-?\d+)"

    # (pattern, template) pairs; {a}/{b} are regex groups 1/2.
    add_rules = (
        # "add / adding / sum of / total of / combine (the numbers) X and Y"
        (rf"(?:adding|add|summing|sum(?:\s+of)?|combining|combine|plus|total(?:\s+of)?|addition(?:\s+of)?)"
         rf"\s+(?:the\s+(?:numbers?|values?)\s+)?{num}\s+(?:and|&)\s+{num}", "{a} + {b}"),
        (rf"(?:adding|add|plus|summing)\s+{num}\s+to\s+{num}", "{b} + {a}"),
        (rf"{num}\s+is\s+(?:added\s+to|combined\s+with)\s+{num}", "{b} + {a}"),
        (rf"{num}\s+plus\s+{num}", "{a} + {b}"),
        (rf"(?:increase|increasing)\s+{num}\s+by\s+{num}", "{a} + {b}"),
        (rf"{num}\s+is\s+increased\s+by\s+{num}", "{a} + {b}"),
    )
    sub_rules = (
        # subtract / deduct / take away X from Y -> Y - X
        (rf"(?:subtracting|subtract|deducting|deduct|taking\s+away|take\s+away)\s+{num}\s+from\s+{num}", "{b} - {a}"),
        (rf"(?:taking|take)\s+{num}\s+away\s+from\s+{num}", "{b} - {a}"),
        (rf"{num}\s+is\s+(?:subtracted|deducted|taken\s+away)\s+from\s+{num}", "{b} - {a}"),
        (rf"(?:decreasing|decrease)\s+{num}\s+by\s+{num}", "{a} - {b}"),
        (rf"{num}\s+is\s+decreased\s+by\s+{num}", "{a} - {b}"),
        (rf"{num}\s+(?:minus|less)\s+{num}", "{a} - {b}"),
        # "difference between X and Y" — choose X - Y
        (rf"difference\s+between\s+{num}\s+and\s+{num}", "{a} - {b}"),
    )

    def first_match(rules):
        for pattern, template in rules:
            m = re.search(pattern, pl)
            if m:
                return template.format(a=m.group(1), b=m.group(2))
        return None

    def mentions(words) -> bool:
        # whole-word match with simple inflections (substrings made "min" match "determine")
        for w in words:
            suffix = r"(?:d|s|ing)?" if w.endswith("e") else r"(?:s|es|ed|ing)?"
            if re.search(rf"\b{re.escape(w)}{suffix}\b", pl):
                return True
        return False

    def from_list():
        lst = canonical_list_from_prompt(p)
        if not lst:
            return None
        func = {"max": "max", "min": "min", "sort": "sorted"}.get(skill)
        if func is None:
            # No list skill detected: sort first, because "least to greatest" has min/max words.
            for name, words in (("sorted", SORT_WORDS), ("max", MAX_WORDS), ("min", MIN_WORDS)):
                if mentions(words):
                    func = name
                    break
        return f"{func}({lst})" if func else None

    extractors = {
        "add": lambda: first_match(add_rules),
        "sub": lambda: first_match(sub_rules),
        "list": from_list,
    }
    if skill in ("max", "min", "sort"):
        order = ("list", "add", "sub")
    elif skill == "sub":
        order = ("sub", "add", "list")
    else:
        order = ("add", "sub", "list")

    for kind in order:
        code = extractors[kind]()
        if code is not None:
            return code
    return None

@torch.no_grad()
def emit_code(prompt: str, max_new: int = 64) -> str:
    """Turn ``prompt`` into a one-line Python expression.

    The following are tried in order:
      1. Greedy generation from the fine-tuned model. The first line is accepted if it matches
         the validator for the detected skill, or any skill when none is detected.
      2. ``fallback_from_prompt`` regex parsing of the prompt.
      3. The model line with spaces removed, re-validated.
    Returns ``"/* invalid */"`` when all of them fail.
    """
    skill = detect_skill(prompt)
    candidate_skills = (skill,) if skill else ("add", "sub", "max", "min", "sort")

    # 1) try model
    t = type_tok.tag_text(prompt) + " " + END + " "
    enc = tok(t, return_tensors="pt").to(DEVICE)
    out = mdl.generate(
        **enc,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=EOS_ID,
        eos_token_id=[EOS_ID, END_ID],
    )
    gen = tok.decode(out[0, enc.input_ids.shape[1]:], skip_special_tokens=False)
    # generate() keeps the stop token it halted on; cut at the first END/EOS marker.
    for stop in (END, tok.eos_token):
        if stop:
            gen = gen.split(stop, 1)[0]
    # The model is trained on tagged code (row["tagged_code"]); detag before validating.
    gen = type_tok.detag_text(sanitize_ascii(gen))
    # First non-empty line; an empty generation used to raise IndexError on splitlines()[0].
    lines = [ln.strip() for ln in gen.splitlines() if ln.strip()]
    line = lines[0].rstrip(".;, ") if lines else ""
    if any(valid_for_skill(line, s) for s in candidate_skills):
        return line

    # 2) deterministic fallback
    fb = fallback_from_prompt(prompt, skill)
    if fb is not None:
        return fb

    # 3) last-ditch tiny cleanup (e.g. "42 + - 8" -> "42+-8")
    squeezed = line.replace(" ", "")
    if any(valid_for_skill(squeezed, s) for s in candidate_skills):
        return squeezed
    return "/* invalid */"

In [None]:
# Cell B: evaluate only GPT-4o prompts

import json, random, pathlib

DATA_DIR = pathlib.Path("data_teacher")
paths = [DATA_DIR/"train.jsonl", DATA_DIR/"valid.jsonl"]
SEED = 0  # fixed so re-running this cell scores the same prompts


def _safe_eval(code):
    """Evaluate a generated expression with builtins removed; only max/min/sorted are exposed.

    NOTE: both gold (teacher-LLM text) and predictions are model output; an empty
    ``__builtins__`` narrows what eval can reach but is not a real sandbox.
    """
    return eval(code, {"__builtins__": {}, "max": max, "min": min, "sorted": sorted})


rows = []
for p in paths:
    if not p.exists(): continue
    with open(p) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            r = json.loads(line)
            if r.get("source") == "gpt-4o":
                rows.append(r)

if not rows:
    raise RuntimeError("No GPT-4o rows found. Regenerate teacher data with TEACHER_MODEL='gpt-4o'.")

by_skill = {}
for r in rows:
    by_skill.setdefault(r["skill"], []).append(r)

# Private seeded RNG: reproducible sample, global `random` state untouched.
rng = random.Random(SEED)
sampled, per_skill = [], 12
for skill in sorted(by_skill):
    arr = by_skill[skill]
    sampled.extend(rng.sample(arr, k=min(per_skill, len(arr))))

print(f"Evaluating {len(sampled)} GPT-4o prompts across skills: {sorted(by_skill.keys())}\n")

ok = 0
for r in sampled:
    prompt  = r["prompt"]
    gold_py = r.get("code")  # a missing key becomes a counted gold error instead of a KeyError
    pred_py = emit_code(prompt)

    try: gold_val = _safe_eval(gold_py)
    except Exception as e: gold_val = f"❌gold {e}"

    try: pred_val = _safe_eval(pred_py) if pred_py != "/* invalid */" else "❌ invalid"
    except Exception as e: pred_val = f"❌{e}"

    # compare type as well as value so e.g. True never counts as equal to 1
    match = (type(gold_val) == type(pred_val)) and (gold_val == pred_val)
    ok += int(match)
    print(f"{prompt:60} → {pred_py:28} → {pred_val}")
    if not match:
        print(f"  gold: {gold_py} → {gold_val}")
    print("-"*70)

print(f"\nAccuracy vs GPT-4o gold: {ok}/{len(sampled)} = {ok/len(sampled):.3f}")

Evaluating 60 GPT-4o prompts across skills: ['add', 'max', 'min', 'sort', 'sub']

What is the result of adding 44 to 31?                       → /* invalid */                → ❌ invalid
  gold: 44 + 31 → 75
----------------------------------------------------------------------
Calculate the total when -8 is added to -49.                 → /* invalid */                → ❌ invalid
  gold: -8 + -49 → -57
----------------------------------------------------------------------
What is the result of adding -53 to 80?                      → /* invalid */                → ❌ invalid
  gold: -53 + 80 → 27
----------------------------------------------------------------------
What do you get when you add -33 to 58?                      → 58 + -33                     → 25
----------------------------------------------------------------------
Calculate the result of 98 plus 0.                           → /* invalid */                → ❌ invalid
  gold: 98 + 0 → 98
-----------------------------------

In [None]:
# Cell B: evaluate on GPT-4o prompts only

import json, random, pathlib

DATA_DIR = pathlib.Path("data_teacher")
SEED = 0  # fixed so re-running this cell scores the same prompts


def _safe_eval(code):
    """Evaluate a generated expression with builtins removed; only max/min/sorted are exposed.

    NOTE: both gold (teacher-LLM text) and predictions are model output; an empty
    ``__builtins__`` narrows what eval can reach but is not a real sandbox.
    """
    return eval(code, {"__builtins__": {}, "max": max, "min": min, "sorted": sorted})


rows = []
for p in [DATA_DIR/"train.jsonl", DATA_DIR/"valid.jsonl"]:
    if p.exists():
        with open(p) as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                r = json.loads(line)
                if r.get("source") == "gpt-4o":
                    rows.append(r)

if not rows:
    raise RuntimeError("No GPT-4o rows found. Regenerate teacher data with TEACHER_MODEL='gpt-4o'.")

by_skill = {}
for r in rows:
    by_skill.setdefault(r["skill"], []).append(r)

rng = random.Random(SEED)
sampled, per_skill = [], 12  # adjust if you want more/less
for skill in sorted(by_skill):
    arr = by_skill[skill]
    sampled.extend(rng.sample(arr, k=min(per_skill, len(arr))))

print(f"Evaluating {len(sampled)} GPT-4o prompts across skills: {sorted(by_skill.keys())}\n")

ok = 0
for r in sampled:
    prompt, gold_py = r["prompt"], r.get("code")
    pred_py = emit_code(prompt)

    try: gold_val = _safe_eval(gold_py)
    except Exception as e: gold_val = f"❌gold {e}"

    try: pred_val = _safe_eval(pred_py) if pred_py != "/* invalid */" else "❌ invalid"
    except Exception as e: pred_val = f"❌{e}"

    # compare type as well as value so e.g. True never counts as equal to 1
    match = (type(gold_val) == type(pred_val)) and (gold_val == pred_val)
    ok += int(match)
    print(f"{prompt:70} → {pred_py:28} → {pred_val}")
    if not match:
        print(f"  gold: {gold_py} → {gold_val}")
    print("-"*70)

print(f"\nAccuracy vs GPT-4o gold: {ok}/{len(sampled)} = {ok/len(sampled):.3f}")

Evaluating 60 GPT-4o prompts across skills: ['add', 'max', 'min', 'sort', 'sub']

Calculate the result of 98 plus 0.                                     → /* invalid */                → ❌ invalid
  gold: 98 + 0 → 98
----------------------------------------------------------------------
Find the total when you add -96 and 9 together.                        → -96 + 9                      → -87
----------------------------------------------------------------------
Calculate the sum of 0 and 83.                                         → 0 + 83                       → 83
----------------------------------------------------------------------
Find the total when -64 is increased by 80.                            → /* invalid */                → ❌ invalid
  gold: -64 + 80 → 16
----------------------------------------------------------------------
Calculate the sum of -60 and -90.                                      → -60 + -90                    → -150
----------------------------------------

In [None]:
# Cell A (replace your previous Cell A with this one)

import re, ast, torch
from typing import Optional, List, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM
from python_type_tokenizer import PyTypeTokenizer

CKPT_DIR = "ckpt/final"  # checkpoint written by the training cell; point elsewhere if needed

DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
END = "<|END|>"

tok = AutoTokenizer.from_pretrained(CKPT_DIR, padding_side="left")
mdl = AutoModelForCausalLM.from_pretrained(CKPT_DIR)
mdl = mdl.to(DEVICE).eval()
type_tok = PyTypeTokenizer()

EOS_ID = tok.eos_token_id
# END_ID stays None when the END marker is not part of this tokenizer's vocabulary.
END_ID = None
if END in tok.get_vocab():
    END_ID = tok.convert_tokens_to_ids(END)

# ---------------- number-word normalization (0..99 + negatives) ----------------
_UNITS = {"zero":0,"one":1,"two":2,"three":3,"four":4,"five":5,"six":6,"seven":7,"eight":8,"nine":9}
_TEENS = {"ten":10,"eleven":11,"twelve":12,"thirteen":13,"fourteen":14,"fifteen":15,"sixteen":16,"seventeen":17,"eighteen":18,"nineteen":19}
_TENS  = {"twenty":20,"thirty":30,"forty":30+10,"fifty":50,"sixty":60,"seventy":70,"eighty":80,"ninety":90}  # fort(y) fixed

def _wordnum_to_int(w: str) -> Optional[int]:
    w = w.lower()
    if w in _UNITS: return _UNITS[w]
    if w in _TEENS: return _TEENS[w]
    if w in _TENS:  return _TENS[w]
    for sep in ("-", " "):
        if sep in w:
            a,b = w.split(sep,1)
            if a in _TENS and b in _UNITS:
                return _TENS[a]+_UNITS[b]
    return None

_WORD_NUM_RE = re.compile(
    r"\b(?P<neg>(?:negative|minus)\s+)?(?P<num>(?:"
    r"(?:twenty|thirty|forty|fifty|sixty|seventy|eighty|ninety)(?:[-\s](?:one|two|three|four|five|six|seven|eight|nine))?"
    r"|ten|eleven|twelve|thirteen|fourteen|fifteen|sixteen|seventeen|eighteen|nineteen"
    r"|zero|one|two|three|four|five|six|seven|eight|nine"
    r"))\b", flags=re.IGNORECASE
)

def replace_number_words(text: str) -> str:
    def _repl(m):
        n = _wordnum_to_int(m.group("num"))
        if n is None: return m.group(0)
        if m.group("neg"): n = -n
        return str(n)
    return _WORD_NUM_RE.sub(_repl, text)

def sanitize_ascii(s: str) -> str:
    """Map U+2212 MINUS SIGN to "-" and drop every other non-ASCII character.

    The minus sign is replaced *before* the ASCII strip. Replacing it afterwards (as before)
    was a no-op, because the strip had already deleted it, so "−5" silently became "5".
    """
    return s.replace("\u2212", "-").encode("ascii", "ignore").decode("ascii")

def normalize_numbers(text: str) -> str:
    """Convert spelled-out numbers to digits, then reduce the result to ASCII."""
    with_digits = replace_number_words(text)
    return sanitize_ascii(with_digits)

# ---------------- list/number extraction ----------------
INT_ITER = re.compile(r"-?\d+").finditer

def extract_two_ints_with_pos(text: str) -> Optional[Tuple[Tuple[int,int], Tuple[int,int]]]:
    """Return ``((a, pos_a), (b, pos_b))`` for the first two integer literals in ``text``.

    ``pos_*`` is the match start offset (a leading "-" counts as part of the number).
    Returns None when fewer than two integers are present.
    """
    matches = INT_ITER(text)
    first, second = next(matches, None), next(matches, None)
    if second is None:
        return None
    return (int(first.group()), first.start()), (int(second.group()), second.start())

def list_from_prompt(text: str) -> Optional[str]:
    """Render every integer in ``text`` as a list literal such as "[3, -1, 2]".

    Returns None if ``text`` has fewer than two integers.
    """
    values = [int(hit.group()) for hit in INT_ITER(text)]
    if len(values) < 2:
        return None
    return "[" + ", ".join(str(v) for v in values) + "]"

# ---------------- skill detection ----------------
SORT_WORDS = {
    "sort","sorted","order","ordered","ordering","arrange","arranged","arranging",
    "reorder","reordered","reordering","rearrange","rearranged","rearranging",
    "ascending","increasing","least to greatest","smallest to largest",
    "from smallest to largest","in ascending order","in increasing order"
}
MAX_WORDS = {"max","maximum","largest","greatest","highest","biggest"}
MIN_WORDS = {"min","minimum","smallest","least","lowest"}
ADD_WORDS = {
    "add","sum","plus","total","tally","summing","sum up","addition","adding",
    "combine","combined","combining","add together","add up","added to","with",
    "increase","increasing","increment"
}
SUB_WORDS = {
    "subtract","minus","take away","deduct","difference",
    "decrease","decreasing","less","less than","subtracted from","taken away from",
    "reduce","reduction","decrement"
}

def detect_skill(prompt: str) -> Optional[str]:
    """Classify ``prompt`` as "sort", "max", "min", "sub" or "add", checked in that order.

    Phrases are matched as whole words, optionally with a simple inflection (-s/-es/-ed/-ing,
    or -d/-s/-ing after a final "e"). So "subtracting" still counts, while "determine" and
    "minus" no longer count as "min", which plain substring tests did.
    NOTE: "increasing" appears in both SORT_WORDS and ADD_WORDS; sort wins because it is
    checked first. Returns None when nothing matches.
    """
    p = sanitize_ascii(prompt.lower())

    def mentions(words) -> bool:
        for w in words:
            suffix = r"(?:d|s|ing)?" if w.endswith("e") else r"(?:s|es|ed|ing)?"
            if re.search(rf"\b{re.escape(w)}{suffix}\b", p):
                return True
        return False

    for skill, words in (("sort", SORT_WORDS), ("max", MAX_WORDS), ("min", MIN_WORDS),
                         ("sub", SUB_WORDS), ("add", ADD_WORDS)):
        if mentions(words):
            return skill
    return None

# ---------------- validators ----------------
# Whole-string shapes accepted per skill.
RE_ADD = re.compile(r"^\s*-?\d+\s*\+\s*-?\d+\s*$")
RE_SUB = re.compile(r"^\s*-?\d+\s*-\s*-?\d+\s*$")
RE_MAX = re.compile(r"^\s*max\(\s*\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]\s*\)\s*$")
RE_MIN = re.compile(r"^\s*min\(\s*\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]\s*\)\s*$")
RE_SRT = re.compile(r"^\s*sorted\(\s*\[\s*-?\d+(?:\s*,\s*-?\d+)*\s*\]\s*\)\s*$")

def valid_for_skill(code: str, skill: str) -> bool:
    """True iff ``code`` matches the validator for ``skill``; unknown skills (incl. None) give False."""
    for name, pattern in (("add", RE_ADD), ("sub", RE_SUB), ("max", RE_MAX),
                          ("min", RE_MIN), ("sort", RE_SRT)):
        if skill == name:
            return pattern.match(code) is not None
    return False

# ---------------- deterministic fallback (expanded coverage) ----------------
# Explicit addition phrasings: (regex over the lower-cased prompt, swap) where
# swap=True means the expression is "group2 + group1".
_ADD_PATTERNS = (
    # add/sum/total/plus/combine ... [the] [numbers|values] X and|& Y -> X + Y
    # (the noun now needs its own trailing whitespace; the old
    # "(?:numbers?|values?\s+)?" never matched "the numbers 3 and 4")
    (r"(?:add|sum(?:\s+of)?|total(?:\s+of)?|plus|combine|combining|combined|addition(?:\s+of)?|adding|sum up)"
     r"\s+(?:the\s+)?(?:(?:numbers?|values?)\s+)?(-?\d+)\s+(?:and|&)\s+(-?\d+)", False),
    # add X to Y -> Y + X
    (r"(?:add|adding|plus)\s+(-?\d+)\s+to\s+(-?\d+)", True),
    # X is added to Y -> Y + X
    (r"(-?\d+)\s+is\s+added\s+to\s+(-?\d+)", True),
    # increase/increment X by Y -> X + Y
    (r"(?:increase|increasing|increment)\s+(-?\d+)\s+by\s+(-?\d+)", False),
)

# Explicit subtraction phrasings; swap=True means "group2 - group1".
_SUB_PATTERNS = (
    # subtract/deduct/take away X from Y -> Y - X
    (r"(?:subtract|deduct|take away)\s+(-?\d+)\s+from\s+(-?\d+)", True),
    # X is subtracted from Y -> Y - X
    (r"(-?\d+)\s+is\s+subtracted\s+from\s+(-?\d+)", True),
    # decrease/reduce/decrement X by Y -> X - Y
    (r"(?:decrease|decreasing|reduce|reduction|decrement)\s+(-?\d+)\s+by\s+(-?\d+)", False),
    # X minus Y / X less Y -> X - Y
    (r"(-?\d+)\s+(?:minus|less)\s+(-?\d+)", False),
    # (nonstandard) subtract X by Y -> X - Y
    (r"subtract\s+(-?\d+)\s+by\s+(-?\d+)", False),
    # difference between/of X and Y -> X - Y
    (r"difference\s+(?:between|of)\s+(-?\d+)\s+and\s+(-?\d+)", False),
)

# Expression function emitted for each list skill.
_LIST_FN = {"max": "max", "min": "min", "sort": "sorted"}


def _first_pattern_match(pl: str, patterns, op: str) -> Optional[str]:
    """Return "x <op> y" from the first pattern in `patterns` that matches `pl`, else None."""
    for rx, swap in patterns:
        m = re.search(rx, pl)
        if m:
            x, y = m.group(1), m.group(2)
            return f"{y} {op} {x}" if swap else f"{x} {op} {y}"
    return None


def _fallback_add(p: str, pl: str) -> Optional[str]:
    """Explicit addition patterns, then the first two integers of `p` in reading order."""
    expr = _first_pattern_match(pl, _ADD_PATTERNS, "+")
    if expr is not None:
        return expr
    ab = extract_two_ints_with_pos(p)
    if ab:
        (a, _pa), (b, _pb) = ab
        return f"{a} + {b}"
    return None


def _fallback_sub(p: str, pl: str) -> Optional[str]:
    """Explicit subtraction patterns, then a positional guess around " from " / " less than "."""
    expr = _first_pattern_match(pl, _SUB_PATTERNS, "-")
    if expr is not None:
        return expr
    ab = extract_two_ints_with_pos(p)
    if not ab:
        return None
    # presumably pa/pb are character offsets of a/b in p (same length as pl) -- verify
    (a, pa), (b, pb) = ab
    idx_from = pl.find(" from ")
    idx_less_than = pl.find(" less than ")
    if idx_from != -1 and pa < idx_from < pb:
        return f"{b} - {a}"  # "... X from Y ..." -> Y - X
    if idx_less_than != -1 and pa < idx_less_than < pb:
        return f"{b} - {a}"  # "X less than Y" -> Y - X
    return f"{a} - {b}"


def fallback_from_prompt(raw_prompt: str, skill: Optional[str]) -> Optional[str]:
    """Rule-based translation of a prompt into an expression string.

    Args:
        raw_prompt: the user prompt; numbers are normalised with normalize_numbers().
        skill: the skill guessed by detect_skill(), or None.

    Returns:
        "a + b", "a - b", "max(<list>)", "min(<list>)" or "sorted(<list>)",
        or None when no rule applies.

    Stage 1 gives the detected skill the first try, using only its confident
    rules: the explicit phrase patterns for add/sub, or the prompt's list for
    max/min/sort. This is needed because cue words overlap between sets:
    "increasing" is in both ADD_WORDS and SORT_WORDS, and SORT_WORDS phrases
    such as "least to greatest" contain MAX_WORDS cues. With keyword routing
    alone, a sort prompt could reach the add rule (whose last resort adds the
    first two integers) or the max rule.

    Stage 2 is the keyword-driven cascade: addition, subtraction, then list tasks.
    """
    p = normalize_numbers(raw_prompt)
    pl = p.lower()

    # ---- Stage 1: the detected skill gets the first try ----
    if skill == "add":
        expr = _first_pattern_match(pl, _ADD_PATTERNS, "+")
        if expr is not None:
            return expr
    elif skill == "sub":
        expr = _first_pattern_match(pl, _SUB_PATTERNS, "-")
        if expr is not None:
            return expr
    elif skill in _LIST_FN:
        lst = list_from_prompt(p)
        if lst:
            return f"{_LIST_FN[skill]}({lst})"

    # ---- Stage 2: keyword-driven cascade ----
    if skill == "add" or any(w in pl for w in ADD_WORDS):
        expr = _fallback_add(p, pl)
        if expr is not None:
            return expr
    if skill == "sub" or any(w in pl for w in SUB_WORDS):
        expr = _fallback_sub(p, pl)
        if expr is not None:
            return expr
    lst = list_from_prompt(p)
    if lst:
        if skill == "max" or any(w in pl for w in MAX_WORDS):
            return f"max({lst})"
        if skill == "min" or any(w in pl for w in MIN_WORDS):
            return f"min({lst})"
        if skill == "sort" or any(w in pl for w in SORT_WORDS):
            return f"sorted({lst})"
    return None

@torch.no_grad()
def emit_code(prompt: str, max_new: int = 64) -> str:
    """Translate a natural-language prompt into a one-line Python expression.

    Steps:
      1. detect_skill(prompt) guesses the task family.
      2. The type-tagged prompt followed by the END marker is fed to the model.
         Generation is deterministic (do_sample=False), uses at most `max_new`
         new tokens, and stops at EOS_ID / END_ID.
      3. The first non-empty generated line is returned if it validates for
         the detected skill.
      4. Otherwise fallback_from_prompt() is tried, then a space-stripped copy
         of the model line. If nothing validates, "/* invalid */" is returned.
    """
    skill = detect_skill(prompt)
    enc_text = type_tok.tag_text(prompt) + " " + END + " "
    enc = tok(enc_text, return_tensors="pt").to(DEVICE)
    stop_ids = [x for x in [EOS_ID, END_ID] if x is not None]
    out = mdl.generate(
        **enc,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=EOS_ID,
        eos_token_id=stop_ids,
    )
    # Keep only the newly generated ids and cut at the first stop id.
    # generate() keeps the stop token it ended on. Decoded with
    # skip_special_tokens=False, its text can end up on the expression's line
    # and make a correct expression fail validation.
    gen_ids = out[0, enc.input_ids.shape[1]:].tolist()
    for i, tid in enumerate(gen_ids):
        if tid in stop_ids:
            gen_ids = gen_ids[:i]
            break
    gen = tok.decode(gen_ids, skip_special_tokens=False)
    # First non-empty line, or "" when nothing usable was generated.
    # (The previous splitlines()[0] raised IndexError on empty output.)
    line = next((ln.strip() for ln in sanitize_ascii(gen).splitlines() if ln.strip()), "")
    line = line.rstrip(".;, ")
    if skill and valid_for_skill(line, skill):
        return line

    fb = fallback_from_prompt(prompt, skill)
    if fb is not None:
        return fb

    # Last tight pass. The validators allow no space between "-" and its digits,
    # or between a function name and "(", so retry with all spaces removed.
    # NOTE(review): this also joins digits split by spaces ("1 2 + 3" -> "12+3").
    squeezed = line.replace(" ", "")
    if skill and valid_for_skill(squeezed, skill):
        return squeezed
    return "/* invalid */"

In [None]:
# Cell B: evaluate on GPT-4o prompts only
#
# Collects teacher rows whose "source" is "gpt-4o" and samples up to `per_skill`
# prompts per skill. Each prompt is run through emit_code(), and a hit is
# counted when the evaluated prediction equals the evaluated gold expression.

import json, random, pathlib

DATA_DIR = pathlib.Path("data_teacher")
EVAL_SEED = 0  # seeds the prompt sample so re-running the cell scores the same subset

rows = []
for path in [DATA_DIR/"train.jsonl", DATA_DIR/"valid.jsonl"]:
    if not path.exists():
        continue
    # JSON Lines files are UTF-8; do not depend on the platform default encoding.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():  # tolerate blank / trailing lines in the JSONL
                continue
            r = json.loads(line)
            if r.get("source") == "gpt-4o":
                rows.append(r)

if not rows:
    raise RuntimeError("No GPT-4o rows found. Regenerate teacher data with TEACHER_MODEL='gpt-4o'.")

by_skill = {}
for r in rows:
    by_skill.setdefault(r["skill"], []).append(r)

# A private, seeded RNG keeps the sample independent of other uses of `random`.
rng = random.Random(EVAL_SEED)
sampled, per_skill = [], 12  # adjust if you want more/less
for k, arr in by_skill.items():
    sampled.extend(rng.sample(arr, k=min(per_skill, len(arr))))

print(f"Evaluating {len(sampled)} GPT-4o prompts across skills: {sorted(by_skill.keys())}\n")

# NOTE(review): eval() executes teacher-generated (gold) and model/fallback
# (pred) code with full builtins. That is acceptable for locally generated data
# only; sandbox it before pointing this cell at untrusted JSONL.
ok = 0
for r in sampled:
    prompt, gold_py = r["prompt"], r["code"]
    pred_py = emit_code(prompt)

    try: gold_val = eval(gold_py)
    except Exception as e: gold_val = f"❌gold {e}"

    try: pred_val = eval(pred_py) if pred_py != "/* invalid */" else "❌ invalid"
    except Exception as e: pred_val = f"❌{e}"

    # The error strings carry distinct prefixes, so a failed gold never equals a failed pred.
    match = (gold_val == pred_val)
    ok += int(match)
    print(f"{prompt:70} → {pred_py:28} → {pred_val}")
    if not match:
        print(f"  gold: {gold_py} → {gold_val}")
    print("-"*70)

print(f"\nAccuracy vs GPT-4o gold: {ok}/{len(sampled)} = {ok/len(sampled):.3f}")

Evaluating 60 GPT-4o prompts across skills: ['add', 'max', 'min', 'sort', 'sub']

Combine -92 and -3 by addition.                                        → -92 + -3                     → -95
----------------------------------------------------------------------
What is the result of adding 78 to -23?                                → -23 + 78                     → 55
----------------------------------------------------------------------
What is the sum of -53 and -48?                                        → -53 + -48                    → -101
----------------------------------------------------------------------
Combine -25 and -25 by addition and find the result.                   → -25 + -25                    → -50
----------------------------------------------------------------------
What is the result when you sum 96 and 43?                             → 96 + 43                      → 139
----------------------------------------------------------------------
If you combine -52 and 