In [None]:
import math
import numpy as np
import pandas as pd
import os
import gc
import torch
import torch.nn.functional as F
import random
import time
import warnings
from collections import defaultdict
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
warnings.filterwarnings('ignore')

class PerformanceConfig:
    """Runtime settings shared by the model-loading and equilibrium code.

    The device is chosen once at construction time: CUDA when torch reports
    it as available, CPU otherwise.  Half precision is enabled only on CUDA.
    Constructing an instance prints a two-line summary of the choice.
    """

    def __init__(self):
        has_cuda = torch.cuda.is_available()

        # Hardware and numeric precision.
        self.device = "cuda" if has_cuda else "cpu"
        self.device_name = torch.cuda.get_device_name() if has_cuda else "CPU"
        self.use_fp16 = has_cuda
        self.dtype = torch.float16 if has_cuda else torch.float32

        # Tuning knobs read by the rest of the pipeline.
        self.clear_cache_every = 50
        self.equilibrium_iterations = 5
        self.early_stop_threshold = 0.001

        precision_label = 'FP16' if self.use_fp16 else 'FP32'
        print("Performance Configuration:")
        print(f"  Device: {self.device} ({self.device_name})")
        print(f"  Precision: {precision_label}")

# Module-wide settings instance read by the loaders and the equilibrium loop.
# Constructing it at import time prints the detected device and precision.
PERF_CONFIG = PerformanceConfig()

def load_arc_datasets():
    """Load the test splits of ARC-Challenge and ARC-Easy as DataFrames.

    Each frame is de-duplicated on ``question``, has ``choices`` reduced to
    the sequence of choice texts, and gains a constant ``subject`` column set
    to ``"science"``.

    ``answerKey`` is rewritten as the letter (``"A"``, ``"B"``, ...) of the
    correct choice's *position*.  The rest of this file decodes keys with
    ``ord(key) - ord('A')`` and labels choices positionally, so keys that the
    dataset stores as ``"1"``-``"4"`` would otherwise decode to a negative
    index once the per-choice labels are dropped.

    Returns:
        tuple: ``(arc_df, arc_easy_df)``.
    """

    def _letter_for_answer(row):
        # Locate the answer among the dataset's own choice labels and express
        # that position as A/B/C/...; keep the key as-is if it is not listed.
        labels = [str(label) for label in row["choices"]["label"]]
        try:
            return chr(ord("A") + labels.index(str(row["answerKey"])))
        except ValueError:
            return row["answerKey"]

    def _load_split(config_name, display_name):
        # Shared load/clean routine for one ARC configuration.
        print(f"Loading {display_name} dataset...")
        frame = load_dataset("allenai/ai2_arc", config_name, split="test").to_pandas()
        frame = frame.drop_duplicates(subset=['question'])
        if len(frame):
            # Must happen before "choices" loses its labels just below.
            frame["answerKey"] = frame.apply(_letter_for_answer, axis=1)
        frame["choices"] = frame["choices"].apply(lambda x: x["text"])
        frame["subject"] = "science"
        print(f"{display_name} shape: {frame.shape}")
        return frame

    arc_df = _load_split("ARC-Challenge", "ARC Challenge")
    arc_easy_df = _load_split("ARC-Easy", "ARC Easy")
    return arc_df, arc_easy_df

#==============================================================================
# HYBRID CALIBRATION + GAME SOFTMAX CLASSES
#==============================================================================

class HybridCalibrationGameSoftmax:
    """
    Core hybrid approach: Calibrated weights for foundation + Softmax for game dynamics
    Phase 1: Establish base model strengths through comprehensive calibration
    Phase 2: Apply those strengths dynamically in game theory with softmax
    """
    
    def __init__(self, preprocessing_system, adversarial_framework):
        self.preprocessing = preprocessing_system
        self.adversarial = adversarial_framework
        
        # Phase 1: Base strengths from calibration
        self.base_phi2_strength = 0.5
        self.base_qwen2_strength = 0.5
        self.calibration_confidence = 0.5
        self.calibration_details = {}
        
        # Phase 2: Game dynamics tracking
        self.temperature_history = []
        self.weight_evolution = []
        
    def calibrate_base_strengths(self, calibration_df, num_questions=50, method='comprehensive'):
        """
        Phase 1: Establish empirically-grounded base model strengths
        This gives us the foundation logits for game theory
        """
        print(f"\n{'='*60}")
        print(f"PHASE 1: CALIBRATING BASE MODEL STRENGTHS")
        print(f"Method: {method.upper()}")
        print(f"Questions: {num_questions}")
        print(f"{'='*60}")
        
        sample_df = calibration_df.sample(n=min(num_questions, len(calibration_df)), random_state=42)
        
        if method == 'comprehensive':
            strengths = self._comprehensive_strength_calibration(sample_df)
        elif method == 'adversarial_aware':
            strengths = self._adversarial_aware_calibration(sample_df)
        else:
            strengths = self._simple_strength_calibration(sample_df)
        
        # Store base strengths (these become our logits for game theory)
        self.base_phi2_strength = strengths['phi2_strength']
        self.base_qwen2_strength = strengths['qwen2_strength']
        self.calibration_confidence = strengths['confidence']
        self.calibration_details = strengths
        
        print(f"\nBase Model Strengths Established:")
        print(f"  Phi-2 base strength: {self.base_phi2_strength:.3f}")
        print(f"  Qwen2 base strength: {self.base_qwen2_strength:.3f}")
        print(f"  Calibration confidence: {self.calibration_confidence:.3f}")
        print(f"  Method used: {method}")
        
        return strengths
    
    def _comprehensive_strength_calibration(self, sample_df):
        """Comprehensive calibration considering multiple factors"""
        
        phi2_metrics = {'accuracy': 0, 'confidence': 0, 'consistency': 0, 'robustness': 0}
        qwen2_metrics = {'accuracy': 0, 'confidence': 0, 'consistency': 0, 'robustness': 0}
        
        phi2_predictions = []
        qwen2_predictions = []
        correct_answers = []
        phi2_confidences = []
        qwen2_confidences = []
        
        print("  Analyzing model performance across multiple criteria...")
        
        for idx, (_, row) in enumerate(sample_df.iterrows()):
            question = row["question"]
            choices = row["choices"]
            correct_answer = row.get("answerKey")
            
            if not correct_answer:
                continue
                
            correct_idx = ord(correct_answer) - ord('A')
            prompt = self.preprocessing.build_analysis_prompt(question, choices)
            
            # Get model responses
            phi2_probs = self.preprocessing.get_model_confidence_scores(
                self.preprocessing.phi2_model, self.preprocessing.phi2_tokenizer, prompt, len(choices)
            )
            qwen2_probs = self.preprocessing.get_model_confidence_scores(
                self.preprocessing.qwen2_model, self.preprocessing.qwen2_tokenizer, prompt, len(choices)
            )
            
            phi2_pred = np.argmax(phi2_probs)
            qwen2_pred = np.argmax(qwen2_probs)
            phi2_conf = phi2_probs[phi2_pred]
            qwen2_conf = qwen2_probs[qwen2_pred]
            
            # Store for analysis
            phi2_predictions.append(phi2_pred)
            qwen2_predictions.append(qwen2_pred)
            correct_answers.append(correct_idx)
            phi2_confidences.append(phi2_conf)
            qwen2_confidences.append(qwen2_conf)
            
            # Update metrics
            phi2_metrics['accuracy'] += (phi2_pred == correct_idx)
            qwen2_metrics['accuracy'] += (qwen2_pred == correct_idx)
            phi2_metrics['confidence'] += phi2_conf
            qwen2_metrics['confidence'] += qwen2_conf
            
            # Progress update
            if (idx + 1) % 10 == 0 or idx == len(sample_df) - 1:
                print(f"    Processed {idx + 1}/{len(sample_df)} questions...")
        
        n = len(sample_df)
        
        # Normalize basic metrics
        phi2_metrics['accuracy'] /= n
        qwen2_metrics['accuracy'] /= n
        phi2_metrics['confidence'] /= n
        qwen2_metrics['confidence'] /= n
        
        # Calculate advanced metrics
        phi2_metrics['consistency'] = self._calculate_consistency(phi2_predictions)
        qwen2_metrics['consistency'] = self._calculate_consistency(qwen2_predictions)
        
        phi2_metrics['robustness'] = self._calculate_robustness(phi2_confidences, phi2_predictions, correct_answers)
        qwen2_metrics['robustness'] = self._calculate_robustness(qwen2_confidences, qwen2_predictions, correct_answers)
        
        # Calculate complementarity
        complementarity = self._calculate_complementarity(phi2_predictions, qwen2_predictions, correct_answers)
        
        # Combine into strength scores with weighted criteria
        criteria_weights = {
            'accuracy': 0.4,      # Most important: raw performance
            'confidence': 0.2,    # How confident when making predictions
            'consistency': 0.2,   # How consistent across questions
            'robustness': 0.2     # High confidence when correct
        }
        
        phi2_strength = sum(criteria_weights[k] * phi2_metrics[k] for k in criteria_weights)
        qwen2_strength = sum(criteria_weights[k] * qwen2_metrics[k] for k in criteria_weights)
        
        # Calculate calibration confidence (how reliable these strengths are)
        strength_variance = abs(phi2_strength - qwen2_strength)
        confidence = min(1.0, 0.3 + 0.7 * (1 - strength_variance))  # Higher difference = higher confidence
        
        print(f"  Comprehensive Analysis Results:")
        print(f"    Phi-2 - Accuracy: {phi2_metrics['accuracy']:.1%}, Confidence: {phi2_metrics['confidence']:.3f}")
        print(f"           Consistency: {phi2_metrics['consistency']:.3f}, Robustness: {phi2_metrics['robustness']:.3f}")
        print(f"    Qwen2 - Accuracy: {qwen2_metrics['accuracy']:.1%}, Confidence: {qwen2_metrics['confidence']:.3f}")
        print(f"           Consistency: {qwen2_metrics['consistency']:.3f}, Robustness: {qwen2_metrics['robustness']:.3f}")
        print(f"    Complementarity: {complementarity:.1%}")
        print(f"    Combined strengths: Phi-2={phi2_strength:.3f}, Qwen2={qwen2_strength:.3f}")
        
        return {
            'phi2_strength': phi2_strength,
            'qwen2_strength': qwen2_strength,
            'confidence': confidence,
            'phi2_metrics': phi2_metrics,
            'qwen2_metrics': qwen2_metrics,
            'complementarity': complementarity,
            'sample_size': n,
            'method': 'comprehensive'
        }
    
    def _simple_strength_calibration(self, sample_df):
        """Simple accuracy-based strength calibration"""
        phi2_correct = 0
        qwen2_correct = 0
        total_phi2_conf = 0
        total_qwen2_conf = 0
        
        print("  Running simple accuracy-based calibration...")
        
        for idx, (_, row) in enumerate(sample_df.iterrows()):
            question = row["question"]
            choices = row["choices"]
            correct_answer = row.get("answerKey")
            
            if not correct_answer:
                continue
                
            correct_idx = ord(correct_answer) - ord('A')
            prompt = self.preprocessing.build_analysis_prompt(question, choices)
            
            phi2_probs = self.preprocessing.get_model_confidence_scores(
                self.preprocessing.phi2_model, self.preprocessing.phi2_tokenizer, prompt, len(choices)
            )
            qwen2_probs = self.preprocessing.get_model_confidence_scores(
                self.preprocessing.qwen2_model, self.preprocessing.qwen2_tokenizer, prompt, len(choices)
            )
            
            phi2_pred = np.argmax(phi2_probs)
            qwen2_pred = np.argmax(qwen2_probs)
            
            if phi2_pred == correct_idx:
                phi2_correct += 1
            if qwen2_pred == correct_idx:
                qwen2_correct += 1
                
            total_phi2_conf += phi2_probs[phi2_pred]
            total_qwen2_conf += qwen2_probs[qwen2_pred]
            
            # Progress update
            if (idx + 1) % 10 == 0 or idx == len(sample_df) - 1:
                print(f"    Processed {idx + 1}/{len(sample_df)} questions...")
        
        phi2_strength = phi2_correct / len(sample_df)
        qwen2_strength = qwen2_correct / len(sample_df)
        avg_phi2_conf = total_phi2_conf / len(sample_df)
        avg_qwen2_conf = total_qwen2_conf / len(sample_df)
        
        # Confidence based on how clear the winner is
        strength_diff = abs(phi2_strength - qwen2_strength)
        confidence = min(1.0, 0.5 + strength_diff)
        
        print(f"  Simple Calibration Results:")
        print(f"    Phi-2: {phi2_strength:.1%} accuracy, {avg_phi2_conf:.3f} avg confidence")
        print(f"    Qwen2: {qwen2_strength:.1%} accuracy, {avg_qwen2_conf:.3f} avg confidence")
        
        return {
            'phi2_strength': phi2_strength,
            'qwen2_strength': qwen2_strength,
            'confidence': confidence,
            'phi2_accuracy': phi2_strength,
            'qwen2_accuracy': qwen2_strength,
            'phi2_avg_confidence': avg_phi2_conf,
            'qwen2_avg_confidence': avg_qwen2_conf,
            'sample_size': len(sample_df),
            'method': 'simple'
        }
    
    def _adversarial_aware_calibration(self, sample_df):
        """Calibration that considers adversarial dynamics"""
        # This would run mini adversarial games during calibration
        # For now, fall back to comprehensive method
        print("  Running adversarial-aware calibration (using comprehensive method)...")
        return self._comprehensive_strength_calibration(sample_df)
    
    def game_theory_softmax_equilibrium(self, gen_init, disc_init, candidates, T=None,
                                       temperature_strategy='adaptive'):
        """
        Phase 2: Use calibrated base strengths in game theory with softmax dynamics
        """
        if T is None:
            T = PERF_CONFIG.equilibrium_iterations
            
        # Initialize with base strengths
        gen = {"correct": dict(gen_init["correct"]), "incorrect": dict(gen_init["incorrect"])}
        disc = {y: dict(disc_init[y]) for y in candidates}
        
        # Q-values for learning
        Qg = {"correct": {y: 0.0 for y in candidates}, "incorrect": {y: 0.0 for y in candidates}}
        Qd = {y: {"correct": 0.0, "incorrect": 0.0} for y in candidates}
        
        equilibrium_history = []
        
        for t in range(1, T+1):
            # Determine temperature for this iteration
            temperature = self._get_temperature(t, T, temperature_strategy, gen, disc, candidates)
            
            # Apply base strengths with current temperature to get dynamic weights
            base_logits = torch.tensor([self.base_phi2_strength, self.base_qwen2_strength], dtype=torch.float32)
            current_weights = F.softmax(base_logits / temperature, dim=0)
            
            gen_weight = float(current_weights[0])  # Phi-2 weight (generator)
            disc_weight = float(current_weights[1])  # Qwen2 weight (discriminator)
            
            # Store iteration info
            iteration_info = {
                'iteration': t,
                'temperature': temperature,
                'gen_weight': gen_weight,
                'disc_weight': disc_weight,
                'base_phi2_strength': self.base_phi2_strength,
                'base_qwen2_strength': self.base_qwen2_strength
            }
            
            # Update Q-values with current weights
            for v in ["correct", "incorrect"]:
                for y in candidates:
                    # Discriminator feedback weighted by its current strength
                    Qg[v][y] += (disc_weight / (2.0 * t)) * disc[y][v]
            
            for y in candidates:
                for v in ["correct", "incorrect"]:
                    # Generator feedback weighted by its current strength
                    Qd[y][v] += (gen_weight / (2.0 * t)) * gen[v][y]
            
            # Update generator policy (with weighted learning rate)
            eta_G = 0.1 * gen_weight
            for v in ["correct", "incorrect"]:
                logits = []
                for y in candidates:
                    val = (Qg[v][y] + 0.1 * math.log(gen_init[v][y] + 1e-12)) / (1/eta_G + 0.1)
                    logits.append(val)
                
                if max(logits) > min(logits):
                    new_probs = self._softmax(np.array(logits))
                    for i, y in enumerate(candidates):
                        gen[v][y] = new_probs[i]
            
            # Update discriminator policy (with weighted learning rate)
            eta_D = 0.1 * disc_weight
            for y in candidates:
                logits = [
                    (Qd[y]["correct"] + 0.01 * math.log(disc_init[y]["correct"] + 1e-12)) / (1/eta_D + 0.01),
                    (Qd[y]["incorrect"] + 0.01 * math.log(disc_init[y]["incorrect"] + 1e-12)) / (1/eta_D + 0.01)
                ]
                
                if max(logits) > min(logits):
                    probs = self._softmax(np.array(logits))
                    disc[y]["correct"] = probs[0]
                    disc[y]["incorrect"] = probs[1]
            
            # Evaluate convergence
            convergence_metrics = self._evaluate_convergence(gen, disc, candidates, t, T)
            iteration_info.update(convergence_metrics)
            
            equilibrium_history.append(iteration_info)
            
            # Early stopping if converged
            if convergence_metrics.get('converged', False):
                break
        
        # Store history for analysis
        self.temperature_history.extend([info['temperature'] for info in equilibrium_history])
        self.weight_evolution.extend(equilibrium_history)
        
        return {
            'gen_final': gen,
            'disc_final': disc,
            'equilibrium_history': equilibrium_history,
            'final_weights': {
                'gen_weight': equilibrium_history[-1]['gen_weight'],
                'disc_weight': equilibrium_history[-1]['disc_weight']
            },
            'convergence_iteration': len(equilibrium_history),
            'temperature_strategy': temperature_strategy
        }
    
    def _get_temperature(self, iteration, total_iterations, strategy, gen, disc, candidates):
        """Determine temperature for current iteration"""
        
        if strategy == 'fixed':
            return 1.0
        
        elif strategy == 'annealing':
            # Start high, cool down linearly
            return 2.0 * (total_iterations - iteration) / total_iterations + 0.5
        
        elif strategy == 'adaptive':
            # Adapt based on calibration confidence and game state
            
            # Base temperature from calibration confidence
            base_temp = 0.5 + (2.0 - 0.5) * (1 - self.calibration_confidence)
            
            # Adjust based on iteration progress
            if iteration <= 2:
                # Early iterations: explore more
                return base_temp * 1.5
            
            # Check strategy stability for later iterations
            if iteration > 3:
                stability = self._calculate_strategy_stability(gen, disc, candidates)
                if stability > 0.95:  # Very stable
                    return base_temp * 0.7  # Exploit more
                elif stability < 0.8:  # Unstable
                    return base_temp * 1.3  # Explore more
            
            return base_temp
        
        elif strategy == 'confidence_based':
            # Temperature based purely on calibration confidence
            return 0.5 + 1.5 * (1 - self.calibration_confidence)
        
        else:
            return 1.0  # Default
    
    def _calculate_strategy_stability(self, gen, disc, candidates):
        """Calculate how stable the current strategies are"""
        
        # Generator stability (how concentrated is the "correct" strategy)
        gen_probs = [gen["correct"][y] for y in candidates]
        gen_entropy = self._entropy(gen_probs)
        
        # Discriminator stability (how concentrated are the "correct" judgments)
        disc_probs = [disc[y]["correct"] for y in candidates]
        disc_entropy = self._entropy(disc_probs)
        
        # Normalize by maximum possible entropy
        max_entropy = np.log(len(candidates))
        if max_entropy > 0:
            gen_stability = 1.0 - (gen_entropy / max_entropy)
            disc_stability = 1.0 - (disc_entropy / max_entropy)
        else:
            gen_stability = disc_stability = 1.0
        
        # Overall stability
        return (gen_stability + disc_stability) / 2
    
    def _evaluate_convergence(self, gen, disc, candidates, iteration, total_iterations):
        """Evaluate if equilibrium has converged"""
        
        # Check strategy concentration
        gen_concentration = max(gen["correct"][y] for y in candidates)
        disc_concentration = max(disc[y]["correct"] for y in candidates)
        
        # Check early stopping criteria
        converged = False
        if iteration >= 3:  # Don't stop too early
            if gen_concentration > 0.8 and disc_concentration > 0.8:
                converged = True
            elif iteration >= total_iterations:
                converged = True
        
        return {
            'converged': converged,
            'gen_concentration': gen_concentration,
            'disc_concentration': disc_concentration,
            'stability': self._calculate_strategy_stability(gen, disc, candidates)
        }
    
    def get_final_answer_with_hybrid_weights(self, equilibrium_result, candidates, correct_answer=None):
        """Get final answer using the hybrid approach results"""
        
        gen_final = equilibrium_result['gen_final']
        disc_final = equilibrium_result['disc_final']
        final_weights = equilibrium_result['final_weights']
        
        # Use the final dynamic weights from the game
        gen_weight = final_weights['gen_weight']
        disc_weight = final_weights['disc_weight']
        
        # Get individual predictions
        gen_answer = max(candidates, key=lambda x: gen_final["correct"][x])
        disc_answer = max(candidates, key=lambda x: disc_final[x]["correct"])
        
        # Weighted combination for final answer
        combined_scores = {}
        for y in candidates:
            gen_score = gen_final["correct"][y]
            disc_score = disc_final[y]["correct"]
            combined_scores[y] = gen_weight * gen_score + disc_weight * disc_score
        
        best_answer = max(candidates, key=lambda x: combined_scores[x])
        
        # Calculate accuracy if correct answer provided
        accuracy = 0.0  # Default to 0.0 instead of None
        if correct_answer:
            accuracy = 1.0 if best_answer == correct_answer else 0.0
        
        return {
            'final_answer': best_answer,
            'gen_answer': gen_answer,
            'disc_answer': disc_answer,
            'combined_scores': combined_scores,
            'weights_used': final_weights,
            'base_strengths': {
                'phi2': self.base_phi2_strength,
                'qwen2': self.base_qwen2_strength
            },
            'weight_evolution': equilibrium_result['equilibrium_history'],
            'accuracy': accuracy
        }
    
    # Helper methods
    def _calculate_consistency(self, predictions):
        """Calculate prediction consistency (1 - normalized entropy)"""
        if not predictions:
            return 0.0
        unique, counts = np.unique(predictions, return_counts=True)
        probs = counts / len(predictions)
        entropy = -sum(p * np.log(p + 1e-12) for p in probs if p > 0)
        max_entropy = np.log(len(unique)) if len(unique) > 1 else 1
        return 1.0 - (entropy / max_entropy) if max_entropy > 0 else 1.0
    
    def _calculate_robustness(self, confidences, predictions, correct_answers):
        """Calculate prediction robustness (high confidence when correct)"""
        correct_mask = np.array(predictions) == np.array(correct_answers)
        if not any(correct_mask):
            return 0.0
        correct_confidences = np.array(confidences)[correct_mask]
        return float(np.mean(correct_confidences))
    
    def _calculate_complementarity(self, phi2_preds, qwen2_preds, correct_answers):
        """Calculate how well models complement each other"""
        phi2_correct = np.array(phi2_preds) == np.array(correct_answers)
        qwen2_correct = np.array(qwen2_preds) == np.array(correct_answers)
        
        # Combined performance (either model correct)
        either_correct = phi2_correct | qwen2_correct
        combined_accuracy = np.mean(either_correct)
        
        # Best individual performance
        best_individual = max(np.mean(phi2_correct), np.mean(qwen2_correct))
        
        # Complementarity = relative improvement from combination
        if best_individual > 0:
            return max(0.0, (combined_accuracy - best_individual) / best_individual)
        else:
            return 0.0
    
    def _softmax(self, arr):
        """Numerically stable softmax"""
        m = np.max(arr)
        exp_vals = np.exp(arr - m)
        return exp_vals / np.sum(exp_vals)
    
    def _entropy(self, probs):
        """Calculate entropy of probability distribution"""
        probs = np.array(probs)
        probs = probs[probs > 0]  # Remove zeros
        if len(probs) == 0:
            return 0.0
        return -np.sum(probs * np.log(probs + 1e-12))

#==============================================================================
# ENHANCED PREPROCESSING SYSTEM (Compatible with Hybrid)
#==============================================================================

class HybridOptimizedPreprocessingSystem:
    """
    Enhanced preprocessing system that works with hybrid calibration
    """
    
    def __init__(self, confidence_threshold=0.10):
        self.confidence_threshold = confidence_threshold
        self.phi2_model = None
        self.phi2_tokenizer = None
        self.qwen2_model = None
        self.qwen2_tokenizer = None
        
        # Weights (will be set by hybrid calibrator)
        self.phi2_weight = 0.5
        self.qwen2_weight = 0.5
        self.weights_calibrated = False
        self.calibration_results = None
        
    def load_preprocessing_models(self):
        """Load Phi-2 and Qwen2-1.5B-Instruct with their tokenizers.

        Models are loaded with the dtype from ``PERF_CONFIG`` and
        ``device_map="auto"`` on CUDA (``"cpu"`` otherwise).  A tokenizer
        without a pad token reuses its EOS token.  Both models are switched to
        eval mode once loading has finished.
        """
        print(f"Loading preprocessing models on {PERF_CONFIG.device}...")

        placement = "auto" if PERF_CONFIG.device == "cuda" else "cpu"
        checkpoints = (
            ("phi2", "microsoft/phi-2", "Loading Phi-2 for preprocessing..."),
            ("qwen2", "Qwen/Qwen2-1.5B-Instruct", "Loading Qwen2-1.5B for preprocessing..."),
        )

        for prefix, repo_id, announcement in checkpoints:
            print(announcement)
            model = AutoModelForCausalLM.from_pretrained(
                repo_id,
                torch_dtype=PERF_CONFIG.dtype,
                low_cpu_mem_usage=True,
                device_map=placement,
                trust_remote_code=True
            )
            setattr(self, f"{prefix}_model", model)

            tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
            setattr(self, f"{prefix}_tokenizer", tokenizer)
            if tokenizer.pad_token is None:
                # The tokenizer call elsewhere pads, so a pad token must exist.
                tokenizer.pad_token = tokenizer.eos_token

        # Set models to eval mode
        for loaded_model in (self.phi2_model, self.qwen2_model):
            loaded_model.eval()

        print("Preprocessing models loaded successfully!")
        
    def build_analysis_prompt(self, question, choices):
        """Build prompt for confidence analysis."""
        prompt = f"""The following is a multiple choice science question. Analyze all choices and select the most likely correct answer.

Question: {question}
"""
        for i, choice in enumerate(choices):
            prompt += f"{chr(65+i)}. {choice}\n"
        prompt += "\nAnswer:"
        return prompt
        
    def get_model_confidence_scores(self, model, tokenizer, prompt, num_choices):
        """Get confidence scores from a model for answer choices."""
        try:
            inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
            
            # Move inputs to device if using GPU
            if PERF_CONFIG.device == "cuda":
                inputs = {k: v.to(model.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=PERF_CONFIG.use_fp16 and PERF_CONFIG.device == "cuda"):
                    outputs = model(**inputs)
                    logits = outputs.logits[0, -1]
            
            # Get probabilities for A, B, C, D based on number of choices
            choice_logits = []
            for i in range(num_choices):
                letter = chr(65 + i)  # A, B, C, D
                token_id = tokenizer.encode(letter, add_special_tokens=False)[-1]
                choice_logits.append(logits[token_id].item())
            
            # Convert to probabilities
            choice_logits = torch.tensor(choice_logits)
            probs = torch.nn.functional.softmax(choice_logits, dim=0).numpy()
            
            return probs
            
        except Exception as e:
            print(f"Error getting confidence scores: {e}")
            # Return uniform distribution as fallback
            return np.ones(num_choices) / num_choices
    
    def analyze_all_choices(self, question, choices):
        """Analyze all choices using both models with current weights"""
        prompt = self.build_analysis_prompt(question, choices)
        num_choices = len(choices)
        
        # Get confidence scores from both models
        phi2_probs = self.get_model_confidence_scores(
            self.phi2_model, self.phi2_tokenizer, prompt, num_choices
        )
        qwen2_probs = self.get_model_confidence_scores(
            self.qwen2_model, self.qwen2_tokenizer, prompt, num_choices
        )
        
        # Use current weights (set by hybrid calibrator)
        combined_probs = self.phi2_weight * phi2_probs + self.qwen2_weight * qwen2_probs
        
        return {
            'phi2_probs': phi2_probs,
            'qwen2_probs': qwen2_probs,
            'combined_probs': combined_probs,
            'weights_used': {
                'phi2': self.phi2_weight,
                'qwen2': self.qwen2_weight
            }
        }
    
    def filter_low_confidence_choices(self, choices, combined_probs):
        """Filter out choices with prob < threshold"""
        # Find choices above threshold
        high_confidence_indices = [
            i for i, prob in enumerate(combined_probs) 
            if prob >= self.confidence_threshold
        ]
        
        # Fallback: keep all choices if none meet threshold
        if not high_confidence_indices:
            high_confidence_indices = list(range(len(choices)))
        
        # Extract high-confidence choices
        filtered_choices = [choices[i] for i in high_confidence_indices]
        filtered_probs = [combined_probs[i] for i in high_confidence_indices]
        
        return filtered_choices, filtered_probs, high_confidence_indices
    
    def randomize_choice_order(self, choices, probs, original_indices):
        """Randomly shuffle choices to prevent positional bias."""
        combined = list(zip(choices, probs, original_indices))
        random.shuffle(combined)
        
        shuffled_choices, shuffled_probs, shuffled_original_indices = zip(*combined)
        position_mapping = {i: orig_idx for i, orig_idx in enumerate(shuffled_original_indices)}
        
        return list(shuffled_choices), list(shuffled_probs), position_mapping
    
    def preprocess_question(self, question, choices, answer_key=None):
        """Complete preprocessing pipeline"""
        # Analyze all choices
        analysis_results = self.analyze_all_choices(question, choices)
        
        # Filter low-confidence choices
        filtered_choices, filtered_probs, high_confidence_indices = self.filter_low_confidence_choices(
            choices, analysis_results['combined_probs']
        )
        
        # Randomize to prevent bias
        final_choices, final_probs, position_mapping = self.randomize_choice_order(
            filtered_choices, filtered_probs, high_confidence_indices
        )
        
        # Update answer key if provided
        new_answer_key = None
        if answer_key:
            original_answer_idx = ord(answer_key) - ord('A')
            if original_answer_idx in high_confidence_indices:
                for new_pos, orig_pos in position_mapping.items():
                    if orig_pos == original_answer_idx:
                        new_answer_key = chr(ord('A') + new_pos)
                        break
        
        return {
            'original_question': question,
            'original_choices': choices,
            'filtered_choices': final_choices,
            'analysis_results': analysis_results,
            'position_mapping': position_mapping,
            'high_confidence_indices': high_confidence_indices,
            'original_answer_key': answer_key,
            'new_answer_key': new_answer_key,
            'filtering_applied': len(final_choices) < len(choices)
        }
    
    def cleanup_preprocessing_models(self):
        """Clean up preprocessing models."""
        if self.phi2_model:
            del self.phi2_model
        if self.phi2_tokenizer:
            del self.phi2_tokenizer
        if self.qwen2_model:
            del self.qwen2_model
        if self.qwen2_tokenizer:
            del self.qwen2_tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

#==============================================================================
# ENHANCED ADVERSARIAL FRAMEWORK (Compatible with Hybrid)
#==============================================================================

class HybridOptimizedAdversarialFramework:
    """
    Enhanced adversarial framework that works with hybrid calibration.

    A generator model scores the lettered options of a multiple-choice
    prompt, a discriminator model judges each proposed answer as
    correct/incorrect, and ``equilibrium_search`` iteratively reconciles the
    two sets of probabilities.
    """

    def __init__(self):
        self.generator_model = None      # Phi-2
        self.generator_tokenizer = None
        self.discriminator_model = None  # Qwen2-1.5B
        self.discriminator_tokenizer = None

    def setup_adversarial_models(self, phi2_model, phi2_tokenizer, qwen2_model, qwen2_tokenizer):
        """Assign the adversarial roles: Phi-2 generates, Qwen2 discriminates."""
        self.generator_model = phi2_model         # Phi-2 as Generator
        self.generator_tokenizer = phi2_tokenizer
        self.discriminator_model = qwen2_model    # Qwen2-1.5B as Discriminator
        self.discriminator_tokenizer = qwen2_tokenizer

        print("Adversarial roles set:")
        print("  Generator: Phi-2")
        print("  Discriminator: Qwen2-1.5B-Instruct")

    def format_subject(self, subject):
        """Return ``subject`` with underscores replaced by spaces."""
        return " ".join(subject.split("_"))

    def build_generator_prompt(self, subject, question, choices, get_correct):
        """Build the generator prompt.

        Args:
            subject: Subject name; underscores are shown as spaces.
            question: Question text.
            choices: Answer texts, labelled ``A``, ``B``, ... in order.
            get_correct: If true the prompt ends with ``Answer:``, otherwise
                with ``Incorrect Answer:``.

        Returns:
            The prompt string.
        """
        parts = [
            f"The following are multiple choice questions (with answers) about {self.format_subject(subject)}.\n\n",
            f"{question}",
        ]
        parts.extend(f"\n{chr(65+i)}. {choice}" for i, choice in enumerate(choices))
        parts.append("\nAnswer:" if get_correct else "\nIncorrect Answer:")
        return "".join(parts)

    @staticmethod
    def _choice_probs(logits, token_ids):
        """Softmax over ``logits[token_ids]``, computed in float32 on the CPU.

        Both models may produce float16 logits.  Up-casting before the softmax
        keeps generator and discriminator numerically consistent and avoids
        CPU builds of PyTorch that have no half-precision softmax kernel
        (which previously fell into the caller's ``except`` and silently
        produced a uniform fallback).

        Args:
            logits: 1-D tensor of next-token logits.
            token_ids: Token ids whose logits are compared.

        Returns:
            ``numpy.ndarray`` of dtype float32 with ``len(token_ids)`` entries
            summing to 1.
        """
        selected = logits[list(token_ids)].detach().float().cpu()
        return torch.nn.functional.softmax(selected, dim=0).numpy()

    def get_generator_probabilities(self, prompt_text, num_choices):
        """Return the generator's probability for each answer letter.

        Args:
            prompt_text: Prompt produced by ``build_generator_prompt``.
            num_choices: Number of lettered options (``A`` onwards).

        Returns:
            Array of ``num_choices`` probabilities.  On any error the problem
            is printed and a uniform distribution is returned (best effort).
        """
        try:
            inputs = self.generator_tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)

            # Move to device if using GPU
            if PERF_CONFIG.device == "cuda":
                inputs = {k: v.to(self.generator_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=PERF_CONFIG.use_fp16 and PERF_CONFIG.device == "cuda"):
                    outputs = self.generator_model(**inputs)
                    logits = outputs.logits[0, -1]  # next-token logits of the only sequence

            # Last token id of each bare letter A, B, C, D...
            token_ids = [
                self.generator_tokenizer.encode(chr(65 + i), add_special_tokens=False)[-1]
                for i in range(num_choices)
            ]
            return self._choice_probs(logits, token_ids)

        except Exception as e:
            print(f"Error in generator probabilities: {e}")
            return np.ones(num_choices) / num_choices

    def get_generator_initial_probs(self, question, choices, subject):
        """Query the generator twice (correct / incorrect prompt).

        Returns:
            ``{"correct": {letter: p}, "incorrect": {letter: p}}``.
        """
        gen_init = {"correct": {}, "incorrect": {}}
        candidates = [f"{chr(65+i)}" for i in range(len(choices))]

        for get_correct in [True, False]:
            prompt = self.build_generator_prompt(subject, question, choices, get_correct)
            probs = self.get_generator_probabilities(prompt, len(choices))

            role = "correct" if get_correct else "incorrect"
            for i, candidate in enumerate(candidates):
                gen_init[role][candidate] = probs[i]

        return gen_init

    def build_discriminator_prompt(self, subject, question, proposed_answer):
        """Build the correct/incorrect judgement prompt for one proposed answer."""
        prompt = f"""You are an expert evaluator of questions about {self.format_subject(subject)}. 
Determine if the proposed answer is correct. Output ONLY 'A' or 'B'.
Question: {question}
Proposed Answer: {proposed_answer}

Is this answer correct? Respond ONLY with:
A. Correct
B. Incorrect

Answer:"""
        return prompt

    def get_discriminator_probabilities(self, prompt_text):
        """Return the discriminator's verdict for one prompt.

        Returns:
            ``{"correct": p, "incorrect": 1 - p}`` as Python floats.  On any
            error the problem is printed and 0.5/0.5 is returned (best effort).
        """
        try:
            inputs = self.discriminator_tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)

            # Move to device if using GPU
            if PERF_CONFIG.device == "cuda":
                inputs = {k: v.to(self.discriminator_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=PERF_CONFIG.use_fp16 and PERF_CONFIG.device == "cuda"):
                    outputs = self.discriminator_model(**inputs)
                    logits = outputs.logits[0, -1]

            # "A" means correct and "B" means incorrect in the prompt above.
            a_token = self.discriminator_tokenizer.encode("A", add_special_tokens=False)[-1]
            b_token = self.discriminator_tokenizer.encode("B", add_special_tokens=False)[-1]

            probs = self._choice_probs(logits, [a_token, b_token])

            return {"correct": float(probs[0]), "incorrect": float(probs[1])}

        except Exception as e:
            print(f"Error in discriminator probabilities: {e}")
            return {"correct": 0.5, "incorrect": 0.5}

    def get_discriminator_initial_probs(self, question, choices, subject):
        """Judge every choice separately.

        Returns:
            ``{letter: {"correct": p, "incorrect": q}}`` keyed ``A``, ``B``, ...
        """
        results = {}

        for idx, answer in enumerate(choices):
            prompt = self.build_discriminator_prompt(subject, question, answer)
            results[f"{chr(65+idx)}"] = self.get_discriminator_probabilities(prompt)

        return results

    def softmax(self, arr):
        """Numerically stable softmax (the maximum is subtracted before exp)."""
        m = np.max(arr)
        exp_vals = np.exp(arr - m)
        return exp_vals / np.sum(exp_vals)

    def equilibrium_search(self, gen_init, disc_init, candidates, T=None, 
                          eta_G=0.1, eta_D=0.1, lam_G=0.1, lam_D=0.01):
        """Standard equilibrium search (for compatibility).

        Each iteration ``t`` adds ``1/(2t)`` times the opponent's current
        policy to the running Q values and then recomputes every policy as
        ``softmax((Q + lam * log(initial)) / (1/eta + lam))``.

        Args:
            gen_init: ``{"correct"|"incorrect": {candidate: prob}}``.
            disc_init: ``{candidate: {"correct": p, "incorrect": q}}``.
            candidates: Candidate labels to iterate over.
            T: Number of iterations; defaults to
                ``PERF_CONFIG.equilibrium_iterations``.
            eta_G, eta_D: Step-size terms of generator / discriminator.
            lam_G, lam_D: Weights pulling each policy towards its initial one.

        Returns:
            ``(gen, disc)`` in the same layout as the inputs.  The inputs are
            copied and not modified.
        """
        if T is None:
            T = PERF_CONFIG.equilibrium_iterations

        roles = ["correct", "incorrect"]

        gen = {v: dict(gen_init[v]) for v in roles}
        disc = {y: dict(disc_init[y]) for y in candidates}

        Qg = {v: {y: 0.0 for y in candidates} for v in roles}
        Qd = {y: {v: 0.0 for v in roles} for y in candidates}

        # Loop-invariant denominators of the two policy updates.
        gen_scale = 1/eta_G + lam_G
        disc_scale = 1/eta_D + lam_D

        prev_gen = None
        prev_disc = None

        for t in range(1, T+1):
            step = 1.0/(2.0*t)

            # Update Q values (both from the policies of the previous round)
            for v in roles:
                for y in candidates:
                    Qg[v][y] += step * disc[y][v]

            for y in candidates:
                for v in roles:
                    Qd[y][v] += step * gen[v][y]

            # Update generator policy; 1e-12 keeps log() finite for zero priors
            for v in roles:
                logits = [
                    (Qg[v][y] + lam_G * math.log(gen_init[v][y] + 1e-12)) / gen_scale
                    for y in candidates
                ]
                new_probs = self.softmax(np.array(logits))
                for i, y in enumerate(candidates):
                    gen[v][y] = new_probs[i]

            # Update discriminator policy
            for y in candidates:
                logits = [
                    (Qd[y][v] + lam_D * math.log(disc_init[y][v] + 1e-12)) / disc_scale
                    for v in roles
                ]
                probs = self.softmax(np.array(logits))
                disc[y]["correct"] = probs[0]
                disc[y]["incorrect"] = probs[1]

            # Early stopping check: mean absolute change of the faster-moving player
            if t > 2 and prev_gen is not None:
                gen_change = sum(abs(gen[v][y] - prev_gen[v][y]) 
                               for v in roles for y in candidates)
                disc_change = sum(abs(disc[y][v] - prev_disc[y][v]) 
                                for y in candidates for v in roles)

                max_change = max(gen_change, disc_change) / (len(candidates) * 2)

                if max_change < PERF_CONFIG.early_stop_threshold:
                    break

            # Store previous values
            prev_gen = {v: dict(gen[v]) for v in roles}
            prev_disc = {y: dict(disc[y]) for y in candidates}

        return gen, disc

    def get_final_answers(self, gen_final, disc_final, candidates):
        """Pick each player's preferred candidate.

        Returns:
            ``(gen_answer, disc_answer)``: the candidates with the highest
            ``"correct"`` probability for generator and discriminator.  The
            first candidate wins ties; both are ``None`` when ``candidates``
            is empty.
        """
        def best(score):
            # Strict ">" keeps the earliest candidate on ties.
            winner, top = None, -1.0
            for y in candidates:
                p = score(y)
                if p > top:
                    winner, top = y, p
            return winner

        gen_answer = best(lambda y: gen_final["correct"][y])
        disc_answer = best(lambda y: disc_final[y]["correct"])

        return gen_answer, disc_answer

#==============================================================================
# COMPLETE HYBRID INTEGRATED PIPELINE
#==============================================================================

class HybridIntegratedPipeline:
    """
    Complete hybrid pipeline: Calibrated base strengths + Softmax game dynamics
    This is the main class that orchestrates everything

    It owns the preprocessing system, the adversarial framework and (once the
    models are loaded) the hybrid calibrator, and records the wall-clock time
    of every processed question in ``processing_times``.
    """

    def __init__(self, confidence_threshold=0.10, 
                 calibration_method='comprehensive',
                 temperature_strategy='adaptive',
                 calibration_size=50):
        """Create the pipeline components; no model is loaded yet.

        Args:
            confidence_threshold: Passed to the preprocessing system.
            calibration_method: Forwarded to ``calibrate_base_strengths``.
            temperature_strategy: Forwarded to the equilibrium search.
            calibration_size: Number of questions used (and then removed)
                for calibration in ``process_dataset``.
        """
        self.preprocessing = HybridOptimizedPreprocessingSystem(confidence_threshold)
        self.adversarial = HybridOptimizedAdversarialFramework()
        self.hybrid_calibrator = None

        self.models_loaded = False
        self.calibration_method = calibration_method
        self.temperature_strategy = temperature_strategy
        self.calibration_size = calibration_size
        self.processing_times = []

    def initialize_all_models(self):
        """Load the models once and wire them into all components."""
        if self.models_loaded:
            return

        print("Initializing hybrid pipeline...")

        # Load preprocessing models
        self.preprocessing.load_preprocessing_models()

        # The adversarial framework reuses the very same model objects.
        self.adversarial.setup_adversarial_models(
            self.preprocessing.phi2_model, 
            self.preprocessing.phi2_tokenizer,
            self.preprocessing.qwen2_model, 
            self.preprocessing.qwen2_tokenizer
        )

        # Initialize hybrid calibrator
        self.hybrid_calibrator = HybridCalibrationGameSoftmax(
            self.preprocessing,
            self.adversarial
        )

        self.models_loaded = True
        print("Hybrid pipeline initialized successfully!")

    def process_single_question(self, question, choices, answer_key, subject):
        """Process a single question through the hybrid pipeline.

        Args:
            question: Question text.
            choices: Original answer texts.
            answer_key: Letter of the correct answer, or a falsy value.
            subject: Subject name used in the prompts.

        Returns:
            The preprocessing result dict extended with ``final_answer``,
            ``accuracy``, ``processing_time`` and, when the adversarial stage
            ran, its diagnostics.  If preprocessing leaves a single choice it
            is returned directly as answer ``"A"``.
        """
        start_time = time.time()

        # Steps 2-3: Preprocessing
        preprocessing_result = self.preprocessing.preprocess_question(question, choices, answer_key)

        filtered_choices = preprocessing_result["filtered_choices"]
        new_answer_key = preprocessing_result.get("new_answer_key")

        # Check if only one choice remains (automatic consensus)
        if len(filtered_choices) == 1:
            consensus_answer = "A"
            processing_time = time.time() - start_time
            self.processing_times.append(processing_time)

            # Without a remapped key the answer is scored as 0.0.
            consensus_accuracy = 1.0 if (new_answer_key and consensus_answer == new_answer_key) else 0.0

            return {
                **preprocessing_result,
                'consensus_achieved': True,
                'final_answer': consensus_answer,
                'hybrid_method_used': False,
                'accuracy': consensus_accuracy,
                'processing_time': processing_time
            }

        # Steps 4-6: Hybrid Adversarial Training
        candidates = [f"{chr(65+i)}" for i in range(len(filtered_choices))]

        # Step 4: Get initial probabilities
        gen_init = self.adversarial.get_generator_initial_probs(question, filtered_choices, subject)
        disc_init = self.adversarial.get_discriminator_initial_probs(question, filtered_choices, subject)

        # Step 5: HYBRID equilibrium search with calibrated strengths + softmax
        equilibrium_result = self.hybrid_calibrator.game_theory_softmax_equilibrium(
            gen_init, disc_init, candidates,
            temperature_strategy=self.temperature_strategy
        )

        # Step 6: Get final answer with hybrid weights
        final_result = self.hybrid_calibrator.get_final_answer_with_hybrid_weights(
            equilibrium_result, candidates, new_answer_key
        )

        final_answer = final_result['final_answer']
        accuracy = final_result.get('accuracy', 0.0)

        processing_time = time.time() - start_time
        self.processing_times.append(processing_time)

        return {
            **preprocessing_result,
            'consensus_achieved': False,
            'final_answer': final_answer,
            'gen_answer': final_result['gen_answer'],
            'disc_answer': final_result['disc_answer'],
            'hybrid_method_used': True,
            'hybrid_weights_used': final_result['weights_used'],
            'base_strengths': final_result['base_strengths'],
            'weight_evolution': final_result['weight_evolution'],
            'equilibrium_history': equilibrium_result['equilibrium_history'],
            'convergence_iteration': equilibrium_result['convergence_iteration'],
            'temperature_strategy': self.temperature_strategy,
            'accuracy': accuracy,
            'processing_time': processing_time
        }

    def process_dataset(self, df, max_questions=None):
        """Process complete dataset through hybrid pipeline.

        Phase 1 calibrates the base strengths and removes the calibration
        sample from ``df``; phase 2 runs every remaining question (optionally
        only the first ``max_questions``) through ``process_single_question``.

        Args:
            df: DataFrame with ``question``, ``choices`` and ``subject``
                columns and optionally ``answerKey``.
            max_questions: Optional cap applied after the calibration rows
                have been removed.

        Returns:
            ``pandas.DataFrame`` with one row per processed question.
        """
        self.initialize_all_models()

        # Phase 1: Hybrid Calibration
        print("\nPHASE 1: HYBRID CALIBRATION")
        calibration_results = self.hybrid_calibrator.calibrate_base_strengths(
            df, 
            num_questions=self.calibration_size,
            method=self.calibration_method
        )

        # Remove calibration questions from main dataset.
        # NOTE(review): this presumably mirrors the sample drawn inside
        # calibrate_base_strengths (same n and random_state) -- verify there.
        calibration_indices = df.sample(n=min(self.calibration_size, len(df)), random_state=42).index.tolist()
        df = df.drop(calibration_indices)
        print(f"Calibration complete. Removed {len(calibration_indices)} questions from main dataset")
        print(f"Remaining questions: {len(df)}")

        # Update preprocessing weights with calibrated base strengths
        self.preprocessing.phi2_weight = self.hybrid_calibrator.base_phi2_strength
        self.preprocessing.qwen2_weight = self.hybrid_calibrator.base_qwen2_strength
        self.preprocessing.weights_calibrated = True

        # Limit to subset if specified
        if max_questions is not None:
            df = df.head(max_questions)
            print(f"Processing subset of {len(df)} questions...")

        # Phase 2: Main Processing
        results = []

        # Accuracy tracking
        total_questions = 0
        consensus_questions = 0
        hybrid_questions = 0
        total_correct = 0

        print(f"\nPHASE 2: HYBRID GAME THEORY PROCESSING")
        print(f"Base strengths: Phi-2={self.hybrid_calibrator.base_phi2_strength:.3f}, Qwen2={self.hybrid_calibrator.base_qwen2_strength:.3f}")
        print(f"Temperature strategy: {self.temperature_strategy}")
        print(f"Calibration method: {self.calibration_method}")

        start_time = time.time()

        # Process each question
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Hybrid pipeline"):
            total_questions += 1

            # Process through hybrid pipeline
            result = self.process_single_question(
                row["question"], 
                row["choices"], 
                row.get("answerKey"),
                row["subject"]
            )

            # Add original row data
            result.update(row.to_dict())

            # Track accuracy
            if result['consensus_achieved']:
                consensus_questions += 1
            else:
                hybrid_questions += 1

            accuracy = result.get('accuracy', 0)
            if accuracy is not None and accuracy > 0:
                total_correct += 1

            results.append(result)

            # Memory management
            if total_questions % PERF_CONFIG.clear_cache_every == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        total_time = time.time() - start_time
        avg_time_per_question = total_time / total_questions if total_questions > 0 else 0

        # Print comprehensive results
        self._print_hybrid_results(
            results, total_questions, consensus_questions, hybrid_questions, 
            total_correct, total_time, avg_time_per_question, calibration_results
        )

        return pd.DataFrame(results)

    def _print_hybrid_results(self, results, total_questions, consensus_questions, hybrid_questions, 
                             total_correct, total_time, avg_time_per_question, calibration_results):
        """Print comprehensive hybrid results.

        Safe for an empty run: rates and percentages are reported as 0 when
        ``total_questions`` or ``total_time`` is zero instead of raising
        ``ZeroDivisionError``.  ``total_correct`` is accepted for interface
        compatibility; the printed count is recomputed from ``results``.
        """
        questions_per_second = total_questions / total_time if total_time > 0 else 0.0

        def percent_of_total(count):
            return count / total_questions * 100 if total_questions else 0.0

        print(f"\n{'='*70}")
        print(f"HYBRID CALIBRATION + GAME SOFTMAX RESULTS")
        print(f"{'='*70}")

        # Performance metrics
        print(f"Performance Metrics:")
        print(f"   Total processing time: {total_time:.2f} seconds")
        print(f"   Average time per question: {avg_time_per_question:.3f} seconds")
        print(f"   Questions per second: {questions_per_second:.1f}")
        print(f"   Device used: {PERF_CONFIG.device} ({PERF_CONFIG.device_name})")
        print(f"   Precision: {'FP16' if PERF_CONFIG.use_fp16 else 'FP32'}")

        if self.processing_times:
            print(f"   Min processing time: {min(self.processing_times):.3f}s")
            print(f"   Max processing time: {max(self.processing_times):.3f}s")

        # Calibration info
        print(f"\nCalibration Details:")
        print(f"   Method: {self.calibration_method}")
        print(f"   Sample size: {self.calibration_size}")
        print(f"   Phi-2 base strength: {self.hybrid_calibrator.base_phi2_strength:.3f}")
        print(f"   Qwen2 base strength: {self.hybrid_calibrator.base_qwen2_strength:.3f}")
        print(f"   Calibration confidence: {self.hybrid_calibrator.calibration_confidence:.3f}")

        if 'phi2_metrics' in calibration_results:
            phi2_metrics = calibration_results['phi2_metrics']
            qwen2_metrics = calibration_results['qwen2_metrics']
            print(f"   Phi-2 metrics: Acc={phi2_metrics['accuracy']:.1%}, Conf={phi2_metrics['confidence']:.3f}")
            print(f"   Qwen2 metrics: Acc={qwen2_metrics['accuracy']:.1%}, Conf={qwen2_metrics['confidence']:.3f}")
            print(f"   Complementarity: {calibration_results.get('complementarity', 0):.1%}")

        # Processing breakdown
        print(f"\nProcessing Breakdown:")
        print(f"   Total questions: {total_questions}")
        print(f"   Preprocessing consensus: {consensus_questions} ({percent_of_total(consensus_questions):.1f}%)")
        print(f"   Hybrid game theory: {hybrid_questions} ({percent_of_total(hybrid_questions):.1f}%)")

        # Final accuracy (results whose accuracy is None are left out)
        valid_results = [r.get('accuracy', 0) for r in results if r.get('accuracy') is not None]
        overall_accuracy = sum(valid_results) / len(valid_results) if valid_results else 0
        correct_count = sum(1 for acc in valid_results if acc > 0)

        print(f"\nFinal Results:")
        print(f"   Overall accuracy: {correct_count}/{len(valid_results)} = {overall_accuracy:.1%}")
        print(f"   Valid results: {len(valid_results)}/{total_questions}")
        print(f"   Temperature strategy: {self.temperature_strategy}")

        # Temperature analysis
        if self.hybrid_calibrator.temperature_history:
            temps = self.hybrid_calibrator.temperature_history
            print(f"   Temperature range: {min(temps):.2f} - {max(temps):.2f}")
            print(f"   Average temperature: {np.mean(temps):.2f}")

        # Weight evolution analysis
        if self.hybrid_calibrator.weight_evolution:
            final_weights = self.hybrid_calibrator.weight_evolution[-1]
            print(f"   Final dynamic weights: Gen={final_weights['gen_weight']:.3f}, Disc={final_weights['disc_weight']:.3f}")

        print(f"\nHYBRID PIPELINE COMPLETE!")
        print(f"{'='*70}")

    def cleanup_all_models(self):
        """Release every model reference held by the pipeline.

        The adversarial framework shares the preprocessing system's model and
        tokenizer objects, so its references are dropped first; otherwise the
        models would stay alive and the cache flush performed by the
        preprocessing cleanup could not free their memory.
        """
        self.adversarial.generator_model = None
        self.adversarial.generator_tokenizer = None
        self.adversarial.discriminator_model = None
        self.adversarial.discriminator_tokenizer = None

        self.preprocessing.cleanup_preprocessing_models()
        self.models_loaded = False
        self.processing_times = []

#==============================================================================
# MAIN EXECUTION FUNCTIONS
#==============================================================================

def test_hybrid_subset(num_questions=50, 
                      calibration_method='comprehensive',
                      temperature_strategy='adaptive'):
    """Test the hybrid pipeline on a subset of ARC Challenge.

    Args:
        num_questions: Number of questions processed after calibration.
        calibration_method: Calibration method handed to the pipeline.
        temperature_strategy: Temperature strategy handed to the pipeline.

    Returns:
        The results DataFrame, which is also written to a CSV file under
        ``Data/`` (the directory is created if it does not exist).
    """
    print("="*70)
    print(f"TESTING HYBRID PIPELINE - {num_questions} QUESTIONS")
    print(f"Calibration method: {calibration_method}")
    print(f"Temperature strategy: {temperature_strategy}")
    print(f"Device: {PERF_CONFIG.device}")
    print("="*70)

    # Load datasets (only the Challenge split is used here)
    arc_df, _ = load_arc_datasets()

    # Initialize hybrid pipeline
    pipeline = HybridIntegratedPipeline(
        confidence_threshold=0.10,
        calibration_method=calibration_method,
        temperature_strategy=temperature_strategy,
        calibration_size=30
    )

    try:
        # Test on ARC Challenge subset
        print(f"\nTesting hybrid approach on {num_questions} ARC Challenge questions...")
        results = pipeline.process_dataset(arc_df, max_questions=num_questions)

        # Save results; create the folder first so a finished run is not lost
        # to a missing output directory.
        os.makedirs("Data", exist_ok=True)
        filename = f"Data/hybrid_test_{num_questions}_{calibration_method}_{temperature_strategy}_{PERF_CONFIG.device}.csv"
        results.to_csv(filename, index=False)
        print(f"Results saved to {filename}")
    finally:
        # Cleanup even when processing or saving fails
        pipeline.cleanup_all_models()

    return results

def compare_approaches(num_questions=100):
    """Compare hybrid vs standard approaches.

    Runs five calibration/temperature configurations on the same slice of
    ARC Challenge, prints each accuracy and saves the results of the best
    configuration to ``Data/`` (created if missing).

    Args:
        num_questions: Number of questions evaluated per configuration.

    Returns:
        Dict keyed by configuration name.  Successful runs hold
        ``accuracy``, ``results`` and ``config``; failed runs hold
        ``accuracy`` 0 and ``error``.
    """
    print("="*70)
    print(f"COMPARING APPROACHES - {num_questions} QUESTIONS EACH")
    print("="*70)

    # Load data
    arc_df, _ = load_arc_datasets()
    test_data = arc_df.head(num_questions + 100)  # Extra for calibration

    results = {}

    # Test different configurations
    configs = [
        {'method': 'simple', 'temp': 'fixed', 'name': 'Simple + Fixed Temp'},
        {'method': 'simple', 'temp': 'adaptive', 'name': 'Simple + Adaptive Temp'},
        {'method': 'comprehensive', 'temp': 'fixed', 'name': 'Comprehensive + Fixed Temp'},
        {'method': 'comprehensive', 'temp': 'adaptive', 'name': 'Comprehensive + Adaptive Temp'},
        {'method': 'comprehensive', 'temp': 'annealing', 'name': 'Comprehensive + Annealing Temp'},
    ]

    for config in configs:
        print(f"\nTesting: {config['name']}")

        pipeline = HybridIntegratedPipeline(
            calibration_method=config['method'],
            temperature_strategy=config['temp'],
            calibration_size=50
        )

        try:
            config_results = pipeline.process_dataset(test_data, max_questions=num_questions)

            # Missing accuracies may come back from the DataFrame as NaN rather
            # than None; an "is not None" filter lets NaN through and turns the
            # whole mean into NaN, so drop both explicitly.
            if 'accuracy' in config_results.columns:
                scores = pd.to_numeric(config_results['accuracy'], errors='coerce').dropna()
                accuracy = float(scores.mean()) if len(scores) else 0
            else:
                accuracy = 0  # empty result frame: nothing was processed

            results[config['name']] = {
                'accuracy': accuracy,
                'results': config_results,
                'config': config
            }

            print(f"   Accuracy: {accuracy:.1%}")

        except Exception as e:
            # Best effort: one failing configuration must not abort the others.
            print(f"   Error: {e}")
            results[config['name']] = {'accuracy': 0, 'error': str(e)}

        finally:
            pipeline.cleanup_all_models()

    # Find best configuration (failed runs have accuracy 0 and drop out here)
    valid_results = {k: v for k, v in results.items() if 'accuracy' in v and v['accuracy'] > 0}
    if valid_results:
        best_config = max(valid_results, key=lambda k: valid_results[k]['accuracy'])
        print(f"\nBEST CONFIGURATION: {best_config}")
        print(f"   Accuracy: {valid_results[best_config]['accuracy']:.1%}")

        # Save best results
        best_results = valid_results[best_config]['results']
        os.makedirs("Data", exist_ok=True)
        filename = f"Data/hybrid_best_config_{num_questions}_{PERF_CONFIG.device}.csv"
        best_results.to_csv(filename, index=False)
        print(f"Best results saved to {filename}")

    return results

def main_hybrid():
    """Main execution with hybrid pipeline.

    Shows an interactive menu and runs the selected mode:
    1 quick test, 2 configuration comparison, 3 both full ARC datasets,
    4 custom test.  Any other input falls back to the quick test.

    Returns:
        Whatever the selected mode returns: a results DataFrame (modes 1, 4
        and the fallback), the comparison dict (mode 2) or the tuple
        ``(results_challenge, results_easy)`` (mode 3).
    """

    print("="*70)
    print("HYBRID ADVERSARIAL QA SYSTEM")
    print("="*70)
    print("Phase 1: Comprehensive calibration (base model strengths)")
    print("Phase 2: Softmax game theory dynamics")
    print("Features: Temperature control, weight evolution, adaptive equilibrium")
    print("="*70)

    print("\nChoose execution mode:")
    print("1. Quick test (50 questions, comprehensive + adaptive)")
    print("2. Compare different configurations (100 questions each)")
    print("3. Full ARC Challenge dataset (1000+ questions)")
    print("4. Custom configuration test")

    choice = input("Enter choice (1/2/3/4): ").strip()

    if choice == "1":
        return test_hybrid_subset(50, 'comprehensive', 'adaptive')

    if choice == "2":
        return compare_approaches(100)

    if choice == "3":
        # Full dataset processing
        arc_df, arc_easy_df = load_arc_datasets()

        pipeline = HybridIntegratedPipeline(
            calibration_method='comprehensive',
            temperature_strategy='adaptive',
            calibration_size=100
        )

        # Make sure the output folder exists before the long runs start.
        os.makedirs("Data", exist_ok=True)

        try:
            print("\nProcessing FULL ARC Challenge dataset with hybrid approach...")
            results_challenge = pipeline.process_dataset(arc_df)

            # Save results
            challenge_filename = f"Data/hybrid_full_arc_challenge_{PERF_CONFIG.device}.csv"
            results_challenge.to_csv(challenge_filename, index=False)
            print(f"ARC Challenge results saved to {challenge_filename}")

            print("\nProcessing FULL ARC Easy dataset with hybrid approach...")
            results_easy = pipeline.process_dataset(arc_easy_df)

            # Save results
            easy_filename = f"Data/hybrid_full_arc_easy_{PERF_CONFIG.device}.csv"
            results_easy.to_csv(easy_filename, index=False)
            print(f"ARC Easy results saved to {easy_filename}")
        finally:
            # Final cleanup, also when one of the runs fails
            pipeline.cleanup_all_models()

        print(f"\nCOMPLETE HYBRID PROCESSING FINISHED!")
        return results_challenge, results_easy

    if choice == "4":
        # Custom configuration
        print("\nCustom Configuration:")

        cal_method = input("Calibration method (simple/comprehensive) [comprehensive]: ").strip() or 'comprehensive'
        temp_strategy = input("Temperature strategy (fixed/adaptive/annealing) [adaptive]: ").strip() or 'adaptive'
        raw_count = input("Number of test questions [100]: ").strip() or '100'

        # A typo must not crash the menu: fall back to the advertised default.
        try:
            num_questions = int(raw_count)
        except ValueError:
            num_questions = 0
        if num_questions <= 0:
            print(f"Invalid number of questions '{raw_count}'. Using 100.")
            num_questions = 100

        return test_hybrid_subset(num_questions, cal_method, temp_strategy)

    print("Invalid choice. Running quick test...")
    return test_hybrid_subset(50, 'comprehensive', 'adaptive')

# Start the interactive menu only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main_hybrid()