In [None]:
# Install required packages (some may already be pre-installed in Kaggle).
# `-q` keeps pip's output quiet so the cell log stays short.
# NOTE(review): bitsandbytes is installed here, but the visible code loads the
# model in plain float16 with no quantization config — confirm it is needed.
!pip install -q datasets transformers torch tqdm accelerate bitsandbytes

import os
import json
import time
import re
from tqdm import tqdm
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Create output directory
# The per-example JSONL log and the summary JSON written later in this cell
# both live under this path; exist_ok=True makes re-running the cell harmless.
os.makedirs("/kaggle/working/results", exist_ok=True)

# ---------- Utility functions ----------
def normalize_answer(ans: str) -> str:
    """Return *ans* trimmed, lower-cased, and with every comma removed."""
    trimmed = ans.strip()
    lowered = trimmed.lower()
    # Splitting on "," and re-joining drops all commas in one pass.
    return "".join(lowered.split(","))

def is_answer_correct(pred, gold):
    """Compare numeric answers, ignoring formatting and symbols like % or commas.

    The last number found in each string is taken as its answer, and the two
    are considered equal when they differ by less than 1e-3. If either string
    contains no number at all, the comparison falls back to a
    case-insensitive match of the stripped strings.

    Args:
        pred: Predicted answer text.
        gold: Reference answer text.

    Returns:
        True when the answers match, False otherwise.
    """
    number = r"[-+]?\d*\.?\d+"
    # A comma between a digit and a group of exactly three digits is a
    # thousands separator. Without removing it, "1,000" would be read as the
    # two numbers "1" and "000" and compared as 0.
    thousands_sep = r"(?<=\d),(?=\d{3}(?!\d))"

    pred_nums = re.findall(number, re.sub(thousands_sep, "", pred))
    gold_nums = re.findall(number, re.sub(thousands_sep, "", gold))

    if not pred_nums or not gold_nums:
        return pred.strip().lower() == gold.strip().lower()

    # Every regex match contains at least one digit, so float() cannot fail.
    return abs(float(pred_nums[-1]) - float(gold_nums[-1])) < 1e-3

# ---------- Benchmark class ----------
class PhiBenchmark:
    """Benchmark a Hugging Face causal LM on numeric question answering.

    The tokenizer and model are loaded once in the constructor. `benchmark`
    then feeds every dataset example through `run_one` and appends one JSON
    record per example to a JSONL file, so an interrupted run can be resumed.
    """

    def __init__(self, model_name="microsoft/Phi-3.5-mini-instruct",
                 device=None, max_new_tokens=512, temperature=0.0):
        """Load the tokenizer and model.

        Args:
            model_name: Identifier passed to both `from_pretrained` calls.
            device: Device the tokenized prompt is moved to. Defaults to
                "cuda" when available, otherwise "cpu". The model itself is
                always placed with `device_map="auto"`.
            max_new_tokens: Upper bound on generated tokens per example.
            temperature: Stored on the instance. The first generation pass is
                greedy (`do_sample=False`), which ignores temperature, so the
                value is not forwarded to `generate`.
        """
        print(f"🔹 Loading model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto"
        )
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

    def make_prompt(self, question):
        """Wrap *question* in the instruction prompt sent to the model."""
        return (
            "You are a math reasoning assistant. "
            "Solve step by step and finish with 'Answer: <number>'.\n\n"
            f"Problem: {question}\n\nSolution:\n"
        )

    def extract_answer(self, text):
        """Pull the final answer out of generated *text*.

        *text* is expected to be the model's completion without the prompt
        (the prompt itself contains the phrase "Answer:").

        Order of preference:
          1. the first line following the last "Answer:" marker,
          2. the last number appearing in the text,
          3. the last line of the text.

        Anything from a line-leading "Problem:" or "Question:" onwards is
        discarded first, so a follow-up exercise invented by the model cannot
        supply the answer.
        """
        # strip() plus the required leading newline means a marker at the very
        # start (the model restating the task) does not wipe the whole text.
        body = re.split(
            r"\n[ \t]*(?:Problem|Question):", text.strip(), maxsplit=1
        )[0]
        if "Answer:" in body:
            tail = body.rsplit("Answer:", 1)[-1].strip()
            return tail.split("\n")[0].strip()
        nums = re.findall(r"[-+]?\d*\.?\d+", body)
        return nums[-1].strip() if nums else body.strip().split("\n")[-1]

    @staticmethod
    def _is_junk(pred):
        """Return True when *pred* has no letter or digit (empty or punctuation only)."""
        return not any(ch.isalnum() for ch in pred)

    def _generate_completion(self, inputs, **gen_kwargs):
        """Generate from *inputs* and decode only the newly produced tokens."""
        output = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
            **gen_kwargs,
        )
        # generate() returns prompt + continuation; slice the prompt off so
        # its own "Answer:" / "Problem:" text never reaches extract_answer.
        prompt_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    @torch.inference_mode()
    def run_one(self, question):
        """Answer a single question.

        Args:
            question: Problem statement inserted into the prompt.

        Returns:
            Tuple `(decoded, pred, latency)`: the generated completion text,
            the answer extracted from it, and the wall-clock seconds spent
            generating (including the retry, when one happens).
        """
        prompt = self.make_prompt(question)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        start = time.perf_counter()
        decoded = self._generate_completion(inputs, do_sample=False)
        pred = self.extract_answer(decoded)

        # Retry once with sampling if the greedy pass produced no usable
        # answer. A single digit such as "5" is a valid answer, so only
        # predictions without any letter or digit count as junk.
        if self._is_junk(pred):
            decoded = self._generate_completion(
                inputs, do_sample=True, temperature=0.7, top_p=0.9
            )
            pred = self.extract_answer(decoded)
        latency = time.perf_counter() - start

        return decoded, pred, latency

    def benchmark(self, dataset, save_path, q_key="question", a_key="answer", num_examples=None, checkpoint_interval=50):
        """Run the model over *dataset* and score it, with resume support.

        Each processed example is appended to *save_path* as one JSON line and
        flushed to disk immediately. Indices already present in that file are
        skipped, and their stored results are folded into the returned metrics
        so a resumed run reports on every scored example.

        Args:
            dataset: Iterable of mapping-like examples.
            save_path: JSONL file used both as result log and as checkpoint.
            q_key: Key of the question text in each example.
            a_key: Key of the gold answer in each example.
            num_examples: Only the first `num_examples` examples are
                considered; None or 0 means the whole dataset.
            checkpoint_interval: Print the record of every example whose index
                is a multiple of this value; 0 or None disables printing.

        Returns:
            Dict with keys "total", "correct", "accuracy" (percent, two
            decimals) and "avg_latency" (seconds, three decimals).
        """
        limit = num_examples or None
        out_dir = os.path.dirname(save_path)
        if out_dir:  # a bare file name has no directory to create
            os.makedirs(out_dir, exist_ok=True)

        # Resume support: index -> (correct, latency) of finished examples.
        done = {}
        ends_with_newline = True
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                for line in f:
                    ends_with_newline = line.endswith("\n")
                    try:
                        data = json.loads(line)
                        idx = data["index"]
                        outcome = (bool(data["correct"]), float(data["latency"]))
                    except (ValueError, KeyError, TypeError):
                        # Blank, truncated or foreign line: not a finished example.
                        continue
                    if isinstance(idx, int):
                        done[idx] = outcome
            print(f"🔄 Resuming from {len(done)} completed examples...")

        # Progress-bar length; datasets without len() get an open-ended bar.
        try:
            n_total = len(dataset)
        except TypeError:
            n_total = None
        if limit is not None:
            n_total = limit if n_total is None else min(limit, n_total)

        with open(save_path, "a", encoding="utf-8") as f:
            if not ends_with_newline:
                # A previous run died mid-write; close off the partial line so
                # the next record starts on a line of its own.
                f.write("\n")

            progress = tqdm(dataset, total=n_total, desc=f"Running on {save_path}")
            for i, ex in enumerate(progress):
                if limit is not None and i >= limit:
                    break
                if i in done:
                    continue

                q, gold = ex[q_key], str(ex[a_key])
                _, pred, t = self.run_one(q)
                ok = is_answer_correct(pred, gold)

                result = {
                    "index": i,
                    "question": q,
                    "gold_answer": gold,
                    "prediction": pred,
                    "correct": ok,
                    "latency": t,
                }

                # Persist right away so a crash loses at most one example.
                f.write(json.dumps(result) + "\n")
                f.flush()
                os.fsync(f.fileno())

                # Print every checkpoint_interval examples
                if checkpoint_interval and i % checkpoint_interval == 0:
                    print(json.dumps(result, indent=2))

                done[i] = (ok, t)

        scored = [
            outcome for idx, outcome in done.items()
            if limit is None or idx < limit
        ]
        total = len(scored)
        correct = sum(ok for ok, _ in scored)
        acc = round(100 * correct / total, 2) if total else 0
        avg_t = round(sum(t for _, t in scored) / total, 3) if total else 0
        print(f"\n✅ {save_path}: {correct}/{total} correct → {acc}% | avg latency {avg_t}s")

        return {"total": total, "correct": correct, "accuracy": acc, "avg_latency": avg_t}


# ---------- Load GSM-Symbolic dataset ----------
# "test" split of the apple/GSM-Symbolic dataset, fetched through `datasets`.
gsm_symbolic = load_dataset("apple/GSM-Symbolic", split="test")

# ---------- Run benchmark ----------
bench = PhiBenchmark("microsoft/Phi-3.5-mini-instruct")

# Column names and console-print cadence for this run; the JSONL path given
# below doubles as the resume checkpoint.
gsm_run_options = {
    "q_key": "question",
    "a_key": "answer",
    "checkpoint_interval": 200,
}
gsm_metrics = bench.benchmark(
    gsm_symbolic,
    "/kaggle/working/results/phib_gsm_symbolic_checkpointed_2.jsonl",
    **gsm_run_options,
)

# ---------- Save summary ----------
# One top-level key per benchmark, holding the metrics dict returned above.
summary = {"gsm_symbolic": gsm_metrics}
with open("/kaggle/working/results/summary_gsm.json", "w") as summary_file:
    summary_file.write(json.dumps(summary, indent=2))

print("\n📂 All results stored in /kaggle/working/results/")


In [None]:
# Recursively archive everything under /kaggle/working/ into results.zip,
# written to the notebook's current working directory.
!zip -r results.zip /kaggle/working/
from IPython.display import FileLink
# Left as the cell's last expression so the notebook displays it as a link
# to the archive.
FileLink('results.zip')