In [None]:
import math
import numpy as np
import pandas as pd
import os
import gc
import torch
from collections import defaultdict
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

#==============================================================================
# DATA LOADING
#==============================================================================

def load_arc_datasets():
    """Load the test splits of ARC-Challenge and ARC-Easy as DataFrames.

    Each frame has duplicate questions removed, keeps the original
    ``choices`` value under ``choices_dic``, replaces ``choices`` with its
    ``"text"`` entry, and gets a constant ``subject`` column of "science".

    Returns:
        A ``(challenge_frame, easy_frame)`` tuple.
    """
    def _load_split(config_name):
        # One pass of the shared pipeline: download, dedupe, reshape columns.
        frame = load_dataset("allenai/ai2_arc", config_name, split="test").to_pandas()
        frame = frame.drop_duplicates(subset=['question'])
        raw_choices = frame["choices"]
        return frame.assign(
            choices_dic=raw_choices,
            choices=raw_choices.apply(lambda entry: entry["text"]),
            subject="science",
        )

    print("Loading ARC datasets...")

    challenge_frame = _load_split("ARC-Challenge")
    print(f"ARC Challenge shape: {challenge_frame.shape}")

    easy_frame = _load_split("ARC-Easy")
    print(f"ARC Easy shape: {easy_frame.shape}")

    return challenge_frame, easy_frame

#==============================================================================
# GENERATOR FUNCTIONS (Phi-2)
#==============================================================================

def format_subject(subject):
    """Replace underscores with spaces, prefixing every word with a space.

    The result therefore always starts with a space,
    e.g. ``"high_school"`` becomes ``" high school"``.
    """
    return "".join(f" {word}" for word in subject.split("_"))

def build_generator_prompt(subject, target_question, target_choices, get_correct):
    """Assemble the multiple-choice prompt shown to the generator.

    The prompt is a header naming the formatted subject, the question, one
    line per entry of ``target_choices``, and a closing cue: "Answer:" when
    ``get_correct`` is truthy, otherwise "Incorrect Answer:".
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject))
    option_lines = "".join("\n{}".format(option) for option in target_choices)
    closing_cue = "\nAnswer:" if get_correct else "\nIncorrect Answer:"
    return header + f"{target_question}" + option_lines + closing_cue

def get_generator_answer_probs(model, tokenizer, prompt_text, choices_list):
    """Score each answer letter by the model's next-token logits.

    The prompt is tokenized (truncated to 512 tokens), the logits at the last
    position are read, and a softmax is taken over only the logits of the
    letters "A", "B", ... (one letter per entry of ``choices_list``).

    Returns:
        A dict mapping each letter to its probability. If anything raises,
        the error is printed and a uniform distribution is returned instead.
    """
    try:
        encoded_prompt = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
        prompt_ids = encoded_prompt.input_ids

        if torch.cuda.is_available():
            prompt_ids = prompt_ids.to(model.device)

        with torch.no_grad():
            next_token_logits = model(input_ids=prompt_ids).logits[0, -1]

        letters = [chr(65 + position) for position, _ in enumerate(choices_list)]

        # The last token of each tokenized letter identifies its logit.
        letter_scores = [
            next_token_logits[
                tokenizer(letter, return_tensors="pt").input_ids[0, -1].item()
            ].item()
            for letter in letters
        ]

        distribution = torch.nn.functional.softmax(
            torch.tensor(letter_scores).float(), dim=0
        ).detach().cpu().numpy()

        return dict(zip(letters, distribution))

    except Exception as e:
        print(f"Error in generator: {e}")
        # Best-effort fallback: every letter equally likely.
        letters = [chr(65 + position) for position, _ in enumerate(choices_list)]
        return {letter: 1.0/len(letters) for letter in letters}

def generator_probs(subject, question, choices_list, get_correct, model, tokenizer):
    """Build the generator prompt for one question and return letter probabilities.

    Each choice is rendered as "<letter>. <text>" in the prompt; the original
    ``choices_list`` is what is passed on to the scoring function.
    """
    lettered_options = [
        f"{chr(65 + position)}. {text}" for position, text in enumerate(choices_list)
    ]
    return get_generator_answer_probs(
        model,
        tokenizer,
        build_generator_prompt(subject, question, lettered_options, get_correct),
        choices_list,
    )

def get_initial_generator_probs(row, model, tokenizer):
    """Query the generator twice for one row: once for the correct answer, once for an incorrect one.

    Reads ``question``, ``choices`` and ``subject`` from ``row``.

    Returns:
        ``{"correct": {...}, "incorrect": {...}}`` where each inner dict maps
        answer letters to probabilities.
    """
    question, options, subject = row["question"], row["choices"], row["subject"]

    # The "correct" query runs first, then the "incorrect" one.
    return {
        "correct": dict(generator_probs(subject, question, options, True, model, tokenizer)),
        "incorrect": dict(generator_probs(subject, question, options, False, model, tokenizer)),
    }

#==============================================================================
# DISCRIMINATOR FUNCTIONS (Qwen2-1.5B)
#==============================================================================

def build_discriminator_prompt(subject: str, question: str, proposed_answer: str) -> str:
    """Return the prompt asking whether ``proposed_answer`` answers ``question``.

    The prompt asks for a single letter: "A" (correct) or "B" (incorrect),
    and ends with "Answer:" with no trailing newline.
    """
    prompt_lines = [
        f"You are an expert evaluator of questions about {format_subject(subject)}.",
        "Determine if the proposed answer is correct. Output ONLY 'A' or 'B'.",
        f"Question: {question}",
        f"Proposed Answer: {proposed_answer}",
        "",
        "Is this answer correct? Respond ONLY with:",
        "A. Correct",
        "B. Incorrect",
        "",
        "Answer:",
    ]
    return "\n".join(prompt_lines)

def get_discriminator_probs(model, tokenizer, prompt_text, choices_list):
    """Return the discriminator's correct/incorrect probabilities for one prompt.

    The prompt is tokenized (truncated to 512 tokens) and the last-position
    logits of the tokens for "A" and "B" are softmaxed; "A" is reported as
    "correct" and "B" as "incorrect". ``choices_list`` is accepted but not
    used here.

    Returns:
        ``{"correct": p, "incorrect": q}``. If anything raises, the error is
        printed and 0.5 / 0.5 is returned.
    """
    try:
        prompt_ids = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512).input_ids

        if torch.cuda.is_available():
            prompt_ids = prompt_ids.to(model.device)

        with torch.no_grad():
            next_token_logits = model(input_ids=prompt_ids).logits[0, -1]

        verdict_by_letter = {"A": "correct", "B": "incorrect"}

        # The last token id of each tokenized letter selects its logit.
        verdict_logits = torch.tensor([
            next_token_logits[tokenizer(letter).input_ids[-1]]
            for letter in verdict_by_letter
        ]).float()

        verdict_probs = torch.nn.functional.softmax(verdict_logits, dim=0).detach().cpu().numpy()

        return dict(zip(verdict_by_letter.values(), verdict_probs))

    except Exception as e:
        print(f"Error in discriminator: {e}")
        return {"correct": 0.5, "incorrect": 0.5}

def evaluate_answer_correctness(row, model, tokenizer):
    """Ask the discriminator to judge every answer choice of one row.

    Reads ``subject``, ``question`` and ``choices`` from ``row``.

    Returns:
        A dict mapping each positional letter ("A", "B", ...) to the
        discriminator's probability dict for that choice.
    """
    subject, question, options = row["subject"], row["question"], row["choices"]

    return {
        chr(65 + position): get_discriminator_probs(
            model,
            tokenizer,
            build_discriminator_prompt(
                subject=subject,
                question=question,
                proposed_answer=f"{option_text}",
            ),
            options,
        )
        for position, option_text in enumerate(options)
    }

def get_initial_discriminator_probs(row, model, tokenizer):
    """Return the discriminator's starting policy for one row.

    This simply forwards to ``evaluate_answer_correctness``.
    """
    return evaluate_answer_correctness(row, model, tokenizer)

#==============================================================================
# EQUILIBRIUM SEARCH
#==============================================================================

def softmax(arr):
    """Return the softmax of ``arr``.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow.
    """
    shifted_exp = np.exp(arr - np.max(arr))
    return shifted_exp / shifted_exp.sum()

def equilibrium_search(gen_init, disc_init, candidates, T=20, eta_G=0.1, eta_D=0.1, lam_G=0.1, lam_D=0.01):
    """Iteratively pull the generator and discriminator policies towards each other.

    Args:
        gen_init: ``{"correct": {y: p}, "incorrect": {y: p}}`` starting generator policy.
        disc_init: ``{y: {"correct": p, "incorrect": p}}`` starting discriminator policy.
        candidates: the answer identifiers ``y`` to iterate over.
        T: number of update rounds.
        eta_G, eta_D: step-size terms; they enter the update as ``1/eta``.
        lam_G, lam_D: weights on the log of the initial policy.

    Returns:
        ``(gen, disc)`` — the updated policies, in the same layout as the inputs.
        The inputs themselves are not modified.
    """
    roles = ("correct", "incorrect")

    # Work on copies so the initial policies stay available as anchors.
    gen_policy = {role: dict(gen_init[role]) for role in roles}
    disc_policy = {y: dict(disc_init[y]) for y in candidates}

    gen_value = {role: {y: 0.0 for y in candidates} for role in roles}
    disc_value = {y: {role: 0.0 for role in roles} for y in candidates}

    for step in range(1, T + 1):
        # NOTE(review): the 1/(2t) weight is applied to each round's policy and
        # accumulated, so the values are a harmonically weighted sum rather than
        # a running average — confirm this is the intended update.
        step_weight = 1.0 / (2.0 * step)

        # Both value tables are updated from the policies of the previous round,
        # before either policy is changed.
        for role in roles:
            for y in candidates:
                gen_value[role][y] += step_weight * disc_policy[y][role]

        for y in candidates:
            for role in roles:
                disc_value[y][role] += step_weight * gen_policy[role][y]

        # Generator: one softmax over candidates per role.
        for role in roles:
            scores = np.array([
                (gen_value[role][y] + lam_G * math.log(gen_init[role][y] + 1e-12)) / (1/eta_G + lam_G)
                for y in candidates
            ])
            for y, prob in zip(candidates, softmax(scores)):
                gen_policy[role][y] = prob

        # Discriminator: one softmax over the two roles per candidate.
        for y in candidates:
            scores = np.array([
                (disc_value[y][role] + lam_D * math.log(disc_init[y][role] + 1e-12)) / (1/eta_D + lam_D)
                for role in roles
            ])
            disc_policy[y]["correct"], disc_policy[y]["incorrect"] = softmax(scores)

    return gen_policy, disc_policy

def pick_answer(gen, disc, candidates, method="generator"):
    """Return the candidate with the highest "correct" probability.

    With ``method="generator"`` the score is ``gen["correct"][y]``; with any
    other value it is ``disc[y]["correct"]``. Ties go to the earliest
    candidate, and ``None`` is returned when no candidate scores above -1.0
    (for example when ``candidates`` is empty).
    """
    if method == "generator":
        score_of = lambda y: gen["correct"][y]
    else:
        score_of = lambda y: disc[y]["correct"]

    winner, top_score = None, -1.0
    for y in candidates:
        score = score_of(y)
        # Strict comparison keeps the first of several equal scores.
        if score > top_score:
            winner, top_score = y, score
    return winner

#==============================================================================
# MODEL LOADING
#==============================================================================

def load_model(model_name):
    """Load a causal language model and its tokenizer by name.

    With CUDA available the model is loaded in float16 with
    ``device_map="auto"``; otherwise in float32 on the CPU. If the tokenizer
    has no pad token, its EOS token is used as the pad token.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading {model_name}...")

    cuda_ready = torch.cuda.is_available()
    model_options = {
        "torch_dtype": torch.float16 if cuda_ready else torch.float32,
        "low_cpu_mem_usage": True,
        "device_map": "auto" if cuda_ready else "cpu",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_options)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"✓ {model_name} loaded successfully")
    return model, tokenizer

#==============================================================================
# MAIN PROCESSING
#==============================================================================

def _answer_key_as_letter(row):
    """Translate a row's ``answerKey`` into a positional letter ("A", "B", ...).

    Predictions in this file are always positional letters. When the row
    carries the original choice labels under ``choices_dic["label"]``, the key
    is looked up there and converted by position, so a key such as "3" becomes
    "C". If the labels are unavailable or the key is not among them, the raw
    key is returned unchanged.
    """
    raw_key = row.get("answerKey")
    if not raw_key:
        return raw_key

    try:
        labels = [str(label) for label in row.get("choices_dic")["label"]]
    except (KeyError, TypeError, IndexError, ValueError):
        # No usable label list on this row: fall back to the raw key.
        return raw_key

    key_text = str(raw_key)
    if key_text in labels:
        return chr(65 + labels.index(key_text))
    return raw_key


def process_dataset(generator_model, generator_tokenizer,
                   discriminator_model, discriminator_tokenizer,
                   df, dataset_name="Dataset"):
    """Run every question of ``df`` through the consensus game and report accuracy.

    For each row the initial discriminator and generator policies are
    computed, an answer is picked from each, the equilibrium search is run,
    and an answer is picked again from the final policies. Accuracy before
    and after is printed at the end.

    Args:
        generator_model, generator_tokenizer: model/tokenizer used for the generator.
        discriminator_model, discriminator_tokenizer: model/tokenizer used for the discriminator.
        df: DataFrame with ``question``, ``choices``, ``subject`` and
            (optionally) ``answerKey`` / ``choices_dic`` columns.
        dataset_name: label used in progress output and the summary.

    Returns:
        A DataFrame with one row per question. Rows that raised during
        processing carry default answers of 'A' plus an ``error`` column.
        An empty ``df`` yields an empty DataFrame and no accuracy summary.
    """

    total_questions = len(df)
    results = []

    print(f"\nProcessing {dataset_name}...")
    print(f"Total questions: {total_questions}")

    # Nothing to score: return early instead of dividing by zero in the summary.
    if total_questions == 0:
        print(f"No questions to process for {dataset_name}; skipping accuracy summary.")
        return pd.DataFrame(results)

    # Initialize result tracking
    gen_init_correct = 0
    disc_init_correct = 0
    gen_final_correct = 0
    disc_final_correct = 0

    for idx, (_, row) in enumerate(tqdm(df.iterrows(), total=total_questions, desc=dataset_name)):
        try:
            # Get initial discriminator probabilities
            disc_init = get_initial_discriminator_probs(row, discriminator_model, discriminator_tokenizer)

            # Get initial generator probabilities
            gen_init = get_initial_generator_probs(row, generator_model, generator_tokenizer)

            # Initial answers
            candidates = [chr(65 + i) for i, _ in enumerate(row["choices"])]
            gen_init_answer = pick_answer(gen_init, disc_init, candidates, method="generator")
            disc_init_answer = pick_answer(gen_init, disc_init, candidates, method="discriminator")

            # Equilibrium search
            gen_final, disc_final = equilibrium_search(gen_init, disc_init, candidates, T=20)

            # Final answers
            gen_final_answer = pick_answer(gen_final, disc_final, candidates, method="generator")
            disc_final_answer = pick_answer(gen_final, disc_final, candidates, method="discriminator")

            # Track accuracy. Predictions are positional letters, so the gold
            # key is converted to a positional letter before comparing.
            correct_answer = row.get("answerKey")
            correct_letter = _answer_key_as_letter(row)
            if correct_letter:
                gen_init_correct += gen_init_answer == correct_letter
                disc_init_correct += disc_init_answer == correct_letter
                gen_final_correct += gen_final_answer == correct_letter
                disc_final_correct += disc_final_answer == correct_letter

            # Store results (answerKey is kept exactly as it appears in the row)
            results.append({
                'question': row['question'],
                'choices': row['choices'],
                'answerKey': correct_answer,
                'gen_init_answer': gen_init_answer,
                'disc_init_answer': disc_init_answer,
                'gen_answer': gen_final_answer,
                'disc_answer': disc_final_answer,
                'gen_init_policy': gen_init,
                'disc_init_policy': disc_init,
                'gen_final_policy': gen_final,
                'disc_final_policy': disc_final,
            })

            # Memory management
            if (idx + 1) % 100 == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        except Exception as e:
            print(f"Error processing question {idx}: {e}")
            # Add default values for failed questions; they still count in
            # the accuracy denominators below.
            results.append({
                'question': row['question'],
                'choices': row['choices'],
                'answerKey': row.get('answerKey'),
                'gen_init_answer': 'A',
                'disc_init_answer': 'A',
                'gen_answer': 'A',
                'disc_answer': 'A',
                'error': str(e),
            })

    # Print accuracy summary
    print(f"\n{dataset_name} Results:")
    print("Generator (Phi-2) Performance:")
    print(f"  Initial Accuracy: {gen_init_correct}/{total_questions} = {gen_init_correct/total_questions*100:.2f}%")
    print(f"  Final Accuracy:   {gen_final_correct}/{total_questions} = {gen_final_correct/total_questions*100:.2f}%")
    print(f"  Improvement:      {(gen_final_correct - gen_init_correct)/total_questions*100:+.2f}%")

    print("\nDiscriminator (Qwen2-1.5B) Performance:")
    print(f"  Initial Accuracy: {disc_init_correct}/{total_questions} = {disc_init_correct/total_questions*100:.2f}%")
    print(f"  Final Accuracy:   {disc_final_correct}/{total_questions} = {disc_final_correct/total_questions*100:.2f}%")
    print(f"  Improvement:      {(disc_final_correct - disc_init_correct)/total_questions*100:+.2f}%")

    overall_improvement = ((gen_final_correct + disc_final_correct) - (gen_init_correct + disc_init_correct)) / (2 * total_questions) * 100
    print(f"\nOverall Consensus Game Impact: {overall_improvement:+.2f}%")

    return pd.DataFrame(results)

#==============================================================================
# MAIN EXECUTION
#==============================================================================

def main():
    """Run the baseline consensus game on ARC Challenge and ARC Easy.

    Loads both datasets and both models, processes each dataset, writes the
    per-question results to CSV files under ``results/``, and frees the models.

    Returns:
        ``(arc_challenge_results, arc_easy_results)`` as DataFrames.
    """
    rule = "="*70
    banner_lines = (
        rule,
        "BASELINE CONSENSUS GAME",
        rule,
        "Simple consensus game between:",
        "  Generator: Phi-2",
        "  Discriminator: Qwen2-1.5B",
        "Fixed parameters, no preprocessing",
        rule,
    )
    for banner_line in banner_lines:
        print(banner_line)

    # Datasets first, then both models.
    arc_df, arc_easy_df = load_arc_datasets()

    print("\nLoading models...")
    generator_model, generator_tokenizer = load_model("microsoft/phi-2")
    discriminator_model, discriminator_tokenizer = load_model("Qwen/Qwen2-1.5B-Instruct")

    print("\nModels loaded successfully!")

    # ARC Challenge: process, then create the output directory and save.
    arc_challenge_results = process_dataset(
        generator_model, generator_tokenizer,
        discriminator_model, discriminator_tokenizer,
        arc_df, "ARC Challenge"
    )

    os.makedirs("results", exist_ok=True)
    challenge_filename = "results/baseline_arc_challenge.csv"
    arc_challenge_results.to_csv(challenge_filename, index=False)
    print(f"\nARC Challenge results saved to {challenge_filename}")

    # ARC Easy: process and save next to the Challenge results.
    arc_easy_results = process_dataset(
        generator_model, generator_tokenizer,
        discriminator_model, discriminator_tokenizer,
        arc_easy_df, "ARC Easy"
    )

    easy_filename = "results/baseline_arc_easy.csv"
    arc_easy_results.to_csv(easy_filename, index=False)
    print(f"ARC Easy results saved to {easy_filename}")

    # Drop the model references and release cached GPU memory.
    del generator_model, generator_tokenizer
    del discriminator_model, discriminator_tokenizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("\n" + rule)
    print("BASELINE PROCESSING COMPLETE")
    print(rule)

    return arc_challenge_results, arc_easy_results

# Run the full pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()