In [None]:
import math
import numpy as np
import pandas as pd
import os
import gc
import torch
import torch.nn.functional as F
import random
import time
import warnings
from collections import defaultdict
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
# Silence every warning for the rest of the process (library deprecation
# notices, numpy RuntimeWarnings, and warnings raised by code in this file).
# NOTE(review): consider narrowing this to specific categories/modules so that
# genuinely useful warnings are not hidden.
warnings.filterwarnings('ignore')

class PerformanceConfig:
    """Process-wide runtime settings, resolved once when the object is built.

    Attributes:
        device: "cuda" when torch reports an available GPU, otherwise "cpu".
        device_name: the CUDA device name, or "CPU" without a GPU.
        use_fp16: True exactly when running on CUDA.
        dtype: torch.float16 on CUDA, torch.float32 on CPU.
        clear_cache_every: cache-clearing interval; not read by the code
            visible next to this class.
        equilibrium_iterations: default iteration budget used by the
            equilibrium searches when no explicit ``T`` is given.
        early_stop_threshold: normalised per-step change below which
            ``equilibrium_search`` stops early.
    """

    def __init__(self):
        if torch.cuda.is_available():
            self.device, self.device_name = "cuda", torch.cuda.get_device_name()
        else:
            self.device, self.device_name = "cpu", "CPU"

        # Half precision is used only on CUDA.
        self.use_fp16 = (self.device == "cuda")
        if self.use_fp16:
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32

        # Tuning knobs.
        self.clear_cache_every = 50
        self.equilibrium_iterations = 5
        self.early_stop_threshold = 0.001


# Shared instance read by the model-loading and scoring code in this file.
PERF_CONFIG = PerformanceConfig()

def prepare_arc_frame(arc_df, subject="science"):
    """Normalise one ARC split as returned by ``Dataset.to_pandas()``.

    Steps:
      * drop rows whose ``question`` text repeats an earlier row;
      * rewrite ``answerKey`` to a letter ('A', 'B', ...) based on its position
        in the row's ``choices['label']`` list. Some ARC questions label their
        options '1', '2', ... instead of 'A', 'B', ...; the rest of this file
        converts keys with ``ord(key) - ord('A')``, which only works for letters;
      * replace the ``choices`` dict with the list of option texts;
      * add a constant ``subject`` column.

    Keys that are not strings (e.g. missing values) or that do not appear in
    the label list are left unchanged.

    Args:
        arc_df: DataFrame with ``question``, ``choices`` (dict with ``text`` and
            ``label`` sequences) and ``answerKey`` columns.
        subject: value written to the ``subject`` column.

    Returns:
        A new DataFrame; ``arc_df`` is not modified.
    """
    df = arc_df.drop_duplicates(subset=["question"]).copy()

    def _letter_key(choice_dict, key):
        if not isinstance(key, str):
            return key
        labels = [str(label) for label in choice_dict["label"]]
        if key in labels:
            return chr(ord("A") + labels.index(key))
        return key

    df["answerKey"] = [
        _letter_key(choice_dict, key)
        for choice_dict, key in zip(df["choices"], df["answerKey"])
    ]
    df["choices"] = df["choices"].apply(lambda choice_dict: list(choice_dict["text"]))
    df["subject"] = subject
    return df


def load_arc_datasets():
    """Download and normalise the ARC-Challenge and ARC-Easy test splits.

    Each split is passed through ``prepare_arc_frame``, so ``choices`` is a
    list of option texts and ``answerKey`` is a letter.

    Returns:
        ``(arc_challenge_df, arc_easy_df)``.
    """
    frames = []
    for config_name in ("ARC-Challenge", "ARC-Easy"):
        split = load_dataset("allenai/ai2_arc", config_name, split="test")
        frames.append(prepare_arc_frame(split.to_pandas()))
    return frames[0], frames[1]

class HybridCalibrationGameSoftmax:
    """Calibrate per-model strengths and run a strength-weighted two-player game.

    Calibration (``calibrate_base_strengths``) scores Phi-2 and Qwen2 on a
    sample of labelled questions and reduces the results to two scalar "base
    strengths" plus a calibration confidence. The game
    (``game_theory_softmax_equilibrium``) iterates generator/discriminator
    updates whose step sizes are scaled by a temperature softmax over those
    strengths.

    Args:
        preprocessing_system: object providing ``build_analysis_prompt``,
            ``get_model_confidence_scores`` and the ``phi2_model``,
            ``phi2_tokenizer``, ``qwen2_model`` and ``qwen2_tokenizer``
            attributes.
        adversarial_framework: stored as ``self.adversarial``; no method of
            this class reads it.
    """

    def __init__(self, preprocessing_system, adversarial_framework):
        self.preprocessing = preprocessing_system
        self.adversarial = adversarial_framework

        # Neutral defaults until calibrate_base_strengths() runs.
        self.base_phi2_strength = 0.5
        self.base_qwen2_strength = 0.5
        self.calibration_confidence = 0.5
        self.calibration_details = {}

        # Accumulated across every equilibrium run on this instance.
        self.temperature_history = []
        self.weight_evolution = []

    def calibrate_base_strengths(self, calibration_df, num_questions=50, method='comprehensive'):
        """Estimate base strengths from up to ``num_questions`` sampled rows.

        Args:
            calibration_df: DataFrame with ``question``, ``choices`` and
                ``answerKey`` columns.
            num_questions: maximum number of rows to sample (fixed
                ``random_state``, so repeated calls see the same rows).
            method: ``'comprehensive'``, ``'adversarial_aware'`` (currently the
                same as comprehensive) or anything else for the accuracy-only
                ``'simple'`` calibration.

        Returns:
            The details dict produced by the chosen calibration. Its
            strengths and confidence are also stored on the instance.
        """
        sample_df = calibration_df.sample(n=min(num_questions, len(calibration_df)), random_state=42)

        if method == 'comprehensive':
            strengths = self._comprehensive_strength_calibration(sample_df)
        elif method == 'adversarial_aware':
            strengths = self._adversarial_aware_calibration(sample_df)
        else:
            strengths = self._simple_strength_calibration(sample_df)

        self.base_phi2_strength = strengths['phi2_strength']
        self.base_qwen2_strength = strengths['qwen2_strength']
        self.calibration_confidence = strengths['confidence']
        self.calibration_details = strengths

        return strengths

    @staticmethod
    def _answer_key_to_index(answer_key):
        """Return the 0-based option index for ``answer_key``, or None if missing.

        Accepts letter keys ('A', 'B', ...) and the numeric keys ('1', '2', ...)
        that some ARC questions use. Anything that is not a non-empty string
        (None, NaN, '') is treated as missing.
        """
        if not isinstance(answer_key, str):
            return None
        key = answer_key.strip()
        if not key:
            return None
        if key.isdigit():
            return int(key) - 1
        return ord(key.upper()) - ord('A')

    def _collect_calibration_predictions(self, sample_df):
        """Score every row of ``sample_df`` that has a usable answer key with both models.

        Returns:
            Five parallel lists: Phi-2 predicted indices, Qwen2 predicted
            indices, gold indices, Phi-2 top probabilities and Qwen2 top
            probabilities. Rows without an answer key are skipped.
        """
        prep = self.preprocessing
        phi2_preds, qwen2_preds, gold = [], [], []
        phi2_confs, qwen2_confs = [], []

        for _, row in sample_df.iterrows():
            correct_idx = self._answer_key_to_index(row.get("answerKey"))
            if correct_idx is None:
                continue

            choices = row["choices"]
            prompt = prep.build_analysis_prompt(row["question"], choices)

            phi2_probs = prep.get_model_confidence_scores(
                prep.phi2_model, prep.phi2_tokenizer, prompt, len(choices)
            )
            qwen2_probs = prep.get_model_confidence_scores(
                prep.qwen2_model, prep.qwen2_tokenizer, prompt, len(choices)
            )

            phi2_pred = int(np.argmax(phi2_probs))
            qwen2_pred = int(np.argmax(qwen2_probs))

            phi2_preds.append(phi2_pred)
            qwen2_preds.append(qwen2_pred)
            gold.append(correct_idx)
            phi2_confs.append(float(phi2_probs[phi2_pred]))
            qwen2_confs.append(float(qwen2_probs[qwen2_pred]))

        return phi2_preds, qwen2_preds, gold, phi2_confs, qwen2_confs

    def _comprehensive_strength_calibration(self, sample_df):
        """Score each model on accuracy, confidence, consistency and robustness.

        Strength = 0.4*accuracy + 0.2*(confidence + consistency + robustness).
        Calibration confidence rises as the two strengths get closer.
        Averages are taken over the rows actually scored, so rows without an
        answer key do not count as wrong answers.
        """
        (phi2_preds, qwen2_preds, gold,
         phi2_confs, qwen2_confs) = self._collect_calibration_predictions(sample_df)

        # max(..., 1) keeps an empty or fully-unlabelled sample from dividing by zero.
        denom = max(len(gold), 1)

        def _metrics(preds, confs):
            return {
                'accuracy': sum(p == g for p, g in zip(preds, gold)) / denom,
                'confidence': sum(confs) / denom,
                'consistency': self._calculate_consistency(preds),
                'robustness': self._calculate_robustness(confs, preds, gold),
            }

        phi2_metrics = _metrics(phi2_preds, phi2_confs)
        qwen2_metrics = _metrics(qwen2_preds, qwen2_confs)

        complementarity = self._calculate_complementarity(phi2_preds, qwen2_preds, gold)

        criteria_weights = {
            'accuracy': 0.4,
            'confidence': 0.2,
            'consistency': 0.2,
            'robustness': 0.2
        }

        phi2_strength = sum(criteria_weights[k] * phi2_metrics[k] for k in criteria_weights)
        qwen2_strength = sum(criteria_weights[k] * qwen2_metrics[k] for k in criteria_weights)

        strength_variance = abs(phi2_strength - qwen2_strength)
        confidence = min(1.0, 0.3 + 0.7 * (1 - strength_variance))

        return {
            'phi2_strength': phi2_strength,
            'qwen2_strength': qwen2_strength,
            'confidence': confidence,
            'phi2_metrics': phi2_metrics,
            'qwen2_metrics': qwen2_metrics,
            'complementarity': complementarity,
            'sample_size': len(sample_df),
            'num_scored': len(gold),
            'method': 'comprehensive'
        }

    def _simple_strength_calibration(self, sample_df):
        """Use plain accuracy as each model's strength.

        Confidence is 0.5 plus the accuracy gap, capped at 1.0. Averages are
        taken over the rows actually scored.
        """
        (phi2_preds, qwen2_preds, gold,
         phi2_confs, qwen2_confs) = self._collect_calibration_predictions(sample_df)

        denom = max(len(gold), 1)

        phi2_strength = sum(p == g for p, g in zip(phi2_preds, gold)) / denom
        qwen2_strength = sum(p == g for p, g in zip(qwen2_preds, gold)) / denom
        avg_phi2_conf = sum(phi2_confs) / denom
        avg_qwen2_conf = sum(qwen2_confs) / denom

        strength_diff = abs(phi2_strength - qwen2_strength)
        confidence = min(1.0, 0.5 + strength_diff)

        return {
            'phi2_strength': phi2_strength,
            'qwen2_strength': qwen2_strength,
            'confidence': confidence,
            'phi2_accuracy': phi2_strength,
            'qwen2_accuracy': qwen2_strength,
            'phi2_avg_confidence': avg_phi2_conf,
            'qwen2_avg_confidence': avg_qwen2_conf,
            'sample_size': len(sample_df),
            'num_scored': len(gold),
            'method': 'simple'
        }

    def _adversarial_aware_calibration(self, sample_df):
        """Placeholder: currently delegates to the comprehensive calibration."""
        return self._comprehensive_strength_calibration(sample_df)

    def game_theory_softmax_equilibrium(self, gen_init, disc_init, candidates, T=None,
                                       temperature_strategy='adaptive'):
        """Run up to ``T`` rounds of the strength-weighted generator/discriminator game.

        Args:
            gen_init: ``{"correct"|"incorrect": {candidate: prob}}`` initial
                generator distributions (not modified).
            disc_init: ``{candidate: {"correct": p, "incorrect": q}}`` initial
                discriminator judgements (not modified).
            candidates: candidate labels, e.g. ``["A", "B", "C"]``.
            T: maximum number of rounds; defaults to
                ``PERF_CONFIG.equilibrium_iterations``. Must be >= 1.
            temperature_strategy: see ``_get_temperature``.

        Returns:
            A dict with the final distributions, the per-iteration history,
            the last iteration's player weights, the number of iterations run
            and the strategy name.

        Raises:
            ValueError: if ``T`` < 1 (there would be no final weights to report).
        """
        if T is None:
            T = PERF_CONFIG.equilibrium_iterations
        if T < 1:
            raise ValueError(f"T must be >= 1, got {T}")

        gen = {"correct": dict(gen_init["correct"]), "incorrect": dict(gen_init["incorrect"])}
        disc = {y: dict(disc_init[y]) for y in candidates}

        # Accumulated payoffs for each player.
        Qg = {"correct": {y: 0.0 for y in candidates}, "incorrect": {y: 0.0 for y in candidates}}
        Qd = {y: {"correct": 0.0, "incorrect": 0.0} for y in candidates}

        equilibrium_history = []

        for t in range(1, T+1):
            temperature = self._get_temperature(t, T, temperature_strategy, gen, disc, candidates)

            # Player weights: temperature softmax over the calibrated strengths.
            base_logits = torch.tensor([self.base_phi2_strength, self.base_qwen2_strength], dtype=torch.float32)
            current_weights = F.softmax(base_logits / temperature, dim=0)

            gen_weight = float(current_weights[0])
            disc_weight = float(current_weights[1])

            iteration_info = {
                'iteration': t,
                'temperature': temperature,
                'gen_weight': gen_weight,
                'disc_weight': disc_weight,
                'base_phi2_strength': self.base_phi2_strength,
                'base_qwen2_strength': self.base_qwen2_strength
            }

            # Each player accumulates payoff from the opponent's current strategy
            # with a 1/(2t) step scaled by the opponent's weight.
            for v in ["correct", "incorrect"]:
                for y in candidates:
                    Qg[v][y] += (disc_weight / (2.0 * t)) * disc[y][v]

            for y in candidates:
                for v in ["correct", "incorrect"]:
                    Qd[y][v] += (gen_weight / (2.0 * t)) * gen[v][y]

            # Generator update: softmax over payoffs plus a 0.1 * log(gen_init)
            # anchor. When all logits tie the previous probabilities are kept.
            eta_G = 0.1 * gen_weight
            for v in ["correct", "incorrect"]:
                logits = []
                for y in candidates:
                    val = (Qg[v][y] + 0.1 * math.log(gen_init[v][y] + 1e-12)) / (1/eta_G + 0.1)
                    logits.append(val)

                if max(logits) > min(logits):
                    new_probs = self._softmax(np.array(logits))
                    for i, y in enumerate(candidates):
                        gen[v][y] = new_probs[i]

            # Discriminator update: same form with a 0.01 * log(disc_init) anchor.
            eta_D = 0.1 * disc_weight
            for y in candidates:
                logits = [
                    (Qd[y]["correct"] + 0.01 * math.log(disc_init[y]["correct"] + 1e-12)) / (1/eta_D + 0.01),
                    (Qd[y]["incorrect"] + 0.01 * math.log(disc_init[y]["incorrect"] + 1e-12)) / (1/eta_D + 0.01)
                ]

                if max(logits) > min(logits):
                    probs = self._softmax(np.array(logits))
                    disc[y]["correct"] = probs[0]
                    disc[y]["incorrect"] = probs[1]

            convergence_metrics = self._evaluate_convergence(gen, disc, candidates, t, T)
            iteration_info.update(convergence_metrics)

            equilibrium_history.append(iteration_info)

            if convergence_metrics.get('converged', False):
                break

        self.temperature_history.extend([info['temperature'] for info in equilibrium_history])
        self.weight_evolution.extend(equilibrium_history)

        return {
            'gen_final': gen,
            'disc_final': disc,
            'equilibrium_history': equilibrium_history,
            'final_weights': {
                'gen_weight': equilibrium_history[-1]['gen_weight'],
                'disc_weight': equilibrium_history[-1]['disc_weight']
            },
            'convergence_iteration': len(equilibrium_history),
            'temperature_strategy': temperature_strategy
        }

    def _get_temperature(self, iteration, total_iterations, strategy, gen, disc, candidates):
        """Return the softmax temperature for ``iteration`` under ``strategy``.

        * ``'fixed'``: 1.0.
        * ``'annealing'``: linear decrease reaching 0.5 at the last iteration.
        * ``'adaptive'``: a base in [0.5, 2.0] that grows as calibration
          confidence falls. It is raised by 1.5x for the first two iterations.
          After iteration 3 it is cooled (x0.7) when the strategies are very
          stable and heated (x1.3) when they are unstable.
        * ``'confidence_based'``: the adaptive base without adjustments.
        * anything else: 1.0.
        """
        if strategy == 'fixed':
            return 1.0

        elif strategy == 'annealing':
            return 2.0 * (total_iterations - iteration) / total_iterations + 0.5

        elif strategy == 'adaptive':
            base_temp = 0.5 + (2.0 - 0.5) * (1 - self.calibration_confidence)

            if iteration <= 2:
                return base_temp * 1.5

            if iteration > 3:
                stability = self._calculate_strategy_stability(gen, disc, candidates)
                if stability > 0.95:
                    return base_temp * 0.7
                elif stability < 0.8:
                    return base_temp * 1.3

            return base_temp

        elif strategy == 'confidence_based':
            return 0.5 + 1.5 * (1 - self.calibration_confidence)

        else:
            return 1.0

    def _calculate_strategy_stability(self, gen, disc, candidates):
        """Mean of (1 - normalised entropy) over the two players' 'correct' distributions.

        The discriminator's per-candidate P(correct) values are treated as a
        vector, without renormalisation. The result is 1.0 when there is only
        one candidate.
        """
        gen_probs = [gen["correct"][y] for y in candidates]
        gen_entropy = self._entropy(gen_probs)

        disc_probs = [disc[y]["correct"] for y in candidates]
        disc_entropy = self._entropy(disc_probs)

        max_entropy = np.log(len(candidates))
        if max_entropy > 0:
            gen_stability = 1.0 - (gen_entropy / max_entropy)
            disc_stability = 1.0 - (disc_entropy / max_entropy)
        else:
            gen_stability = disc_stability = 1.0

        return (gen_stability + disc_stability) / 2

    def _evaluate_convergence(self, gen, disc, candidates, iteration, total_iterations):
        """Report concentration and stability; converged from iteration 3 onward.

        Convergence means both players put > 0.8 on their top 'correct'
        candidate, or the final iteration has been reached.
        """
        gen_concentration = max(gen["correct"][y] for y in candidates)
        disc_concentration = max(disc[y]["correct"] for y in candidates)

        converged = False
        if iteration >= 3:
            if gen_concentration > 0.8 and disc_concentration > 0.8:
                converged = True
            elif iteration >= total_iterations:
                converged = True

        return {
            'converged': converged,
            'gen_concentration': gen_concentration,
            'disc_concentration': disc_concentration,
            'stability': self._calculate_strategy_stability(gen, disc, candidates)
        }

    def get_final_answer_with_hybrid_weights(self, equilibrium_result, candidates, correct_answer=None):
        """Pick the candidate maximising the weight-blended P(correct) of both players.

        The score for each candidate is
        ``gen_weight * gen_final['correct'][y] + disc_weight * disc_final[y]['correct']``.
        ``accuracy`` is 1.0/0.0 when ``correct_answer`` is given, else 0.0.
        """
        gen_final = equilibrium_result['gen_final']
        disc_final = equilibrium_result['disc_final']
        final_weights = equilibrium_result['final_weights']

        gen_weight = final_weights['gen_weight']
        disc_weight = final_weights['disc_weight']

        gen_answer = max(candidates, key=lambda x: gen_final["correct"][x])
        disc_answer = max(candidates, key=lambda x: disc_final[x]["correct"])

        combined_scores = {}
        for y in candidates:
            gen_score = gen_final["correct"][y]
            disc_score = disc_final[y]["correct"]
            combined_scores[y] = gen_weight * gen_score + disc_weight * disc_score

        best_answer = max(candidates, key=lambda x: combined_scores[x])

        accuracy = 0.0
        if correct_answer:
            accuracy = 1.0 if best_answer == correct_answer else 0.0

        return {
            'final_answer': best_answer,
            'gen_answer': gen_answer,
            'disc_answer': disc_answer,
            'combined_scores': combined_scores,
            'weights_used': final_weights,
            'base_strengths': {
                'phi2': self.base_phi2_strength,
                'qwen2': self.base_qwen2_strength
            },
            'weight_evolution': equilibrium_result['equilibrium_history'],
            'accuracy': accuracy
        }

    def _calculate_consistency(self, predictions):
        """1 - normalised entropy of the predicted-index histogram (1.0 if all equal, 0.0 if empty)."""
        if not predictions:
            return 0.0
        unique, counts = np.unique(predictions, return_counts=True)
        probs = counts / len(predictions)
        entropy = -sum(p * np.log(p + 1e-12) for p in probs if p > 0)
        max_entropy = np.log(len(unique)) if len(unique) > 1 else 1
        return 1.0 - (entropy / max_entropy) if max_entropy > 0 else 1.0

    def _calculate_robustness(self, confidences, predictions, correct_answers):
        """Mean top-probability over correctly answered questions (0.0 if none)."""
        correct_mask = np.array(predictions) == np.array(correct_answers)
        if not any(correct_mask):
            return 0.0
        correct_confidences = np.array(confidences)[correct_mask]
        return float(np.mean(correct_confidences))

    def _calculate_complementarity(self, phi2_preds, qwen2_preds, correct_answers):
        """Relative gain of "either model correct" over the better single model.

        Returns 0.0 for empty input or when neither model gets anything right.
        """
        if len(correct_answers) == 0:
            return 0.0

        phi2_correct = np.array(phi2_preds) == np.array(correct_answers)
        qwen2_correct = np.array(qwen2_preds) == np.array(correct_answers)

        either_correct = phi2_correct | qwen2_correct
        combined_accuracy = np.mean(either_correct)

        best_individual = max(np.mean(phi2_correct), np.mean(qwen2_correct))

        if best_individual > 0:
            return max(0.0, (combined_accuracy - best_individual) / best_individual)
        else:
            return 0.0

    def _softmax(self, arr):
        """Numerically stable softmax of a 1-D array."""
        m = np.max(arr)
        exp_vals = np.exp(arr - m)
        return exp_vals / np.sum(exp_vals)

    def _entropy(self, probs):
        """Shannon entropy (nats) of the strictly positive entries of ``probs``."""
        probs = np.array(probs)
        probs = probs[probs > 0]
        if len(probs) == 0:
            return 0.0
        return -np.sum(probs * np.log(probs + 1e-12))

class HybridOptimizedPreprocessingSystem:
    """Score answer options with Phi-2 and Qwen2, prune unlikely ones, then shuffle.

    Args:
        confidence_threshold: minimum blended probability an option needs to
            survive ``filter_low_confidence_choices``.
    """

    def __init__(self, confidence_threshold=0.10):
        self.confidence_threshold = confidence_threshold
        self.phi2_model = None
        self.phi2_tokenizer = None
        self.qwen2_model = None
        self.qwen2_tokenizer = None

        # Blend weights used by analyze_all_choices. weights_calibrated and
        # calibration_results are not read by any method of this class.
        self.phi2_weight = 0.5
        self.qwen2_weight = 0.5
        self.weights_calibrated = False
        self.calibration_results = None

    @staticmethod
    def _load_causal_lm(model_id):
        """Load ``model_id`` and its tokenizer on PERF_CONFIG's device/dtype, in eval mode."""
        # NOTE(review): trust_remote_code=True executes code from the model repo.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=PERF_CONFIG.dtype,
            low_cpu_mem_usage=True,
            device_map="auto" if PERF_CONFIG.device == "cuda" else "cpu",
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Prompts are truncated to 512 tokens when scored, and the logits at the
        # final position are what get read, so over-long prompts must lose their
        # beginning rather than the trailing "Answer:".
        tokenizer.truncation_side = "left"
        model.eval()
        return model, tokenizer

    def load_preprocessing_models(self):
        """Load Phi-2 and Qwen2-1.5B-Instruct together with their tokenizers."""
        self.phi2_model, self.phi2_tokenizer = self._load_causal_lm("microsoft/phi-2")
        self.qwen2_model, self.qwen2_tokenizer = self._load_causal_lm("Qwen/Qwen2-1.5B-Instruct")

    def build_analysis_prompt(self, question, choices):
        """Format ``question`` and lettered ``choices`` as a prompt ending in "Answer:"."""
        prompt = f"""The following is a multiple choice science question. Analyze all choices and select the most likely correct answer.

Question: {question}
"""
        for i, choice in enumerate(choices):
            prompt += f"{chr(65+i)}. {choice}\n"
        prompt += "\nAnswer:"
        return prompt

    def get_model_confidence_scores(self, model, tokenizer, prompt, num_choices):
        """Return a probability vector over the option letters A, B, ... for ``prompt``.

        The next-token logits at the prompt's last position are restricted to
        the first ``num_choices`` letters and softmax-normalised. On any failure
        a warning is issued and a uniform distribution is returned, so a single
        bad question does not abort a whole run.
        """
        try:
            inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)

            if PERF_CONFIG.device == "cuda":
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=PERF_CONFIG.use_fp16 and PERF_CONFIG.device == "cuda"):
                    outputs = model(**inputs)
                    logits = outputs.logits[0, -1]

            # NOTE(review): letters are encoded without a leading space; with
            # BPE tokenizers the token following "Answer:" is often " A".
            # Confirm which variant these tokenizers actually emit.
            choice_logits = []
            for i in range(num_choices):
                letter = chr(65 + i)
                token_id = tokenizer.encode(letter, add_special_tokens=False)[-1]
                choice_logits.append(logits[token_id].item())

            choice_logits = torch.tensor(choice_logits)
            return torch.nn.functional.softmax(choice_logits, dim=0).numpy()

        except Exception as exc:  # deliberate best-effort fallback
            warnings.warn(
                f"get_model_confidence_scores failed ({type(exc).__name__}: {exc}); "
                "falling back to uniform probabilities"
            )
            return np.ones(num_choices) / num_choices

    def analyze_all_choices(self, question, choices):
        """Score ``choices`` with both models and blend with phi2_weight / qwen2_weight."""
        prompt = self.build_analysis_prompt(question, choices)
        num_choices = len(choices)

        phi2_probs = self.get_model_confidence_scores(
            self.phi2_model, self.phi2_tokenizer, prompt, num_choices
        )
        qwen2_probs = self.get_model_confidence_scores(
            self.qwen2_model, self.qwen2_tokenizer, prompt, num_choices
        )

        combined_probs = self.phi2_weight * phi2_probs + self.qwen2_weight * qwen2_probs

        return {
            'phi2_probs': phi2_probs,
            'qwen2_probs': qwen2_probs,
            'combined_probs': combined_probs,
            'weights_used': {
                'phi2': self.phi2_weight,
                'qwen2': self.qwen2_weight
            }
        }

    def filter_low_confidence_choices(self, choices, combined_probs):
        """Keep options whose blended probability is >= confidence_threshold.

        If nothing passes, every option is kept. Returns
        ``(kept_choices, kept_probs, kept_original_indices)``.
        """
        high_confidence_indices = [
            i for i, prob in enumerate(combined_probs)
            if prob >= self.confidence_threshold
        ]

        if not high_confidence_indices:
            high_confidence_indices = list(range(len(choices)))

        filtered_choices = [choices[i] for i in high_confidence_indices]
        filtered_probs = [combined_probs[i] for i in high_confidence_indices]

        return filtered_choices, filtered_probs, high_confidence_indices

    def randomize_choice_order(self, choices, probs, original_indices):
        """Shuffle the options in place order using the global ``random`` module.

        Returns ``(choices, probs, {new_position: original_index})``. Empty
        input yields ``([], [], {})``.
        """
        combined = list(zip(choices, probs, original_indices))
        if not combined:
            return [], [], {}
        random.shuffle(combined)

        shuffled_choices, shuffled_probs, shuffled_original_indices = zip(*combined)
        position_mapping = {i: orig_idx for i, orig_idx in enumerate(shuffled_original_indices)}

        return list(shuffled_choices), list(shuffled_probs), position_mapping

    @staticmethod
    def _answer_key_to_index(answer_key):
        """Return the 0-based index for a letter ('A'...) or numeric ('1'...) key, or None if missing."""
        if not isinstance(answer_key, str):
            return None
        key = answer_key.strip()
        if not key:
            return None
        if key.isdigit():
            return int(key) - 1
        return ord(key.upper()) - ord('A')

    def preprocess_question(self, question, choices, answer_key=None):
        """Score, prune and shuffle ``choices``, remapping ``answer_key`` to the new order.

        ``new_answer_key`` is None when no key is given or when the correct
        option was pruned.
        """
        analysis_results = self.analyze_all_choices(question, choices)

        filtered_choices, filtered_probs, high_confidence_indices = self.filter_low_confidence_choices(
            choices, analysis_results['combined_probs']
        )

        final_choices, final_probs, position_mapping = self.randomize_choice_order(
            filtered_choices, filtered_probs, high_confidence_indices
        )

        new_answer_key = None
        original_answer_idx = self._answer_key_to_index(answer_key)
        if original_answer_idx is not None:
            for new_pos, orig_pos in position_mapping.items():
                if orig_pos == original_answer_idx:
                    new_answer_key = chr(ord('A') + new_pos)
                    break

        return {
            'original_question': question,
            'original_choices': choices,
            'filtered_choices': final_choices,
            'analysis_results': analysis_results,
            'position_mapping': position_mapping,
            'high_confidence_indices': high_confidence_indices,
            'original_answer_key': answer_key,
            'new_answer_key': new_answer_key,
            'filtering_applied': len(final_choices) < len(choices)
        }

    def cleanup_preprocessing_models(self):
        """Drop this object's model/tokenizer references and release cached memory.

        Attributes are reset to None rather than deleted, so the method can be
        called repeatedly and the attributes stay readable. Any other object
        holding the same models keeps them alive.
        """
        self.phi2_model = None
        self.phi2_tokenizer = None
        self.qwen2_model = None
        self.qwen2_tokenizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

class HybridOptimizedAdversarialFramework:
    """Generator/discriminator scoring plus a regularised equilibrium search.

    The generator model scores each option letter under an "Answer:" and an
    "Incorrect Answer:" prompt. The discriminator model judges each option on
    its own as correct (A) or incorrect (B). ``equilibrium_search`` then updates
    each player's distribution from the other's, anchored to the initial
    distributions. Models are injected via ``setup_adversarial_models``.
    """

    def __init__(self):
        self.generator_model = None
        self.generator_tokenizer = None
        self.discriminator_model = None
        self.discriminator_tokenizer = None

    def setup_adversarial_models(self, phi2_model, phi2_tokenizer, qwen2_model, qwen2_tokenizer):
        """Use Phi-2 as the generator and Qwen2 as the discriminator."""
        self.generator_model = phi2_model
        self.generator_tokenizer = phi2_tokenizer
        self.discriminator_model = qwen2_model
        self.discriminator_tokenizer = qwen2_tokenizer

    def format_subject(self, subject):
        """Turn an underscore-separated subject id into words."""
        return " ".join(subject.split("_"))

    def build_generator_prompt(self, subject, question, choices, get_correct):
        """Build the generator prompt, ending in "Answer:" or "Incorrect Answer:"."""
        prompt = f"The following are multiple choice questions (with answers) about {self.format_subject(subject)}.\n\n"
        prompt += f"{question}"

        for i, choice in enumerate(choices):
            prompt += f"\n{chr(65+i)}. {choice}"

        if get_correct:
            prompt += "\nAnswer:"
        else:
            prompt += "\nIncorrect Answer:"
        return prompt

    def get_generator_probabilities(self, prompt_text, num_choices):
        """Softmax over the generator's next-token logits for letters A, B, ...

        On any failure a warning is issued and a uniform distribution is
        returned (deliberate best-effort fallback).
        """
        try:
            # NOTE(review): truncation uses the tokenizer's truncation_side; if
            # that is "right", prompts over 512 tokens lose the trailing
            # "Answer:" whose final-position logits are read below.
            inputs = self.generator_tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)

            if PERF_CONFIG.device == "cuda":
                inputs = {k: v.to(self.generator_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=PERF_CONFIG.use_fp16 and PERF_CONFIG.device == "cuda"):
                    outputs = self.generator_model(**inputs)
                    logits = outputs.logits[0, -1]

            choice_logits = []
            for i in range(num_choices):
                letter = chr(65 + i)
                token_id = self.generator_tokenizer.encode(letter, add_special_tokens=False)[-1]
                choice_logits.append(logits[token_id].item())

            choice_logits = torch.tensor(choice_logits)
            return torch.nn.functional.softmax(choice_logits, dim=0).numpy()

        except Exception as exc:  # deliberate best-effort fallback
            warnings.warn(
                f"get_generator_probabilities failed ({type(exc).__name__}: {exc}); "
                "falling back to uniform probabilities"
            )
            return np.ones(num_choices) / num_choices

    def get_generator_initial_probs(self, question, choices, subject):
        """Return ``{"correct": {letter: p}, "incorrect": {letter: p}}`` from the generator."""
        gen_init = {"correct": {}, "incorrect": {}}
        candidates = [f"{chr(65+i)}" for i in range(len(choices))]

        for get_correct in [True, False]:
            prompt = self.build_generator_prompt(subject, question, choices, get_correct)
            probs = self.get_generator_probabilities(prompt, len(choices))

            for i, candidate in enumerate(candidates):
                if get_correct:
                    gen_init["correct"][candidate] = probs[i]
                else:
                    gen_init["incorrect"][candidate] = probs[i]

        return gen_init

    def build_discriminator_prompt(self, subject, question, proposed_answer):
        """Build a prompt asking whether ``proposed_answer`` is correct (A) or incorrect (B)."""
        prompt = f"""You are an expert evaluator of questions about {self.format_subject(subject)}.
Determine if the proposed answer is correct. Output ONLY 'A' or 'B'.
Question: {question}
Proposed Answer: {proposed_answer}

Is this answer correct? Respond ONLY with:
A. Correct
B. Incorrect

Answer:"""
        return prompt

    def get_discriminator_probabilities(self, prompt_text):
        """Return ``{"correct": P(A), "incorrect": P(B)}`` from the discriminator.

        On any failure a warning is issued and 0.5/0.5 is returned
        (deliberate best-effort fallback).
        """
        try:
            inputs = self.discriminator_tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)

            if PERF_CONFIG.device == "cuda":
                inputs = {k: v.to(self.discriminator_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=PERF_CONFIG.use_fp16 and PERF_CONFIG.device == "cuda"):
                    outputs = self.discriminator_model(**inputs)
                    logits = outputs.logits[0, -1]

            a_token = self.discriminator_tokenizer.encode("A", add_special_tokens=False)[-1]
            b_token = self.discriminator_tokenizer.encode("B", add_special_tokens=False)[-1]

            # Extract Python floats (as the generator path does) so the softmax
            # runs on a float32 CPU tensor rather than on a tensor rebuilt from
            # device-resident, possibly half-precision, scalar tensors.
            choice_logits = torch.tensor([logits[a_token].item(), logits[b_token].item()])
            probs = torch.nn.functional.softmax(choice_logits, dim=0).numpy()

            return {"correct": float(probs[0]), "incorrect": float(probs[1])}

        except Exception as exc:  # deliberate best-effort fallback
            warnings.warn(
                f"get_discriminator_probabilities failed ({type(exc).__name__}: {exc}); "
                "falling back to 0.5/0.5"
            )
            return {"correct": 0.5, "incorrect": 0.5}

    def get_discriminator_initial_probs(self, question, choices, subject):
        """Return ``{letter: {"correct": p, "incorrect": q}}`` judging each choice separately."""
        results = {}

        for idx, answer in enumerate(choices):
            prompt = self.build_discriminator_prompt(subject, question, answer)
            probs = self.get_discriminator_probabilities(prompt)

            candidate = f"{chr(65+idx)}"
            results[candidate] = probs

        return results

    def softmax(self, arr):
        """Numerically stable softmax of a 1-D array."""
        m = np.max(arr)
        exp_vals = np.exp(arr - m)
        return exp_vals / np.sum(exp_vals)

    def equilibrium_search(self, gen_init, disc_init, candidates, T=None,
                          eta_G=0.1, eta_D=0.1, lam_G=0.1, lam_D=0.01):
        """Iterate regularised best responses between generator and discriminator.

        Each round adds ``1/(2t)`` times the opponent's current probabilities
        to a player's accumulated payoff. It then sets the player's
        distribution to a softmax of ``(payoff + lam * log(initial)) /
        (1/eta + lam)``. From round 3 on, the search stops early once the
        largest per-entry mean change drops below
        ``PERF_CONFIG.early_stop_threshold``.

        Args:
            gen_init: ``{"correct"|"incorrect": {candidate: prob}}`` (not modified).
            disc_init: ``{candidate: {"correct": p, "incorrect": q}}`` (not modified).
            candidates: candidate labels.
            T: maximum rounds; defaults to ``PERF_CONFIG.equilibrium_iterations``.
            eta_G, eta_D: step sizes; lam_G, lam_D: anchor strengths.

        Returns:
            ``(gen, disc)`` final distributions in the same layouts as the inputs.
        """
        if T is None:
            T = PERF_CONFIG.equilibrium_iterations

        gen = {"correct": dict(gen_init["correct"]),
               "incorrect": dict(gen_init["incorrect"])}
        disc = {}
        for y in candidates:
            disc[y] = dict(disc_init[y])

        Qg = {"correct": {y: 0.0 for y in candidates},
              "incorrect": {y: 0.0 for y in candidates}}
        Qd = {y: {"correct": 0.0, "incorrect": 0.0} for y in candidates}

        prev_gen = None
        prev_disc = None

        for t in range(1, T+1):
            for v in ["correct", "incorrect"]:
                for y in candidates:
                    Qg[v][y] += (1.0/(2.0*t)) * disc[y][v]

            for y in candidates:
                for v in ["correct", "incorrect"]:
                    Qd[y][v] += (1.0/(2.0*t)) * gen[v][y]

            for v in ["correct", "incorrect"]:
                logits = []
                for y in candidates:
                    val = (Qg[v][y] + lam_G * math.log(gen_init[v][y] + 1e-12)) / (1/eta_G + lam_G)
                    logits.append(val)

                new_probs = self.softmax(np.array(logits))
                for i, y in enumerate(candidates):
                    gen[v][y] = new_probs[i]

            for y in candidates:
                logits = [
                    (Qd[y]["correct"] + lam_D * math.log(disc_init[y]["correct"] + 1e-12)) / (1/eta_D + lam_D),
                    (Qd[y]["incorrect"] + lam_D * math.log(disc_init[y]["incorrect"] + 1e-12)) / (1/eta_D + lam_D)
                ]
                probs = self.softmax(np.array(logits))
                disc[y]["correct"] = probs[0]
                disc[y]["incorrect"] = probs[1]

            if t > 2 and prev_gen is not None:
                gen_change = sum(abs(gen[v][y] - prev_gen[v][y])
                               for v in ["correct", "incorrect"] for y in candidates)
                disc_change = sum(abs(disc[y][v] - prev_disc[y][v])
                                for y in candidates for v in ["correct", "incorrect"])

                max_change = max(gen_change, disc_change) / (len(candidates) * 2)

                if max_change < PERF_CONFIG.early_stop_threshold:
                    break

            prev_gen = {v: dict(gen[v]) for v in ["correct", "incorrect"]}
            prev_disc = {y: dict(disc[y]) for y in candidates}

        return gen, disc

    def get_final_answers(self, gen_final, disc_final, candidates):
        """Return the generator's and the discriminator's top 'correct' candidate.

        Ties keep the earliest candidate. Returns ``(None, None)`` for an empty
        candidate list.
        """
        gen_answer = None
        best_gen_prob = -1.0
        for y in candidates:
            p = gen_final["correct"][y]
            if p > best_gen_prob:
                best_gen_prob = p
                gen_answer = y

        disc_answer = None
        best_disc_prob = -1.0
        for y in candidates:
            p = disc_final[y]["correct"]
            if p > best_disc_prob:
                best_disc_prob = p
                disc_answer = y

        return gen_answer, disc_answer

class HybridIntegratedPipeline:
    """End-to-end evaluation pipeline for multiple-choice datasets.

    Each question is preprocessed. If only one choice survives preprocessing,
    it is answered directly ("consensus"). Otherwise the answer comes from the
    generator/discriminator equilibrium computed by
    ``HybridCalibrationGameSoftmax``.
    """

    def __init__(self, confidence_threshold=0.10,
                 calibration_method='comprehensive',
                 temperature_strategy='adaptive',
                 calibration_size=50):
        """Create the pipeline components. Models are loaded lazily.

        Args:
            confidence_threshold: forwarded to ``HybridOptimizedPreprocessingSystem``.
            calibration_method: ``method`` passed to ``calibrate_base_strengths``.
            temperature_strategy: ``temperature_strategy`` passed to
                ``game_theory_softmax_equilibrium``.
            calibration_size: maximum number of rows per dataset held out for
                calibration (and therefore excluded from evaluation).
        """
        self.preprocessing = HybridOptimizedPreprocessingSystem(confidence_threshold)
        self.adversarial = HybridOptimizedAdversarialFramework()
        # Built in initialize_all_models, once the underlying models exist.
        self.hybrid_calibrator = None

        self.models_loaded = False
        self.calibration_method = calibration_method
        self.temperature_strategy = temperature_strategy
        self.calibration_size = calibration_size
        # Wall-clock seconds per processed question. This accumulates across
        # process_dataset calls and is reset only by cleanup_all_models.
        self.processing_times = []

    def initialize_all_models(self):
        """Load models and wire up the calibrator. A no-op if already loaded."""
        if not self.models_loaded:
            self.preprocessing.load_preprocessing_models()

            # The adversarial framework reuses the models loaded by preprocessing
            # instead of loading its own copies.
            self.adversarial.setup_adversarial_models(
                self.preprocessing.phi2_model,
                self.preprocessing.phi2_tokenizer,
                self.preprocessing.qwen2_model,
                self.preprocessing.qwen2_tokenizer
            )

            self.hybrid_calibrator = HybridCalibrationGameSoftmax(
                self.preprocessing,
                self.adversarial
            )

            self.models_loaded = True

    def process_single_question(self, question, choices, answer_key, subject):
        """Answer one question and return a flat result record.

        Args:
            question: question text.
            choices: answer option texts, passed to preprocessing.
            answer_key: gold answer label (may be None), passed to preprocessing.
            subject: subject label forwarded to the adversarial models.

        Returns:
            dict: the preprocessing result merged with ``consensus_achieved``,
            ``final_answer``, ``hybrid_method_used``, ``accuracy`` and
            ``processing_time``. Records from the equilibrium path also carry
            generator/discriminator answers, weights and convergence diagnostics.
        """
        start_time = time.time()

        preprocessing_result = self.preprocessing.preprocess_question(question, choices, answer_key)

        filtered_choices = preprocessing_result["filtered_choices"]
        # Presumably the gold label re-mapped onto filtered_choices; may be absent.
        new_answer_key = preprocessing_result.get("new_answer_key")

        if len(filtered_choices) == 1:
            # Only one option survived: by position it is "A", so the
            # equilibrium game is skipped entirely.
            consensus_answer = "A"
            processing_time = time.time() - start_time
            self.processing_times.append(processing_time)

            consensus_accuracy = 0.0
            if new_answer_key:
                consensus_accuracy = 1.0 if consensus_answer == new_answer_key else 0.0

            return {
                **preprocessing_result,
                'consensus_achieved': True,
                'final_answer': consensus_answer,
                'hybrid_method_used': False,
                'accuracy': consensus_accuracy,
                'processing_time': processing_time
            }

        # Positional labels "A", "B", ... for the filtered choices.
        candidates = [chr(65 + i) for i in range(len(filtered_choices))]

        gen_init = self.adversarial.get_generator_initial_probs(question, filtered_choices, subject)
        disc_init = self.adversarial.get_discriminator_initial_probs(question, filtered_choices, subject)

        equilibrium_result = self.hybrid_calibrator.game_theory_softmax_equilibrium(
            gen_init, disc_init, candidates,
            temperature_strategy=self.temperature_strategy
        )

        final_result = self.hybrid_calibrator.get_final_answer_with_hybrid_weights(
            equilibrium_result, candidates, new_answer_key
        )

        final_answer = final_result['final_answer']
        accuracy = final_result.get('accuracy', 0.0)

        processing_time = time.time() - start_time
        self.processing_times.append(processing_time)

        return {
            **preprocessing_result,
            'consensus_achieved': False,
            'final_answer': final_answer,
            'gen_answer': final_result['gen_answer'],
            'disc_answer': final_result['disc_answer'],
            'hybrid_method_used': True,
            'hybrid_weights_used': final_result['weights_used'],
            'base_strengths': final_result['base_strengths'],
            'weight_evolution': final_result['weight_evolution'],
            'equilibrium_history': equilibrium_result['equilibrium_history'],
            'convergence_iteration': equilibrium_result['convergence_iteration'],
            'temperature_strategy': self.temperature_strategy,
            'accuracy': accuracy,
            'processing_time': processing_time
        }

    def process_dataset(self, df, max_questions=None):
        """Calibrate on a held-out sample of ``df``, then evaluate the remaining rows.

        Side effect: the calibrated base strengths are copied onto the
        preprocessing system (``phi2_weight``, ``qwen2_weight``).

        Args:
            df: DataFrame with ``question``, ``choices`` and ``subject`` columns
                and, optionally, ``answerKey``.
            max_questions: if given, evaluate only the first ``max_questions``
                rows left after removing the calibration sample.

        Returns:
            pd.DataFrame: one row per evaluated question. Each row holds the
            result fields plus the original row's columns, which win on key
            clashes.
        """
        self.initialize_all_models()

        # Draw the calibration sample once and give the calibrator exactly those
        # rows. The rows calibration can see are then exactly the rows removed
        # from evaluation below, so there is no calibration/evaluation overlap.
        calibration_df = df.sample(n=min(self.calibration_size, len(df)), random_state=42)
        calibration_results = self.hybrid_calibrator.calibrate_base_strengths(
            calibration_df,
            num_questions=len(calibration_df),
            method=self.calibration_method
        )
        df = df.drop(calibration_df.index)

        self.preprocessing.phi2_weight = self.hybrid_calibrator.base_phi2_strength
        self.preprocessing.qwen2_weight = self.hybrid_calibrator.base_qwen2_strength
        self.preprocessing.weights_calibrated = True

        if max_questions is not None:
            df = df.head(max_questions)

        results = []

        total_questions = 0
        consensus_questions = 0
        hybrid_questions = 0
        total_correct = 0

        start_time = time.time()

        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
            total_questions += 1

            result = self.process_single_question(
                row["question"],
                row["choices"],
                row.get("answerKey"),
                row["subject"]
            )

            # Attach the original row's columns to the result record.
            result.update(row.to_dict())

            if result['consensus_achieved']:
                consensus_questions += 1
            else:
                hybrid_questions += 1

            accuracy = result.get('accuracy', 0)
            if accuracy is not None and accuracy > 0:
                total_correct += 1

            results.append(result)

            # Periodically release Python and CUDA cached memory during long runs.
            if total_questions % PERF_CONFIG.clear_cache_every == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        total_time = time.time() - start_time
        avg_time_per_question = total_time / total_questions if total_questions > 0 else 0

        self._print_hybrid_results(
            results, total_questions, consensus_questions, hybrid_questions,
            total_correct, total_time, avg_time_per_question, calibration_results
        )

        return pd.DataFrame(results)

    def _print_hybrid_results(self, results, total_questions, consensus_questions, hybrid_questions,
                             total_correct, total_time, avg_time_per_question, calibration_results):
        """Print a summary of a dataset run.

        Accuracy is recomputed from ``results``: records whose ``accuracy`` is
        missing or None are excluded, and a record counts as correct when its
        accuracy is > 0. ``total_correct`` and ``calibration_results`` are
        accepted for interface compatibility but are not printed.
        """
        scored = [r.get('accuracy', 0) for r in results if r.get('accuracy') is not None]
        overall_accuracy = sum(scored) / len(scored) if scored else 0
        num_correct = sum(1 for acc in scored if acc > 0)

        print(f"\nOverall accuracy: {num_correct}/{len(scored)} = {overall_accuracy:.1%}")
        print(f"Questions evaluated: {total_questions} "
              f"(consensus: {consensus_questions}, hybrid equilibrium: {hybrid_questions})")
        print(f"Total processing time: {total_time:.2f} seconds")
        print(f"Average time per question: {avg_time_per_question:.3f} seconds")

    def cleanup_all_models(self):
        """Release preprocessing models and reset load state and timing history."""
        self.preprocessing.cleanup_preprocessing_models()
        # NOTE(review): self.adversarial still holds the model references handed
        # over in initialize_all_models. Confirm whether they should be dropped
        # here too.
        self.models_loaded = False
        self.processing_times = []

def main():
    """Run the hybrid pipeline on ARC-Challenge and ARC-Easy and save results.

    Per-question results are written as CSV files under ``Data/``, with the
    compute device in each filename. Models are released through
    ``cleanup_all_models`` even if a dataset run raises.

    Returns:
        tuple: ``(results_challenge, results_easy)`` DataFrames.
    """
    arc_df, arc_easy_df = load_arc_datasets()

    pipeline = HybridIntegratedPipeline(
        calibration_method='comprehensive',
        temperature_strategy='adaptive',
        calibration_size=100
    )

    # DataFrame.to_csv does not create parent directories. Create "Data/" up
    # front so a long evaluation run is not lost at the final write.
    os.makedirs("Data", exist_ok=True)

    try:
        print("\nProcessing ARC Challenge dataset...")
        results_challenge = pipeline.process_dataset(arc_df)

        challenge_filename = f"Data/hybrid_full_arc_challenge_{PERF_CONFIG.device}.csv"
        results_challenge.to_csv(challenge_filename, index=False)
        print(f"ARC Challenge results saved to {challenge_filename}")

        print("\nProcessing ARC Easy dataset...")
        results_easy = pipeline.process_dataset(arc_easy_df)

        easy_filename = f"Data/hybrid_full_arc_easy_{PERF_CONFIG.device}.csv"
        results_easy.to_csv(easy_filename, index=False)
        print(f"ARC Easy results saved to {easy_filename}")
    finally:
        # Free model memory whether or not the runs above succeeded.
        pipeline.cleanup_all_models()

    return results_challenge, results_easy


if __name__ == "__main__":
    main()

README.md: 0.00B [00:00, ?B/s]

ARC-Challenge/train-00000-of-00001.parqu(…):   0%|          | 0.00/190k [00:00<?, ?B/s]

ARC-Challenge/test-00000-of-00001.parque(…):   0%|          | 0.00/204k [00:00<?, ?B/s]

ARC-Challenge/validation-00000-of-00001.(…):   0%|          | 0.00/55.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1119 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1172 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/299 [00:00<?, ? examples/s]

ARC-Easy/train-00000-of-00001.parquet:   0%|          | 0.00/331k [00:00<?, ?B/s]

ARC-Easy/test-00000-of-00001.parquet:   0%|          | 0.00/346k [00:00<?, ?B/s]

ARC-Easy/validation-00000-of-00001.parqu(…):   0%|          | 0.00/86.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2251 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2376 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/570 [00:00<?, ? examples/s]


Processing ARC Challenge dataset...


config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/564M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]