In [6]:
# Equilibrium Search for AI Multiple Choice Questions
# Uses dual-model approach: Phi-2 (generator) + Qwen2-1.5B (discriminator)
# Implements an equilibrium algorithm to improve answer selection

import math
import numpy as np
import google.protobuf
import sentencepiece
import pandas as pd
import os
from huggingface_hub import login
from datasets import load_dataset
import gc
import torch
from collections import defaultdict
from tqdm import tqdm
import time

################################
# DATASET LOADING
################################

def load_arc_datasets():
    """Load and preprocess ARC-Challenge and ARC-Easy datasets"""
    # Load ARC-Challenge dataset
    arc_data = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
    arc_df = arc_data.to_pandas()
    arc_df = arc_df.drop_duplicates(subset=['question'])
    
    # Preprocess choices and add subject label
    arc_df["choices_dic"] = arc_df["choices"]
    arc_df["choices"] = arc_df["choices"].apply(lambda x: x["text"])
    arc_df["subject"] = "science"
    
    # Load ARC-Easy dataset with same preprocessing
    arc_data_easy = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test")
    arc_df_easy = arc_data_easy.to_pandas()
    arc_df_easy = arc_df_easy.drop_duplicates(subset=['question'])
    arc_df_easy["choices_dic"] = arc_df_easy["choices"]
    arc_df_easy["choices"] = arc_df_easy["choices"].apply(lambda x: x["text"])
    arc_df_easy["subject"] = "science"
    
    print(f"Loaded ARC-Challenge: {arc_df.shape[0]} questions")
    print(f"Loaded ARC-Easy: {arc_df_easy.shape[0]} questions")
    
    return arc_df, arc_df_easy

# Load datasets
arc_df, arc_df_easy = load_arc_datasets()

# Setup device configuration
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    device = "cuda"
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
else:
    device = "cpu"

print(f"Using device: {device}")

# Import required transformers components
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList, MinLengthLogitsProcessor
import accelerate
import re

################################
# CONFIGURATION
################################

# Processing mode: True for full dataset, False for testing with small subset
USE_BATCH_PROCESSING = False  # Changed to False for quick testing
BATCH_SIZE = 50              # Questions per batch (approximately 1.4 hours per batch)
TEST_SIZE_CHALLENGE = 20     # Test mode: subset size for ARC-Challenge (increased for better results)
TEST_SIZE_EASY = 20          # Test mode: subset size for ARC-Easy (increased for better results)

################################
# GENERATOR FUNCTIONS (PHI-2)
################################

def format_subject(subject):
    """Convert underscore-separated subject to readable format"""
    return " ".join(subject.split("_"))

def build_generator_prompt(subject, target_question, target_choices, get_correct):
    """
    Build prompt for generator model (Phi-2) to predict correct/incorrect answers
    
    Args:
        subject: Subject area (e.g., "science")
        target_question: The question text
        target_choices: List of answer choices
        get_correct: If True, asks for correct answer; if False, asks for incorrect answer
    """
    prompt = f"The following are multiple choice questions (with answers) about {format_subject(subject)}.\n\n"
    prompt += f"{target_question}"
    
    # Add labeled choices (A, B, C, D)
    for i, choice in enumerate(target_choices):
        prompt += f"\n{chr(65+i)}. {choice}"
    
    # Set prompt ending based on what we want to elicit
    prompt += "\nAnswer:" if get_correct else "\nIncorrect Answer:"
    return prompt

def get_generator_answer_probs(model, tokenizer, prompt_text, choices_list):
    """
    Get probability distribution over answer choices from generator model
    
    Returns:
        dict: Mapping from choice letters (A, B, C, D) to probabilities
    """
    # Tokenize prompt and get model predictions
    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(model.device)
    
    with torch.no_grad():
        # Get logits for next token prediction
        logits = model(input_ids=input_ids).logits[0, -1]

    # Extract logits for each answer choice letter
    choices = [f"{chr(65+i)}" for i, _ in enumerate(choices_list)]
    choice_logits = []
    
    for letter in choices:
        try:
            token_id = tokenizer.encode(letter, add_special_tokens=False)
            if token_id:
                choice_logits.append(logits[token_id[0]].item())
            else:
                choice_logits.append(-100.0)  # Very low probability if token not found
        except:
            choice_logits.append(-100.0)
    
    # Convert logits to probabilities using softmax
    choice_logits = torch.tensor(choice_logits, device=model.device).float()
    probs = torch.nn.functional.softmax(choice_logits, dim=0).detach().cpu().numpy()
    
    return {choice: prob for choice, prob in zip(choices, probs)}

def generator_probs(subject, question, choices_list, get_correct, model, tokenizer):
    """Wrapper function to get generator probabilities for a question"""
    choices = [f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices_list)]
    prompt = build_generator_prompt(subject, question, choices, get_correct)
    return get_generator_answer_probs(model, tokenizer, prompt, choices_list)

def get_initial_generator_probs(row, model, tokenizer):
    """
    Get initial generator probabilities for both correct and incorrect answer modes
    
    Returns:
        dict: Contains 'correct' and 'incorrect' probability distributions
    """
    gen_init = {"correct": {}, "incorrect": {}}
    subject, question, choices = row["subject"], row["question"], row["choices"]
    
    # Get probabilities for both correct and incorrect answer prompts
    for get_correct in [True, False]:
        choice_probs = generator_probs(subject, question, choices, get_correct, model, tokenizer)
        key = "correct" if get_correct else "incorrect"
        gen_init[key] = choice_probs
    
    return gen_init

################################
# DISCRIMINATOR FUNCTIONS (QWEN2-1.5B)
################################

def build_discriminator_prompt(subject, question, proposed_answer):
    """
    Build prompt for discriminator model to evaluate answer correctness
    
    Args:
        subject: Subject area
        question: The question text  
        proposed_answer: Answer choice to evaluate
    
    Returns:
        str: Formatted prompt asking for correctness judgment
    """
    return f"""You are an expert evaluator of questions about {format_subject(subject)}. 
Determine if the proposed answer is correct. Output ONLY 'A' or 'B'.
Question: {question}
Proposed Answer: {proposed_answer}

Is this answer correct? Respond ONLY with:
A. Correct
B. Incorrect

Answer:
"""

def get_discriminator_probs(model, tokenizer, prompt_text, choices_list):
    """
    Get correctness probabilities from discriminator model
    
    Returns:
        dict: {'correct': prob, 'incorrect': prob}
    """
    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(model.device)
    
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits[0, -1]

    try:
        # Get logits for 'A' (correct) and 'B' (incorrect) tokens
        choice_logits = torch.tensor([
            logits[tokenizer("A").input_ids[-1]],
            logits[tokenizer("B").input_ids[-1]],
        ]).float()
        
        # Convert to probabilities
        probs = torch.nn.functional.softmax(choice_logits, dim=0).detach().cpu().numpy()
        return {"correct": probs[0], "incorrect": probs[1]}
    except:
        # Fallback to uniform distribution if token extraction fails
        return {"correct": 0.5, "incorrect": 0.5}

def evaluate_answer_correctness(row, model, tokenizer):
    """
    Evaluate correctness of each answer choice using discriminator model
    
    Returns:
        dict: Mapping from choice letters to correctness probabilities
    """
    subject = row["subject"]
    question = row["question"]
    choices = row["choices"]
    
    results = {}
    
    # Evaluate each answer choice
    for idx, answer in enumerate(choices):
        prompt = build_discriminator_prompt(subject, question, answer)
        probs = get_discriminator_probs(model, tokenizer, prompt, choices)
        choice_letter = chr(65 + idx)  # Convert index to letter (A, B, C, D)
        results[choice_letter] = probs
    
    return results

def get_initial_discriminator_probs(row, model, tokenizer):
    """Get initial discriminator probabilities for all answer choices"""
    return evaluate_answer_correctness(row, model, tokenizer)

################################
# EQUILIBRIUM SEARCH ALGORITHM
################################

def pick_answer(gen, disc, candidates, method="generator"):
    """
    Select best answer based on either generator or discriminator probabilities
    
    Args:
        gen: Generator probability distributions
        disc: Discriminator probability distributions  
        candidates: List of answer choice letters
        method: Either "generator" or "discriminator"
    
    Returns:
        str: Best answer choice letter
    """
    if method == "generator":
        # Choose answer with highest generator "correct" probability
        return max(candidates, key=lambda y: gen["correct"][y])
    else:
        # Choose answer with highest discriminator "correct" probability
        return max(candidates, key=lambda y: disc[y]["correct"])

def softmax(arr):
    """Numerically stable softmax implementation"""
    m = np.max(arr)
    exp_vals = np.exp(arr - m)
    return exp_vals / np.sum(exp_vals)

def equilibrium_search(gen_init, disc_init, candidates, 
                       T=5,           
                       eta_G=0.1,      
                       eta_D=0.1,     # Normal discriminator learning rate
                       lam_G=0.1,     # Strong generator regularization
                       lam_D=0.1):    # Normal discriminator regularization
    """
    Equilibrium Search Algorithm
    
    This algorithm iteratively updates generator and discriminator policies to reach
    an equilibrium while preserving the generator's strong initial performance.
    
    This prevents the generator from deviating too much from its strong initial policy.
    """
    
    # Initialize working copies of probability distributions
    gen = {"correct": dict(gen_init["correct"]), 
           "incorrect": dict(gen_init["incorrect"])}
    disc = {y: dict(disc_init[y]) for y in candidates}

    # Initialize Q-values (cumulative rewards) for both models
    Qg = {"correct": {y: 0.0 for y in candidates}, 
          "incorrect": {y: 0.0 for y in candidates}}
    Qd = {y: {"correct": 0.0, "incorrect": 0.0} for y in candidates}

    # Iterative equilibrium search
    for t in range(1, T + 1):
        # Update Q-values based on current policies
        # Generator Q-values incorporate discriminator feedback
        for v in ["correct", "incorrect"]:
            for y in candidates:
                Qg[v][y] += (1.0 / (2.0 * t)) * disc[y][v]

        # Discriminator Q-values incorporate generator feedback  
        for y in candidates:
            for v in ["correct", "incorrect"]:
                Qd[y][v] += (1.0 / (2.0 * t)) * gen[v][y]

        # Update generator policy with regularization
        # Strong regularization keeps it close to initial policy
        for v in ["correct", "incorrect"]:
            logits = []
            for y in candidates:
                # Combine Q-value with regularization toward initial policy
                val = (Qg[v][y] + lam_G * math.log(gen_init[v][y] + 1e-12)) / (1/eta_G + lam_G)
                logits.append(val)

            # Apply softmax to get new probability distribution
            new_probs = softmax(np.array(logits))
            for i, y in enumerate(candidates):
                gen[v][y] = new_probs[i]
        
        # Update discriminator policy
        # Less regularization allows more adaptation
        for y in candidates:
            logits_correct = []
            logits_incorrect = []
            
            for y_inner in candidates:
                # Correct probabilities
                val_correct = (Qd[y]["correct"] + lam_D * math.log(disc_init[y]["correct"] + 1e-12)) / (1/eta_D + lam_D)
                logits_correct.append(val_correct if y_inner == y else -float('inf'))
                
                # Incorrect probabilities  
                val_incorrect = (Qd[y]["incorrect"] + lam_D * math.log(disc_init[y]["incorrect"] + 1e-12)) / (1/eta_D + lam_D)
                logits_incorrect.append(val_incorrect if y_inner == y else -float('inf'))

            # Update this answer choice's probabilities
            new_probs_correct = softmax(np.array([val_correct]))
            new_probs_incorrect = softmax(np.array([val_incorrect]))
            
            disc[y]["correct"] = new_probs_correct[0]
            disc[y]["incorrect"] = new_probs_incorrect[0]

    return gen, disc

################################
# MODEL LOADING
################################

def load_dual_models():
    """
    Load the dual-model setup: Phi-2 (generator) + Qwen2-1.5B (discriminator)
    
    Returns:
        tuple: (generator_model, generator_tokenizer, discriminator_model, discriminator_tokenizer)
    """
    
    print("Loading dual-model setup:")
    print("  Generator: microsoft/phi-2") 
    print("  Discriminator: Qwen/Qwen2-1.5B-Instruct")
    
    # Configure device and precision
    if device == "cuda":
        torch_dtype = torch.float16
        device_map = "cuda"
    else:
        torch_dtype = torch.float32
        device_map = "cpu"
    
    # Load Phi-2 as generator
    generator_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        torch_dtype=torch_dtype,
        load_in_8bit=False,
        low_cpu_mem_usage=True,
        device_map=device_map,
        trust_remote_code=True
    )
    generator_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    if generator_tokenizer.pad_token is None:
        generator_tokenizer.pad_token = generator_tokenizer.eos_token
    
    # Load Qwen2-1.5B as discriminator
    discriminator_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-1.5B-Instruct",
        torch_dtype=torch_dtype,
        load_in_8bit=False,
        low_cpu_mem_usage=True,
        device_map=device_map,
        trust_remote_code=True
    )
    discriminator_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct", trust_remote_code=True)
    if discriminator_tokenizer.pad_token is None:
        discriminator_tokenizer.pad_token = discriminator_tokenizer.eos_token
    
    print(f"Models loaded successfully on {device}")
    
    return generator_model, generator_tokenizer, discriminator_model, discriminator_tokenizer

################################
# MAIN PROCESSING PIPELINE
################################

def process_questions_with_equilibrium(generator_model, generator_tokenizer, 
                                      discriminator_model, discriminator_tokenizer, df):
    """
    Main pipeline for processing questions using equilibrium search
    
    Args:
        generator_model, generator_tokenizer: Phi-2 model and tokenizer
        discriminator_model, discriminator_tokenizer: Qwen2-1.5B model and tokenizer
        df: DataFrame containing questions to process
        
    Returns:
        tuple: (processed_dataframe, average_time_per_question)
    """
    
    category_df = df.copy()
    
    # Initialize result lists
    gen_answer = []
    disc_answer = []
    gen_init_answer = []
    disc_init_answer = []
    disc_init_policy = []
    gen_init_policy = []
    disc_final_policy_consensus = []
    gen_final_policy_consensus = []
    
    print(f"Processing {len(category_df)} questions with Equilibrium Search...")
    
    question_times = []
    
    # Process each question
    for idx, (_, row) in enumerate(tqdm(category_df.iterrows(), total=len(category_df))):
        question_start_time = time.time()

        # Get initial discriminator probabilities
        disc_init = get_initial_discriminator_probs(row, discriminator_model, discriminator_tokenizer)
        disc_init_policy.append(disc_init)
        
        # Clean up memory
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
            
        # Get initial generator probabilities  
        gen_init = get_initial_generator_probs(row, generator_model, generator_tokenizer)
        gen_init_policy.append(gen_init)
        
        # Clean up memory
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
        
        # Record initial answers (before equilibrium search)
        gen_init_answer.append(max(gen_init["correct"], key=gen_init["correct"].get))
        disc_init_answer.append(max(disc_init, key=lambda choice: disc_init[choice]["correct"]))
        
        # Prepare for equilibrium search
        candidates = [f"{chr(65+i)}" for i in range(len(row["choices"]))]

        # Run equilibrium search
        gen_final, disc_final = equilibrium_search(
            gen_init, disc_init, candidates,
            T=5,        # Fewer iterations  
            eta_G=0.1,  # Slow generator learning
            eta_D=0.1,  # Normal discriminator learning
            lam_G=0.5,  # Strong generator regularization
            lam_D=0.1   # Normal discriminator regularization
        )
        
        # Store final policies
        disc_final_policy_consensus.append(disc_final)
        gen_final_policy_consensus.append(gen_final)

        # Select final answers using updated policies
        best_answer_g = pick_answer(gen_final, disc_final, candidates, method="generator")
        best_answer_d = pick_answer(gen_final, disc_final, candidates, method="discriminator")
        
        gen_answer.append(best_answer_g)
        disc_answer.append(best_answer_d)
        
        # Track timing
        question_time = time.time() - question_start_time
        question_times.append(question_time)
    
    avg_time_per_question = np.mean(question_times)
    
    # Add results to dataframe
    category_df["gen_init_answer"] = gen_init_answer
    category_df["disc_answer"] = disc_answer
    category_df["gen_answer"] = gen_answer
    category_df["disc_init_answer"] = disc_init_answer
    category_df["disc_final_policy_consensus"] = disc_final_policy_consensus
    category_df["disc_init_policy"] = disc_init_policy
    category_df["gen_init_policy"] = gen_init_policy
    category_df["gen_final_policy_consensus"] = gen_final_policy_consensus
    
    return category_df, avg_time_per_question

################################
# BATCH PROCESSING
################################

def process_in_batches(generator_model, generator_tokenizer, 
                      discriminator_model, discriminator_tokenizer, 
                      df, dataset_name, batch_size=50):
    """
    Process large dataset in batches to manage memory and save intermediate results
    
    Args:
        models and tokenizers: The loaded models
        df: Dataset to process
        dataset_name: Name for saving files
        batch_size: Number of questions per batch
        
    Returns:
        pd.DataFrame: Complete processed results
    """
    
    total_questions = len(df)
    num_batches = (total_questions + batch_size - 1) // batch_size
    
    print(f"\nProcessing {dataset_name} in {num_batches} batches of {batch_size}")
    print(f"Total questions: {total_questions}")
    
    all_results = []
    
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, total_questions)
        batch_df = df.iloc[start_idx:end_idx].copy()
        
        print(f"\nBatch {batch_idx + 1}/{num_batches}: Questions {start_idx+1}-{end_idx}")
        
        batch_start_time = time.time()
        batch_result, avg_time = process_questions_with_equilibrium(
            generator_model, generator_tokenizer,
            discriminator_model, discriminator_tokenizer, 
            batch_df
        )
        batch_time = time.time() - batch_start_time
        
        # Save batch result immediately for crash recovery
        batch_filename = f'Data/{dataset_name}_batch_{batch_idx+1}_of_{num_batches}.csv'
        batch_result.to_csv(batch_filename, index=False)
        all_results.append(batch_result)
        
        # Progress reporting
        completed_questions = end_idx
        remaining_questions = total_questions - completed_questions
        if remaining_questions > 0:
            estimated_remaining_time = (remaining_questions / len(batch_df)) * batch_time
            print(f"Batch complete: {batch_time/60:.1f}min | Remaining: {estimated_remaining_time/3600:.1f}h")
        else:
            print(f"Batch complete: {batch_time/60:.1f}min")
        
        # Memory cleanup between batches
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
    
    # Combine all batch results
    print(f"\nCombining {num_batches} batches...")
    final_result = pd.concat(all_results, ignore_index=True)
    
    # Save final combined result
    final_filename = f'Data/{dataset_name}_complete_results.csv'
    final_result.to_csv(final_filename, index=False)
    print(f"Results saved: {final_filename}")
    
    return final_result

################################
# RESULTS ANALYSIS
################################

def analyze_results(results_df, dataset_name):
    """Analyze and display results for a processed dataset"""
    total_questions = len(results_df)
    
    # Calculate accuracy metrics
    gen_init_acc = (results_df['gen_init_answer'] == results_df['answerKey']).mean()
    disc_init_acc = (results_df['disc_init_answer'] == results_df['answerKey']).mean()
    gen_final_acc = (results_df['gen_answer'] == results_df['answerKey']).mean()
    disc_final_acc = (results_df['disc_answer'] == results_df['answerKey']).mean()
    
    print(f"\n{dataset_name.upper()} RESULTS ({total_questions} questions):")
    print(f"Generator (Initial):   {gen_init_acc:.3f} ({int(gen_init_acc*total_questions)}/{total_questions})")
    print(f"Discriminator (Initial): {disc_init_acc:.3f} ({int(disc_init_acc*total_questions)}/{total_questions})")
    print(f"Generator (Final):     {gen_final_acc:.3f} ({int(gen_final_acc*total_questions)}/{total_questions})")
    print(f"Discriminator (Final): {disc_final_acc:.3f} ({int(disc_final_acc*total_questions)}/{total_questions})")
    print(f"Generator Improvement: {gen_final_acc - gen_init_acc:+.3f}")
    print(f"Discriminator Improvement: {disc_final_acc - disc_init_acc:+.3f}")
    print(f"Best Method: {max(gen_final_acc, disc_final_acc):.3f}")
    
    return {
        'gen_init': gen_init_acc,
        'disc_init': disc_init_acc, 
        'gen_final': gen_final_acc,
        'disc_final': disc_final_acc
    }

################################
# MAIN EXECUTION
################################

def main():
    """Main execution function"""
    
    print(f"\nConfiguration:")
    if USE_BATCH_PROCESSING:
        print(f"  Mode: BATCH PROCESSING (Full Dataset)")
        print(f"  Batch Size: {BATCH_SIZE} questions")
        total_estimated_batches = (len(arc_df) + BATCH_SIZE - 1) // BATCH_SIZE + (len(arc_df_easy) + BATCH_SIZE - 1) // BATCH_SIZE
        print(f"  Estimated Total Batches: {total_estimated_batches}")
    else:
        print(f"  Mode: TEST (Small Subset)")
        print(f"  Test Size: {TEST_SIZE_CHALLENGE + TEST_SIZE_EASY} questions")

    # Clean up any existing models and load fresh ones
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    # Load models
    print("\n" + "="*50)
    print("LOADING MODELS")
    print("="*50)
    generator_model, generator_tokenizer, discriminator_model, discriminator_tokenizer = load_dual_models()

    # Create output directory
    os.makedirs('Data', exist_ok=True)

    if USE_BATCH_PROCESSING:
        # Full dataset processing mode
        print("\n" + "="*50)
        print("BATCH PROCESSING - FULL DATASET")
        print("="*50)
        
        # Process ARC-Challenge
        challenge_results = process_in_batches(
            generator_model, generator_tokenizer,
            discriminator_model, discriminator_tokenizer,
            arc_df, "arc_challenge", BATCH_SIZE
        )
        
        # Process ARC-Easy
        easy_results = process_in_batches(
            generator_model, generator_tokenizer,
            discriminator_model, discriminator_tokenizer,
            arc_df_easy, "arc_easy", BATCH_SIZE
        )
        
        # Analyze final results
        print("\n" + "="*50)
        print("FINAL RESULTS")
        print("="*50)
        
        challenge_metrics = analyze_results(challenge_results, "ARC-Challenge")
        easy_metrics = analyze_results(easy_results, "ARC-Easy")
        
        print(f"\nSUMMARY:")
        print(f"ARC-Challenge Best: {max(challenge_metrics['gen_final'], challenge_metrics['disc_final']):.1%}")
        print(f"ARC-Easy Best: {max(easy_metrics['gen_final'], easy_metrics['disc_final']):.1%}")

    else:
        # Test mode with small subsets
        print("\n" + "="*50)
        print("TEST MODE")
        print("="*50)

        # Create test subsets
        arc_df_test = arc_df.head(TEST_SIZE_CHALLENGE).copy()
        arc_df_easy_test = arc_df_easy.head(TEST_SIZE_EASY).copy()

        # Process test subsets
        print(f"\nProcessing ARC-Challenge subset ({TEST_SIZE_CHALLENGE} questions)...")
        challenge_results, _ = process_questions_with_equilibrium(
            generator_model, generator_tokenizer,
            discriminator_model, discriminator_tokenizer, 
            arc_df_test
        )
        
        print(f"\nProcessing ARC-Easy subset ({TEST_SIZE_EASY} questions)...")
        easy_results, _ = process_questions_with_equilibrium(
            generator_model, generator_tokenizer,
            discriminator_model, discriminator_tokenizer,
            arc_df_easy_test
        )

        # Save test results
        challenge_results.to_csv(f'Data/arc_challenge_test_{TEST_SIZE_CHALLENGE}.csv', index=False)
        easy_results.to_csv(f'Data/arc_easy_test_{TEST_SIZE_EASY}.csv', index=False)

        # Analyze test results
        print("\n" + "="*50)
        print("TEST RESULTS")
        print("="*50)
        
        analyze_results(challenge_results, "ARC-Challenge Test")
        analyze_results(easy_results, "ARC-Easy Test")

    # Final cleanup
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    
    print("\n Experiment completed successfully!")

if __name__ == "__main__":
    main()

Loaded ARC-Challenge: 1170 questions
Loaded ARC-Easy: 2371 questions
Using device: cpu

Configuration:
  Mode: TEST (Small Subset)
  Test Size: 40 questions

LOADING MODELS
Loading dual-model setup:
  Generator: microsoft/phi-2
  Discriminator: Qwen/Qwen2-1.5B-Instruct


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Models loaded successfully on cpu

TEST MODE

Processing ARC-Challenge subset (20 questions)...
Processing 20 questions with Conservative Equilibrium...


100%|██████████████████████████████████████████| 20/20 [34:41<00:00, 104.10s/it]



Processing ARC-Easy subset (20 questions)...
Processing 20 questions with Conservative Equilibrium...


100%|███████████████████████████████████████████| 20/20 [31:33<00:00, 94.66s/it]



TEST RESULTS

ARC-CHALLENGE TEST RESULTS (20 questions):
Generator (Initial):   0.850 (17/20)
Discriminator (Initial): 0.600 (12/20)
Generator (Final):     0.800 (16/20)
Discriminator (Final): 0.200 (4/20)
Generator Improvement: -0.050
Discriminator Improvement: -0.400
Best Method: 0.800

ARC-EASY TEST RESULTS (20 questions):
Generator (Initial):   0.800 (16/20)
Discriminator (Initial): 0.800 (16/20)
Generator (Final):     0.800 (16/20)
Discriminator (Final): 0.150 (3/20)
Generator Improvement: +0.000
Discriminator Improvement: -0.650
Best Method: 0.800

 Experiment completed successfully!
