In [None]:
# ==============================================================================
# CELL 0: Install/Upgrade Required Packages
# ==============================================================================
# Run this first, then restart kernel
# NOTE: the commands below are wrapped in a triple-quoted string, so running
# this cell as-is installs nothing. Remove the surrounding quotes to execute
# the pinned installs.
# NOTE(review): the next cell runs `pip install -U` on bitsandbytes, accelerate
# and transformers, which would replace the versions pinned here — confirm
# which set of versions is intended.
"""
!pip uninstall -y transformers accelerate trl peft -y
!pip install transformers==4.36.2 --no-cache-dir
!pip install accelerate==0.25.0 --no-cache-dir
!pip install peft==0.7.1 --no-cache-dir
!pip install trl==0.7.4 --no-cache-dir
!pip install bitsandbytes==0.41.3 --no-cache-dir
!pip install datasets==3.6.0 --no-cache-dir
"""

In [None]:
# ==============================================================================
# RUN THIS CELL FIRST - Install/Upgrade BitsAndBytes
# ==============================================================================
# Best-effort installer: every step reports its outcome and the cell always runs
# to completion, so a failed install never blocks the rest of the notebook.
import subprocess
import sys

print("="*80)
print("INSTALLING/UPGRADING BITSANDBYTES")
print("="*80)

# Method 1: Try standard upgrade
bnb_upgraded = False
print("\n1. Upgrading bitsandbytes to latest version...")
try:
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-U", "bitsandbytes"],
        capture_output=True,
        text=True
    )
    print(result.stdout)
    if result.returncode == 0:
        bnb_upgraded = True
        print("Successfully upgraded bitsandbytes")
    else:
        # Show pip's own error output; it was captured and would otherwise be lost.
        print(result.stderr)
        print("Upgrade had some issues, trying alternative method...")
except Exception as e:
    # `Exception` (not a bare except) so Ctrl-C / kernel interrupt still works.
    print(f"Could not run pip: {e}")

if not bnb_upgraded:
    # Method 2: Try with specific version
    print("Trying to install specific version (0.41.0)...")
    try:
        # DEVNULL instead of PIPE: check_call never reads the pipe, so a full
        # pipe buffer could block the child process.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "--upgrade", "bitsandbytes==0.41.0"],
            stdout=subprocess.DEVNULL
        )
        print("Installed bitsandbytes 0.41.0")
    except Exception as e:
        print(f"Could not install specific version: {e}")

# Verify installation
print("\n3. Verifying installation...")
try:
    import bitsandbytes as bnb
    print(f"✓ bitsandbytes version: {bnb.__version__}")
    print("✓ Import successful!")
except Exception as e:
    print(f"✗ Import failed: {e}")
    print("\n⚠ IMPORTANT: If bitsandbytes still doesn't work:")
    print("   - Set USE_QUANTIZATION = False in the config")
    print("   - The code will automatically fall back to FP16")
    print("   - You'll need more GPU memory but it will work")

# Also upgrade related packages
print("\n4. Upgrading related packages...")
try:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-U", "-q", "accelerate", "transformers"],
    )
    print("Upgraded accelerate and transformers")
except Exception as e:
    print(f"Could not upgrade all packages: {e}")

print("\n" + "="*80)
print("INSTALLATION COMPLETE ")
print("="*80)

In [None]:


# ==============================================================================
# CELL 1: Imports and Configuration
# ==============================================================================
import os
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig, 
    get_peft_model, 
    prepare_model_for_kbit_training,
    PeftModel
)
# Note: We'll implement a simplified PPO instead of using trl's PPOTrainer to avoid import issues
from datasets import Dataset
from kaggle_secrets import UserSecretsClient
import numpy as np
from typing import List, Dict, Tuple
import re
from torch.optim import AdamW
from torch.utils.data import DataLoader

# Get HF token
# The token is read from Kaggle's secret store so it never appears in the notebook.
hf = UserSecretsClient()
HF_TOKEN = hf.get_secret("HF_TOKEN")

# Configuration
REPO = "O1-OPEN/OpenO1-LLama-8B-v0.1"
# Checkpoint subfolder inside REPO. It is only read when USE_SUBFOLDER is True,
# but stays defined so that flipping the flag cannot raise a NameError.
SUBFOLDER = "checkpoint-1000"
USE_SUBFOLDER = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512       # max prompt length in tokens (tokenizer truncation)
MAX_NEW_TOKENS = 256   # max tokens sampled per generation

# Training hyperparameters
STAGE1_EPOCHS = 3
STAGE1_BATCH_SIZE = 2
STAGE1_GRAD_ACCUM = 8
STAGE1_LR = 2e-4
KL_COEF = 0.1  # KL penalty coefficient

STAGE2_STEPS = 1000
STAGE2_BATCH_SIZE = 2
STAGE2_LR = 1e-5
CORRECTION_BONUS = 1.0  # Bonus when second attempt > first

os.environ["TOKENIZERS_PARALLELISM"] = "false"

print(f"Using device: {DEVICE}")
print(f"Available GPUs: {torch.cuda.device_count()}")

# ==============================================================================
# CELL 2: Load Tokenizer and Models (4-bit Quantization)
# ==============================================================================
print("Loading tokenizer...")
# Shared tokenizer arguments; the checkpoint subfolder is added only on request.
tokenizer_kwargs = {"trust_remote_code": True, "token": HF_TOKEN}
if USE_SUBFOLDER:
    tokenizer_kwargs["subfolder"] = SUBFOLDER
tokenizer = AutoTokenizer.from_pretrained(REPO, **tokenizer_kwargs)

# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print("Setting up 4-bit quantization...")
# NF4 weights with double quantization; matmuls run in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

print("Loading base model (this may take several minutes)...")
model_kwargs = dict(
    quantization_config=bnb_config,
    device_map="auto",  # Spreads across both T4s
    trust_remote_code=True,
    token=HF_TOKEN,
)
if USE_SUBFOLDER:
    model_kwargs["subfolder"] = SUBFOLDER

# Trainable policy: the KV cache is switched off on its config.
base_model = AutoModelForCausalLM.from_pretrained(REPO, **model_kwargs)
base_model.config.use_cache = False  # Required for gradient checkpointing

print("Loading reference model (frozen)...")
# Second copy of the same checkpoint, kept in eval mode with gradients disabled.
ref_model = AutoModelForCausalLM.from_pretrained(REPO, **model_kwargs)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

print(f"Base model device map: {base_model.hf_device_map}")

# ==============================================================================
# CELL 3: Prepare Model with LoRA
# ==============================================================================
# LoRA configuration: adapters on the four attention projection layers.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

print("Preparing model for k-bit training...")
# Rebinds base_model, so re-running this cell applies the preparation again.
base_model = prepare_model_for_kbit_training(base_model)

print("Attaching LoRA adapters...")
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

# ==============================================================================
# CELL 4: Prepare Dataset (Example Format)
# ==============================================================================
# Your dataset should have: problem, first_attempt, second_attempt, correctness
# Format: {"problem": "...", "first_attempt": "...", "second_attempt": "...", "answer": "...", "is_correct_1": bool, "is_correct_2": bool}
def create_sample_dataset():
    """Build a toy self-correction dataset (4 problems repeated 25 times).

    Every record holds a problem, an incorrect first attempt, a corrected
    second attempt, the ground-truth answer, and the two correctness flags.

    Returns:
        A `Dataset` with 100 rows.
    """
    # (problem, first_attempt, second_attempt, answer)
    seed_rows = [
        (
            "What is 25 + 17?",
            "Let me calculate: 25 + 17 = 41",
            "Let me recalculate: 25 + 17 = 42",
            "42",
        ),
        (
            "Solve: 3x + 5 = 14",
            "3x = 14 - 5 = 9, so x = 4",
            "3x = 14 - 5 = 9, so x = 3",
            "3",
        ),
        (
            "What is 15 * 8?",
            "15 * 8 = 110",
            "Let me recalculate: 15 * 8 = 120",
            "120",
        ),
        (
            "If y - 7 = 12, what is y?",
            "y = 12 + 7 = 20",
            "y = 12 + 7 = 19",
            "19",
        ),
    ]

    # In every seed example the first attempt is wrong and the second is right.
    base_samples = [
        {
            "problem": problem,
            "first_attempt": attempt_1,
            "second_attempt": attempt_2,
            "answer": answer,
            "is_correct_1": False,
            "is_correct_2": True,
        }
        for problem, attempt_1, attempt_2, answer in seed_rows
    ]

    # Repeat to create larger dataset
    repeated = []
    for _ in range(25):
        repeated.extend(base_samples)
    return Dataset.from_list(repeated)

# def create_sample_dataset():
#     """Create a small sample dataset for testing"""
#     samples = [
#         {
#             "problem": "What is 25 + 17?",
#             "first_attempt": "Let me calculate: 25 + 17 = 41",
#             "second_attempt": "Let me recalculate: 25 + 17 = 42",
#             "answer": "42",
#             "is_correct_1": False,
#             "is_correct_2": True
#         },
#         {
#             "problem": "Solve: 3x + 5 = 14",
#             "first_attempt": "3x = 14 - 5 = 9, so x = 4",
#             "second_attempt": "3x = 14 - 5 = 9, so x = 3",
#             "answer": "3",
#             "is_correct_1": False,
#             "is_correct_2": True
#         },
#         # Add more examples...
#     ]
#     return Dataset.from_list(samples)

# Load your actual dataset here
# For now this builds the in-memory sample set from create_sample_dataset();
# replace the call below to train on real data with the same fields.
print("Creating/loading dataset...")
train_dataset = create_sample_dataset()
print(f"Dataset size: {len(train_dataset)}")



In [None]:
# STAGE I - Supervised Fine-tuning with KL Penalty (Fixed)
# ==============================================================================
print("\n" + "="*80)
print("STAGE I: Supervised Fine-tuning with KL Penalty (Fixed)")
print("="*80)
# DataLoader is also imported in the imports/configuration cell; the repeat
# here only matters if this cell is run without that one.
from torch.utils.data import DataLoader

# --- Step 1: Prepare DataLoader ---
# Batches of STAGE1_BATCH_SIZE examples, reshuffled on every pass over the loader.
train_loader = DataLoader(train_dataset, batch_size=STAGE1_BATCH_SIZE, shuffle=True)

# --- Step 2: Define KL Divergence ---
# def compute_kl_divergence(logits_policy, logits_ref):
#     """
#     Compute KL divergence between policy and reference model
#     """
#     log_probs_policy = F.log_softmax(logits_policy, dim=-1)
#     probs_ref = F.softmax(logits_ref, dim=-1)
#     kl = (probs_ref * (probs_ref.log() - log_probs_policy)).sum(dim=-1)
#     return kl.mean()
import torch
import torch.nn.functional as F

def compute_kl_divergence(policy_logits, ref_logits, attention_mask=None, eps=1e-12):
    """
    Compute KL(P_ref || Q_policy) per token and return its mean.

    The computation is done in float32 regardless of the input dtype, because a
    softmax over a large vocabulary in float16 loses precision. The reference
    logits and the mask are moved to the device of `policy_logits`.

    Args:
      policy_logits: Tensor [batch, seq_len, vocab]
      ref_logits:    Tensor [batch, seq_len, vocab]
      attention_mask: Optional Tensor [batch, seq_len] with 1 for real tokens, 0 for padding.
      eps: unused; kept only so existing callers that pass it keep working.

    Returns:
      scalar float32 tensor: per-sample mean KL over non-padding tokens,
      averaged over the batch.

    Raises:
      AssertionError: if the two logit tensors differ in shape.
    """
    # ensure shapes match
    assert policy_logits.shape == ref_logits.shape, f"policy {policy_logits.shape} vs ref {ref_logits.shape}"

    # Upcast before the softmax and bring both tensors onto one device.
    policy_logits = policy_logits.float()
    ref_logits = ref_logits.to(device=policy_logits.device, dtype=torch.float32)

    # stable log-probs
    log_probs_policy = F.log_softmax(policy_logits, dim=-1)   # log Q
    log_probs_ref = F.log_softmax(ref_logits, dim=-1)         # log P

    # per-token KL: sum_vocab P * (log P - log Q), with P = exp(log P)
    kl_per_token = (log_probs_ref.exp() * (log_probs_ref - log_probs_policy)).sum(dim=-1)  # [batch, seq_len]

    if attention_mask is None:
        # mean over seq_len when no mask provided
        kl_per_sample = kl_per_token.mean(dim=1)
    else:
        mask = attention_mask.to(device=kl_per_token.device, dtype=kl_per_token.dtype)
        # Zero out padding, then average over the valid tokens of each sample.
        valid_tokens_per_sample = mask.sum(dim=1).clamp_min(1.0)  # avoid div by 0
        kl_per_sample = (kl_per_token * mask).sum(dim=1) / valid_tokens_per_sample

    return kl_per_sample.mean()  # scalar

# --- Step 3: Stage I training step ---
def stage1_train_step(batch, model, ref_model, optimizer, tokenizer):
    """Run one Stage I optimisation step and return its loss components.

    The loss is the language-modelling loss on the second attempt (prompt and
    padding positions are excluded from the labels) plus KL_COEF times the KL
    divergence between the reference and the policy on the first-attempt
    prompt.

    Args:
        batch: mapping with "problem", "first_attempt" and "second_attempt",
            each an iterable of strings of equal length.
        model: trainable policy model.
        ref_model: frozen reference model, evaluated under `torch.no_grad()`.
        optimizer: optimizer over the policy parameters; stepped once.
        tokenizer: tokenizer shared by both models.

    Returns:
        (total_loss, lm_loss, kl_loss) as Python floats.
    """
    model.train()

    # Prepare prompts and targets
    prompts = [f"Problem: {p}\n\nFirst attempt: {a1}\n\nLet me reconsider:"
               for p, a1 in zip(batch["problem"], batch["first_attempt"])]
    targets = [f"{t}" for t in batch["second_attempt"]]

    # Combine prompts and targets for proper tokenization
    full_texts = [p + t for p, t in zip(prompts, targets)]

    # Tokenize the combined text
    inputs = tokenizer(full_texts, padding='longest', truncation=True,
                       max_length=MAX_LENGTH, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # Tokenize the prompts alone to find where each target starts
    prompt_inputs = tokenizer(prompts, padding='longest', truncation=True,
                              max_length=MAX_LENGTH, return_tensors="pt")
    prompt_lengths = prompt_inputs["attention_mask"].sum(dim=1).tolist()

    # Create labels: -100 for prompt tokens, actual tokens for target.
    # With left padding the prompt starts after the padding, not at column 0.
    seq_len = input_ids.shape[1]
    real_lengths = attention_mask.sum(dim=1).tolist()
    left_padded = getattr(tokenizer, "padding_side", "right") == "left"
    labels = input_ids.clone()
    for i, prompt_len in enumerate(prompt_lengths):
        start = seq_len - real_lengths[i] if left_padded else 0
        labels[i, start:start + prompt_len] = -100

    # Ignore padding positions. The attention mask is used rather than the pad
    # token id because the pad token may be the same id as EOS.
    labels[attention_mask == 0] = -100

    # Forward pass
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    loss_lm = outputs.loss

    # --- KL penalty on first attempt ---
    first_prompts = [f"Problem: {p}\n\nSolution:" for p in batch["problem"]]
    first_inputs = tokenizer(first_prompts, padding='longest', truncation=True,
                             max_length=MAX_LENGTH, return_tensors="pt")
    first_inputs = {k: v.to(model.device) for k, v in first_inputs.items()
                    if k in ["input_ids", "attention_mask"]}

    with torch.no_grad():
        ref_logits = ref_model(**first_inputs).logits

    policy_logits = model(**first_inputs).logits

    # Pass the mask so padded positions do not dilute the KL estimate.
    kl_loss = compute_kl_divergence(
        policy_logits,
        ref_logits.to(policy_logits.device),
        attention_mask=first_inputs["attention_mask"],
    )

    # --- Total loss ---
    loss = loss_lm + KL_COEF * kl_loss

    # Backward and optimizer step
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    return loss.item(), loss_lm.item(), kl_loss.item()

# --- Step 4: Training loop ---
print("Starting Stage I training...")
optimizer = torch.optim.AdamW(model.parameters(), lr=STAGE1_LR)

for epoch in range(STAGE1_EPOCHS):
    # Running sums of the three loss components for this epoch.
    total_loss = total_lm_loss = total_kl_loss = 0

    for step, batch in enumerate(train_loader):
        loss, lm_loss, kl_loss = stage1_train_step(batch, model, ref_model, optimizer, tokenizer)
        total_loss, total_lm_loss, total_kl_loss = (
            total_loss + loss,
            total_lm_loss + lm_loss,
            total_kl_loss + kl_loss,
        )

        # Log every 10th step.
        if not step % 10:
            print(f"Epoch {epoch+1}, Step {step}: "
                  f"Loss={loss:.4f}, LM={lm_loss:.4f}, KL={kl_loss:.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} completed. Avg Loss: {avg_loss:.4f}")

# --- Step 5: Save checkpoint ---
# Adapter weights and tokenizer files go to the same directory.
print("Stage I completed! Saving checkpoint...")
for artifact in (model, tokenizer):
    artifact.save_pretrained("./stage1_lora")
print("Stage I checkpoint saved at ./stage1_lora")

In [None]:
# CELL 6: Stage II - Simplified REINFORCE with Correction Reward (FIXED)
# ==============================================================================
# Banner marking the start of Stage II in the cell output.
print("\n" + "="*80)
print("STAGE II: REINFORCE Training with Correction Rewards")
print("="*80)

def extract_answer(text: str) -> str:
    """Return the last number that appears in `text`, or "" if there is none.

    A number is an optional minus sign, digits, and an optional fractional
    part. A trailing sentence period is not part of the number, so
    "... = 42." yields "42" rather than "42.".
    """
    numbers = re.findall(r'-?\d+(?:\.\d+)?', text)
    return numbers[-1] if numbers else ""

def _answers_match(predicted: str, expected: str) -> bool:
    """Return True when two answer strings denote the same value.

    Both strings are compared numerically when they parse as floats (so "42.0"
    and "42." match "42"); otherwise the comparison is exact string equality.
    """
    try:
        return abs(float(predicted) - float(expected)) <= 1e-9
    except ValueError:
        return predicted == expected


def compute_reward(problem: str, first_attempt: str, second_attempt: str,
                   ground_truth: str) -> float:
    """
    Compute reward for SCoRe:
    - Base reward for correctness
    - Bonus if second attempt is better than first

    Args:
        problem: the problem statement (not used in the computation).
        first_attempt: text of the first attempt.
        second_attempt: text of the second attempt.
        ground_truth: expected answer; surrounding whitespace is ignored.

    Returns:
        1.0 if the second attempt is correct, else 0.0, plus CORRECTION_BONUS
        when the second attempt fixes a wrong first attempt, minus
        CORRECTION_BONUS when it breaks a correct first attempt.
    """
    gt = ground_truth.strip()

    correct_1 = _answers_match(extract_answer(first_attempt), gt)
    correct_2 = _answers_match(extract_answer(second_attempt), gt)

    reward = 1.0 if correct_2 else 0.0

    if not correct_1 and correct_2:
        reward += CORRECTION_BONUS

    if correct_1 and not correct_2:
        reward -= CORRECTION_BONUS

    return reward

class SimpleValueHead(torch.nn.Module):
    """Simple value head for policy gradient.

    A single linear layer mapping the hidden state at the last sequence
    position to one scalar per batch element.
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.value_head = torch.nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        """Return one value per sequence from its final hidden state.

        Args:
            hidden_states: tensor indexed as [batch, seq_len, hidden_size].

        Returns:
            Tensor of shape [batch], in the dtype and on the device of the
            linear layer.
        """
        weight = self.value_head.weight
        # The backbone may emit float16 activations, or activations on another
        # GPU, while this layer keeps its own dtype and device. Align the input
        # so the matmul cannot fail on a mismatch.
        last_hidden = hidden_states[:, -1, :].to(device=weight.device, dtype=weight.dtype)
        return self.value_head(last_hidden).squeeze(-1)

# Add value head to model
# The value head is a separate module (it is not attached to `model`); its input
# width is read from the model config and it is moved to `model.device`.
print("Adding value head to model...")
hidden_size = model.config.hidden_size
value_head = SimpleValueHead(hidden_size).to(model.device)
# One optimizer updates both the policy parameters and the value head.
optimizer_rl = AdamW(
    list(model.parameters()) + list(value_head.parameters()),
    lr=STAGE2_LR
)

def reinforce_step(model, value_head, ref_model, tokenizer, batch, optimizer):
    """Single REINFORCE training step on one example.

    Samples a first attempt, samples a second attempt conditioned on it, scores
    the pair with `compute_reward`, and updates the policy and value head with
    policy loss + 0.5 * value loss + 0.01 * KL(ref || policy) on the
    correction prompt.

    Args:
        model: trainable policy model.
        value_head: module mapping final-layer hidden states to a value.
        ref_model: frozen reference model for the KL penalty.
        tokenizer: tokenizer shared by both models.
        batch: mapping with a "problem" string and an "answer" string.
        optimizer: optimizer over policy and value-head parameters.

    Returns:
        (total_loss, reward, policy_loss, value_loss). The three losses are 0.0
        and no update is made when the second generation is empty.
    """
    model.train()
    value_head.train()

    problem = batch["problem"]
    ground_truth = batch["answer"]

    # Generate first attempt
    prompt = f"Problem: {problem}\n\nSolution:"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True,
                      truncation=True, max_length=MAX_LENGTH).to(model.device)

    with torch.no_grad():
        outputs_first = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens. `generate` returns prompt +
    # continuation, and keeping the prompt would repeat the problem text in the
    # correction prompt below.
    first_prompt_len = inputs.input_ids.shape[1]
    first_attempt = tokenizer.decode(
        outputs_first[0][first_prompt_len:], skip_special_tokens=True
    ).strip()

    # Generate second attempt
    correction_prompt = f"Problem: {problem}\n\nSolution:\n\nFirst attempt: {first_attempt}\n\nLet me reconsider:"
    correction_inputs = tokenizer(correction_prompt, return_tensors="pt",
                                  padding=True, truncation=True,
                                  max_length=MAX_LENGTH).to(model.device)

    # Sampling needs no gradients; the log-probs are recomputed below.
    with torch.no_grad():
        outputs_second = model.generate(
            **correction_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Get generated tokens (everything after the correction prompt)
    prompt_len = correction_inputs.input_ids.shape[1]
    generated_tokens = outputs_second[0][prompt_len:]

    # Score only the continuation, so an answer copied from the first attempt
    # inside the prompt cannot be credited to the second attempt.
    second_attempt = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Compute reward
    reward = compute_reward(problem, first_attempt, second_attempt, ground_truth)

    # Nothing was generated: there is no action to reinforce.
    if generated_tokens.numel() == 0:
        return 0.0, reward, 0.0, 0.0

    # Forward pass with gradients enabled and hidden states output
    with torch.enable_grad():
        full_outputs = model(
            input_ids=outputs_second,
            attention_mask=torch.ones_like(outputs_second),
            output_hidden_states=True  # needed by the value head
        )
        logits = full_outputs.logits

        # Logits at position t predict token t+1, hence the shift by one.
        log_probs = F.log_softmax(logits[0, prompt_len - 1:-1, :].float(), dim=-1)

        # Ensure we have enough generated tokens
        num_gen_tokens = min(len(generated_tokens), log_probs.shape[0])
        if num_gen_tokens == 0:
            return 0.0, reward, 0.0, 0.0

        # Pick, for each position, the log-prob of the token actually sampled.
        # (Indexing with `log_probs[:n, tokens]` would build an n x n matrix.)
        token_ids = generated_tokens[:num_gen_tokens].to(log_probs.device)
        selected_log_probs = log_probs[:num_gen_tokens].gather(
            1, token_ids.unsqueeze(1)
        ).squeeze(1)

        # Compute value estimate; align dtype/device with the value head first.
        head_param = next(value_head.parameters())
        last_layer_hidden = full_outputs.hidden_states[-1].to(
            device=head_param.device, dtype=head_param.dtype
        )
        value_estimate = value_head(last_layer_hidden)

        # REINFORCE loss with the value estimate as baseline
        reward_tensor = torch.tensor(
            [reward], dtype=value_estimate.dtype, device=value_estimate.device
        )
        advantage = (reward_tensor - value_estimate.detach()).to(selected_log_probs.device)
        policy_loss = -(selected_log_probs.mean() * advantage)
        value_loss = F.mse_loss(value_estimate, reward_tensor)

        # KL penalty with reference model on the correction prompt
        with torch.no_grad():
            ref_logits = ref_model(**correction_inputs).logits

        # The model is causal, so the logits of the first `prompt_len` positions
        # of the full sequence are the policy logits on the prompt; reusing
        # them avoids a second forward pass through the policy.
        policy_logits = logits[:, :prompt_len, :]
        kl_loss = compute_kl_divergence(policy_logits, ref_logits.to(policy_logits.device))

        # Total loss (all terms on the policy-loss device)
        loss_device = policy_loss.device
        total_loss = (
            policy_loss
            + 0.5 * value_loss.to(loss_device)
            + 0.01 * kl_loss.to(loss_device)
        )

    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(
        list(model.parameters()) + list(value_head.parameters()),
        1.0
    )
    optimizer.step()

    return total_loss.item(), reward, policy_loss.item(), value_loss.item()

# Stage II Training Loop
print("Starting Stage II REINFORCE training...")

# NOTE(review): the step count is hard-coded to 10; STAGE2_STEPS appears only in
# the trailing comment. With 10 steps the `step % 50` log fires only at step 0
# and the `step % 200` checkpoint branch never runs — confirm this is a
# deliberate smoke-test setting.
for step in range(10): #STAGE2_STEPS
    # Walk through the dataset one example per step, wrapping around at the end.
    idx = step % len(train_dataset)
    batch = train_dataset[idx]
    
    try:
        loss, reward, policy_loss, value_loss = reinforce_step(
            model, value_head, ref_model, tokenizer, batch, optimizer_rl
        )
        
        if step % 50 == 0:
            print(f"Step {step}: Total Loss={loss:.4f}, Reward={reward:.3f}, "
                  f"Policy Loss={policy_loss:.4f}, Value Loss={value_loss:.4f}")
        
        # Periodic checkpoint: adapter via save_pretrained, value head state
        # dict written into the same directory.
        if step % 200 == 0 and step > 0:
            print(f"Saving checkpoint at step {step}...")
            model.save_pretrained(f"./stage2_lora_step{step}")
            torch.save(value_head.state_dict(), f"./stage2_lora_step{step}/value_head.pt")
    
    except Exception as e:
        # Best-effort: report the failure with its traceback and continue with
        # the next step instead of aborting the run.
        print(f"Error at step {step}: {e}")
        import traceback
        traceback.print_exc()
        continue

print("Stage II completed!")
print("Saving final model...")
# Final artifacts: adapter, value head state dict, and tokenizer in one directory.
model.save_pretrained("./stage2_lora_final")
torch.save(value_head.state_dict(), "./stage2_lora_final/value_head.pt")
tokenizer.save_pretrained("./stage2_lora_final")

# ==============================================================================
# CELL 7: Inference Test
# ==============================================================================
# print("\n" + "="*80)
# print("TESTING TRAINED MODEL")
# print("="*80)

# def test_score_inference(problem: str):
#     """Test the trained SCoRe model"""
#     model.eval()
    
#     prompt = f"Problem: {problem}\n\nSolution:"
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
#     with torch.no_grad():
#         output1 = model.generate(
#             **inputs,
#             max_new_tokens=MAX_NEW_TOKENS,
#             do_sample=True,
#             temperature=0.7,
#             pad_token_id=tokenizer.pad_token_id,
#         )
    
#     first_attempt = tokenizer.decode(output1[0], skip_special_tokens=True)
    
#     correction_prompt = f"{prompt}\n\nFirst attempt: {first_attempt}\n\nLet me reconsider:"
#     inputs2 = tokenizer(correction_prompt, return_tensors="pt").to(model.device)
    
#     with torch.no_grad():
#         output2 = model.generate(
#             **inputs2,
#             max_new_tokens=MAX_NEW_TOKENS,
#             do_sample=True,
#             temperature=0.7,
#             pad_token_id=tokenizer.pad_token_id,
#         )
    
#     second_attempt = tokenizer.decode(output2[0], skip_special_tokens=True)
    
#     print(f"Problem: {problem}")
#     print(f"\nFirst Attempt:\n{first_attempt}")
#     print(f"\nSecond Attempt (Self-Correction):\n{second_attempt}")
#     print("-" * 80)
    

# test_problems = [
#     "What is 144 + 256?",
#     "Solve for x: 2x - 8 = 14",
# ]

# for prob in test_problems:
#     test_score_inference(prob)

# print("\n✓ Training complete! LoRA adapters saved to ./stage2_lora_final")

In [None]:

# Verify the installation
# Prints the version of the `datasets` package that this kernel imports.
import datasets
print(f"Successfully installed datasets version: {datasets.__version__}")

In [None]:
# ==============================================================================
# COMPREHENSIVE BENCHMARK EVALUATION
# ==============================================================================
print("\n" + "="*80)
print("BENCHMARK EVALUATION")
print("="*80)

# Imports used by the evaluation helpers defined below in this cell.
import re
import json
from tqdm import tqdm
from typing import Dict, List, Tuple
import numpy as np
import datasets
print("version : ",datasets.__version__)
from datasets import load_dataset
# Configuration
# NOTE(review): "hellaswag" is listed but no evaluate_hellaswag function is
# defined in the visible part of this notebook — confirm it exists further down.
tasks_to_run = ["gsm8k", "math", "mmlu", "hellaswag", "arc_challenge", "bbh"]
MAX_SAMPLES = 50 # Limit samples per task for faster evaluation
EVAL_BATCH_SIZE = 1  # Process one at a time for generation

def normalize_answer(text: str) -> str:
    """Lower-case `text`, trim it, and collapse whitespace runs to one space."""
    tokens = text.lower().strip().split()
    return ' '.join(tokens)

def extract_numeric_answer(text: str) -> str:
    """Extract a numeric answer from text.

    Tried in order:
      1. a number after "####" (GSM8K format);
      2. a number after "answer is" / "equal" / "equals", optionally preceded
         by ":" or "=" and a "$" sign;
      3. the last number anywhere in the text.

    A number is an optional minus sign, digits, and an optional fractional
    part, so a sentence-ending period is not captured and a leading minus sign
    is kept.

    Returns:
        The number as a string, or "" when the text contains no number.
    """
    number = r'-?\d+(?:\.\d+)?'

    # Look for patterns like "####" followed by number (GSM8K format)
    match = re.search(r'####\s*(' + number + r')', text)
    if match:
        return match.group(1)

    # Look for "the answer is X". The separator deliberately excludes "-" so
    # that "answer is -5" keeps its sign.
    match = re.search(r'(?:answer is|equals?)\s*[:=]?\s*\$?\s*(' + number + r')', text.lower())
    if match:
        return match.group(1)

    # Extract last number in text
    numbers = re.findall(number, text)
    return numbers[-1] if numbers else ""

def extract_letter_answer(text: str) -> str:
    """Extract letter answer (A, B, C, D) from text.

    Tried in order:
      1. a letter after "answer is", "answer:" or "correct answer is"
         (case-insensitive);
      2. a letter in parentheses or brackets (case-insensitive);
      3. the first standalone upper-case letter A-D.

    The last step is case-sensitive on purpose: matching lower-case letters
    would pick up the English article "a" in ordinary prose.

    Returns:
        The upper-case letter, or "" when none is found.
    """
    # Look for explicit answer format
    match = re.search(r'(?:answer is|answer:|correct answer is)\s*([A-D])', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Look for standalone letter in parentheses or brackets
    match = re.search(r'[\(\[]([A-D])[\)\]]', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Last resort: first standalone upper-case letter A-D
    match = re.search(r'\b([A-D])\b', text)
    if match:
        return match.group(1)

    return ""

# ==============================================================================
# GSM8K Evaluation
# ==============================================================================
def evaluate_gsm8k(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on GSM8K dataset.

    Args:
        model: causal LM used for generation; switched to eval mode.
        tokenizer: tokenizer matching `model`.
        num_samples: maximum number of test items to score.

    Returns:
        Accuracy in [0, 1] (0 when no item was evaluated).
    """
    print("\n--- GSM8K Evaluation ---")
    ds = load_dataset("gsm8k", "main", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    # Disable dropout: the model may still be in train mode after fine-tuning.
    model.eval()

    correct = 0
    total = 0

    for item in tqdm(ds, desc="GSM8K"):
        question = item["question"]
        # Gold answer follows "####"; drop thousands separators ("1,250").
        answer = item["answer"].split("####")[-1].strip().replace(",", "")

        prompt = f"Problem: {question}\n\nSolution:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the continuation; otherwise numbers from the question
        # itself could be picked up as the prediction.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        # Remove thousands separators inside numbers so "1,250" reads as 1250.
        response = re.sub(r'(?<=\d),(?=\d{3}\b)', '', response)
        predicted = extract_numeric_answer(response)

        if normalize_answer(predicted) == normalize_answer(answer):
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"GSM8K Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# MATH Dataset Evaluation
# ==============================================================================
# def evaluate_math(model, tokenizer, num_samples=MAX_SAMPLES):
#     """Evaluate on MATH dataset"""
#     print("\n--- MATH Dataset Evaluation ---")
#     ds = load_dataset("math_dataset", "algebra__linear_1d", split="test")
#     ds = ds.select(range(min(num_samples, len(ds))))
    
#     correct = 0
#     total = 0
    
#     for item in tqdm(ds, desc="MATH"):
#         question = item["question"]
#         answer = item["answer"]
        
#         prompt = f"Problem: {question}\n\nSolution:"
#         inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
#                           max_length=MAX_LENGTH).to(model.device)
        
#         with torch.no_grad():
#             output = model.generate(
#                 **inputs,
#                 max_new_tokens=MAX_NEW_TOKENS,
#                 temperature=0.7,
#                 do_sample=True,
#                 pad_token_id=tokenizer.pad_token_id,
#             )
        
#         response = tokenizer.decode(output[0], skip_special_tokens=True)
#         predicted = extract_numeric_answer(response)
#         actual = extract_numeric_answer(answer)
        
#         if normalize_answer(predicted) == normalize_answer(actual):
#             correct += 1
#         total += 1
    
#     accuracy = correct / total if total > 0 else 0
#     print(f"MATH Accuracy: {accuracy:.4f} ({correct}/{total})")
#     return accuracy
def evaluate_math(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on the `math_dataset` "algebra__linear_1d" test split.

    Args:
        model: causal LM used for generation; switched to eval mode.
        tokenizer: tokenizer matching `model`.
        num_samples: maximum number of test items to score.

    Returns:
        Accuracy in [0, 1] (0 when no item was evaluated).
    """
    print("\n--- MATH Dataset Evaluation ---")
    # Add trust_remote_code=True to allow custom dataset code
    ds = load_dataset("math_dataset", "algebra__linear_1d", split="test", trust_remote_code=True)
    ds = ds.select(range(min(num_samples, len(ds))))

    # Disable dropout: the model may still be in train mode after fine-tuning.
    model.eval()

    correct = 0
    total = 0

    for item in tqdm(ds, desc="MATH"):
        question = item["question"]
        answer = item["answer"]

        prompt = f"Problem: {question}\n\nSolution:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the continuation; otherwise numbers from the question
        # itself could be picked up as the prediction.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        predicted = extract_numeric_answer(response)
        actual = extract_numeric_answer(answer)

        if normalize_answer(predicted) == normalize_answer(actual):
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"MATH Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy
# ==============================================================================
# MMLU Evaluation
# ==============================================================================
def evaluate_mmlu(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on the MMLU "abstract_algebra" test split with greedy decoding.

    Args:
        model: causal LM used for generation; switched to eval mode.
        tokenizer: tokenizer matching `model`.
        num_samples: maximum number of test items to score.

    Returns:
        Accuracy in [0, 1] (0 when no item was evaluated).
    """
    print("\n--- MMLU Evaluation ---")
    ds = load_dataset("cais/mmlu", "abstract_algebra", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    # Disable dropout: the model may still be in train mode after fine-tuning.
    model.eval()

    correct = 0
    total = 0

    for item in tqdm(ds, desc="MMLU"):
        question = item["question"]
        choices = item["choices"]
        answer_idx = item["answer"]

        # Format multiple choice
        choice_text = "\n".join([f"{chr(65+i)}. {c}" for i, c in enumerate(choices)])
        prompt = f"Question: {question}\n\n{choice_text}\n\nAnswer:"

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)

        with torch.no_grad():
            # Greedy decoding; `temperature` has no effect without sampling,
            # so it is not passed.
            output = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the continuation: the prompt lists "A. ... D." itself,
        # which the letter extractor would otherwise match.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        predicted_letter = extract_letter_answer(response)
        correct_letter = chr(65 + answer_idx)

        if predicted_letter == correct_letter:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"MMLU Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# ARC Challenge Evaluation
# ==============================================================================
def evaluate_arc_challenge(model, tokenizer, num_samples=MAX_SAMPLES):
    """Measure multiple-choice accuracy on the ARC-Challenge test split.

    Choices are always shown with positional letters (A, B, C, ...). Items whose
    dataset labels are not letters (ARC contains items labelled "1".."4") have
    their ``answerKey`` mapped to the letter at the same position, so a letter
    extracted from the model output can match them. Only the newly generated
    tokens are inspected for the predicted letter.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``; its ``pad_token_id`` is passed
            to ``generate``.
        num_samples: Upper bound on the number of test items evaluated.

    Returns:
        Fraction of items answered correctly (0 when no item was evaluated).
    """
    print("\n--- ARC Challenge Evaluation ---")
    ds = load_dataset("ai2_arc", "ARC-Challenge", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    correct = 0
    total = 0

    for item in tqdm(ds, desc="ARC-C"):
        question = item["question"]
        choices = item["choices"]["text"]
        labels = item["choices"]["label"]
        answer = item["answerKey"]

        # Present every item with letters; for letter-labelled items this is
        # identical to the dataset labels, for digit-labelled ones it remaps.
        letters = [chr(65 + i) for i in range(len(choices))]
        if answer in labels:
            gold_letter = letters[labels.index(answer)]
        else:
            # Key not among the labels: fall back to comparing it verbatim.
            gold_letter = str(answer).upper()

        # Format multiple choice
        choice_text = "\n".join(f"{letter}. {text}" for letter, text in zip(letters, choices))
        prompt = f"Question: {question}\n\n{choice_text}\n\nAnswer:"

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                           max_length=MAX_LENGTH).to(model.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # No `temperature` here: it is ignored when do_sample=False and
            # only triggers a configuration warning.
            output = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Keep only the continuation so the echoed choices are not scanned,
        # but retain the "Answer:" cue for extractors that anchor on it.
        completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        predicted = extract_letter_answer(f"Answer:{completion}")

        # Guard against an extractor that returns None when nothing is found.
        if (predicted or "").upper() == gold_letter:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"ARC Challenge Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# BBH Evaluation
# ==============================================================================
def evaluate_bbh(model, tokenizer, num_samples=MAX_SAMPLES):
    """Measure accuracy on the BBH ``boolean_expressions`` test split.

    The model generates up to 50 new tokens with sampling disabled, and an item
    counts as correct when the normalized target appears inside the normalized
    *generated* text. The prompt is excluded from that check: the questions
    themselves contain the words "True"/"False", so matching against the echoed
    prompt would count wrong answers as correct.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``; its ``pad_token_id`` is passed
            to ``generate``.
        num_samples: Upper bound on the number of test items evaluated.

    Returns:
        Fraction of items answered correctly (0 when no item was evaluated).
    """
    print("\n--- BBH Evaluation ---")
    ds = load_dataset("lukaemon/bbh", "boolean_expressions", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))

    correct = 0
    total = 0

    for item in tqdm(ds, desc="BBH"):
        question = item["input"]
        answer = item["target"]

        prompt = f"Question: {question}\n\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                           max_length=MAX_LENGTH).to(model.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # No `temperature` here: it is ignored when do_sample=False and
            # only triggers a configuration warning.
            output = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )

        # generate() returns prompt + continuation for a causal LM; score the
        # continuation only.
        completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        # Check if answer is contained in the generated text
        if normalize_answer(answer) in normalize_answer(completion):
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"BBH Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# HellaSwag Evaluation
# ==============================================================================
def evaluate_hellaswag(model, tokenizer, num_samples=MAX_SAMPLES):
    """Measure ending-selection accuracy on the HellaSwag validation split.

    For each item, every candidate ending is appended to the context and the
    resulting text is scored with the negated language-modelling loss the model
    returns when the input ids are also used as labels. The ending with the
    highest score is the prediction.

    NOTE(review): the loss is computed over the whole text (context included),
    not over the ending tokens alone -- confirm this is the intended scoring.

    Args:
        model: Causal language model; called directly with ``labels``.
        tokenizer: Tokenizer matching ``model``.
        num_samples: Upper bound on the number of validation items evaluated.

    Returns:
        Fraction of items where the top-scored ending equals the gold label
        (0 when no item was evaluated).
    """
    print("\n--- HellaSwag Evaluation ---")
    dataset = load_dataset("hellaswag", split="validation")
    subset_size = min(num_samples, len(dataset))
    dataset = dataset.select(range(subset_size))

    def _negative_loss(text):
        # Higher is better: a lower LM loss means the text is more likely.
        encoded = tokenizer(text, return_tensors="pt", truncation=True,
                            max_length=MAX_LENGTH).to(model.device)
        with torch.no_grad():
            lm_out = model(**encoded, labels=encoded["input_ids"])
            return -lm_out.loss.item()

    hits = 0
    seen = 0

    for example in tqdm(dataset, desc="HellaSwag"):
        ctx_text = example["ctx"]
        candidates = example["endings"]
        gold_index = int(example["label"])

        # One score per candidate ending, in dataset order
        candidate_scores = [_negative_loss(ctx_text + " " + tail) for tail in candidates]

        best_index = np.argmax(candidate_scores)
        if best_index == gold_index:
            hits += 1
        seen += 1

    accuracy = hits / seen if seen > 0 else 0
    print(f"HellaSwag Accuracy: {accuracy:.4f} ({hits}/{seen})")
    return accuracy

# ==============================================================================
# Run All Evaluations
# ==============================================================================
def run_all_benchmarks(model, tokenizer, tasks=None):
    """Run every requested benchmark and collect its accuracy.

    Args:
        model: Model forwarded unchanged to each ``evaluate_*`` function.
        tokenizer: Tokenizer forwarded unchanged to each ``evaluate_*`` function.
        tasks: Task names to run (a single name or an iterable of names).
            Defaults to the module-level ``tasks_to_run``. Recognized names:
            ``gsm8k``, ``mmlu``, ``arc_challenge``, ``bbh``, ``hellaswag``,
            ``math``.

    Returns:
        Dict mapping each requested, recognized task name to the accuracy
        returned by its evaluator. Unrecognized names are reported and skipped.
    """
    if tasks is None:
        tasks = tasks_to_run

    # A bare string would otherwise be iterated character by character.
    requested = [tasks] if isinstance(tasks, str) else list(tasks)

    # Zero-argument thunks: an evaluator is only looked up when its task is
    # actually requested, so a missing evaluator cannot break other tasks.
    evaluators = {
        "gsm8k": lambda: evaluate_gsm8k(model, tokenizer),
        "mmlu": lambda: evaluate_mmlu(model, tokenizer),
        "arc_challenge": lambda: evaluate_arc_challenge(model, tokenizer),
        "bbh": lambda: evaluate_bbh(model, tokenizer),
        "hellaswag": lambda: evaluate_hellaswag(model, tokenizer),
        "math": lambda: evaluate_math(model, tokenizer),
    }

    unknown = [name for name in requested if name not in evaluators]
    if unknown:
        print(f"Skipping unknown benchmark tasks: {unknown}")

    results = {}
    for name, run_benchmark in evaluators.items():
        if name in requested:
            results[name] = run_benchmark()

    return results

# ==============================================================================
# Main Evaluation
# ==============================================================================
print("\n" + "="*80)
print("STARTING BENCHMARK EVALUATION")
print("="*80)

# Nothing is loaded here: `model` and `tokenizer` must already exist in the
# kernel from an earlier cell. The model is only switched to eval mode.
print("Loading Stage 2 model...")
model.eval()

# Run benchmarks
results = run_all_benchmarks(model, tokenizer, tasks_to_run)

# Print summary
print("\n" + "="*80)
print("BENCHMARK RESULTS SUMMARY")
print("="*80)
print(results)
for task, accuracy in results.items():
    print(f"{task.upper()}: {accuracy*100:.2f}%")

# Calculate average. Averaging an empty result set would yield NaN (with a
# RuntimeWarning) and then be written as the non-standard JSON token `NaN`.
if results:
    # Plain float keeps the JSON output independent of numpy scalar types.
    avg_accuracy = float(np.mean(list(results.values())))
    print(f"\nAVERAGE ACCURACY: {avg_accuracy*100:.2f}%")
else:
    avg_accuracy = None
    print("\nNo benchmarks were run; check tasks_to_run.")

# Save results ("average" is null when nothing was run)
results_with_avg = {**results, "average": avg_accuracy}
with open("benchmark_results.json", "w") as f:
    json.dump(results_with_avg, f, indent=2)

print("\n✓ Results saved to benchmark_results.json")