In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory tree and print the full path of every file found.
# If the directory does not exist, os.walk yields nothing and the loop is a no-op.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from kaggle_secrets import UserSecretsClient
import os

# Fetch the secret named "HF_TOKEN" from the Kaggle secrets client and export it
# as an environment variable; `hf_token` is also passed explicitly to
# `from_pretrained` in the model-loading cell.
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_TOKEN")  # your Hugging Face token stored in Kaggle secrets
os.environ["HF_TOKEN"] = hf_token


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Hub repository and the checkpoint sub-directory inside it to load from.
repo_id = "O1-OPEN/OpenO1-Qwen-7B-v0.1"  # full repo
subfolder = "checkpoint-1000"            # the model folder inside repo

# Use the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer
# NOTE(review): trust_remote_code=True presumably lets code shipped in the repo
# run locally -- confirm the repository is trusted.
tokenizer = AutoTokenizer.from_pretrained(
    repo_id,
    subfolder=subfolder,
    trust_remote_code=True,
    token=hf_token
)

# Model
# Half precision and automatic device placement are requested only on GPU;
# on CPU the weights are loaded as float32 with no device map.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    subfolder=subfolder,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device == "cuda" else None,
    trust_remote_code=True,
    token=hf_token
)

# Switch the module to evaluation mode before running the benchmarks.
model.eval()


In [None]:
# ==============================================================================
# COMPREHENSIVE BENCHMARK EVALUATION
# ==============================================================================
# Banner marking the start of the benchmark cell in the notebook output.
print("\n" + "="*80)
print("BENCHMARK EVALUATION")
print("="*80)
# Token budget for prompts: passed as `max_length` when tokenizing with truncation.
MAX_LENGTH = 512
# Generation budget: passed as `max_new_tokens` by the GSM8K and MATH evaluators.
MAX_NEW_TOKENS = 256
import re
import json
from tqdm import tqdm
from typing import Dict, List, Tuple
import numpy as np
from datasets import load_dataset
# Configuration
# Task names understood by run_all_benchmarks; used as its default when no task list is given.
tasks_to_run = ["gsm8k", "math", "mmlu", "hellaswag", "arc_challenge", "bbh"]
MAX_SAMPLES = 50 # Limit samples per task for faster evaluation
# NOTE(review): EVAL_BATCH_SIZE is not referenced by any evaluator visible in this notebook.
EVAL_BATCH_SIZE = 1  # Process one at a time for generation

def normalize_answer(text: str) -> str:
    """Lower-case ``text`` and collapse every whitespace run to a single space.

    Leading and trailing whitespace is dropped, so two answers that differ
    only in casing or spacing compare equal.
    """
    # str.split() with no separator already discards leading/trailing
    # whitespace and splits on any whitespace run.
    return ' '.join(text.lower().split())

def extract_numeric_answer(text: str) -> str:
    """Extract the most likely final numeric answer from ``text``.

    Candidates are tried in priority order:

    1. a number following a GSM8K-style ``####`` marker,
    2. a number following "answer is" / "equal(s)",
    3. the last number anywhere in the text.

    Thousands separators are removed ("1,234" -> "1234") and a decimal point
    is only kept when digits follow it, so a sentence-ending period is never
    returned as part of the number ("... is 42." -> "42").

    Args:
        text: Free-form text such as a model response or reference solution.

    Returns:
        The number as a string, or "" when the text contains no digits.
    """
    # Integer part is either comma-grouped (1,234,567) or a plain digit run;
    # the fractional part requires at least one digit after the point.
    number = r'-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?'

    # Look for patterns like "####" followed by number (GSM8K format)
    match = re.search(r'####\s*(' + number + r')', text)
    if match:
        return match.group(1).replace(',', '')

    # Look for "the answer is X"
    match = re.search(r'(?:answer is|equals?)\s*[:\-]?\s*(' + number + r')', text.lower())
    if match:
        return match.group(1).replace(',', '')

    # Extract last number in text
    numbers = re.findall(number, text)
    return numbers[-1].replace(',', '') if numbers else ""
# NOTE(review): redundant re-declaration -- the same two constants are assigned
# the same values at the top of this cell; these lines can be removed.
MAX_LENGTH = 512
MAX_NEW_TOKENS = 256
def extract_letter_answer(text: str) -> str:
    """Extract a multiple-choice letter (A, B, C or D) from ``text``.

    Patterns are tried from most to least explicit:

    1. "answer is X" / "answer: X" / "correct answer is X" (any case); the
       letter must be a whole word, so "answer is Denmark" does not yield D,
    2. a letter wrapped in parentheses or brackets, e.g. "(b)",
    3. a letter that opens the text, e.g. "b", "C." or "d) ...",
    4. the first stand-alone *upper-case* letter.

    The last step is case-sensitive on purpose: matching lower-case letters
    there would turn the English article "a" into answer "A".

    Args:
        text: Model output to inspect.

    Returns:
        The upper-case letter, or "" when no candidate is found.
    """
    # Look for explicit answer format
    match = re.search(
        r'(?:answer is|answer:|correct answer is)\s*[\(\[]?([A-D])\b',
        text,
        re.IGNORECASE,
    )
    if match:
        return match.group(1).upper()

    # Look for standalone letter in parentheses or brackets
    match = re.search(r'[\(\[]([A-D])[\)\]]', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # A response that starts with the letter (bare, or followed by . ) :)
    match = re.match(r'\s*([A-D])(?:[\.\):]|\s*$)', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Last resort: first stand-alone capital A-D that appears
    match = re.search(r'\b([A-D])\b', text)
    if match:
        return match.group(1)

    return ""

# ==============================================================================
# GSM8K Evaluation
# ==============================================================================
def evaluate_gsm8k(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on the GSM8K test split and return accuracy in [0, 1].

    Only the newly generated tokens are scored. ``generate`` returns the
    prompt followed by the continuation, and decoding the whole sequence
    would let numbers from the question leak into the extracted answer.
    Answers are compared numerically when both sides parse as numbers, so
    "1,000" matches "1000" and "18.0" matches "18".

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        num_samples: Maximum number of test problems to evaluate.

    Returns:
        Fraction of problems answered correctly (0 when nothing was evaluated).
    """
    print("\n--- GSM8K Evaluation ---")
    ds = load_dataset("gsm8k", "main", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    def _as_number(raw):
        """Parse ``raw`` as a float, ignoring commas and a trailing period."""
        try:
            return float(raw.replace(",", "").strip().rstrip("."))
        except ValueError:
            return None

    correct = 0
    total = 0

    for item in tqdm(ds, desc="GSM8K"):
        question = item["question"]
        # The reference solution ends with "#### <final answer>".
        answer = item["answer"].split("####")[-1].strip()

        prompt = f"Problem: {question}\n\nSolution:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Drop the echoed prompt tokens so only the continuation is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        # Remove thousands separators so "1,000" is read as a single number.
        response = re.sub(r'(?<=\d),(?=\d{3}\b)', '', response)
        predicted = extract_numeric_answer(response)

        predicted_value = _as_number(predicted)
        answer_value = _as_number(answer)
        if predicted_value is not None and answer_value is not None:
            is_correct = predicted_value == answer_value
        else:
            # Fall back to the textual comparison when either side is not numeric.
            is_correct = normalize_answer(predicted) == normalize_answer(answer)

        if is_correct:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"GSM8K Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# MATH Dataset Evaluation
# ==============================================================================
def evaluate_math(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on the ``math_dataset`` ``algebra__linear_1d`` test split.

    Only the generated continuation is scored, because the sequence returned
    by ``generate`` starts with the prompt tokens. A sample counts as correct
    only when a reference number could be extracted, so an empty prediction
    never matches an empty reference.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        num_samples: Maximum number of problems to evaluate.

    Returns:
        Fraction of problems answered correctly (0 when nothing was evaluated).
    """
    print("\n--- MATH Dataset Evaluation ---")
    ds = load_dataset("math_dataset", "algebra__linear_1d", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    correct = 0
    total = 0

    for item in tqdm(ds, desc="MATH"):
        question = item["question"]
        answer = item["answer"]

        prompt = f"Problem: {question}\n\nSolution:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the tokens produced after the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        # rstrip(".") guards against a sentence-ending period captured with the number.
        predicted = extract_numeric_answer(response).rstrip(".")
        actual = extract_numeric_answer(answer).rstrip(".")

        # Require a non-empty reference so "" == "" is not scored as a hit.
        if actual != "" and normalize_answer(predicted) == normalize_answer(actual):
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"MATH Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# MMLU Evaluation
# ==============================================================================
def evaluate_mmlu(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on the MMLU ``abstract_algebra`` test split.

    The letter is extracted from the generated continuation only. Decoding
    the full sequence would include the prompt, whose option lines ("A. ...")
    can be picked up as the answer.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        num_samples: Maximum number of questions to evaluate.

    Returns:
        Fraction of questions answered correctly (0 when nothing was evaluated).
    """
    print("\n--- MMLU Evaluation ---")
    ds = load_dataset("cais/mmlu", "abstract_algebra", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    correct = 0
    total = 0

    for item in tqdm(ds, desc="MMLU"):
        question = item["question"]
        choices = item["choices"]
        answer_idx = item["answer"]

        # Format multiple choice
        choice_text = "\n".join([f"{chr(65+i)}. {c}" for i, c in enumerate(choices)])
        prompt = f"Question: {question}\n\n{choice_text}\n\nAnswer:"

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)

        with torch.no_grad():
            # Greedy decoding; `temperature` only applies when sampling, so it is omitted.
            output = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Score the continuation only, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        predicted_letter = extract_letter_answer(response)
        correct_letter = chr(65 + answer_idx)

        if predicted_letter == correct_letter:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"MMLU Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# ARC Challenge Evaluation
# ==============================================================================
def evaluate_arc_challenge(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on the ARC-Challenge test split.

    The prediction is read from the generated continuation only (the decoded
    full sequence would include the option lines from the prompt). When the
    letter extractor does not return one of this question's own labels -- for
    example when the labels are not drawn from A-D -- the first stand-alone
    occurrence of any of the question's labels is used instead.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        num_samples: Maximum number of questions to evaluate.

    Returns:
        Fraction of questions answered correctly (0 when nothing was evaluated).
    """
    print("\n--- ARC Challenge Evaluation ---")
    ds = load_dataset("ai2_arc", "ARC-Challenge", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    correct = 0
    total = 0

    for item in tqdm(ds, desc="ARC-C"):
        question = item["question"]
        choices = item["choices"]["text"]
        labels = item["choices"]["label"]
        answer = item["answerKey"]

        # Format multiple choice
        choice_text = "\n".join([f"{labels[i]}. {choices[i]}" for i in range(len(choices))])
        prompt = f"Question: {question}\n\n{choice_text}\n\nAnswer:"

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)

        with torch.no_grad():
            # Greedy decoding; `temperature` only applies when sampling, so it is omitted.
            output = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Score the continuation only, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        predicted = extract_letter_answer(response)

        # Fall back to searching for this question's own labels when the
        # A-D extractor did not produce one of them.
        if labels and predicted not in labels:
            label_pattern = r"\b(" + "|".join(re.escape(str(label)) for label in labels) + r")\b"
            label_match = re.search(label_pattern, response)
            predicted = label_match.group(1) if label_match else ""

        if predicted.upper() == answer.upper():
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"ARC Challenge Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# BBH Evaluation
# ==============================================================================
def evaluate_bbh(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on the BBH ``boolean_expressions`` test split.

    A sample is correct when the normalized target appears in the generated
    continuation. Only the new tokens are inspected: the prompt contains the
    boolean expression itself, so a containment check over prompt plus
    output would succeed whenever the target word appears in the question.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        num_samples: Maximum number of questions to evaluate.

    Returns:
        Fraction of questions answered correctly (0 when nothing was evaluated).
    """
    print("\n--- BBH Evaluation ---")
    ds = load_dataset("lukaemon/bbh", "boolean_expressions", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    correct = 0
    total = 0

    for item in tqdm(ds, desc="BBH"):
        question = item["input"]
        answer = item["target"]

        prompt = f"Question: {question}\n\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)

        with torch.no_grad():
            # Greedy decoding; `temperature` only applies when sampling, so it is omitted.
            output = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Keep only the tokens generated after the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        # Check if answer is contained in the generated response
        if normalize_answer(answer) in normalize_answer(response):
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"BBH Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# HellaSwag Evaluation
# ==============================================================================
def evaluate_hellaswag(model, tokenizer, num_samples=MAX_SAMPLES):
    """Score HellaSwag validation examples by language-model loss.

    For each example, every candidate ending is appended to the context and
    the model's loss over the whole text is computed; the ending with the
    lowest loss is the prediction.

    Returns:
        Fraction of examples whose predicted ending index equals the label
        (0 when nothing was evaluated).
    """
    print("\n--- HellaSwag Evaluation ---")
    dataset = load_dataset("hellaswag", split="validation")
    dataset = dataset.select(range(min(num_samples, len(dataset))))

    n_correct = 0
    n_seen = 0

    for example in tqdm(dataset, desc="HellaSwag"):
        gold_index = int(example["label"])

        # One score per ending: the negated loss, so higher is better.
        ending_scores = []
        for candidate in example["endings"]:
            encoded = tokenizer(
                example["ctx"] + " " + candidate,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_LENGTH,
            ).to(model.device)

            with torch.no_grad():
                result = model(**encoded, labels=encoded["input_ids"])
                ending_scores.append(-result.loss.item())

        best_index = np.argmax(ending_scores)
        n_seen += 1
        if best_index == gold_index:
            n_correct += 1

    accuracy = n_correct / n_seen if n_seen > 0 else 0
    print(f"HellaSwag Accuracy: {accuracy:.4f} ({n_correct}/{n_seen})")
    return accuracy

# ==============================================================================
# Run All Evaluations
# ==============================================================================
def run_all_benchmarks(model, tokenizer, tasks=None):
    """Run each requested benchmark and collect its accuracy.

    Benchmarks always execute in a fixed order (gsm8k, math, mmlu,
    arc_challenge, bbh, hellaswag) regardless of the order in ``tasks``.
    Names in ``tasks`` that are not recognised are ignored.

    Args:
        model: Model handed to every evaluator.
        tokenizer: Tokenizer handed to every evaluator.
        tasks: Container of task names; defaults to ``tasks_to_run``.

    Returns:
        Dict mapping each executed task name to its accuracy.
    """
    selected = tasks_to_run if tasks is None else tasks

    # Lambdas defer evaluator lookup until a task is actually selected.
    runners = (
        ("gsm8k", lambda: evaluate_gsm8k(model, tokenizer)),
        ("math", lambda: evaluate_math(model, tokenizer)),
        ("mmlu", lambda: evaluate_mmlu(model, tokenizer)),
        ("arc_challenge", lambda: evaluate_arc_challenge(model, tokenizer)),
        ("bbh", lambda: evaluate_bbh(model, tokenizer)),
        ("hellaswag", lambda: evaluate_hellaswag(model, tokenizer)),
    )

    scores = {}
    for task_name, run in runners:
        if task_name in selected:
            scores[task_name] = run()
    return scores

# ==============================================================================
# Main Evaluation
# ==============================================================================
# Banner marking the start of the actual benchmark run.
print("\n" + "="*80)
print("STARTING BENCHMARK EVALUATION")
print("="*80)

# Nothing is loaded here: `model` comes from the earlier model-loading cell.
# This only (re)asserts evaluation mode before generation.
print("Loading model...")
model.eval()

# Run benchmarks; `results` maps task name -> accuracy.
results = run_all_benchmarks(model, tokenizer, tasks_to_run)

# Print summary: the raw dict first, then one percentage line per task.
print("\n" + "="*80)
print("BENCHMARK RESULTS SUMMARY")
print("="*80)
print(results)
for task, accuracy in results.items():
    print(f"{task.upper()}: {accuracy*100:.2f}%")

# Unweighted mean over the per-task accuracies.
avg_accuracy = np.mean(list(results.values()))
print(f"\nAVERAGE ACCURACY: {avg_accuracy*100:.2f}%")

# Save the per-task results plus the average as JSON in the working directory.
results_with_avg = {**results, "average": avg_accuracy}
with open("benchmark_results.json", "w") as f:
    json.dump(results_with_avg, f, indent=2)

print("\n✓ Results saved to benchmark_results.json")

In [None]:
# Cell 4 - Small helper functions: generation, normalize answers, scoring
import numpy as np

def generate_text(prompt, max_new_tokens=128, do_sample=False, temperature=0.0):
    """Generate a continuation for ``prompt`` and return only the new text.

    The prompt is removed by slicing off its tokens rather than by searching
    for the prompt string in the decoded output. The decoded text does not
    always reproduce the prompt verbatim, and the string search then returned
    prompt and answer together.

    Args:
        prompt: Text fed to the module-level ``tokenizer`` / ``model``.
        max_new_tokens: Upper bound on generated tokens.
        do_sample: Sample instead of decoding greedily.
        temperature: Sampling temperature; forwarded only when ``do_sample``
            is true, since it has no effect on greedy decoding.

    Returns:
        The generated continuation with surrounding whitespace stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)

    # The returned sequence starts with the prompt tokens; keep what follows.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

# math answer normalizer (simple)
def normalize_math_answer(text):
    """Reduce a math answer to a comparable token.

    Steps:
      * lower-case, trim and drop thousands separators,
      * if a GSM8K-style "####" marker is present, keep only what follows the
        last marker (the final answer, not the worked steps),
      * strip a leading "answer" / "the answer is" / "final answer:" prefix,
      * return the first integer, decimal or simple fraction found.

    Args:
        text: Model output or reference answer.

    Returns:
        The numeric token as a string, or the whitespace-collapsed remaining
        text when no number is present.
    """
    # remove commas, words, extract first numeric-like token
    s = text.strip().lower()
    s = s.replace(",", "")
    # GSM8K references end with "#### <answer>"; everything before is reasoning.
    if "####" in s:
        s = s.rsplit("####", 1)[1].strip()
    # common prefixes "the answer is", "answer:"
    s = re.sub(r"^(?:the\s+)?(?:final\s+)?answer(?:\s+is)?[:\s]*", "", s)
    # try to find a number or fraction
    m = re.search(r"(-?\d+(\.\d+)?(/\d+)?)", s)
    if m:
        return m.group(1)
    # fallback: return cleaned text
    return re.sub(r"\s+", " ", s).strip()

# multiple-choice answer normalizer (for MMLU/Hellaswag/ARC etc.)
def normalize_choice(text):
    """Reduce a multiple-choice response to a single letter A-D when possible.

    A stand-alone capital letter in the original text is preferred. Searching
    the upper-cased text first would turn the English article "a" into the
    answer "A". Only when no capital letter is found is the upper-cased text
    searched, so a bare lower-case reply such as "b" still resolves.

    Args:
        text: Model output.

    Returns:
        "A".."D" when a letter is found, otherwise the stripped, upper-cased
        text unchanged (it cannot be mapped to an option here).
    """
    stripped = text.strip()
    # prefer single capital letter A/B/C/D as written by the model
    m = re.search(r"\b([A-D])\b", stripped)
    if m:
        return m.group(1)
    t = stripped.upper()
    # fall back to a case-insensitive stand-alone letter
    m = re.search(r"\b([A-D])\b", t)
    if m:
        return m.group(1)
    # maybe returns option text; we can't map automatically, return full text
    return t


In [None]:
# Per-task cap passed as `n` to the eval_* helpers in the cells below.
max_examples = 100  # set to None for full dataset (may be large); use small for quick testing

# NOTE(review): `output_file` is not referenced by any cell visible in this notebook.
output_file = "eval_results.json"
# Rebinds the task list defined in the benchmark cell (same values).
tasks_to_run = ["gsm8k", "math", "mmlu", "hellaswag", "arc_challenge", "bbh"]

In [None]:
# Cell 5 - Evaluate GSM8K (exact-match numeric)
from datasets import load_dataset
from tqdm.auto import tqdm
from pathlib import Path
import os, json, math, re
def eval_gsm8k(n=None):
    """Exact-match evaluation on the GSM8K test split.

    The reference field holds the worked solution followed by
    "#### <final answer>", so only the part after the last marker is
    normalized. Normalizing the whole field would pick the first number in
    the reasoning.

    Args:
        n: Optional cap on the number of examples; falsy means all.

    Returns:
        Dict with "accuracy", "correct" and "total".
    """
    ds = load_dataset("gsm8k", "main", split="test")
    if n:
        ds = ds.select(range(min(n, len(ds))))
    correct = 0
    total = 0
    for ex in tqdm(ds, desc="GSM8K"):
        prompt = ex["question"] + "\n\nLet's think step by step.\nAnswer:"
        ans = generate_text(prompt, max_new_tokens=128, do_sample=False, temperature=0.0)
        pred = normalize_math_answer(ans)
        # Keep only the final answer that follows the "####" marker.
        gold = normalize_math_answer(ex["answer"].split("####")[-1])
        total += 1
        if pred == gold:
            correct += 1
    return {"accuracy": correct/total if total>0 else 0, "correct": correct, "total": total}

# Quick run (use max_examples to limit); the result dict is kept in `gsm8k_res`.
gsm8k_res = eval_gsm8k(n=max_examples)
print("GSM8K:", gsm8k_res)


In [None]:
from datasets import load_dataset
from tqdm.auto import tqdm
from pathlib import Path
import os, json, math, re

# Cell 6 - Evaluate Swallow-MATH (exact match of final answer)
def eval_swallow_math(n=None):
    """Exact-match evaluation on the Swallow-MATH test split.

    Question and answer fields are decoded from bytes when necessary. An
    example only counts as correct when the normalized reference is non-empty
    and equals the normalized prediction.

    Args:
        n: Optional cap on the number of examples; falsy means all.

    Returns:
        Dict with "accuracy", "correct" and "total".
    """
    dataset = load_dataset("tokyotech-llm/swallow-math", split="test")
    if n:
        dataset = dataset.select(range(min(n, len(dataset))))

    def _as_text(value):
        # Byte payloads are decoded leniently; anything else passes through.
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8", errors="ignore")
        return value

    hits = 0
    seen = 0
    for row in tqdm(dataset, desc="Swallow-MATH"):
        question = _as_text(row.get("question", ""))
        reference = _as_text(row.get("answer", ""))

        completion = generate_text(
            question + "\n\nLet's think step by step.\nAnswer:",
            max_new_tokens=256,
            do_sample=False,
        )
        pred = normalize_math_answer(completion)
        gold = normalize_math_answer(reference)

        seen += 1
        if gold != "" and pred == gold:
            hits += 1

    accuracy = hits / seen if seen > 0 else 0
    return {"accuracy": accuracy, "correct": hits, "total": seen}

# May fail if dataset missing/format mismatch. Use max_examples small to debug.
# Best-effort run: any failure is reported and recorded in `swallow_math_res`
# instead of stopping the notebook.
try:
    swallow_math_res = eval_swallow_math(n=max_examples)
    print("Swallow-MATH:", swallow_math_res)
except Exception as e:
    print("Swallow-MATH eval failed. Error:", e)
    swallow_math_res = {"error": str(e)}


In [None]:
from datasets import load_dataset
from tqdm.auto import tqdm
from pathlib import Path
import os, json, math, re
# Cell 6 - Evaluate MATH (exact match of final boxed answer)
# MATH dataset is large and has many long problems, use small n for debugging
def eval_math(n=None):
    """Exact-match evaluation on a MATH-style dataset.

    The primary dataset is ``math_dataset`` / ``algebra__linear_1d``. If
    loading it raises or returns no rows, ``MATH`` is tried instead. The
    original fallback only checked for an empty dataset and was unreachable
    when the load raised. The problem text is read from "question" or
    "problem", and the reference from "answer" or "solution", so either
    schema works.

    Args:
        n: Optional cap on the number of examples; falsy means all.

    Returns:
        Dict with "accuracy", "correct" and "total".

    Raises:
        Exception: whatever the fallback ``load_dataset`` call raises when
            neither dataset can be loaded.
    """
    try:
        ds = load_dataset("math_dataset", "algebra__linear_1d", split="test")
    except Exception as load_err:
        # Deliberate best-effort: report and try the alternative dataset id.
        print("math_dataset load failed, trying 'MATH':", load_err)
        ds = None
    if ds is None or len(ds)==0:
        ds = load_dataset("MATH", split="test")
    if n:
        ds = ds.select(range(min(n, len(ds))))
    correct=0; total=0
    for ex in tqdm(ds, desc="MATH"):
        # problem field name differs between datasets
        question = ex.get("question") or ex.get("problem") or ""
        prompt = question + "\n\nLet's think step by step.\nAnswer:"
        ans = generate_text(prompt, max_new_tokens=256, do_sample=False)
        pred = normalize_math_answer(ans)
        # gold answer fields vary; try some keys
        gold = ex.get("answer") or ex.get("solution") or ""
        gold = normalize_math_answer(gold)
        total += 1
        if pred == gold and gold!="":
            correct += 1
    return {"accuracy": correct/total if total>0 else 0, "correct": correct, "total": total}

# May fail if dataset names differ. Use max_examples small to debug.
# Best-effort run: any failure is reported and recorded in `math_res`
# instead of stopping the notebook.
try:
    math_res = eval_math(n=max_examples)
    print("MATH:", math_res)
except Exception as e:
    print("MATH eval failed (dataset name or format mismatch). Error:", e)
    math_res = {"error": str(e)}


In [None]:
# Cell 7 - Evaluate multiple-choice tasks (MMLU, Hellaswag, ARC, BBH where applicable)
# We'll handle MMLU (57 tasks aggregated) using the 'mmlu' dataset from HF; some datasets may require processing.
# For BBH and others, dataset ids vary; we'll try standard names, otherwise leave as TODO.

def eval_multiple_choice(dataset_name, split="test", n=None, prompt_template=None, choice_key="choices", question_key="question", answer_key="answer"):
    """Zero-shot letter-choice evaluation for a generic multiple-choice dataset.

    Each example is rendered as the question followed by lettered choices
    (A, B, ...) and the model's reply is reduced to a letter with
    ``normalize_choice``.

    The gold answer is mapped to a letter as follows:
      * a label present in the example's own ``choices["label"]`` list -> its position,
      * an int, or a digit string such as "2" -> zero-based index,
      * a single letter A-D -> that letter,
      * any other non-empty string -> the first choice whose text contains it.

    Digit strings and empty strings were previously substring-matched against
    the choice texts, which produced arbitrary gold letters.

    Args:
        dataset_name: Hub dataset id passed to ``load_dataset``.
        split: Dataset split to evaluate.
        n: Optional cap on the number of examples; falsy means all.
        prompt_template: Accepted for interface compatibility; currently unused.
        choice_key: Field holding the choices (a list, or a dict with "text").
        question_key: Field holding the question text.
        answer_key: Field holding the gold answer.

    Returns:
        Dict with "accuracy", "correct" and "total".
    """
    ds = load_dataset(dataset_name, split=split)
    if n:
        ds = ds.select(range(min(n, len(ds))))
    correct=0; total=0
    for ex in tqdm(ds, desc=dataset_name):
        # Build a zero-shot prompt. Some datasets have different formats.
        q = ex.get(question_key) or ex.get("question") or ex.get("context") or ex.get("ctx") or ""
        choices = ex.get(choice_key) or ex.get("choices") or ex.get("options") or []

        labels = []
        if isinstance(choices, dict):
            # sometimes choices is a dict with "text" (and "label") lists
            choice_list = choices.get("text", [])
            labels = list(choices.get("label", []) or [])
        else:
            choice_list = choices if isinstance(choices, list) else []

        # simple prompt
        prompt = q + "\n\nChoices:\n"
        for i, ch in enumerate(choice_list):
            prompt += f"{chr(65+i)}. {ch}\n"
        prompt += "\nAnswer (choose letter):"
        out = generate_text(prompt, max_new_tokens=30, do_sample=False)
        pred = normalize_choice(out)

        gold = ex.get(answer_key)
        gold_letter = None
        if isinstance(gold, str) and gold in labels:
            # the dataset's own label (e.g. "1".."4") -> position in the prompt
            gold_letter = chr(65 + labels.index(gold))
        elif isinstance(gold, str) and gold.strip().isdigit():
            # index stored as a string -> zero-based position
            gold_letter = chr(65 + int(gold.strip()))
        elif isinstance(gold, int):
            gold_letter = chr(65+gold)
        elif isinstance(gold, str) and gold.strip().upper() in ["A","B","C","D"]:
            gold_letter = gold.strip().upper()
        elif isinstance(gold, str) and gold.strip():
            # gold is answer text: find the matching choice
            needle = gold.strip().lower()
            for i, ch in enumerate(choice_list):
                if needle in str(ch).lower():
                    gold_letter = chr(65+i)
                    break
        total += 1
        if gold_letter and pred==gold_letter:
            correct += 1
    return {"accuracy": correct/total if total>0 else 0, "correct": correct, "total": total}

# Example: Hellaswag
# The context field is "ctx" (the key the loss-based HellaSwag evaluator reads);
# with "context" the prompt contained no question text at all.
try:
    hellaswag_res = eval_multiple_choice("hellaswag", split="validation", n=max_examples, question_key="ctx", choice_key="endings", answer_key="label")
    print("Hellaswag:", hellaswag_res)
except Exception as e:
    # Best-effort run: record the failure instead of stopping the notebook.
    print("Hellaswag eval failed:", e)
    hellaswag_res = {"error": str(e)}


In [None]:
# Cell 8 - Try MMLU (use 'mmlu' dataset from HF if installed)
try:
    # Single MMLU subject ("abstract_algebra"); the full benchmark spans many subjects.
    mmlu_ds = load_dataset("cais/mmlu", "abstract_algebra", split="test")
    if max_examples:
        mmlu_ds = mmlu_ds.select(range(min(max_examples, len(mmlu_ds))))

    def run_mmlu(ds):
        """Zero-shot letter-choice accuracy over ``ds``.

        Field names are probed in order, so rows may expose the question as
        "input"/"question"/"prompt", the options as
        "options"/"choices"/"targets" and the gold as "output"/"answer".
        Returns a dict with "accuracy", "correct" and "total".
        """
        hits = 0
        seen = 0
        for row in tqdm(ds, desc="MMLU"):
            question = row.get("input") or row.get("question") or row.get("prompt") or ""
            options = row.get("options") or row.get("choices") or row.get("targets") or []

            # Render the options as "A. ...", "B. ...", one per line.
            option_lines = "".join(f"{chr(65+i)}. {opt}\n" for i, opt in enumerate(options))
            prompt = question + "\n\nChoices:\n" + option_lines + "\nAnswer (choose letter):"

            pred = normalize_choice(generate_text(prompt, max_new_tokens=20))

            gold = row.get("output") or row.get("answer")
            gold_letter = None
            if isinstance(gold, int):
                gold_letter = chr(65+gold)
            elif isinstance(gold, str):
                if gold.strip().upper() in ["A","B","C","D"]:
                    gold_letter = gold.strip().upper()
                else:
                    # Gold is answer text: take the first option containing it.
                    needle = gold.strip().lower()
                    gold_letter = next(
                        (chr(65+i) for i, opt in enumerate(options) if needle in str(opt).lower()),
                        None,
                    )

            seen += 1
            if gold_letter and pred == gold_letter:
                hits += 1

        accuracy = hits/seen if seen > 0 else 0
        return {"accuracy": accuracy, "correct": hits, "total": seen}

    mmlu_res = run_mmlu(mmlu_ds)
    print("MMLU (sample):", mmlu_res)
except Exception as e:
    # Best-effort run: record the failure instead of stopping the notebook.
    print("MMLU eval failed:", e)
    mmlu_res = {"error": str(e)}


In [None]:
from datasets import load_dataset
from tqdm.auto import tqdm

# ---------- ARC-Challenge ----------
def eval_arc_challenge(n=None):
    """Zero-shot evaluation on the ARC-Challenge test split.

    Choices are shown to the model as A, B, C, ... by position. The gold
    ``answerKey`` is therefore translated to the letter at the same position
    via the example's own label list, so keys that use a different labelling
    than the prompt (for example numeric labels) are scored correctly.

    Args:
        n: Optional cap on the number of examples; falsy means all.

    Returns:
        Dict with "accuracy", "correct" and "total".
    """
    ds = load_dataset("ai2_arc", "ARC-Challenge", split="test")
    if n:
        ds = ds.select(range(min(n, len(ds))))
    correct, total = 0, 0
    for ex in tqdm(ds, desc="ARC-Challenge"):
        q = ex.get("question", "")
        choice_info = ex.get("choices", {})
        choices = choice_info.get("text", [])
        labels = list(choice_info.get("label", []) or [])
        ans_key = ex.get("answerKey", "")

        # Map the dataset's key onto the letter used for that option in the prompt.
        if ans_key in labels:
            gold_letter = chr(65 + labels.index(ans_key))
        else:
            gold_letter = ans_key.strip().upper()

        # Build multiple-choice style prompt
        choice_str = "\n".join([f"{chr(65+i)}. {c}" for i, c in enumerate(choices)])
        prompt = q + "\n" + choice_str + "\n\nAnswer:"

        model_out = generate_text(prompt, max_new_tokens=64, do_sample=False)
        pred = model_out.strip().upper()[:1]  # first char like "A","B","C"

        total += 1
        if pred == gold_letter:
            correct += 1
    return {"accuracy": correct/total if total>0 else 0, "correct": correct, "total": total}


# ---------- BBH (boolean_expressions subset) ----------
def eval_bbh_boolean(n=None):
    """Evaluate on the BBH ``boolean_expressions`` test split.

    The prediction is the first stand-alone "true" or "false" in the model's
    reply. With up to 128 new tokens the reply usually carries extra text, so
    comparing the entire reply to the target would almost never match. When
    neither word appears, the whole lower-cased reply is compared instead.

    Args:
        n: Optional cap on the number of examples; falsy means all.

    Returns:
        Dict with "accuracy", "correct" and "total".
    """
    import re  # local import: this cell does not import `re` itself

    ds = load_dataset("lukaemon/bbh", "boolean_expressions", split="test")
    if n:
        ds = ds.select(range(min(n, len(ds))))
    correct, total = 0, 0
    for ex in tqdm(ds, desc="BBH-Boolean"):
        q = ex.get("input", "")
        gold = ex.get("target", "")

        prompt = q + "\n\nAnswer:"
        ans = generate_text(prompt, max_new_tokens=128, do_sample=False)
        reply = ans.strip().lower()
        verdict = re.search(r"\b(true|false)\b", reply)
        pred = verdict.group(1) if verdict else reply
        gold = gold.strip().lower()

        total += 1
        if pred == gold:
            correct += 1
    return {"accuracy": correct/total if total>0 else 0, "correct": correct, "total": total}


# ---------- Run them ----------
# Each benchmark runs in its own best-effort block: a failure is printed and
# stored as {"error": ...} so the other benchmark still runs.
try:
    arc_res = eval_arc_challenge(n=max_examples)
    print("ARC-Challenge:", arc_res)
except Exception as e:
    print("ARC-Challenge eval failed:", e)
    arc_res = {"error": str(e)}

try:
    bbh_res = eval_bbh_boolean(n=max_examples)
    print("BBH-Boolean:", bbh_res)
except Exception as e:
    print("BBH-Boolean eval failed:", e)
    bbh_res = {"error": str(e)}
