# In [None]:
import math
import numpy as np
import pandas as pd
import os
import gc
import torch
import random
from collections import defaultdict
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from huggingface_hub import login

# Set up HuggingFace authentication
#os.environ["HF_TOKEN"] = "hf_***REDACTED***"
#login(os.environ["HF_TOKEN"])

#==============================================================================
# STEP 1: DATA LOADING AND SETUP
#==============================================================================

def _letter_answer_key(answer_key, labels):
    """Translate an answer key into the letter of its position in ``labels``.

    Returns ``answer_key`` unchanged when it does not appear in ``labels``.
    """
    label_list = [str(label) for label in labels]
    key = str(answer_key)
    if key in label_list:
        return chr(ord("A") + label_list.index(key))
    return answer_key


def _load_arc_split(config_name, display_name):
    """Load one ARC test split and reshape it for the pipeline."""
    print(f"Loading {display_name} dataset...")
    frame = load_dataset("allenai/ai2_arc", config_name, split="test").to_pandas()
    frame = frame.drop_duplicates(subset=['question'])

    # The labels are discarded on the next step, so express the answer key as
    # a position letter first. Downstream code computes the answer index with
    # ord(key) - ord('A'), which cannot match a key such as "1".
    if "answerKey" in frame.columns:
        frame["answerKey"] = [
            _letter_answer_key(key, choice["label"])
            for key, choice in zip(frame["answerKey"], frame["choices"])
        ]

    # Keep only the choice texts; the pipeline re-letters them by position.
    frame["choices"] = frame["choices"].apply(lambda x: x["text"])
    frame["subject"] = "science"
    print(f"{display_name} shape: {frame.shape}")
    return frame


def load_arc_datasets():
    """Load the ARC Challenge and ARC Easy test splits.

    Each split is de-duplicated on the question text, its ``choices`` column is
    reduced to the list of choice texts, its ``answerKey`` is expressed as the
    letter of the correct choice's position, and a constant ``subject`` column
    ("science") is added.

    Returns:
        Tuple ``(arc_df, arc_easy_df)`` of pandas DataFrames.
    """
    arc_df = _load_arc_split("ARC-Challenge", "ARC Challenge")
    arc_easy_df = _load_arc_split("ARC-Easy", "ARC Easy")
    return arc_df, arc_easy_df

#==============================================================================
# STEP 2: PREPROCESSING SYSTEM (Phi-2 + Qwen2-1.5B for Confidence Filtering)
#==============================================================================

class PreprocessingSystem:
    """
    Steps 2-3: score every answer choice with Phi-2 and Qwen2-1.5B-Instruct,
    average the two distributions, and drop the choices whose combined
    probability is below ``confidence_threshold``.
    """

    def __init__(self, confidence_threshold=0.10):
        """Store the filtering threshold; models are loaded lazily."""
        # Minimum combined probability a choice needs to survive filtering.
        self.confidence_threshold = confidence_threshold
        # Filled in by load_preprocessing_models(); None until then and again
        # after cleanup_preprocessing_models().
        self.phi2_model = None
        self.phi2_tokenizer = None
        self.qwen2_model = None
        self.qwen2_tokenizer = None

    @staticmethod
    def _load_model_and_tokenizer(model_id):
        """Load a causal LM and its tokenizer on CPU in float32."""
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            device_map="cpu",
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        # The tokenizer is called with padding=True, which needs a pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    def load_preprocessing_models(self):
        """Load both models (and their tokenizers) used for preprocessing."""
        print("Loading Phi-2 for preprocessing...")
        self.phi2_model, self.phi2_tokenizer = self._load_model_and_tokenizer(
            "microsoft/phi-2"
        )

        print("Loading Qwen2-1.5B for preprocessing...")
        self.qwen2_model, self.qwen2_tokenizer = self._load_model_and_tokenizer(
            "Qwen/Qwen2-1.5B-Instruct"
        )

        print("Preprocessing models loaded successfully!")

    def build_analysis_prompt(self, question, choices):
        """Build the prompt whose next token is read as the chosen letter.

        Choices are lettered A, B, C, ... in the order given.
        """
        header = (
            "The following is a multiple choice science question. "
            "Analyze all choices and select the most likely correct answer.\n"
            "\n"
            f"Question: {question}\n"
        )
        lettered = "".join(
            f"{chr(65+i)}. {choice}\n" for i, choice in enumerate(choices)
        )
        return header + lettered + "\nAnswer:"

    def get_model_confidence_scores(self, model, tokenizer, prompt, num_choices):
        """Return a probability per answer letter from one model.

        The logits of the token following ``prompt`` are read at the token ids
        of the letters A, B, ... and softmax-normalised over those letters
        only.

        Returns:
            Array of length ``num_choices``. On any error a uniform
            distribution is returned so one failing question does not abort a
            long run.
        """
        try:
            inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

            with torch.no_grad():
                next_token_logits = model(**inputs).logits[0, -1]

            # One logit per letter; the last id is used in case a letter is
            # split into several tokens.
            letter_logits = [
                next_token_logits[
                    tokenizer.encode(chr(65 + i), add_special_tokens=False)[-1]
                ].item()
                for i in range(num_choices)
            ]

            return torch.nn.functional.softmax(
                torch.tensor(letter_logits), dim=0
            ).numpy()

        except Exception as e:
            print(f"Error getting confidence scores: {e}")
            # Deliberate best-effort fallback: uniform distribution.
            return np.ones(num_choices) / num_choices

    def analyze_all_choices(self, question, choices):
        """
        Step 2: score all choices with both models and average the results.

        Returns:
            Dict with ``phi2_probs``, ``qwen2_probs`` and ``combined_probs``
            (the 50/50 average of the two).
        """
        prompt = self.build_analysis_prompt(question, choices)
        num_choices = len(choices)

        phi2_probs = self.get_model_confidence_scores(
            self.phi2_model, self.phi2_tokenizer, prompt, num_choices
        )
        qwen2_probs = self.get_model_confidence_scores(
            self.qwen2_model, self.qwen2_tokenizer, prompt, num_choices
        )

        return {
            'phi2_probs': phi2_probs,
            'qwen2_probs': qwen2_probs,
            'combined_probs': 0.5 * phi2_probs + 0.5 * qwen2_probs
        }

    def filter_low_confidence_choices(self, choices, combined_probs):
        """
        Step 3: keep the choices whose probability is >= the threshold.

        If no choice reaches the threshold, all choices are kept.

        Returns:
            Tuple ``(kept_choices, kept_probs, kept_original_indices)``.
        """
        kept_indices = [
            i for i, prob in enumerate(combined_probs)
            if prob >= self.confidence_threshold
        ]

        # Never hand an empty choice list to the next stage.
        if not kept_indices:
            kept_indices = list(range(len(choices)))

        kept_choices = [choices[i] for i in kept_indices]
        kept_probs = [combined_probs[i] for i in kept_indices]

        return kept_choices, kept_probs, kept_indices

    def randomize_choice_order(self, choices, probs, original_indices):
        """Shuffle the choices to prevent positional bias.

        Returns:
            Tuple ``(shuffled_choices, shuffled_probs, position_mapping)``
            where ``position_mapping[new_position]`` is the choice's index in
            the original, unfiltered list. Empty input yields empty results.
        """
        combined = list(zip(choices, probs, original_indices))
        if not combined:
            # zip(*[]) cannot be unpacked into three names.
            return [], [], {}

        random.shuffle(combined)

        shuffled_choices, shuffled_probs, shuffled_original_indices = zip(*combined)
        position_mapping = dict(enumerate(shuffled_original_indices))

        return list(shuffled_choices), list(shuffled_probs), position_mapping

    @staticmethod
    def _answer_key_to_index(answer_key):
        """Convert an answer key to a zero-based choice index.

        Accepts a single letter ("A" -> 0) or a one-based number ("1" -> 0).
        Returns None for anything else.
        """
        key = str(answer_key).strip()
        if key.isdecimal():
            return int(key) - 1
        if len(key) == 1:
            return ord(key.upper()) - ord('A')
        return None

    def preprocess_question(self, question, choices, answer_key=None):
        """
        Complete preprocessing pipeline: Steps 2-3.

        Args:
            question: Question text.
            choices: Sequence of choice texts in their original order.
            answer_key: Optional key of the correct choice (letter or
                one-based number) referring to the original order.

        Returns:
            Dict describing the analysis, the filtered and shuffled choices,
            and ``new_answer_key``: the letter of the correct choice in the
            shuffled list, or None when no key was given or the correct
            choice was filtered out.
        """
        # Step 2: analyze all choices
        analysis_results = self.analyze_all_choices(question, choices)

        # Step 3: filter low-confidence choices
        filtered_choices, filtered_probs, high_confidence_indices = self.filter_low_confidence_choices(
            choices, analysis_results['combined_probs']
        )

        # Randomize to prevent bias
        final_choices, final_probs, position_mapping = self.randomize_choice_order(
            filtered_choices, filtered_probs, high_confidence_indices
        )

        # Re-express the answer key relative to the shuffled, filtered list.
        new_answer_key = None
        if answer_key:
            original_answer_idx = self._answer_key_to_index(answer_key)
            for new_pos, orig_pos in position_mapping.items():
                if orig_pos == original_answer_idx:
                    new_answer_key = chr(ord('A') + new_pos)
                    break

        return {
            'original_question': question,
            'original_choices': choices,
            'filtered_choices': final_choices,
            'analysis_results': analysis_results,
            'position_mapping': position_mapping,
            'high_confidence_indices': high_confidence_indices,
            'original_answer_key': answer_key,
            'new_answer_key': new_answer_key,
            'filtering_applied': len(final_choices) < len(choices)
        }

    def cleanup_preprocessing_models(self):
        """Drop this object's references to the models and free caches.

        The attributes are reset to None rather than deleted, so the instance
        stays usable and calling this twice is harmless.
        """
        self.phi2_model = None
        self.phi2_tokenizer = None
        self.qwen2_model = None
        self.qwen2_tokenizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

#==============================================================================
# STEP 4: ADVERSARIAL FRAMEWORK (Phi-2 Generator vs Qwen2-1.5B Discriminator)
#==============================================================================

class AdversarialFramework:
    """
    Steps 4-6: a generator model proposes answers, a discriminator model
    judges them, and an iterative search reconciles the two policies.

    Which model plays which role is decided by ``setup_adversarial_models``;
    as wired there, Phi-2 is the generator and Qwen2-1.5B-Instruct the
    discriminator.
    """

    def __init__(self):
        """Create an unconfigured framework; see setup_adversarial_models."""
        self.generator_model = None
        self.generator_tokenizer = None
        self.discriminator_model = None
        self.discriminator_tokenizer = None

    def setup_adversarial_models(self, phi2_model, phi2_tokenizer, qwen2_model, qwen2_tokenizer):
        """
        Step 4: Set up adversarial roles
        Generator: Phi-2 (tries to pick correct answers)
        Discriminator: Qwen2-1.5B (tries to evaluate correctness)
        """
        self.generator_model = phi2_model
        self.generator_tokenizer = phi2_tokenizer
        self.discriminator_model = qwen2_model
        self.discriminator_tokenizer = qwen2_tokenizer

        print("Adversarial roles set:")
        print("  Generator: Phi-2")
        print("  Discriminator: Qwen2-1.5B-Instruct")

    def format_subject(self, subject):
        """Replace underscores in a subject name with spaces."""
        return " ".join(subject.split("_"))

    def build_generator_prompt(self, subject, question, choices, get_correct):
        """Build the generator prompt.

        Args:
            subject: Subject name; underscores are shown as spaces.
            question: Question text.
            choices: Choice texts, lettered A, B, ... in order.
            get_correct: True to ask for the answer, False to ask for an
                incorrect answer.
        """
        parts = [
            f"The following are multiple choice questions (with answers) about {self.format_subject(subject)}.\n\n",
            f"{question}",
        ]
        parts.extend(f"\n{chr(65+i)}. {choice}" for i, choice in enumerate(choices))
        parts.append("\nAnswer:" if get_correct else "\nIncorrect Answer:")
        return "".join(parts)

    def get_generator_probabilities(self, prompt_text, num_choices):
        """Return the generator's probability for each answer letter.

        The next-token logits after ``prompt_text`` are read at the token ids
        of A, B, ... and softmax-normalised over those letters only. On any
        error a uniform distribution is returned.
        """
        try:
            tokenizer = self.generator_tokenizer
            device = self.generator_model.device
            input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)

            with torch.no_grad():
                logits = self.generator_model(input_ids=input_ids).logits[0, -1]

            # add_special_tokens=False ensures the id read is the letter's own
            # token even if the tokenizer appends special tokens.
            choice_logits = [
                logits[tokenizer.encode(chr(65 + i), add_special_tokens=False)[-1]].item()
                for i in range(num_choices)
            ]

            choice_logits = torch.tensor(choice_logits, device=device).float()
            return torch.nn.functional.softmax(choice_logits, dim=0).detach().cpu().numpy()

        except Exception as e:
            print(f"Error in generator probabilities: {e}")
            return np.ones(num_choices) / num_choices

    def get_generator_initial_probs(self, question, choices, subject):
        """Query the generator for its "correct" and "incorrect" policies.

        Returns:
            ``{"correct": {letter: p}, "incorrect": {letter: p}}``.
        """
        candidates = [chr(65 + i) for i in range(len(choices))]
        gen_init = {}

        for role, get_correct in (("correct", True), ("incorrect", False)):
            prompt = self.build_generator_prompt(subject, question, choices, get_correct)
            probs = self.get_generator_probabilities(prompt, len(choices))
            gen_init[role] = dict(zip(candidates, probs))

        return gen_init

    def build_discriminator_prompt(self, subject, question, proposed_answer):
        """Build the prompt asking whether ``proposed_answer`` is correct.

        The model is asked to reply 'A' (correct) or 'B' (incorrect).
        """
        # NOTE(review): the prompt ends with a newline after "Answer:", so the
        # logits read downstream are for the token after that newline —
        # confirm this is intended.
        return (
            f"You are an expert evaluator of questions about {self.format_subject(subject)}. \n"
            "Determine if the proposed answer is correct. Output ONLY 'A' or 'B'.\n"
            f"Question: {question}\n"
            f"Proposed Answer: {proposed_answer}\n"
            "\n"
            "Is this answer correct? Respond ONLY with:\n"
            "A. Correct\n"
            "B. Incorrect\n"
            "\n"
            "Answer:\n"
        )

    def get_discriminator_probabilities(self, prompt_text):
        """Return the discriminator's verdict for one proposed answer.

        Returns:
            ``{"correct": p, "incorrect": 1 - p}`` from a softmax over the
            logits of the tokens 'A' and 'B'. On any error both are 0.5.
        """
        try:
            tokenizer = self.discriminator_tokenizer
            input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(
                self.discriminator_model.device
            )

            with torch.no_grad():
                logits = self.discriminator_model(input_ids=input_ids).logits[0, -1]

            verdict_ids = [
                tokenizer.encode(letter, add_special_tokens=False)[-1]
                for letter in ("A", "B")
            ]
            choice_logits = torch.tensor(
                [logits[token_id].item() for token_id in verdict_ids]
            ).float()

            probs = torch.nn.functional.softmax(choice_logits, dim=0).detach().cpu().numpy()

            return {"correct": probs[0], "incorrect": probs[1]}

        except Exception as e:
            print(f"Error in discriminator probabilities: {e}")
            return {"correct": 0.5, "incorrect": 0.5}

    def get_discriminator_initial_probs(self, question, choices, subject):
        """Query the discriminator once per choice.

        Returns:
            ``{letter: {"correct": p, "incorrect": q}}``.
        """
        return {
            chr(65 + idx): self.get_discriminator_probabilities(
                self.build_discriminator_prompt(subject, question, answer)
            )
            for idx, answer in enumerate(choices)
        }

    def softmax(self, arr):
        """Numerically stable softmax; an empty input yields an empty array."""
        arr = np.asarray(arr)
        if arr.size == 0:
            return arr
        # Subtracting the maximum keeps exp() from overflowing.
        shifted = np.exp(arr - np.max(arr))
        return shifted / np.sum(shifted)

    def equilibrium_search(self, gen_init, disc_init, candidates, T=5000,
                          eta_G=0.1, eta_D=0.1, lam_G=0.1, lam_D=0.01):
        """
        Step 5: iteratively update the generator and discriminator policies.

        Args:
            gen_init: ``{"correct"|"incorrect": {candidate: p}}``.
            disc_init: ``{candidate: {"correct": p, "incorrect": q}}``.
            candidates: Candidate labels to iterate over.
            T: Number of update rounds.
            eta_G, eta_D: Step-size terms for generator / discriminator.
            lam_G, lam_D: Weights pulling each policy towards its initial
                distribution.

        Returns:
            Tuple ``(gen, disc)`` in the same layouts as the inputs. The
            inputs are not modified.
        """
        labels = ("correct", "incorrect")
        gen = {v: dict(gen_init[v]) for v in labels}
        disc = {y: dict(disc_init[y]) for y in candidates}

        # Running scores: Qg is fed by the discriminator's policy, Qd by the
        # generator's.
        Qg = {v: {y: 0.0 for y in candidates} for v in labels}
        Qd = {y: {v: 0.0 for v in labels} for y in candidates}

        gen_scale = 1/eta_G + lam_G
        disc_scale = 1/eta_D + lam_D

        for t in range(1, T+1):
            # Each round adds the current opposing policy weighted by 1/(2t).
            step = 1.0/(2.0*t)
            for v in labels:
                for y in candidates:
                    Qg[v][y] += step * disc[y][v]
                    Qd[y][v] += step * gen[v][y]

            # Generator: for each label, a distribution over candidates.
            for v in labels:
                logits = np.array([
                    (Qg[v][y] + lam_G * math.log(gen_init[v][y] + 1e-12)) / gen_scale
                    for y in candidates
                ])
                for y, p in zip(candidates, self.softmax(logits)):
                    gen[v][y] = p

            # Discriminator. NOTE(review): this is also normalised across
            # candidates per label, so disc[y]["correct"] and
            # disc[y]["incorrect"] no longer sum to 1 for a given candidate
            # after the first round — confirm this is intended.
            for v in labels:
                logits = np.array([
                    (Qd[y][v] + lam_D * math.log(disc_init[y][v] + 1e-12)) / disc_scale
                    for y in candidates
                ])
                for y, p in zip(candidates, self.softmax(logits)):
                    disc[y][v] = p

        return gen, disc

    def get_final_answers(self, gen_final, disc_final, candidates):
        """
        Step 6: pick each player's final answer.

        Returns:
            Tuple ``(gen_answer, disc_answer)``: the candidates with the
            highest "correct" probability under each policy (first wins on
            ties), or ``(None, None)`` when there are no candidates.
        """
        gen_answer = max(
            candidates, key=lambda y: gen_final["correct"][y], default=None
        )
        disc_answer = max(
            candidates, key=lambda y: disc_final[y]["correct"], default=None
        )
        return gen_answer, disc_answer

#==============================================================================
# STEP 7: INTEGRATED PIPELINE
#==============================================================================

class IntegratedPipeline:
    """
    Complete pipeline: confidence-based preprocessing followed by the
    generator/discriminator equilibrium search, with accuracy bookkeeping.
    """

    def __init__(self, confidence_threshold=0.10):
        """Create both stages; models are loaded on first use."""
        self.preprocessing = PreprocessingSystem(confidence_threshold)
        self.adversarial = AdversarialFramework()
        self.models_loaded = False

    def initialize_all_models(self):
        """Load the models once and hand them to the adversarial stage."""
        if self.models_loaded:
            return

        print("Initializing complete pipeline...")

        self.preprocessing.load_preprocessing_models()

        # Both stages share the same model instances.
        self.adversarial.setup_adversarial_models(
            self.preprocessing.phi2_model,
            self.preprocessing.phi2_tokenizer,
            self.preprocessing.qwen2_model,
            self.preprocessing.qwen2_tokenizer
        )

        self.models_loaded = True
        print("All models initialized successfully!")

    def _role_labels(self):
        """Return display names ``(generator, discriminator)`` for reports."""
        def label(model, fallback):
            # Read the model's own name when it exposes one; otherwise use the
            # names printed by AdversarialFramework.setup_adversarial_models.
            return getattr(model, "name_or_path", None) or fallback

        return (
            label(self.adversarial.generator_model, "Phi-2"),
            label(self.adversarial.discriminator_model, "Qwen2-1.5B-Instruct"),
        )

    def evaluate_policy_accuracy(self, policy_probs, correct_answer, candidates, policy_type):
        """Score a policy's top "correct" candidate against the answer key.

        Args:
            policy_probs: Generator layout ``{"correct": {cand: p}}`` when
                ``policy_type == "generator"``, otherwise discriminator
                layout ``{cand: {"correct": p}}``.
            correct_answer: Expected candidate label (may be None).
            candidates: Candidate labels to consider.
            policy_type: "generator" or anything else for the discriminator.

        Returns:
            1.0 if the highest-probability candidate equals
            ``correct_answer``, else 0.0.
        """
        if policy_type == "generator":
            def score(cand):
                return policy_probs["correct"][cand]
        else:
            def score(cand):
                return policy_probs[cand]["correct"]

        best_answer = max(candidates, key=score)
        return 1.0 if best_answer == correct_answer else 0.0

    def process_single_question(self, question, choices, answer_key, subject):
        """Run one question through preprocessing and the equilibrium search.

        Returns:
            The preprocessing result dict extended with the final answers and
            accuracy fields. When preprocessing leaves a single choice, the
            adversarial stage is skipped and that choice ("A") is the answer
            of both players.
        """
        # Steps 2-3: Preprocessing
        preprocessing_result = self.preprocessing.preprocess_question(question, choices, answer_key)

        filtered_choices = preprocessing_result["filtered_choices"]
        new_answer_key = preprocessing_result.get("new_answer_key")

        # Only one choice left: automatic consensus.
        if len(filtered_choices) == 1:
            consensus_answer = "A"
            consensus_accuracy = 1.0 if new_answer_key == "A" else 0.0
            return {
                **preprocessing_result,
                'consensus_achieved': True,
                'gen_final_answer': consensus_answer,
                'disc_final_answer': consensus_answer,
                'adversarial_training_applied': False,
                'preprocessing_eliminated_choices': len(choices) - 1,
                'gen_initial_accuracy': None,
                'disc_initial_accuracy': None,
                'gen_final_accuracy': consensus_accuracy,
                'disc_final_accuracy': consensus_accuracy,
                'gen_accuracy_change': None,
                'disc_accuracy_change': None
            }

        # Steps 4-6: candidates are re-lettered over the filtered choices.
        candidates = [chr(65 + i) for i in range(len(filtered_choices))]

        # Step 4: initial policies
        gen_init = self.adversarial.get_generator_initial_probs(question, filtered_choices, subject)
        disc_init = self.adversarial.get_discriminator_initial_probs(question, filtered_choices, subject)

        # Accuracy BEFORE the equilibrium search
        gen_initial_accuracy = self.evaluate_policy_accuracy(gen_init, new_answer_key, candidates, "generator")
        disc_initial_accuracy = self.evaluate_policy_accuracy(disc_init, new_answer_key, candidates, "discriminator")

        # Step 5: equilibrium search
        gen_final, disc_final = self.adversarial.equilibrium_search(
            gen_init, disc_init, candidates, T=20,
            eta_G=0.1, eta_D=0.1, lam_G=0.1, lam_D=0.1
        )

        # Step 6: final answers
        gen_answer, disc_answer = self.adversarial.get_final_answers(gen_final, disc_final, candidates)

        # Accuracy AFTER the equilibrium search
        gen_final_accuracy = 1.0 if gen_answer == new_answer_key else 0.0
        disc_final_accuracy = 1.0 if disc_answer == new_answer_key else 0.0

        return {
            **preprocessing_result,
            'consensus_achieved': False,
            'gen_final_answer': gen_answer,
            'disc_final_answer': disc_answer,
            'adversarial_training_applied': True,
            'gen_init_probs': gen_init,
            'disc_init_probs': disc_init,
            'gen_final_probs': gen_final,
            'disc_final_probs': disc_final,
            'preprocessing_eliminated_choices': len(choices) - len(filtered_choices),
            'gen_initial_accuracy': gen_initial_accuracy,
            'disc_initial_accuracy': disc_initial_accuracy,
            'gen_final_accuracy': gen_final_accuracy,
            'disc_final_accuracy': disc_final_accuracy,
            'gen_accuracy_change': gen_final_accuracy - gen_initial_accuracy,
            'disc_accuracy_change': disc_final_accuracy - disc_initial_accuracy
        }

    def process_dataset(self, df, max_questions=None):
        """Process a dataset through the pipeline and print a summary.

        Args:
            df: DataFrame with ``question``, ``choices`` and ``subject``
                columns; ``answerKey`` and ``id`` are used when present.
            max_questions: If given, only the first ``max_questions`` rows
                are processed.

        Returns:
            DataFrame with one row per question: the pipeline result merged
            with the original row's columns.
        """
        self.initialize_all_models()

        if max_questions is not None:
            df = df.head(max_questions)
            print(f"Processing SUBSET of {len(df)} questions for testing...")

        gen_name, disc_name = self._role_labels()
        results = []
        counts = {
            'total': 0, 'consensus': 0, 'adversarial': 0,
            'preprocessing_correct': 0, 'generator_correct': 0, 'discriminator_correct': 0,
            'gen_initial_correct': 0, 'disc_initial_correct': 0,
            'gen_improved': 0, 'gen_degraded': 0,
            'disc_improved': 0, 'disc_degraded': 0,
        }

        print("Processing dataset through complete pipeline...")
        print("Pipeline: Steps 1-6")
        print("  1. Load questions")
        print("  2. Phi-2 + Qwen2-1.5B → confidence scores")
        print(f"  3. Filter choices with prob < {self.preprocessing.confidence_threshold:.2f}")
        print(f"  4. {gen_name} (Generator) vs {disc_name} (Discriminator)")
        print("  5. Nash equilibrium search")
        print("  6. Final answers")

        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing pipeline"):
            counts['total'] += 1

            result = self.process_single_question(
                row["question"],
                row["choices"],
                row.get("answerKey"),
                row["subject"]
            )

            # Carry the original row's columns along with the result.
            result.update(row.to_dict())

            question_id = row.get('id', 'Q' + str(counts['total']))
            gen_right = result['gen_final_answer'] == result['new_answer_key']
            disc_right = result['disc_final_answer'] == result['new_answer_key']

            if result['consensus_achieved']:
                counts['consensus'] += 1
                # Both players share the single remaining choice, so one
                # check covers all three tallies.
                if gen_right:
                    counts['preprocessing_correct'] += 1
                    counts['generator_correct'] += 1
                    counts['discriminator_correct'] += 1
                print(f"Consensus: {question_id} - Correct: {gen_right}")
            else:
                counts['adversarial'] += 1
                if gen_right:
                    counts['generator_correct'] += 1
                if disc_right:
                    counts['discriminator_correct'] += 1

                if result['gen_initial_accuracy'] == 1.0:
                    counts['gen_initial_correct'] += 1
                if result['disc_initial_accuracy'] == 1.0:
                    counts['disc_initial_correct'] += 1

                gen_delta = result['gen_accuracy_change']
                disc_delta = result['disc_accuracy_change']
                if gen_delta > 0:
                    counts['gen_improved'] += 1
                elif gen_delta < 0:
                    counts['gen_degraded'] += 1
                if disc_delta > 0:
                    counts['disc_improved'] += 1
                elif disc_delta < 0:
                    counts['disc_degraded'] += 1

                print(f"Adversarial: {question_id} - Gen: {gen_right} (Δ{gen_delta:+.1f}), Disc: {disc_right} (Δ{disc_delta:+.1f})")

            results.append(result)

            # Periodic memory management
            if counts['total'] % 50 == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        self._print_summary(counts, gen_name, disc_name)

        return pd.DataFrame(results)

    def _print_summary(self, counts, gen_name, disc_name):
        """Print the accuracy report for one process_dataset run."""
        total_questions = counts['total']
        if total_questions == 0:
            # Every percentage below divides by the question count.
            print("No questions were processed.")
            return

        consensus_questions = counts['consensus']
        adversarial_questions = counts['adversarial']
        preprocessing_correct = counts['preprocessing_correct']
        generator_correct = counts['generator_correct']
        discriminator_correct = counts['discriminator_correct']

        print(f"\n{'='*60}")
        print(f"COMPLETE PIPELINE RESULTS")
        print(f"{'='*60}")
        print(f"📊 Total Questions: {total_questions}")
        print(f"📊 Preprocessing Consensus: {consensus_questions} ({consensus_questions/total_questions*100:.1f}%)")
        print(f"📊 Adversarial Training: {adversarial_questions} ({adversarial_questions/total_questions*100:.1f}%)")
        print(f"")
        print(f"🎯 FINAL ACCURACY RESULTS:")
        if consensus_questions > 0:
            print(f"   Preprocessing Only: {preprocessing_correct}/{consensus_questions} = {preprocessing_correct/consensus_questions*100:.1f}%")
        else:
            print("   Preprocessing Only: N/A")
        print(f"   Generator ({gen_name}): {generator_correct}/{total_questions} = {generator_correct/total_questions*100:.1f}%")
        print(f"   Discriminator ({disc_name}): {discriminator_correct}/{total_questions} = {discriminator_correct/total_questions*100:.1f}%")

        if adversarial_questions == 0:
            return

        gen_initial_correct = counts['gen_initial_correct']
        disc_initial_correct = counts['disc_initial_correct']
        gen_improved, gen_degraded = counts['gen_improved'], counts['gen_degraded']
        disc_improved, disc_degraded = counts['disc_improved'], counts['disc_degraded']

        # Consensus questions count as correct for both players, so removing
        # them leaves the correct answers from adversarial questions only.
        gen_adversarial_correct = generator_correct - preprocessing_correct
        disc_adversarial_correct = discriminator_correct - preprocessing_correct

        print(f"")
        print(f"🎲 GAME THEORY IMPACT (on {adversarial_questions} adversarial questions):")
        print(f"   Generator Initial Accuracy: {gen_initial_correct}/{adversarial_questions} = {gen_initial_correct/adversarial_questions*100:.1f}%")
        print(f"   Generator Final Accuracy:   {gen_adversarial_correct}/{adversarial_questions} = {gen_adversarial_correct/adversarial_questions*100:.1f}%")
        print(f"   Generator Improvement:      {gen_improved} improved, {gen_degraded} degraded, {adversarial_questions-gen_improved-gen_degraded} unchanged")
        print(f"")
        print(f"   Discriminator Initial Accuracy: {disc_initial_correct}/{adversarial_questions} = {disc_initial_correct/adversarial_questions*100:.1f}%")
        print(f"   Discriminator Final Accuracy:   {disc_adversarial_correct}/{adversarial_questions} = {disc_adversarial_correct/adversarial_questions*100:.1f}%")
        print(f"   Discriminator Improvement:      {disc_improved} improved, {disc_degraded} degraded, {adversarial_questions-disc_improved-disc_degraded} unchanged")

        gen_net_improvement = gen_adversarial_correct - gen_initial_correct
        disc_net_improvement = disc_adversarial_correct - disc_initial_correct

        print(f"")
        print(f"📈 NET GAME THEORY IMPROVEMENT:")
        print(f"   Generator Net Gain: {gen_net_improvement:+d} questions ({gen_net_improvement/adversarial_questions*100:+.1f}%)")
        print(f"   Discriminator Net Gain: {disc_net_improvement:+d} questions ({disc_net_improvement/adversarial_questions*100:+.1f}%)")

        print(f"\n🔍 QUICK TEST SUMMARY:")
        print(f"   Game theory helped Generator: {gen_net_improvement > 0}")
        print(f"   Game theory helped Discriminator: {disc_net_improvement > 0}")
        print(f"   Overall beneficial: {gen_net_improvement + disc_net_improvement > 0}")

    def cleanup_all_models(self):
        """Release every reference this pipeline holds to the models."""
        # The adversarial stage shares the model objects with preprocessing;
        # unless its references are dropped too, nothing can be freed.
        self.adversarial.generator_model = None
        self.adversarial.generator_tokenizer = None
        self.adversarial.discriminator_model = None
        self.adversarial.discriminator_tokenizer = None
        self.preprocessing.cleanup_preprocessing_models()
        self.models_loaded = False

#==============================================================================
# STEP 8: MAIN EXECUTION
#==============================================================================

def test_subset(num_questions=20, swap_roles=True):
    """Run the pipeline on the first ``num_questions`` ARC Challenge questions.

    Args:
        num_questions: Number of questions to process.
        swap_roles: True (the default, and the behaviour before this
            parameter existed) keeps the pipeline's own wiring, Phi-2 as
            generator and Qwen2-1.5B as discriminator. False exchanges the
            two roles so Qwen2-1.5B generates and Phi-2 discriminates.

    Returns:
        The results DataFrame, which is also written to a CSV under ``Data/``.
    """
    if swap_roles:
        generator_label, discriminator_label = "Phi-2", "Qwen2-1.5B"
        file_tag, banner_tag = "swapped_roles", "SWAPPED ROLES"
    else:
        generator_label, discriminator_label = "Qwen2-1.5B", "Phi-2"
        file_tag, banner_tag = "default_roles", "DEFAULT ROLES"

    print("="*60)
    print(f"TESTING SUBSET - {num_questions} QUESTIONS")
    print(f"Generator: {generator_label} | Discriminator: {discriminator_label}")
    print("="*60)

    # Step 1: Load datasets (only the Challenge split is used here)
    arc_df, _ = load_arc_datasets()

    pipeline = IntegratedPipeline(confidence_threshold=0.10)

    try:
        if not swap_roles:
            # Load the models now so the roles can be exchanged before the
            # run; process_dataset will not reload them.
            pipeline.initialize_all_models()
            adv = pipeline.adversarial
            adv.generator_model, adv.discriminator_model = (
                adv.discriminator_model, adv.generator_model
            )
            adv.generator_tokenizer, adv.discriminator_tokenizer = (
                adv.discriminator_tokenizer, adv.generator_tokenizer
            )

        print(f"\nTesting on {num_questions} ARC Challenge questions...")
        results_test = pipeline.process_dataset(arc_df, max_questions=num_questions)

        # to_csv does not create missing directories.
        os.makedirs("Data", exist_ok=True)
        test_filename = f"Data/test_subset_{num_questions}_questions_{file_tag}.csv"
        results_test.to_csv(test_filename, index=False)
        print(f"Test results saved to {test_filename}")
    finally:
        # Release the models even when the run fails part-way.
        pipeline.cleanup_all_models()

    print(f"\n{'='*60}")
    print(f"SUBSET TEST COMPLETE - {banner_tag}")
    print(f"{'='*60}")

    return results_test


def _mean_result_column(results, column, label):
    """Return the mean of ``results[column]`` as a float, or 0.0 if unavailable.

    A missing column is reported explicitly instead of being silently
    treated as 0% accuracy, and a NaN mean (e.g. an empty frame) is
    normalised to 0.0 so the later comparison behaves predictably.
    """
    if column not in getattr(results, "columns", ()):
        print(f"  Warning: {label} results have no '{column}' column; using 0.")
        return 0.0
    mean_value = results[column].mean()
    return 0.0 if pd.isna(mean_value) else float(mean_value)


def compare_roles_test(num_questions=20):
    """Run the subset test under both role configurations and compare them.

    NOTE(review): relies on ``test_subset`` accepting a ``swap_roles``
    keyword argument.

    Args:
        num_questions: Number of questions to run for each configuration.

    Returns:
        Tuple ``(results_default, results_swapped)`` holding the values
        returned by the two ``test_subset`` calls.
    """
    print("="*70)
    print(f"COMPARING ROLES - {num_questions} QUESTIONS EACH")
    print("="*70)

    # Test default roles first
    print("\n" + "="*35)
    print("TESTING DEFAULT ROLES")
    print("="*35)
    results_default = test_subset(num_questions, swap_roles=False)

    # Test swapped roles
    print("\n" + "="*35)
    print("TESTING SWAPPED ROLES")
    print("="*35)
    results_swapped = test_subset(num_questions, swap_roles=True)

    # Compare results
    print("\n" + "="*70)
    print("ROLE COMPARISON SUMMARY")
    print("="*70)

    # Extract accuracy info from results (0.0 with a warning when absent)
    default_gen_acc = _mean_result_column(results_default, 'generator_accuracy', "default")
    default_disc_acc = _mean_result_column(results_default, 'discriminator_accuracy', "default")
    swapped_gen_acc = _mean_result_column(results_swapped, 'generator_accuracy', "swapped")
    swapped_disc_acc = _mean_result_column(results_swapped, 'discriminator_accuracy', "swapped")

    print("DEFAULT ROLES (Qwen2=Gen, Phi-2=Disc):")
    print(f"  Generator Performance: {default_gen_acc:.1%}")
    print(f"  Discriminator Performance: {default_disc_acc:.1%}")
    print()
    print("SWAPPED ROLES (Phi-2=Gen, Qwen2=Disc):")
    print(f"  Generator Performance: {swapped_gen_acc:.1%}")
    print(f"  Discriminator Performance: {swapped_disc_acc:.1%}")
    print()
    print("RECOMMENDATION:")

    # Rank the configurations by the sum of both players' mean accuracy
    default_total = default_gen_acc + default_disc_acc
    swapped_total = swapped_gen_acc + swapped_disc_acc

    if swapped_total > default_total:
        print("  → Use SWAPPED ROLES (better combined performance)")
    elif default_total > swapped_total:
        print("  → Use DEFAULT ROLES (better combined performance)")
    else:
        print("  → Both configurations perform similarly")

    return results_default, results_swapped


def main():
    """Interactive entry point for the complete pipeline.

    Prints the pipeline overview, asks the user for an execution mode on
    stdin and dispatches accordingly:

    * ``1`` - ``test_subset(20)``
    * ``2`` - ``test_subset(100)``
    * ``3`` - process the full ARC Challenge and ARC Easy datasets and save
      each result as a CSV under ``Data/``
    * anything else - falls back to ``test_subset(20)``

    Returns:
        The value returned by ``test_subset`` for the subset modes, or the
        tuple ``(results_challenge, results_easy)`` for the full run.
    """

    print("="*60)
    print("ADVERSARIAL QA SYSTEM - SWAPPED ROLES")
    print("="*60)
    print("Step 1: Load ARC datasets")
    print("Step 2: Phi-2 + Qwen2-1.5B → confidence analysis")
    print("Step 3: Filter choices with prob < threshold")
    print("Step 4: Phi-2 (Generator) vs Qwen2-1.5B (Discriminator)")
    print("Step 5: Nash equilibrium search")
    print("Step 6: Final answers")
    print("="*60)

    # Ask user for test or full run
    print("\nChoose execution mode:")
    print("1. Test on subset (20 questions) - RECOMMENDED FIRST")
    print("2. Test on larger subset (100 questions)")
    print("3. Full dataset (1000+ questions)")

    choice = input("Enter choice (1/2/3): ").strip()

    if choice == "1":
        return test_subset(20)
    if choice == "2":
        return test_subset(100)
    if choice != "3":
        print("Invalid choice. Running test subset...")
        return test_subset(20)

    # Full dataset processing
    challenge_filename = "Data/complete_pipeline_arc_challenge_swapped.csv"
    easy_filename = "Data/complete_pipeline_arc_easy_swapped.csv"
    # Create the output folder before the long run so that results cannot be
    # lost at save time because the directory is missing.
    os.makedirs(os.path.dirname(challenge_filename), exist_ok=True)

    arc_df, arc_easy_df = load_arc_datasets()

    # Initialize pipeline
    pipeline = IntegratedPipeline(confidence_threshold=0.10)

    try:
        # Process ARC Challenge
        print(f"\nProcessing FULL ARC Challenge dataset...")
        results_challenge = pipeline.process_dataset(arc_df)

        # Save results
        results_challenge.to_csv(challenge_filename, index=False)
        print(f"ARC Challenge results saved to {challenge_filename}")

        # Process ARC Easy
        print(f"\nProcessing FULL ARC Easy dataset...")
        results_easy = pipeline.process_dataset(arc_easy_df)

        # Save results
        results_easy.to_csv(easy_filename, index=False)
        print(f"ARC Easy results saved to {easy_filename}")
    finally:
        # Final cleanup - also runs when processing or saving fails, so the
        # models are not left loaded after an error
        pipeline.cleanup_all_models()

    print(f"\n{'='*60}")
    print("COMPLETE PIPELINE EXECUTION FINISHED - SWAPPED ROLES")
    print(f"{'='*60}")
    print("All datasets processed successfully!")

    return results_challenge, results_easy


def run_full_dataset(swap_roles=False):
    """Process the full ARC Challenge and ARC Easy datasets non-interactively.

    Args:
        swap_roles: Forwarded to ``IntegratedPipeline``; also selects the
            "SWAPPED"/"DEFAULT" banner text and the ``swapped``/``default``
            suffix of the output file names.

    Returns:
        Tuple ``(results_challenge, results_easy)`` of the DataFrames
        returned by ``pipeline.process_dataset``; each one is also written
        to a CSV file under ``Data/``.
    """
    role_desc = "SWAPPED" if swap_roles else "DEFAULT"
    role_suffix = "swapped" if swap_roles else "default"
    challenge_filename = f"Data/complete_pipeline_arc_challenge_{role_suffix}.csv"
    easy_filename = f"Data/complete_pipeline_arc_easy_{role_suffix}.csv"

    print(f"\n{'='*60}")
    print(f"FULL DATASET PROCESSING - {role_desc} ROLES")
    print(f"{'='*60}")

    # Create the output folder before the long run so that results cannot be
    # lost at save time because the directory is missing.
    os.makedirs(os.path.dirname(challenge_filename), exist_ok=True)

    # Load datasets
    arc_df, arc_easy_df = load_arc_datasets()

    # Initialize pipeline
    pipeline = IntegratedPipeline(confidence_threshold=0.10, swap_roles=swap_roles)

    try:
        # Process ARC Challenge
        print(f"\nProcessing FULL ARC Challenge dataset...")
        results_challenge = pipeline.process_dataset(arc_df)

        # Save results
        results_challenge.to_csv(challenge_filename, index=False)
        print(f"ARC Challenge results saved to {challenge_filename}")

        # Process ARC Easy
        print(f"\nProcessing FULL ARC Easy dataset...")
        results_easy = pipeline.process_dataset(arc_easy_df)

        # Save results
        results_easy.to_csv(easy_filename, index=False)
        print(f"ARC Easy results saved to {easy_filename}")
    finally:
        # Final cleanup - also runs when processing or saving fails, so the
        # models are not left loaded after an error
        pipeline.cleanup_all_models()

    print(f"\n{'='*60}")
    print(f"COMPLETE PIPELINE EXECUTION FINISHED - {role_desc} ROLES")
    print(f"{'='*60}")
    print("All datasets processed successfully!")

    return results_challenge, results_easy

# Entry point: start the interactive menu only when this code is executed
# directly (module name "__main__"), not when it is imported.
if __name__ == "__main__":
    main()

ADVERSARIAL QA SYSTEM - SWAPPED ROLES
Step 1: Load ARC datasets
Step 2: Phi-2 + Qwen2-1.5B → confidence analysis
Step 3: Filter choices with prob < threshold
Step 4: Phi-2 (Generator) vs Qwen2-1.5B (Discriminator)
Step 5: Nash equilibrium search
Step 6: Final answers

Choose execution mode:
1. Test on subset (20 questions) - RECOMMENDED FIRST
2. Test on larger subset (100 questions)
3. Full dataset (1000+ questions)


Enter choice (1/2/3):  3


Loading ARC Challenge dataset...
ARC Challenge shape: (1170, 5)
Loading ARC Easy dataset...
ARC Easy shape: (2371, 5)

Processing FULL ARC Challenge dataset...
Initializing complete pipeline...
Loading Phi-2 for preprocessing...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading Qwen2-1.5B for preprocessing...
Preprocessing models loaded successfully!
Adversarial roles set:
  Generator: Phi-2
  Discriminator: Qwen2-1.5B-Instruct
All models initialized successfully!
Processing dataset through complete pipeline...
Pipeline: Steps 1-6
  1. Load questions
  2. Phi-2 + Qwen2-1.5B → confidence scores
  3. Filter choices with prob < 0.10
  4. Qwen2-1.5B (Generator) vs Phi-2 (Discriminator)
  5. Nash equilibrium search
  6. Final answers


Processing pipeline:   0%|                  | 1/1170 [01:31<29:44:49, 91.61s/it]

Consensus: Mercury_7175875 - Correct: True


Processing pipeline:   0%|                  | 2/1170 [02:38<25:03:34, 77.24s/it]

Consensus: Mercury_SC_409171 - Correct: True


Processing pipeline:   0%|                 | 3/1170 [04:57<34:08:11, 105.31s/it]

Adversarial: Mercury_SC_408547 - Gen: True (Δ+0.0), Disc: True (Δ+1.0)


Processing pipeline:   0%|                  | 4/1170 [06:06<29:25:09, 90.83s/it]

Consensus: Mercury_407327 - Correct: True


Processing pipeline:   0%|                 | 5/1170 [08:28<35:24:26, 109.41s/it]

Adversarial: MCAS_2006_9_44 - Gen: False (Δ+0.0), Disc: False (Δ+0.0)


Processing pipeline:   1%|                 | 6/1170 [11:04<40:33:05, 125.42s/it]

Adversarial: Mercury_7270393 - Gen: False (Δ-1.0), Disc: False (Δ+0.0)


Processing pipeline:   1%|                 | 7/1170 [14:16<47:27:07, 146.89s/it]

Adversarial: MCAS_2014_5_7 - Gen: True (Δ+0.0), Disc: True (Δ+1.0)


Processing pipeline:   1%|                 | 8/1170 [17:18<51:04:11, 158.22s/it]

Adversarial: Mercury_7086660 - Gen: False (Δ+0.0), Disc: False (Δ-1.0)


Processing pipeline:   1%|▏                | 9/1170 [20:27<54:06:49, 167.79s/it]

Adversarial: Mercury_7168805 - Gen: False (Δ-1.0), Disc: True (Δ+1.0)


Processing pipeline:   1%|▏               | 10/1170 [23:09<53:31:53, 166.13s/it]

Adversarial: MCAS_2003_8_11 - Gen: True (Δ+0.0), Disc: True (Δ+0.0)


Processing pipeline:   1%|▏               | 11/1170 [25:38<51:46:28, 160.82s/it]

Adversarial: Mercury_7250058 - Gen: True (Δ+0.0), Disc: True (Δ+1.0)


Processing pipeline:   1%|▏               | 12/1170 [28:19<51:42:42, 160.76s/it]

Adversarial: Mercury_7012740 - Gen: False (Δ+0.0), Disc: False (Δ-1.0)


Processing pipeline:   1%|▏               | 13/1170 [31:42<55:45:52, 173.51s/it]

Adversarial: Mercury_LBS10610 - Gen: True (Δ+0.0), Disc: True (Δ+1.0)


Processing pipeline:   1%|▏               | 14/1170 [34:44<56:34:02, 176.16s/it]

Adversarial: Mercury_SC_407400 - Gen: True (Δ+0.0), Disc: True (Δ+0.0)


Processing pipeline:   1%|▏               | 15/1170 [37:50<57:30:02, 179.22s/it]

Adversarial: Mercury_7212993 - Gen: False (Δ-1.0), Disc: True (Δ+1.0)


Processing pipeline:   1%|▏               | 16/1170 [40:24<54:59:24, 171.55s/it]

Adversarial: Mercury_SC_413240 - Gen: True (Δ+0.0), Disc: True (Δ+0.0)


Processing pipeline:   1%|▏               | 17/1170 [41:37<45:25:27, 141.83s/it]

Consensus: Mercury_7186358 - Correct: True


Processing pipeline:   2%|▏               | 18/1170 [43:57<45:14:53, 141.40s/it]

Adversarial: Mercury_7166425 - Gen: True (Δ+0.0), Disc: True (Δ+0.0)


Processing pipeline:   2%|▎               | 19/1170 [45:05<38:07:10, 119.23s/it]

Consensus: MDSA_2007_8_3 - Correct: True


Processing pipeline:   2%|▎               | 20/1170 [48:27<46:05:47, 144.30s/it]

Adversarial: Mercury_7094290 - Gen: True (Δ+0.0), Disc: True (Δ+0.0)


Processing pipeline:   2%|▎               | 21/1170 [49:38<39:00:12, 122.20s/it]

Consensus: Mercury_7186568 - Correct: True


Processing pipeline:   2%|▎               | 22/1170 [51:10<36:04:15, 113.11s/it]

Consensus: Mercury_402216 - Correct: True


Processing pipeline:   2%|▎               | 23/1170 [55:01<47:20:54, 148.61s/it]

Adversarial: Mercury_404894 - Gen: True (Δ+0.0), Disc: True (Δ+0.0)


Processing pipeline:   2%|▎               | 24/1170 [58:28<52:52:59, 166.13s/it]

Adversarial: MCAS_2002_8_11 - Gen: True (Δ+1.0), Disc: False (Δ-1.0)


Processing pipeline:   2%|▎             | 25/1170 [1:00:44<49:57:32, 157.08s/it]

Adversarial: Mercury_SC_405086 - Gen: True (Δ+0.0), Disc: True (Δ+0.0)


Processing pipeline:   2%|▎             | 26/1170 [1:03:06<48:28:04, 152.52s/it]

Adversarial: Mercury_SC_408324 - Gen: True (Δ+1.0), Disc: True (Δ+0.0)


Processing pipeline:   2%|▎             | 27/1170 [1:05:32<47:45:44, 150.43s/it]

Adversarial: Mercury_7218820 - Gen: False (Δ+0.0), Disc: False (Δ+0.0)


Processing pipeline:   2%|▎             | 28/1170 [1:08:06<48:04:09, 151.53s/it]

Adversarial: Mercury_412202 - Gen: True (Δ+0.0), Disc: True (Δ+1.0)


Processing pipeline:   2%|▎             | 29/1170 [1:10:14<45:50:03, 144.61s/it]

Consensus: Mercury_SC_409139 - Correct: True


Processing pipeline:   3%|▎             | 30/1170 [1:13:18<49:32:42, 156.46s/it]

Consensus: Mercury_400687 - Correct: True


Processing pipeline:   3%|▎             | 31/1170 [1:16:47<54:29:52, 172.25s/it]

Adversarial: Mercury_7171605 - Gen: True (Δ+0.0), Disc: True (Δ+1.0)


Processing pipeline:   3%|▍             | 32/1170 [1:19:51<55:32:15, 175.69s/it]

Adversarial: Mercury_7210245 - Gen: False (Δ+0.0), Disc: False (Δ-1.0)


Processing pipeline:   3%|▍             | 33/1170 [1:21:00<45:23:30, 143.72s/it]

Consensus: AKDE&ED_2008_4_25 - Correct: True
