In [None]:
# ==============================================================================
# CELL 0: Install/Upgrade Required Packages
# ==============================================================================
# Run this first, then restart kernel
"""
!pip uninstall -y transformers accelerate trl peft -y
!pip install transformers==4.36.2 --no-cache-dir
!pip install accelerate==0.25.0 --no-cache-dir
!pip install peft==0.7.1 --no-cache-dir
!pip install trl==0.7.4 --no-cache-dir
!pip install bitsandbytes==0.41.3 --no-cache-dir
!pip install datasets --no-cache-dir
"""

In [1]:
# ==============================================================================
# RUN THIS CELL FIRST - Install/Upgrade BitsAndBytes
# ==============================================================================
import subprocess
import sys

print("="*80)
print("INSTALLING/UPGRADING BITSANDBYTES")
print("="*80)

# Method 1: Try standard upgrade
try:
    print("\n1. Upgrading bitsandbytes to latest version...")
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-U", "bitsandbytes"],
        capture_output=True,
        text=True
    )
    print(result.stdout)
    if result.returncode == 0:
        print("Successfully upgraded bitsandbytes")
    else:
        print("Upgrade had some issues, trying alternative method...")
        raise Exception("Standard install failed")
except Exception as e:
    # Method 2: Try with specific version
    print("Trying to install specific version (0.41.0)...")
    try:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "--upgrade", "bitsandbytes==0.41.0"],
            stdout=subprocess.PIPE
        )
        print("Installed bitsandbytes 0.41.0")
    except:
        print("Could not install specific version")

# Verify installation
print("\n3. Verifying installation...")
try:
    import bitsandbytes as bnb
    print(f"✓ bitsandbytes version: {bnb.__version__}")
    print("✓ Import successful!")
except Exception as e:
    print(f"✗ Import failed: {e}")
    print("\n⚠ IMPORTANT: If bitsandbytes still doesn't work:")
    print("   - Set USE_QUANTIZATION = False in the config")
    print("   - The code will automatically fall back to FP16")
    print("   - You'll need more GPU memory but it will work")

# Also upgrade related packages
print("\n4. Upgrading related packages...")
try:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-U", "-q", "accelerate", "transformers"],
    )
    print("Upgraded accelerate and transformers")
except:
    print("Could not upgrade all packages")

print("\n" + "="*80)
print("INSTALLATION COMPLETE ")
print("="*80)

INSTALLING/UPGRADING BITSANDBYTES

1. Upgrading bitsandbytes to latest version...
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cublas_cu1

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 3.6.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.5.1 which is incompatible.


In [None]:
# # # ==============================================================================
# # # CELL 1: Install Required Packages
# # # ==============================================================================
# # print("Installing required packages...")
# # import subprocess
# # import sys

# # # Install bitsandbytes
# # try:
# #     import bitsandbytes
# #     print("✓ bitsandbytes already installed")
# # except ImportError:
# #     print("Installing bitsandbytes...")
# #     subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "bitsandbytes"])
# #     print("✓ bitsandbytes installed successfully")

# # ==============================================================================
# # CELL 2: Imports and Configuration
# # ==============================================================================
# import os
# import torch
# import torch.nn.functional as F
# from transformers import (
#     AutoTokenizer, 
#     AutoModelForCausalLM, 
#     BitsAndBytesConfig,
# )
# from peft import (
#     LoraConfig, 
#     get_peft_model, 
#     prepare_model_for_kbit_training,
#     PeftModel
# )
# from datasets import Dataset, load_dataset
# import numpy as np
# from typing import List, Dict, Tuple
# import re
# from torch.optim import AdamW
# from torch.utils.data import DataLoader

# # Try to get HF token
# try:
#     from kaggle_secrets import UserSecretsClient
#     hf = UserSecretsClient()
#     HF_TOKEN = hf.get_secret("HF_TOKEN")
# except:
#     # If not on Kaggle, try environment variable
#     HF_TOKEN = os.environ.get("HF_TOKEN", None)
#     if not HF_TOKEN:
#         print("Warning: No HF_TOKEN found. Using public models only.")

# # Configuration
# REPO = "O1-OPEN/OpenO1-Qwen-7B-v0.1"
# SUBFOLDER = "checkpoint-1000"
# USE_SUBFOLDER = True
# USE_QUANTIZATION = True  # Set to False if you have enough GPU memory
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# MAX_LENGTH = 512
# MAX_NEW_TOKENS = 256

# # Training hyperparameters
# STAGE1_EPOCHS = 3
# STAGE1_BATCH_SIZE = 2
# STAGE1_GRAD_ACCUM = 8
# STAGE1_LR = 2e-4
# KL_COEF = 0.1

# STAGE2_STEPS = 1000
# STAGE2_BATCH_SIZE = 2
# STAGE2_LR = 1e-5
# CORRECTION_BONUS = 1.0

# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# print(f"Using device: {DEVICE}")
# print(f"Available GPUs: {torch.cuda.device_count()}")
# if torch.cuda.is_available():
#     print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# # ==============================================================================
# # CELL 3: Load Tokenizer
# # ==============================================================================
# print("\nLoading tokenizer...")
# try:
#     if USE_SUBFOLDER:
#         tokenizer = AutoTokenizer.from_pretrained(
#             REPO, 
#             subfolder=SUBFOLDER,
#             trust_remote_code=True,
#             token=HF_TOKEN
#         )
#     else:
#         tokenizer = AutoTokenizer.from_pretrained(
#             REPO,
#             trust_remote_code=True,
#             token=HF_TOKEN
#         )
# except Exception as e:
#     print(f"Error loading from {REPO}: {e}")
#     print("Falling back to Qwen2.5-1.5B-Instruct as alternative...")
#     REPO = "Qwen/Qwen2.5-1.5B-Instruct"
#     USE_SUBFOLDER = False
#     tokenizer = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True)

# # Set pad token
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
#     tokenizer.pad_token_id = tokenizer.eos_token_id

# print(f"✓ Tokenizer loaded from: {REPO}")

# # ==============================================================================
# # CELL 4: Load Models with Quantization
# # ==============================================================================
# print("\nLoading models...")

# # Setup quantization config if enabled
# if USE_QUANTIZATION and DEVICE == "cuda":
#     print("Setting up 4-bit quantization...")
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.float16,
#         bnb_4bit_use_double_quant=True,
#         bnb_4bit_quant_type="nf4",
#     )
#     model_kwargs = {
#         "quantization_config": bnb_config,
#         "device_map": "auto",
#         "trust_remote_code": True,
#         "token": HF_TOKEN,
#     }
# else:
#     print("Loading model without quantization...")
#     model_kwargs = {
#         "device_map": "auto",
#         "trust_remote_code": True,
#         "token": HF_TOKEN,
#         "torch_dtype": torch.float16 if DEVICE == "cuda" else torch.float32,
#     }

# if USE_SUBFOLDER and "Qwen2.5" not in REPO:
#     model_kwargs["subfolder"] = SUBFOLDER

# # Load base model
# print("Loading base model (this may take several minutes)...")
# try:
#     base_model = AutoModelForCausalLM.from_pretrained(REPO, **model_kwargs)
#     base_model.config.use_cache = False
#     print(f"✓ Base model loaded")
#     if hasattr(base_model, 'hf_device_map'):
#         print(f"Device map: {base_model.hf_device_map}")
# except Exception as e:
#     print(f"Error loading model: {e}")
#     raise

# # Load reference model (frozen)
# print("Loading reference model (frozen)...")
# ref_model = AutoModelForCausalLM.from_pretrained(REPO, **model_kwargs)
# ref_model.eval()
# for p in ref_model.parameters():
#     p.requires_grad = False
# print("✓ Reference model loaded and frozen")

# # ==============================================================================
# # CELL 5: Prepare Model with LoRA
# # ==============================================================================
# print("\nPreparing model for training...")
# if USE_QUANTIZATION and DEVICE == "cuda":
#     base_model = prepare_model_for_kbit_training(base_model)
#     print("✓ Model prepared for k-bit training")

# # LoRA configuration - adjust target_modules based on model architecture
# # Common patterns: "q_proj", "k_proj", "v_proj", "o_proj" for most models
# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM"
# )

# print("Attaching LoRA adapters...")
# model = get_peft_model(base_model, lora_config)
# model.print_trainable_parameters()

# # ==============================================================================
# # CELL 6: Prepare Dataset
# # ==============================================================================
# def create_sample_dataset():
#     """Create a small sample dataset for testing"""
#     samples = [
#         {
#             "problem": "What is 25 + 17?",
#             "first_attempt": "Let me calculate: 25 + 17 = 41",
#             "second_attempt": "Let me recalculate: 25 + 17 = 42",
#             "answer": "42",
#             "is_correct_1": False,
#             "is_correct_2": True
#         },
#         {
#             "problem": "Solve: 3x + 5 = 14",
#             "first_attempt": "3x = 14 - 5 = 9, so x = 4",
#             "second_attempt": "3x = 14 - 5 = 9, so x = 3",
#             "answer": "3",
#             "is_correct_1": False,
#             "is_correct_2": True
#         },
#         {
#             "problem": "What is 15 * 8?",
#             "first_attempt": "15 * 8 = 110",
#             "second_attempt": "Let me recalculate: 15 * 8 = 120",
#             "answer": "120",
#             "is_correct_1": False,
#             "is_correct_2": True
#         },
#         {
#             "problem": "If y - 7 = 12, what is y?",
#             "first_attempt": "y = 12 + 7 = 20",
#             "second_attempt": "y = 12 + 7 = 19",
#             "answer": "19",
#             "is_correct_1": False,
#             "is_correct_2": True
#         },
#     ] * 25  # Repeat to create larger dataset
#     return Dataset.from_list(samples)

# print("\nCreating/loading dataset...")
# train_dataset = create_sample_dataset()
# print(f"✓ Dataset size: {len(train_dataset)}")
# print(f"Sample: {train_dataset[0]}")

# print("\n" + "="*80)
# print("SETUP COMPLETE! Ready for training.")
# print("="*80)

In [2]:


# ==============================================================================
# CELL 1: Imports and Configuration
# ==============================================================================
import os
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig, 
    get_peft_model, 
    prepare_model_for_kbit_training,
    PeftModel
)
# Note: We'll implement a simplified PPO instead of using trl's PPOTrainer to avoid import issues
from datasets import Dataset
from kaggle_secrets import UserSecretsClient
import numpy as np
from typing import List, Dict, Tuple
import re
from torch.optim import AdamW
from torch.utils.data import DataLoader

# Get HF token
hf = UserSecretsClient()
HF_TOKEN = hf.get_secret("HF_TOKEN")

# Configuration
REPO = "O1-OPEN/OpenO1-Qwen-7B-v0.1"
SUBFOLDER = "checkpoint-1000"
USE_SUBFOLDER = True
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512
MAX_NEW_TOKENS = 256

# Training hyperparameters
STAGE1_EPOCHS = 3
STAGE1_BATCH_SIZE = 2
STAGE1_GRAD_ACCUM = 8
STAGE1_LR = 2e-4
KL_COEF = 0.1  # KL penalty coefficient

STAGE2_STEPS = 1000
STAGE2_BATCH_SIZE = 2
STAGE2_LR = 1e-5
CORRECTION_BONUS = 1.0  # Bonus when second attempt > first

os.environ["TOKENIZERS_PARALLELISM"] = "false"

print(f"Using device: {DEVICE}")
print(f"Available GPUs: {torch.cuda.device_count()}")

# ==============================================================================
# CELL 2: Load Tokenizer and Models (4-bit Quantization)
# ==============================================================================
print("Loading tokenizer...")
if USE_SUBFOLDER:
    tokenizer = AutoTokenizer.from_pretrained(
        REPO, 
        subfolder=SUBFOLDER,
        trust_remote_code=True,
        token=HF_TOKEN
    )
else:
    tokenizer = AutoTokenizer.from_pretrained(
        REPO,
        trust_remote_code=True,
        token=HF_TOKEN
    )

# Set pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print("Setting up 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

print("Loading base model (this may take several minutes)...")
model_kwargs = {
    "quantization_config": bnb_config,
    "device_map": "auto",  # Spreads across both T4s
    "trust_remote_code": True,
    "token": HF_TOKEN,
}
if USE_SUBFOLDER:
    model_kwargs["subfolder"] = SUBFOLDER

base_model = AutoModelForCausalLM.from_pretrained(REPO, **model_kwargs)
base_model.config.use_cache = False  # Required for gradient checkpointing

print("Loading reference model (frozen)...")
ref_model = AutoModelForCausalLM.from_pretrained(REPO, **model_kwargs)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

print(f"Base model device map: {base_model.hf_device_map}")

# ==============================================================================
# CELL 3: Prepare Model with LoRA
# ==============================================================================
print("Preparing model for k-bit training...")
base_model = prepare_model_for_kbit_training(base_model)

# LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Qwen modules
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

print("Attaching LoRA adapters...")
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

# ==============================================================================
# CELL 4: Prepare Dataset (Example Format)
# ==============================================================================
# Your dataset should have: problem, first_attempt, second_attempt, correctness
# Format: {"problem": "...", "first_attempt": "...", "second_attempt": "...", "answer": "...", "is_correct_1": bool, "is_correct_2": bool}
def create_sample_dataset():
    """Create a small sample dataset for testing"""
    samples = [
        {
            "problem": "What is 25 + 17?",
            "first_attempt": "Let me calculate: 25 + 17 = 41",
            "second_attempt": "Let me recalculate: 25 + 17 = 42",
            "answer": "42",
            "is_correct_1": False,
            "is_correct_2": True
        },
        {
            "problem": "Solve: 3x + 5 = 14",
            "first_attempt": "3x = 14 - 5 = 9, so x = 4",
            "second_attempt": "3x = 14 - 5 = 9, so x = 3",
            "answer": "3",
            "is_correct_1": False,
            "is_correct_2": True
        },
        {
            "problem": "What is 15 * 8?",
            "first_attempt": "15 * 8 = 110",
            "second_attempt": "Let me recalculate: 15 * 8 = 120",
            "answer": "120",
            "is_correct_1": False,
            "is_correct_2": True
        },
        {
            "problem": "If y - 7 = 12, what is y?",
            "first_attempt": "y = 12 + 7 = 20",
            "second_attempt": "y = 12 + 7 = 19",
            "answer": "19",
            "is_correct_1": False,
            "is_correct_2": True
        },
    ] * 25  # Repeat to create larger dataset
    return Dataset.from_list(samples)

# def create_sample_dataset():
#     """Create a small sample dataset for testing"""
#     samples = [
#         {
#             "problem": "What is 25 + 17?",
#             "first_attempt": "Let me calculate: 25 + 17 = 41",
#             "second_attempt": "Let me recalculate: 25 + 17 = 42",
#             "answer": "42",
#             "is_correct_1": False,
#             "is_correct_2": True
#         },
#         {
#             "problem": "Solve: 3x + 5 = 14",
#             "first_attempt": "3x = 14 - 5 = 9, so x = 4",
#             "second_attempt": "3x = 14 - 5 = 9, so x = 3",
#             "answer": "3",
#             "is_correct_1": False,
#             "is_correct_2": True
#         },
#         # Add more examples...
#     ]
#     return Dataset.from_list(samples)

# Load your actual dataset here
print("Creating/loading dataset...")
train_dataset = create_sample_dataset()
print(f"Dataset size: {len(train_dataset)}")



2025-10-04 07:06:41.084717: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759561601.252781      74 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759561601.303558      74 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda
Available GPUs: 2
Loading tokenizer...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

Setting up 4-bit quantization...
Loading base model (this may take several minutes)...


config.json:   0%|          | 0.00/730 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

checkpoint-1000/model-00004-of-00004.saf(…):   0%|          | 0.00/1.09G [00:00<?, ?B/s]

checkpoint-1000/model-00001-of-00004.saf(…):   0%|          | 0.00/4.88G [00:00<?, ?B/s]

checkpoint-1000/model-00002-of-00004.saf(…):   0%|          | 0.00/4.93G [00:00<?, ?B/s]

checkpoint-1000/model-00003-of-00004.saf(…):   0%|          | 0.00/4.33G [00:00<?, ?B/s]

{"timestamp":"2025-10-04T07:07:57.452672Z","level":"WARN","fields":{"message":"Status Code: 504. Retrying...","request_id":""},"filename":"/home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs","line_number":236}
{"timestamp":"2025-10-04T07:07:57.452739Z","level":"WARN","fields":{"message":"Retry attempt #0. Sleeping 2.031185106s before the next attempt"},"filename":"/root/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/reqwest-retry-0.7.0/src/middleware.rs","line_number":171}


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

Loading reference model (frozen)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Base model device map: {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 1, 'model.layers.8': 1, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 1, 'model.layers.19': 1, 'model.layers.20': 1, 'model.layers.21': 1, 'model.layers.22': 1, 'model.layers.23': 1, 'model.layers.24': 1, 'model.layers.25': 1, 'model.layers.26': 1, 'model.layers.27': 1, 'model.norm': 1, 'model.rotary_emb': 1, 'lm_head': 1}
Preparing model for k-bit training...
Attaching LoRA adapters...
trainable params: 10,092,544 || all params: 7,625,709,056 || trainable%: 0.1323
Creating/loading dataset...
Dataset size: 100


In [4]:
# STAGE I - Supervised Fine-tuning with KL Penalty (Fixed)
# ==============================================================================
print("\n" + "="*80)
print("STAGE I: Supervised Fine-tuning with KL Penalty (Fixed)")
print("="*80)
from torch.utils.data import DataLoader

# --- Step 1: Prepare DataLoader ---
train_loader = DataLoader(train_dataset, batch_size=STAGE1_BATCH_SIZE, shuffle=True)

# --- Step 2: Define KL Divergence ---
# def compute_kl_divergence(logits_policy, logits_ref):
#     """
#     Compute KL divergence between policy and reference model
#     """
#     log_probs_policy = F.log_softmax(logits_policy, dim=-1)
#     probs_ref = F.softmax(logits_ref, dim=-1)
#     kl = (probs_ref * (probs_ref.log() - log_probs_policy)).sum(dim=-1)
#     return kl.mean()
import torch
import torch.nn.functional as F

def compute_kl_divergence(policy_logits, ref_logits, attention_mask=None, eps=1e-12):
    """
    Compute KL(P_ref || Q_policy) per token with masking and numeric stability,
    returning the mean KL per sample.

    Args:
      policy_logits: Tensor [batch, seq_len, vocab]
      ref_logits:    Tensor [batch, seq_len, vocab]
      attention_mask: Optional Tensor [batch, seq_len] with 1 for real tokens, 0 for padding.
      eps: small value to avoid div/zero (not usually needed with log_softmax but kept for safety).

    Returns:
      scalar tensor: mean KL across non-padding tokens (averaged over batch)
    """
    # ensure shapes match
    assert policy_logits.shape == ref_logits.shape, f"policy {policy_logits.shape} vs ref {ref_logits.shape}"

    # stable log-probs
    log_probs_policy = F.log_softmax(policy_logits, dim=-1)   # log Q
    log_probs_ref = F.log_softmax(ref_logits, dim=-1)         # log P

    # probs for P (ref) via exp(log_probs_ref) — numerically stable
    probs_ref = log_probs_ref.exp()

    # per-token KL: sum_vocab P * (log P - log Q)
    kl_per_token = (probs_ref * (log_probs_ref - log_probs_policy)).sum(dim=-1)  # [batch, seq_len]

    if attention_mask is not None:
        # cast mask to same dtype
        mask = attention_mask.to(kl_per_token.dtype)  # [batch, seq_len]
        # zero out padding tokens, compute per-sample mean over valid tokens
        valid_tokens_per_sample = mask.sum(dim=1).clamp_min(1.0)  # avoid div by 0
        kl_per_sample = (kl_per_token * mask).sum(dim=1) / valid_tokens_per_sample
    else:
        # mean over seq_len when no mask provided
        kl_per_sample = kl_per_token.mean(dim=1)

    return kl_per_sample.mean()  # scalar

# --- Step 3: Stage I training step ---
def stage1_train_step(batch, model, ref_model, optimizer, tokenizer):
    model.train()
    
    # Prepare prompts and targets
    prompts = [f"Problem: {p}\n\nFirst attempt: {a1}\n\nLet me reconsider:" 
               for p, a1 in zip(batch["problem"], batch["first_attempt"])]
    targets = [f"{t}" for t in batch["second_attempt"]]
    
    # Combine prompts and targets for proper tokenization
    full_texts = [p + t for p, t in zip(prompts, targets)]
    
    # Tokenize the combined text
    inputs = tokenizer(full_texts, padding='longest', truncation=True,
                       max_length=MAX_LENGTH, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    
    # Create labels by tokenizing prompts to find where targets start
    prompt_inputs = tokenizer(prompts, padding='longest', truncation=True,
                              max_length=MAX_LENGTH, return_tensors="pt")
    prompt_lengths = (prompt_inputs["attention_mask"].sum(dim=1)).tolist()
    
    # Create labels: -100 for prompt tokens, actual tokens for target
    labels = input_ids.clone()
    for i, prompt_len in enumerate(prompt_lengths):
        labels[i, :prompt_len] = -100
    
    # Replace padding tokens with -100
    labels[labels == tokenizer.pad_token_id] = -100
    
    # Forward pass
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    loss_lm = outputs.loss
    
    # --- KL penalty on first attempt ---
    first_prompts = [f"Problem: {p}\n\nSolution:" for p in batch["problem"]]
    first_inputs = tokenizer(first_prompts, padding='longest', truncation=True,
                             max_length=MAX_LENGTH, return_tensors="pt")
    first_inputs = {k: v.to(model.device) for k, v in first_inputs.items() 
                    if k in ["input_ids", "attention_mask"]}
    
    with torch.no_grad():
        ref_outputs = ref_model(**first_inputs)
        ref_logits = ref_outputs.logits
    
    policy_outputs = model(**first_inputs)
    policy_logits = policy_outputs.logits
    
    kl_loss = compute_kl_divergence(policy_logits, ref_logits)
    
    # --- Total loss ---
    loss = loss_lm + KL_COEF * kl_loss
    
    # Backward and optimizer step
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    return loss.item(), loss_lm.item(), kl_loss.item()

# --- Step 4: Training loop ---
print("Starting Stage I training...")
optimizer = torch.optim.AdamW(model.parameters(), lr=STAGE1_LR)

for epoch in range(STAGE1_EPOCHS):
    total_loss = 0
    total_lm_loss = 0
    total_kl_loss = 0
    
    for step, batch in enumerate(train_loader):
        loss, lm_loss, kl_loss = stage1_train_step(batch, model, ref_model, optimizer, tokenizer)
        total_loss += loss
        total_lm_loss += lm_loss
        total_kl_loss += kl_loss
        
        if step % 10 == 0:
            print(f"Epoch {epoch+1}, Step {step}: "
                  f"Loss={loss:.4f}, LM={lm_loss:.4f}, KL={kl_loss:.4f}")
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} completed. Avg Loss: {avg_loss:.4f}")

# --- Step 5: Save checkpoint ---
print("Stage I completed! Saving checkpoint...")
model.save_pretrained("./stage1_lora")
tokenizer.save_pretrained("./stage1_lora")
print("Stage I checkpoint saved at ./stage1_lora")


STAGE I: Supervised Fine-tuning with KL Penalty (Fixed)
Starting Stage I training...
Epoch 1, Step 0: Loss=0.0001, LM=0.0000, KL=0.0007
Epoch 1, Step 10: Loss=0.0043, LM=0.0000, KL=0.0427
Epoch 1, Step 20: Loss=0.0003, LM=0.0001, KL=0.0023
Epoch 1, Step 30: Loss=0.0012, LM=0.0000, KL=0.0113
Epoch 1, Step 40: Loss=0.0003, LM=0.0000, KL=0.0029
Epoch 1 completed. Avg Loss: 0.0020
Epoch 2, Step 0: Loss=0.0002, LM=0.0000, KL=0.0016
Epoch 2, Step 10: Loss=0.0049, LM=0.0000, KL=0.0485
Epoch 2, Step 20: Loss=0.0006, LM=0.0000, KL=0.0059
Epoch 2, Step 30: Loss=0.0002, LM=0.0000, KL=0.0021
Epoch 2, Step 40: Loss=0.0003, LM=0.0000, KL=0.0027
Epoch 2 completed. Avg Loss: 0.0006
Epoch 3, Step 0: Loss=0.0000, LM=0.0000, KL=0.0002
Epoch 3, Step 10: Loss=0.0001, LM=0.0000, KL=0.0003
Epoch 3, Step 20: Loss=0.0000, LM=0.0000, KL=0.0002
Epoch 3, Step 30: Loss=0.0001, LM=0.0000, KL=0.0010
Epoch 3, Step 40: Loss=0.0001, LM=0.0000, KL=0.0011
Epoch 3 completed. Avg Loss: 0.0002
Stage I completed! Saving che

In [5]:
# CELL 6: Stage II - Simplified REINFORCE with Correction Reward (FIXED)
# ==============================================================================
print("\n" + "="*80)
print("STAGE II: REINFORCE Training with Correction Rewards")
print("="*80)

def extract_answer(text: str) -> str:
    """Extract numeric answer from text"""
    numbers = re.findall(r'-?\d+\.?\d*', text)
    return numbers[-1] if numbers else ""

def compute_reward(problem: str, first_attempt: str, second_attempt: str, 
                   ground_truth: str) -> float:
    """
    Compute reward for SCoRe:
    - Base reward for correctness
    - Bonus if second attempt is better than first
    """
    ans1 = extract_answer(first_attempt)
    ans2 = extract_answer(second_attempt)
    gt = ground_truth.strip()
    
    correct_1 = (ans1 == gt)
    correct_2 = (ans2 == gt)
    
    reward = 1.0 if correct_2 else 0.0
    
    if not correct_1 and correct_2:
        reward += CORRECTION_BONUS
    
    if correct_1 and not correct_2:
        reward -= CORRECTION_BONUS
    
    return reward

class SimpleValueHead(torch.nn.Module):
    """Simple value head for policy gradient"""
    def __init__(self, hidden_size):
        super().__init__()
        self.value_head = torch.nn.Linear(hidden_size, 1)
    
    def forward(self, hidden_states):
        return self.value_head(hidden_states[:, -1, :]).squeeze(-1)

# Add value head to model
print("Adding value head to model...")
hidden_size = model.config.hidden_size
value_head = SimpleValueHead(hidden_size).to(model.device)
optimizer_rl = AdamW(
    list(model.parameters()) + list(value_head.parameters()),
    lr=STAGE2_LR
)

def reinforce_step(model, value_head, ref_model, tokenizer, batch, optimizer):
    """Single REINFORCE training step"""
    model.train()
    value_head.train()
    
    problem = batch["problem"]
    ground_truth = batch["answer"]
    
    # Generate first attempt
    prompt = f"Problem: {problem}\n\nSolution:"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True,
                      truncation=True, max_length=MAX_LENGTH).to(model.device)
    
    with torch.no_grad():
        outputs_first = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    first_attempt = tokenizer.decode(outputs_first[0], skip_special_tokens=True)
    
    # Generate second attempt
    correction_prompt = f"Problem: {problem}\n\nSolution:\n\nFirst attempt: {first_attempt}\n\nLet me reconsider:"
    correction_inputs = tokenizer(correction_prompt, return_tensors="pt", 
                                  padding=True, truncation=True, 
                                  max_length=MAX_LENGTH).to(model.device)
    
    # Sample from model (with generation tracking)
    outputs_second = model.generate(
        **correction_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id,
    )
    
    second_attempt = tokenizer.decode(outputs_second[0], skip_special_tokens=True)
    
    # Compute reward
    reward = compute_reward(problem, first_attempt, second_attempt, ground_truth)
    
    # Get generated tokens
    generated_tokens = outputs_second[0][correction_inputs.input_ids.shape[1]:]
    
    # Forward pass with gradients enabled and hidden states output
    with torch.enable_grad():
        full_outputs = model(
            input_ids=outputs_second,
            attention_mask=torch.ones_like(outputs_second),
            output_hidden_states=True  # CRITICAL FIX: Enable hidden states
        )
        logits = full_outputs.logits
        
        # Compute log probs for generated tokens
        log_probs = F.log_softmax(logits[0, correction_inputs.input_ids.shape[1]-1:-1, :], dim=-1)
        
        # Ensure we have enough generated tokens
        num_gen_tokens = min(len(generated_tokens), log_probs.shape[0])
        if num_gen_tokens == 0:
            return 0.0, reward, 0.0, 0.0
        
        generated_tokens = generated_tokens[:num_gen_tokens]
        selected_log_probs = log_probs[:num_gen_tokens, generated_tokens]
        
        # Compute value estimate (now hidden_states is available)
        value_estimate = value_head(full_outputs.hidden_states[-1])
        
        # REINFORCE loss
        advantage = reward - value_estimate.detach()
        policy_loss = -(selected_log_probs.mean() * advantage)
        value_loss = F.mse_loss(value_estimate, torch.tensor([reward]).float().to(model.device))
        
        # KL penalty with reference model
        with torch.no_grad():
            ref_outputs = ref_model(**correction_inputs)
            ref_logits = ref_outputs.logits
        
        policy_outputs_kl = model(**correction_inputs)
        policy_logits = policy_outputs_kl.logits
        kl_loss = compute_kl_divergence(policy_logits, ref_logits)
        
        # Total loss
        total_loss = policy_loss + 0.5 * value_loss + 0.01 * kl_loss
    
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(
        list(model.parameters()) + list(value_head.parameters()), 
        1.0
    )
    optimizer.step()
    
    return total_loss.item(), reward, policy_loss.item(), value_loss.item()

# Stage II Training Loop
print("Starting Stage II REINFORCE training...")

for step in range(10): #STAGE2_STEPS
    idx = step % len(train_dataset)
    batch = train_dataset[idx]
    
    try:
        loss, reward, policy_loss, value_loss = reinforce_step(
            model, value_head, ref_model, tokenizer, batch, optimizer_rl
        )
        
        if step % 50 == 0:
            print(f"Step {step}: Total Loss={loss:.4f}, Reward={reward:.3f}, "
                  f"Policy Loss={policy_loss:.4f}, Value Loss={value_loss:.4f}")
        
        if step % 200 == 0 and step > 0:
            print(f"Saving checkpoint at step {step}...")
            model.save_pretrained(f"./stage2_lora_step{step}")
            torch.save(value_head.state_dict(), f"./stage2_lora_step{step}/value_head.pt")
    
    except Exception as e:
        print(f"Error at step {step}: {e}")
        import traceback
        traceback.print_exc()
        continue

print("Stage II completed!")
print("Saving final model...")
model.save_pretrained("./stage2_lora_final")
torch.save(value_head.state_dict(), "./stage2_lora_final/value_head.pt")
tokenizer.save_pretrained("./stage2_lora_final")

# ==============================================================================
# CELL 7: Inference Test
# ==============================================================================
print("\n" + "="*80)
print("TESTING TRAINED MODEL")
print("="*80)

def test_score_inference(problem: str):
    """Test the trained SCoRe model"""
    model.eval()
    
    prompt = f"Problem: {problem}\n\nSolution:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        output1 = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    first_attempt = tokenizer.decode(output1[0], skip_special_tokens=True)
    
    correction_prompt = f"{prompt}\n\nFirst attempt: {first_attempt}\n\nLet me reconsider:"
    inputs2 = tokenizer(correction_prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        output2 = model.generate(
            **inputs2,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    second_attempt = tokenizer.decode(output2[0], skip_special_tokens=True)
    
    print(f"Problem: {problem}")
    print(f"\nFirst Attempt:\n{first_attempt}")
    print(f"\nSecond Attempt (Self-Correction):\n{second_attempt}")
    print("-" * 80)
    

test_problems = [
    "What is 144 + 256?",
    "Solve for x: 2x - 8 = 14",
]

for prob in test_problems:
    test_score_inference(prob)

print("\n✓ Training complete! LoRA adapters saved to ./stage2_lora_final")

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Caching is incompatible with gradient checkpointing in Qwen2DecoderLayer. Setting `past_key_values=None`.



STAGE II: REINFORCE Training with Correction Rewards
Adding value head to model...
Starting Stage II REINFORCE training...




Step 0: Total Loss=1.5820, Reward=0.000, Policy Loss=-2.4214, Value Loss=8.0055
Stage II completed!
Saving final model...

TESTING TRAINED MODEL
Problem: What is 144 + 256?

First Attempt:
Problem: What is 144 + 256?

Solution: To find the sum of 144 and 256, we can add them step by step. 

First, let's align the numbers vertically for clarity:

```
  144
+ 256
-------
```

Now, let's add each column from right to left:

1. **Units Column**: 4 (from 144) + 6 (from 256) = 10. We write down 0 and carry over 1.
2. **Tens Column**: 4 (from 144) + 5 (from 256) = 9. Adding the carried over 1 makes it 10. We write down 0 and carry over 1 again.
3. **Hundreds Column**: 1 (from 144) + 2 (from 256) = 3. Adding the carried over 1 makes it 4.

Putting it all together, the sum is:

```
  144
+ 256
-------
  400
```

Therefore, 144 + 256 equals 400. This method ensures that each place value is

Second Attempt (Self-Correction):
Problem: What is 144 + 256?

Solution:

First attempt: Problem: What is 

In [6]:
# ==============================================================================
# COMPREHENSIVE BENCHMARK EVALUATION
# ==============================================================================
print("\n" + "="*80)
print("BENCHMARK EVALUATION")
print("="*80)

import re
import json
from tqdm import tqdm
from typing import Dict, List, Tuple
import numpy as np
from datasets import load_dataset
# Configuration
tasks_to_run = ["gsm8k", "math", "mmlu", "hellaswag", "arc_challenge", "bbh"]
MAX_SAMPLES = 50 # Limit samples per task for faster evaluation
EVAL_BATCH_SIZE = 1  # Process one at a time for generation

def normalize_answer(text: str) -> str:
    """Normalize answer for comparison"""
    text = text.lower().strip()
    # Remove extra whitespace
    text = ' '.join(text.split())
    return text

def extract_numeric_answer(text: str) -> str:
    """Extract numeric answer from text"""
    # Look for patterns like "####" followed by number (GSM8K format)
    match = re.search(r'####\s*(-?\d+\.?\d*)', text)
    if match:
        return match.group(1)
    
    # Look for "the answer is X"
    match = re.search(r'(?:answer is|equals?)\s*[:\-]?\s*(-?\d+\.?\d*)', text.lower())
    if match:
        return match.group(1)
    
    # Extract last number in text
    numbers = re.findall(r'-?\d+\.?\d*', text)
    return numbers[-1] if numbers else ""

def extract_letter_answer(text: str) -> str:
    """Extract letter answer (A, B, C, D) from text"""
    # Look for explicit answer format
    match = re.search(r'(?:answer is|answer:|correct answer is)\s*([A-D])', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()
    
    # Look for standalone letter in parentheses or brackets
    match = re.search(r'[\(\[]([A-D])[\)\]]', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()
    
    # Last resort: first letter A-D that appears
    match = re.search(r'\b([A-D])\b', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()
    
    return ""

# ==============================================================================
# GSM8K Evaluation
# ==============================================================================
def evaluate_gsm8k(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on GSM8K dataset"""
    print("\n--- GSM8K Evaluation ---")
    ds = load_dataset("gsm8k", "main", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))
    
    correct = 0
    total = 0
    
    for item in tqdm(ds, desc="GSM8K"):
        question = item["question"]
        answer = item["answer"].split("####")[-1].strip()
        
        prompt = f"Problem: {question}\n\nSolution:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, 
                          max_length=MAX_LENGTH).to(model.device)
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        predicted = extract_numeric_answer(response)
        
        if normalize_answer(predicted) == normalize_answer(answer):
            correct += 1
        total += 1
    
    accuracy = correct / total if total > 0 else 0
    print(f"GSM8K Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# MATH Dataset Evaluation
# ==============================================================================
def evaluate_math(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on MATH dataset"""
    print("\n--- MATH Dataset Evaluation ---")
    ds = load_dataset("math_dataset", "algebra__linear_1d", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))
    
    correct = 0
    total = 0
    
    for item in tqdm(ds, desc="MATH"):
        question = item["question"]
        answer = item["answer"]
        
        prompt = f"Problem: {question}\n\nSolution:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        predicted = extract_numeric_answer(response)
        actual = extract_numeric_answer(answer)
        
        if normalize_answer(predicted) == normalize_answer(actual):
            correct += 1
        total += 1
    
    accuracy = correct / total if total > 0 else 0
    print(f"MATH Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# MMLU Evaluation
# ==============================================================================
def evaluate_mmlu(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on MMLU dataset"""
    print("\n--- MMLU Evaluation ---")
    ds = load_dataset("cais/mmlu", "abstract_algebra", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))
    
    correct = 0
    total = 0
    
    for item in tqdm(ds, desc="MMLU"):
        question = item["question"]
        choices = item["choices"]
        answer_idx = item["answer"]
        
        # Format multiple choice
        choice_text = "\n".join([f"{chr(65+i)}. {c}" for i, c in enumerate(choices)])
        prompt = f"Question: {question}\n\n{choice_text}\n\nAnswer:"
        
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=10,
                temperature=0.1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        predicted_letter = extract_letter_answer(response)
        correct_letter = chr(65 + answer_idx)
        
        if predicted_letter == correct_letter:
            correct += 1
        total += 1
    
    accuracy = correct / total if total > 0 else 0
    print(f"MMLU Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# ARC Challenge Evaluation
# ==============================================================================
def evaluate_arc_challenge(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on ARC Challenge dataset"""
    print("\n--- ARC Challenge Evaluation ---")
    ds = load_dataset("ai2_arc", "ARC-Challenge", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))
    
    correct = 0
    total = 0
    
    for item in tqdm(ds, desc="ARC-C"):
        question = item["question"]
        choices = item["choices"]["text"]
        labels = item["choices"]["label"]
        answer = item["answerKey"]
        
        # Format multiple choice
        choice_text = "\n".join([f"{labels[i]}. {choices[i]}" for i in range(len(choices))])
        prompt = f"Question: {question}\n\n{choice_text}\n\nAnswer:"
        
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=10,
                temperature=0.1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        predicted = extract_letter_answer(response)
        
        if predicted.upper() == answer.upper():
            correct += 1
        total += 1
    
    accuracy = correct / total if total > 0 else 0
    print(f"ARC Challenge Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# BBH Evaluation
# ==============================================================================
def evaluate_bbh(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on BBH (Big Bench Hard) dataset"""
    print("\n--- BBH Evaluation ---")
    ds = load_dataset("lukaemon/bbh", "boolean_expressions", split="test")
    ds = ds.select(range(min(num_samples, len(ds))))
    
    correct = 0
    total = 0
    
    for item in tqdm(ds, desc="BBH"):
        question = item["input"]
        answer = item["target"]
        
        prompt = f"Question: {question}\n\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                          max_length=MAX_LENGTH).to(model.device)
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=50,
                temperature=0.1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        
        # Check if answer is contained in response
        if normalize_answer(answer) in normalize_answer(response):
            correct += 1
        total += 1
    
    accuracy = correct / total if total > 0 else 0
    print(f"BBH Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# HellaSwag Evaluation
# ==============================================================================
def evaluate_hellaswag(model, tokenizer, num_samples=MAX_SAMPLES):
    """Evaluate on HellaSwag dataset"""
    print("\n--- HellaSwag Evaluation ---")
    ds = load_dataset("hellaswag", split="validation")
    ds = ds.select(range(min(num_samples, len(ds))))
    
    correct = 0
    total = 0
    
    for item in tqdm(ds, desc="HellaSwag"):
        context = item["ctx"]
        endings = item["endings"]
        label = int(item["label"])
        
        # Score each ending
        scores = []
        for ending in endings:
            full_text = context + " " + ending
            inputs = tokenizer(full_text, return_tensors="pt", truncation=True,
                             max_length=MAX_LENGTH).to(model.device)
            
            with torch.no_grad():
                outputs = model(**inputs, labels=inputs["input_ids"])
                # Use negative loss as score (lower loss = better)
                scores.append(-outputs.loss.item())
        
        # Predict the ending with highest score
        predicted = np.argmax(scores)
        
        if predicted == label:
            correct += 1
        total += 1
    
    accuracy = correct / total if total > 0 else 0
    print(f"HellaSwag Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy

# ==============================================================================
# Run All Evaluations
# ==============================================================================
def run_all_benchmarks(model, tokenizer, tasks=None):
    """Run all specified benchmarks"""
    if tasks is None:
        tasks = tasks_to_run
    
    results = {}
    
    if "gsm8k" in tasks:
        results["gsm8k"] = evaluate_gsm8k(model, tokenizer)
    
    if "math" in tasks:
        results["math"] = evaluate_math(model, tokenizer)
    
    if "mmlu" in tasks:
        results["mmlu"] = evaluate_mmlu(model, tokenizer)
    
    if "arc_challenge" in tasks:
        results["arc_challenge"] = evaluate_arc_challenge(model, tokenizer)
    
    if "bbh" in tasks:
        results["bbh"] = evaluate_bbh(model, tokenizer)
    
    if "hellaswag" in tasks:
        results["hellaswag"] = evaluate_hellaswag(model, tokenizer)
    
    return results

# ==============================================================================
# Main Evaluation
# ==============================================================================
print("\n" + "="*80)
print("STARTING BENCHMARK EVALUATION")
print("="*80)

# Load the trained model
print("Loading Stage 2 model...")
model.eval()

# Run benchmarks
results = run_all_benchmarks(model, tokenizer, tasks_to_run)

# Print summary
print("\n" + "="*80)
print("BENCHMARK RESULTS SUMMARY")
print("="*80)
print(results)
for task, accuracy in results.items():
    print(f"{task.upper()}: {accuracy*100:.2f}%")

# Calculate average
avg_accuracy = np.mean(list(results.values()))
print(f"\nAVERAGE ACCURACY: {avg_accuracy*100:.2f}%")

# Save results
results_with_avg = {**results, "average": avg_accuracy}
with open("benchmark_results.json", "w") as f:
    json.dump(results_with_avg, f, indent=2)

print("\n✓ Results saved to benchmark_results.json")


BENCHMARK EVALUATION

STARTING BENCHMARK EVALUATION
Loading Stage 2 model...

--- GSM8K Evaluation ---


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

GSM8K: 100%|██████████| 50/50 [25:48<00:00, 30.98s/it]


GSM8K Accuracy: 0.2400 (12/50)

--- MATH Dataset Evaluation ---


README.md: 0.00B [00:00, ?B/s]

math_dataset.py: 0.00B [00:00, ?B/s]

The repository for math_dataset contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/math_dataset.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


Downloading data:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1999998 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

MATH: 100%|██████████| 50/50 [25:56<00:00, 31.13s/it]

MATH Accuracy: 0.1000 (5/50)

--- MMLU Evaluation ---





README.md: 0.00B [00:00, ?B/s]

dataset_infos.json: 0.00B [00:00, ?B/s]

abstract_algebra/test-00000-of-00001.par(…):   0%|          | 0.00/9.96k [00:00<?, ?B/s]

abstract_algebra/validation-00000-of-000(…):   0%|          | 0.00/3.73k [00:00<?, ?B/s]

abstract_algebra/dev-00000-of-00001.parq(…):   0%|          | 0.00/3.45k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

MMLU:   0%|          | 0/50 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
MMLU: 100%|██████████| 50/50 [01:05<00:00,  1.31s/it]

MMLU Accuracy: 0.5400 (27/50)

--- ARC Challenge Evaluation ---





README.md: 0.00B [00:00, ?B/s]

ARC-Challenge/train-00000-of-00001.parqu(…):   0%|          | 0.00/190k [00:00<?, ?B/s]

ARC-Challenge/test-00000-of-00001.parque(…):   0%|          | 0.00/204k [00:00<?, ?B/s]

ARC-Challenge/validation-00000-of-00001.(…):   0%|          | 0.00/55.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1119 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1172 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/299 [00:00<?, ? examples/s]

ARC-C: 100%|██████████| 50/50 [01:05<00:00,  1.31s/it]

ARC Challenge Accuracy: 0.8800 (44/50)

--- BBH Evaluation ---





README.md: 0.00B [00:00, ?B/s]

boolean_expressions/test-00000-of-00001.(…):   0%|          | 0.00/4.52k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/250 [00:00<?, ? examples/s]

BBH: 100%|██████████| 50/50 [05:02<00:00,  6.05s/it]

BBH Accuracy: 0.9200 (46/50)

--- HellaSwag Evaluation ---





README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/24.4M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/6.11M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/6.32M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/39905 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10042 [00:00<?, ? examples/s]

HellaSwag: 100%|██████████| 50/50 [00:45<00:00,  1.11it/s]

HellaSwag Accuracy: 0.6000 (30/50)

BENCHMARK RESULTS SUMMARY
{'gsm8k': 0.24, 'math': 0.1, 'mmlu': 0.54, 'arc_challenge': 0.88, 'bbh': 0.92, 'hellaswag': 0.6}
GSM8K: 24.00%
MATH: 10.00%
MMLU: 54.00%
ARC_CHALLENGE: 88.00%
BBH: 92.00%
HELLASWAG: 60.00%

AVERAGE ACCURACY: 54.67%

✓ Results saved to benchmark_results.json



