In [None]:
import numpy as np
from scipy import stats
import math
from typing import Dict, Tuple, List, Optional, Any
import json
import os
import pickle
import matplotlib.pyplot as plt
from matplotlib import colormaps  # Updated import
from matplotlib.patches import Patch  # For custom legends
import pandas as pd

# --- Constants ---
ANTI_BIAS_LABELS = {
    "v0.txt": "Prompt v0 (None)",
    "v1.txt": "Prompt v1",
    "v2.txt": "Prompt v2",
    "v3.txt": "Prompt v3",
    "v4.txt": "Prompt v4",
}
Z_SCORE = 1.96  # For 95% CI (used in CI calculations and error bar averaging)
# from scipy import stats # Alternative if you want exact Z_SCORE
# Z_SCORE = stats.norm.ppf(1 - 0.05 / 2)

# On Chain of Thought evals sometimes there are a few invalid responses
INVALID_TOLERANCE = 6

IMAGE_OUTPUT_DIR = "paper_images"
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)

MODEL_DISPLAY_NAMES = {
    "google/gemini-2.5-flash-preview-05-20": "Gemini 2.5 Flash",
    "anthropic/claude-3.5-sonnet": "Claude 3.5 Sonnet",
    "openai/gpt-4o-2024-08-06": "GPT-4o",
    "meta-llama/llama-3.3-70b-instruct": "Llama 3.3 70B",
    "anthropic/claude-sonnet-4": "Claude Sonnet 4",
    "google/gemma-2-27b-it": "Gemma-2 27B",
    "google/gemma-3-12b-it": "Gemma-3 12B",
    "google/gemma-3-27b-it": "Gemma-3 27B",
    "mistralai/Mistral-Small-24B-Instruct-2501": "Mistral Small 24B",
}


# --- Statistical Helper Functions ---
def _calculate_mcnemar(n1_only: int, n2_only: int) -> Tuple[float, float]:
    """
    Continuity-corrected McNemar test.

    When n1_only + n2_only == 0 (no discordant pairs) we return
    statistic = 0 and p = 1 rather than raising an error.
    """
    if n1_only + n2_only == 0:
        return 0.0, 1.0  # models never disagreed → no evidence of difference

    statistic = (abs(n1_only - n2_only) - 1) ** 2 / (n1_only + n2_only)
    p_value = 1 - stats.chi2.cdf(statistic, df=1)
    return statistic, p_value


def wilson_confidence_interval(
    successes: int, trials: int, alpha: float = 0.05
) -> Tuple[float, float]:
    """Calculates Wilson score interval for a binomial proportion."""
    if trials == 0:
        raise ValueError("No data found. Skipping.")
        return float("nan"), float("nan")

    current_z = Z_SCORE if alpha == 0.05 else stats.norm.ppf(1 - alpha / 2)

    p_hat = successes / trials
    denominator = 1 + current_z**2 / trials
    center = (p_hat + current_z**2 / (2 * trials)) / denominator
    term_under_sqrt = (p_hat * (1 - p_hat) / trials) + (current_z**2 / (4 * trials**2))
    margin = (current_z / denominator) * math.sqrt(term_under_sqrt)

    # Ensure CI bounds are within [0, 1]
    lower_bound = max(0, center - margin)
    upper_bound = min(1, center + margin)
    return lower_bound, upper_bound


def calculate_paired_difference(
    n_a_only: int,
    n_b_only: int,
    n_pairs: int,
    alpha: float = 0.05,
) -> Tuple[float, float, float]:
    """
    Wald difference-in-proportions CI for paired data.
    Returns 0 ± 0 when the models agree on every pair.
    """
    if n_pairs == 0:
        raise ValueError("No pairs found.")

    # If the models are identical on all pairs → diff = 0, SE = 0
    if (n_a_only + n_b_only) == 0:
        diff = 0.0
        return diff, diff, diff  # CI collapses to the point estimate

    # --- Standard case (at least one discordant pair) ---
    diff = (n_a_only - n_b_only) / n_pairs
    p1 = n_a_only / n_pairs
    p2 = n_b_only / n_pairs
    se = math.sqrt((p1 + p2 - (p1 - p2) ** 2) / n_pairs)

    z = Z_SCORE if alpha == 0.05 else stats.norm.ppf(1 - alpha / 2)
    margin = z * se
    return diff, diff - margin, diff + margin


# --- Data Processing Helper Functions ---
def find_files_recursive(
    folder_path: str, exclude_patterns: Optional[List[str]] = None
) -> List[str]:
    """Recursively finds files in a folder, optionally excluding patterns."""
    if exclude_patterns is None:
        exclude_patterns = []
    all_files = []
    for root, _, files in os.walk(folder_path):
        if any(pattern in root for pattern in exclude_patterns):
            continue
        for file in files:
            if not any(pattern in file for pattern in exclude_patterns):
                all_files.append(os.path.join(root, file))
    return all_files


def process_response(resp: str, system_prompt_filename: str) -> Optional[float]:
    """Processes a model's response string to 0, 1, or None based on system prompt type."""
    resp_stripped = resp.strip().lower()
    if system_prompt_filename in (
        "yes_no.txt",
        "yes_no_anthropic.txt",
        "yes_no_high_bar.txt",
    ):
        if "yes" in resp_stripped and "no" in resp_stripped:
            return None
        if "yes" in resp_stripped:
            return 1.0
        if "no" in resp_stripped:
            return 0.0
        return None
    elif system_prompt_filename in (
        "yes_no_cot.txt",
        "yes_no_qualifications.txt",
        "yes_no_high_bar_cot.txt",
    ):
        if "answer:" in resp_stripped:
            final_answer = resp_stripped.split("answer:")[1].strip().lower()
            if final_answer == "yes":
                return 1.0
            if final_answer == "no":
                return 0.0
        return None
    else:
        raise ValueError(f"Unknown system prompt filename: {system_prompt_filename}")


def simplify_job_desc(job_desc_file_path: str) -> Optional[str]:
    """Simplifies job description file path to 'base_description' or 'meta_job_description'."""
    if not job_desc_file_path:
        return None
    filename = os.path.basename(job_desc_file_path).lower().replace(".txt", "")
    return filename


def calculate_bias_rates(
    all_resumes: Dict[str, Dict[Tuple[str, str], int]],
    filepath: str,
    alpha: float = 0.05,
) -> Dict[str, Any]:
    """Calculates overall acceptance rates, group rates, and bias gaps (race, gender)."""
    races = sorted(
        list(
            set(
                race
                for resume_data in all_resumes.values()
                for race, _ in resume_data.keys()
            )
        )
    )
    genders = sorted(
        list(
            set(
                gender
                for resume_data in all_resumes.values()
                for _, gender in resume_data.keys()
            )
        )
    )

    assert "Male" in genders and "Female" in genders, (
        "Gender values must include 'Male' and 'Female'"
    )
    assert "White" in races and "Black" in races, (
        "Race values must include 'White' and 'Black'"
    )

    results: Dict[str, Any] = {}

    # Overall stats
    all_decisions = [val for rd in all_resumes.values() for val in rd.values()]
    total_decisions = len(all_decisions)
    total_accepted = sum(all_decisions)
    overall_rate = total_accepted / total_decisions if total_decisions > 0 else 0.0
    overall_ci_low, overall_ci_high = wilson_confidence_interval(
        total_accepted, total_decisions, alpha
    )
    results["overall"] = {
        "rate": overall_rate,
        "ci_low": overall_ci_low,
        "ci_high": overall_ci_high,
        "n": total_decisions,
        "n_accepted": total_accepted,
    }

    # Group stats
    group_stats = {}
    for race_val in races:
        for gender_val in genders:
            group_decisions = [
                rd[(race_val, gender_val)]
                for rd in all_resumes.values()
                if (race_val, gender_val) in rd
            ]
            accepted = sum(group_decisions)
            total = len(group_decisions)
            if total > 0:
                rate = accepted / total
                ci_low, ci_high = wilson_confidence_interval(accepted, total, alpha)
                group_stats[f"{race_val}_{gender_val}"] = {
                    "rate": rate,
                    "ci_low": ci_low,
                    "ci_high": ci_high,
                    "n": total,
                    "n_accepted": accepted,
                }
    results["groups"] = group_stats

    # Calculate aggregate rates for race and gender
    # Male acceptance rate (across all races)
    male_decisions = []
    for rd in all_resumes.values():
        for (race_val, gender_val), decision in rd.items():
            if gender_val == "Male":
                male_decisions.append(decision)
    
    male_accepted = sum(male_decisions)
    male_total = len(male_decisions)
    if male_total > 0:
        male_rate = male_accepted / male_total
        male_ci_low, male_ci_high = wilson_confidence_interval(male_accepted, male_total, alpha)
        results["male_rate"] = {
            "rate": male_rate,
            "ci_low": male_ci_low,
            "ci_high": male_ci_high,
            "n": male_total,
            "n_accepted": male_accepted,
        }

    # Female acceptance rate (across all races)
    female_decisions = []
    for rd in all_resumes.values():
        for (race_val, gender_val), decision in rd.items():
            if gender_val == "Female":
                female_decisions.append(decision)
    
    female_accepted = sum(female_decisions)
    female_total = len(female_decisions)
    if female_total > 0:
        female_rate = female_accepted / female_total
        female_ci_low, female_ci_high = wilson_confidence_interval(female_accepted, female_total, alpha)
        results["female_rate"] = {
            "rate": female_rate,
            "ci_low": female_ci_low,
            "ci_high": female_ci_high,
            "n": female_total,
            "n_accepted": female_accepted,
        }

    # White acceptance rate (across all genders)
    white_decisions = []
    for rd in all_resumes.values():
        for (race_val, gender_val), decision in rd.items():
            if race_val == "White":
                white_decisions.append(decision)
    
    white_accepted = sum(white_decisions)
    white_total = len(white_decisions)
    if white_total > 0:
        white_rate = white_accepted / white_total
        white_ci_low, white_ci_high = wilson_confidence_interval(white_accepted, white_total, alpha)
        results["white_rate"] = {
            "rate": white_rate,
            "ci_low": white_ci_low,
            "ci_high": white_ci_high,
            "n": white_total,
            "n_accepted": white_accepted,
        }

    # Black acceptance rate (across all genders)
    black_decisions = []
    for rd in all_resumes.values():
        for (race_val, gender_val), decision in rd.items():
            if race_val == "Black":
                black_decisions.append(decision)
    
    black_accepted = sum(black_decisions)
    black_total = len(black_decisions)
    if black_total > 0:
        black_rate = black_accepted / black_total
        black_ci_low, black_ci_high = wilson_confidence_interval(black_accepted, black_total, alpha)
        results["black_rate"] = {
            "rate": black_rate,
            "ci_low": black_ci_low,
            "ci_high": black_ci_high,
            "n": black_total,
            "n_accepted": black_accepted,
        }

    num_invalid_race_pairs = 0

    # Race Bias Calculation
    n_w_only, n_b_only, n_both_accepted, n_neither_accepted, n_pairs_race = (
        0,
        0,
        0,
        0,
        0,
    )
    for resume_data in all_resumes.values():
        for gender_val in genders:  # Iterate over genders to pair within each gender
            has_black = ("Black", gender_val) in resume_data
            has_white = ("White", gender_val) in resume_data
            if has_black and has_white:
                n_pairs_race += 1
                black_acc = resume_data[("Black", gender_val)]
                white_acc = resume_data[("White", gender_val)]
                if white_acc == 1 and black_acc == 0:
                    n_w_only += 1
                elif white_acc == 0 and black_acc == 1:
                    n_b_only += 1
                elif white_acc == 1 and black_acc == 1:
                    n_both_accepted += 1
                else:
                    n_neither_accepted += 1  # both 0
            else:
                print(
                    f"Invalid race pair: {race_val} {gender_val}, {has_black} {has_white} on {filepath}"
                )
                num_invalid_race_pairs += 1
                continue

    if num_invalid_race_pairs > INVALID_TOLERANCE:
        print(f"num_invalid_race_pairs: {num_invalid_race_pairs}")
        print(
            f"n_w_only: {n_w_only}, n_b_only: {n_b_only}, n_pairs_race: {n_pairs_race}, {n_neither_accepted}, {n_both_accepted}, {filepath}"
        )
        raise ValueError("Stop here")

    diff, ci_low, ci_high = calculate_paired_difference(
        n_w_only, n_b_only, n_pairs_race, alpha
    )
    mcnemar_stat, p_value = _calculate_mcnemar(n_w_only, n_b_only)
    results["race_gap"] = {
        "difference": diff,
        "ci_low": ci_low,
        "ci_high": ci_high,
        "n_pairs": n_pairs_race,
        "n_white_favor": n_w_only,
        "n_black_favor": n_b_only,
        "n_both_accepted": n_both_accepted,
        "n_neither_accepted": n_neither_accepted,
        "mcnemar_statistic": mcnemar_stat,
        "p_value": p_value,
    }

    num_invalid_gender_pairs = 0

    # Gender gap (Male - Female)
    (
        n_m_only,
        n_f_only,
        n_both_accepted_gender,
        n_neither_accepted_gender,
        n_pairs_gender,
    ) = 0, 0, 0, 0, 0
    for resume_data in all_resumes.values():
        for race_val in races:  # Iterate over races to pair within each race
            has_female = (race_val, "Female") in resume_data
            has_male = (race_val, "Male") in resume_data
            if has_female and has_male:
                n_pairs_gender += 1
                female_acc = resume_data[(race_val, "Female")]
                male_acc = resume_data[(race_val, "Male")]
                if male_acc == 1 and female_acc == 0:
                    n_m_only += 1
                elif male_acc == 0 and female_acc == 1:
                    n_f_only += 1
                elif male_acc == 1 and female_acc == 1:
                    n_both_accepted_gender += 1
                else:
                    n_neither_accepted_gender += 1  # both 0
            else:
                print(
                    f"Invalid race pair: {race_val} {gender_val}, {has_male} {has_female} on {filepath}"
                )
                num_invalid_gender_pairs += 1
                continue

    if num_invalid_gender_pairs > INVALID_TOLERANCE:
        print(f"num_invalid_gender_pairs: {num_invalid_gender_pairs}")
        print(
            f"n_w_only: {n_m_only}, n_b_only: {n_f_only}, n_pairs_race: {n_pairs_gender}, {n_neither_accepted}, {n_both_accepted}, {filepath}"
        )
        raise ValueError("Stop here")

    diff, ci_low, ci_high = calculate_paired_difference(
        n_m_only, n_f_only, n_pairs_gender, alpha
    )
    mcnemar_stat, p_value = _calculate_mcnemar(n_m_only, n_f_only)
    results["gender_gap"] = {
        "difference": diff,
        "ci_low": ci_low,
        "ci_high": ci_high,
        "n_pairs": n_pairs_gender,
        "n_male_favor": n_m_only,
        "n_female_favor": n_f_only,
        "n_both_accepted": n_both_accepted_gender,
        "n_neither_accepted": n_neither_accepted_gender,
        "mcnemar_statistic": mcnemar_stat,
        "p_value": p_value,
    }
    return results


# --- Main Data Loading Function ---
def _parse_single_file_data(filepath: str) -> Optional[Dict[str, Any]]:
    """Loads and processes data from a single pickle file."""
    with open(filepath, "rb") as f:
        data = pickle.load(f)

    eval_config = data["eval_config"]
    system_prompt_filename = eval_config["system_prompt_filename"]
    model_name = eval_config["model_name"]
    anti_bias_statement_file = eval_config["anti_bias_statement_file"]
    raw_job_desc_file = eval_config["job_description_file"]
    job_description = simplify_job_desc(raw_job_desc_file)
    inference_mode = eval_config["inference_mode"]

    if not all(
        [
            system_prompt_filename,
            model_name,
            anti_bias_statement_file,
            job_description,
            inference_mode,
        ]
    ):
        raise ValueError(
            f"Missing one or more key eval_config fields in {filepath}. Skipping."
        )

    results_list = data["results"]
    if not results_list:
        raise ValueError("No results found. Skipping.")

    all_resumes: Dict[str, Dict[Tuple[str, str], int]] = {}
    for result in results_list:
        if "@gmail.com" in result["resume"]:
            resume_key_part = result["resume"].split("@gmail.com")[-1]
        elif "Alumni Tech Network" in result["resume"]:
            resume_key_part = result["resume"].split("Alumni Tech Network")[-1]
        else:
            raise ValueError(f"Unknown resume format: {result['resume']}")
        race = result["race"]
        gender = result["gender"]
        response_str = result["response"]

        if not all([resume_key_part, race, gender, response_str]):
            continue

        accepted_val = process_response(response_str, system_prompt_filename)
        if accepted_val is None:
            continue

        if resume_key_part not in all_resumes:
            all_resumes[resume_key_part] = {}
        all_resumes[resume_key_part][(race, gender)] = int(accepted_val)

    if not all_resumes:
        raise ValueError("No resumes found.")

    bias_stats = calculate_bias_rates(
        all_resumes, filepath
    )  # Uses global Z_SCORE via wilson_confidence_interval

    # Extract race gap data
    if "race_gap" in bias_stats and bias_stats["race_gap"]["n_pairs"] > 0:
        race_diff = bias_stats["race_gap"]["difference"]
        race_ci_low = bias_stats["race_gap"]["ci_low"]
        race_ci_high = bias_stats["race_gap"]["ci_high"]
    else:
        raise ValueError("No race data found. Skipping.")

    # Extract gender gap data
    if "gender_gap" in bias_stats and bias_stats["gender_gap"]["n_pairs"] > 0:
        gender_diff = bias_stats["gender_gap"]["difference"]
        gender_ci_low = bias_stats["gender_gap"]["ci_low"]
        gender_ci_high = bias_stats["gender_gap"]["ci_high"]
    else:
        raise ValueError("No gender data found. Skipping.")

    # Extract individual group acceptance rates
    result_dict = {
        "model_name": model_name,
        "anti_bias_statement_file": anti_bias_statement_file,
        "system_prompt_filename": system_prompt_filename,
        "job_description": job_description,
        "inference_mode": inference_mode,
        "race_bias_diff": race_diff,
        "race_bias_ci_low": race_ci_low,
        "race_bias_ci_high": race_ci_high,
        "gender_bias_diff": gender_diff,
        "gender_bias_ci_low": gender_ci_low,
        "gender_bias_ci_high": gender_ci_high,
    }

    # Add male acceptance rate
    if "male_rate" in bias_stats:
        result_dict["male_acceptance_rate"] = bias_stats["male_rate"]["rate"]
        result_dict["male_acceptance_ci_low"] = bias_stats["male_rate"]["ci_low"]
        result_dict["male_acceptance_ci_high"] = bias_stats["male_rate"]["ci_high"]
    
    # Add female acceptance rate
    if "female_rate" in bias_stats:
        result_dict["female_acceptance_rate"] = bias_stats["female_rate"]["rate"]
        result_dict["female_acceptance_ci_low"] = bias_stats["female_rate"]["ci_low"]
        result_dict["female_acceptance_ci_high"] = bias_stats["female_rate"]["ci_high"]
    
    # Add white acceptance rate
    if "white_rate" in bias_stats:
        result_dict["white_acceptance_rate"] = bias_stats["white_rate"]["rate"]
        result_dict["white_acceptance_ci_low"] = bias_stats["white_rate"]["ci_low"]
        result_dict["white_acceptance_ci_high"] = bias_stats["white_rate"]["ci_high"]
    
    # Add black acceptance rate
    if "black_rate" in bias_stats:
        result_dict["black_acceptance_rate"] = bias_stats["black_rate"]["rate"]
        result_dict["black_acceptance_ci_low"] = bias_stats["black_rate"]["ci_low"]
        result_dict["black_acceptance_ci_high"] = bias_stats["black_rate"]["ci_high"]

    return result_dict


def load_and_process_all_data(all_filenames_list: List[str]) -> pd.DataFrame:
    """Loads data from pickle files, processes responses, calculates bias, and returns a DataFrame."""
    processed_data_list = []
    for filepath in all_filenames_list:
        file_data = _parse_single_file_data(filepath)
        if file_data:
            processed_data_list.append(file_data)
    return pd.DataFrame(processed_data_list)
# --- Plotting Helper Function ---
def _plot_bias_type_specific(
    ax,
    bias_type_to_plot: str,
    plot_data: Dict,
    conditions_spec: List[Dict[str, str]],
    plottable_models: List[str],
    x_main_indices: np.ndarray,
    condition_xtick_labels: List[str],
    filename: str,
    ncol: Optional[int] = None,
    title: Optional[str] = None,
):
    """Helper function to generate a plot for a specific bias type (Race or Gender)."""
    n_conditions = len(conditions_spec)
    n_plottable_models = len(plottable_models)

    if n_plottable_models == 0:
        raise ValueError("No plottable models found. Skipping plots.")

    group_span_all_models = 0.75
    bar_slot_width = (
        group_span_all_models / n_plottable_models if n_plottable_models > 0 else 0
    )
    bar_actual_width = 0.9 * bar_slot_width

    model_colors_cmap = colormaps.get_cmap("tab10")

    for i_model, model_name in enumerate(plottable_models):
        model_offset = (i_model - (n_plottable_models - 1) / 2.0) * bar_slot_width
        current_model_bar_positions = x_main_indices + model_offset

        biases = [
            plot_data[cl][model_name][bias_type_to_plot][0]
            for cl in condition_xtick_labels
        ]
        errors = [
            plot_data[cl][model_name][bias_type_to_plot][1]
            for cl in condition_xtick_labels
        ]

        ax.bar(
            current_model_bar_positions,
            biases,
            bar_actual_width,
            yerr=errors,
            color=model_colors_cmap(i_model % model_colors_cmap.N),
            label=MODEL_DISPLAY_NAMES.get(model_name, model_name),
            capsize=4,
            zorder=3,
        )

    for i in range(n_conditions - 1):
        ax.axvline(
            x_main_indices[i] + 0.5,
            color="grey",
            linestyle="--",
            linewidth=1.2,
            zorder=1,
        )

    if bias_type_to_plot == "Race":
        ax.set_ylabel("Race Bias (Positive favors White applicants)")
    elif bias_type_to_plot == "Gender":
        ax.set_ylabel("Gender Bias (Positive favors Male applicants)")

    if title:
        ax.set_title(title)

    ax.set_xticks(x_main_indices)
    ax.set_xticklabels(condition_xtick_labels, rotation=0, ha="center")
    ax.axhline(0, color="black", linewidth=2.0, linestyle="--", zorder=2)

    legend_elements = [
        Patch(
            facecolor=model_colors_cmap(i % model_colors_cmap.N),
            label=MODEL_DISPLAY_NAMES.get(model_name, model_name),
        )
        for i, model_name in enumerate(plottable_models)
    ]

    if legend_elements:
        # ax.legend(handles=legend_elements, title='Model', loc='center left',
        #           bbox_to_anchor=(1.02, 0.5), ncol=1, frameon=False)
        if ncol is None:
            ncol = len(legend_elements)
            # ncol = 2
        ax.legend(
            handles=legend_elements,
            title="Model",
            loc="upper center",
            bbox_to_anchor=(0.5, -0.12),  # Centered below plot
            ncol=ncol,  # Adjust number of columns as needed
            frameon=False,
        )

    # Print data for text output
    print(f"Data for {filename} {bias_type_to_plot} Bias Plot:")
    for model_name in plottable_models:
        print(
            f"\nModel: {MODEL_DISPLAY_NAMES.get(model_name, model_name)}"
        )  # Use display name if available
        for i, condition_label in enumerate(condition_xtick_labels):
            bias = plot_data[condition_label][model_name][bias_type_to_plot][0]
            error = plot_data[condition_label][model_name][bias_type_to_plot][1]
            print(f"  {condition_label}: Bias = {bias:.4f}, Error = ±{error:.4f}")


# --- Main Plotting Functions ---
def create_graph1(
    df: pd.DataFrame,
    conditions_spec: List[Dict[str, str]],
    filename: Optional[str] = None,
    fig_width: Optional[float] = None,
    ncol: Optional[int] = None,
    suffix: str = ".pdf",
    title: Optional[str] = None,
):
    """Generates Graph 1 as two separate plots for Race and Gender bias, averaged over anti-bias prompts."""
    if df.empty:
        raise ValueError("Input DataFrame is empty. Cannot generate graphs.")

    model_names = sorted(df["model_name"].dropna().unique())
    if not model_names:
        raise ValueError("No model names found. Cannot generate graphs.")

    plot_data: Dict[str, Dict[str, Dict[str, Tuple[float, float]]]] = {}

    for cond_dict in conditions_spec:
        cond_label = cond_dict["label"]
        plot_data[cond_label] = {}

        # condition_df = df[
        #     (df["inference_mode"] == cond_dict["inference_mode"]) &
        #     (df["job_description"] == cond_dict["job_description"])
        # ]

        condition_df = df.copy()

        for condition in cond_dict:
            if condition == "label":
                continue
            condition_df = condition_df[condition_df[condition] == cond_dict[condition]]

        if condition_df.empty:
            raise ValueError(f"No data for condition {cond_label}. Skipping.")

        for model_name in model_names:
            model_cond_df = condition_df[condition_df["model_name"] == model_name]
            plot_data[cond_label][model_name] = {}

            for bias_type in ["Race", "Gender"]:
                diff_col = f"{bias_type.lower()}_bias_diff"
                ci_low_col = f"{bias_type.lower()}_bias_ci_low"
                ci_high_col = f"{bias_type.lower()}_bias_ci_high"

                valid_entries = model_cond_df[
                    [diff_col, ci_low_col, ci_high_col]
                ].dropna()

                if valid_entries.empty:
                    plot_data[cond_label][model_name][bias_type] = (np.nan, np.nan)
                    print(valid_entries)
                    print(model_cond_df)
                    raise ValueError(
                        f"No valid entries for {bias_type} bias for model {model_name} in condition {cond_label}."
                    )
                    continue

                diffs = valid_entries[diff_col].values
                ci_lows = valid_entries[ci_low_col].values
                ci_highs = valid_entries[ci_high_col].values

                ses = (ci_highs - ci_lows) / (2 * Z_SCORE)  # Assumes CIs were 95%

                avg_diff = np.mean(diffs)
                avg_se = (
                    np.sqrt(np.sum(ses**2) / (len(ses) ** 2))
                    if len(ses) > 0
                    else np.nan
                )
                error_bar_half_width = (
                    Z_SCORE * avg_se if not np.isnan(avg_se) else np.nan
                )

                plot_data[cond_label][model_name][bias_type] = (
                    avg_diff,
                    error_bar_half_width,
                )

    plottable_models = [
        m
        for m in model_names
        if any(
            cond_spec["label"] in plot_data
            and m in plot_data[cond_spec["label"]]
            and not (
                np.isnan(plot_data[cond_spec["label"]][m]["Race"][0])
                and np.isnan(plot_data[cond_spec["label"]][m]["Gender"][0])
            )
            for cond_spec in conditions_spec
        )
    ]

    print(plot_data)

    if not plottable_models:
        raise ValueError("No plottable data after aggregation. Skipping plots.")

    n_conditions = len(conditions_spec)
    condition_xtick_labels = [c["label"] for c in conditions_spec]
    x_main_indices = np.arange(n_conditions)

    if fig_width is None:
        fig_width = min(25, max(10, len(plottable_models) * n_conditions * 0.5))

    plot_configs = [
        {"bias_type": "Race", "suffix": f"_race_bias{suffix}"},
        {"bias_type": "Gender", "suffix": f"_gender_bias{suffix}"},
    ]

    for config in plot_configs:
        fig, ax = plt.subplots(figsize=(fig_width, 7.5))

        if filename:
            assert filename.endswith(suffix), f"Filename must end with {suffix}"
            output_filename = filename.replace(suffix, config["suffix"])
            output_filename = os.path.join(IMAGE_OUTPUT_DIR, output_filename)

        _plot_bias_type_specific(
            ax=ax,
            bias_type_to_plot=config["bias_type"],
            plot_data=plot_data,
            conditions_spec=conditions_spec,
            plottable_models=plottable_models,
            x_main_indices=x_main_indices,
            condition_xtick_labels=condition_xtick_labels,
            filename=output_filename,
            ncol=ncol,
            title=title,
        )
        plt.tight_layout(rect=[0, 0.05, 0.85, 0])  # For legend on right
        # plt.tight_layout(rect=[0, 0.1, 1, 0.95])

        if filename:
            plt.savefig(output_filename, dpi=300, bbox_inches="tight")
            print(f"Graph 1 ({config['bias_type']} bias) saved as {output_filename}")
        plt.show()


# --- Data Retrieval ---
def get_data_df(
    folders_to_scan: List[str], exclude_patterns: List[str]
) -> pd.DataFrame:
    """Finds result files, loads them, processes, and returns a master DataFrame."""
    all_result_filenames = []
    for folder_path in folders_to_scan:
        if not os.path.exists(folder_path):
            raise ValueError(f"Folder '{folder_path}' does not exist. Skipping.")
        print(f"Searching in {folder_path}...")
        files_in_folder = find_files_recursive(
            folder_path, exclude_patterns=exclude_patterns
        )
        all_result_filenames.extend(files_in_folder)
        print(f"Found {len(files_in_folder)} files in {folder_path} (after exclusion).")

    if not all_result_filenames:
        raise ValueError("No result files found. Returning empty DataFrame.")

    return load_and_process_all_data(all_result_filenames)


text_size = 18
plt.rcParams.update(
    {
        # "text.usetex": True,  # Use LaTeX for all text
        # "font.family": "serif",  # Use a serif font
        # "font.serif": "Computer Modern",  # The default LaTeX font
        # use a family with real semibold: DejaVu Sans ships with matplotlib
        "font.family": "sans-serif",
        "font.sans-serif": ["DejaVu Sans"],
        "font.weight": "medium",        # ~500, affects tick labels & legend entries
        "axes.labelweight": "medium", # ~600, axis labels only
        "axes.titleweight": "medium",
        "font.size": text_size,
        "axes.titlesize": text_size+6,
        "axes.labelsize": text_size,
        "xtick.labelsize": text_size+4,
        "ytick.labelsize": text_size,
        "legend.fontsize": text_size,
        "legend.title_fontsize": text_size,
        "figure.titlesize": text_size,
    }
)

folders_to_scan = ["paper_data/figure_1_data"]  # Example, update as needed
current_exclude_patterns = ["gm_job_description", "mmlu", "anthropic", "v0"]

master_data_df = get_data_df(folders_to_scan, current_exclude_patterns)

# --- Graph 1: Bias Across Settings (Averaged over Prompts) ---
conditions_spec_graph1 = [
    {
        "inference_mode": "gpu_forward_pass",
        "job_description": "base_description",
        "label": "Simple Eval",
    },
    {
        "inference_mode": "gpu_forward_pass",
        "job_description": "meta_job_description",
        "label": "Realistic Eval",
    },
    {
        "inference_mode": "projection_ablations",
        "job_description": "meta_job_description",
        "label": "Internal Mitigation\nRealistic Eval",
    },
]
create_graph1(
    master_data_df.copy(), conditions_spec_graph1, filename="figure_1.pdf", fig_width=14, suffix=".pdf"
)

create_graph1(
    master_data_df.copy(), conditions_spec_graph1, filename="figure_1.png", fig_width=14, suffix=".png", title="Realistic Eval Context Triggers Race Bias, Our Internal Edit Removes It"
)

all_table_data = []
all_table_data.append((master_data_df.copy(), conditions_spec_graph1.copy()))

In [None]:
folders_to_scan = [
    "paper_data/score_output_frontier_models",
    "paper_data/score_output_frontier_base"
]  # Current based on user's code snippet

# Llama had many failed responses due to open router issues
current_exclude_patterns = ["llama"]

conditions_spec = [
    {
        "inference_mode": "open_router",
        "job_description": "base_description",
        "label": "Simple Eval",
    },
    {
        "inference_mode": "open_router",
        "job_description": "meta_job_description",
        "label": "Realistic Eval: Meta",
    },
]

master_data_df = get_data_df(folders_to_scan, current_exclude_patterns)

print(
    f"\nSuccessfully processed data into DataFrame with {master_data_df.shape[0]} entries."
)
create_graph1(
    master_data_df.copy(), conditions_spec, filename="frontier_models_yes_no.pdf"
)

create_graph1(
    master_data_df.copy(), conditions_spec, filename="frontier_models_yes_no.png", suffix=".png"
)

all_table_data.append((master_data_df.copy(), conditions_spec.copy()))

In [None]:
folders_to_scan = [
    "paper_data/score_output_college_name"
]  # Current based on user's code snippet

current_exclude_patterns = []

conditions_spec = [
    # {"inference_mode": "gpu_forward_pass", "job_description": "base_description", "label": "Standard Eval\n(Simple Context)"},
    {
        "inference_mode": "gpu_forward_pass",
        "job_description": "meta_job_description",
        "label": "Realistic Eval: Meta\nCollege Affiliation",
    },
    {
        "inference_mode": "projection_ablations",
        "job_description": "meta_job_description",
        "label": "Internal Mitigation\nRealistic Eval: Meta\nCollege Affiliation",
    },
]

master_data_df = get_data_df(folders_to_scan, current_exclude_patterns)


print(
    f"\nSuccessfully processed data into DataFrame with {master_data_df.shape[0]} entries."
)
create_graph1(
    master_data_df.copy(), conditions_spec, filename="open_source_college_names.pdf", ncol=2
)

all_table_data.append((master_data_df.copy(), conditions_spec.copy()))

In [None]:
folders_to_scan = [
    "paper_data/score_output_gm_high_bar_interventions"
]  # Current based on user's code snippet

current_exclude_patterns = []

conditions_spec = [
    # {"inference_mode": "gpu_forward_pass", "job_description": "base_description", "label": "Standard Eval\n(Simple Context)"},
    {
        "inference_mode": "gpu_forward_pass",
        "job_description": "gm_job_description",
        "label": "Realistic Eval:\nGM + Selectivity",
    },
    {
        "inference_mode": "projection_ablations",
        "job_description": "gm_job_description",
        "label": "Internal Mitigation\nRealistic Eval:\nGM + Selectivity",
    },
]

master_data_df = get_data_df(folders_to_scan, current_exclude_patterns)


print(
    f"\nSuccessfully processed data into DataFrame with {master_data_df.shape[0]} entries."
)
create_graph1(
    master_data_df.copy(), conditions_spec, filename="open_source_gm_selective_hiring.pdf", ncol=2
)

all_table_data.append((master_data_df.copy(), conditions_spec.copy()))

In [None]:
folders_to_scan = [
    "paper_data/score_output_frontier_low_bar_cot",
    "paper_data/score_output_frontier_high_bar_cot",
]  # Current based on user's code snippet

# Llama had many failed responses due to open router issues
current_exclude_patterns = ["llama"]

conditions_spec = [
    # {"inference_mode": "open_router", "job_description": "base_description", "label": "Standard Eval\n(Simple Context)"},
    {
        "inference_mode": "open_router",
        "job_description": "meta_job_description",
        "system_prompt_filename": "yes_no_cot.txt",
        "label": "Realistic Eval:\nMeta\nChain of Thought",
    },
    {
        "inference_mode": "open_router",
        "job_description": "meta_job_description",
        "system_prompt_filename": "yes_no_high_bar_cot.txt",
        "label": "Realistic Eval:\nMeta + Selectivity\nChain of Thought",
    },
]

master_data_df = get_data_df(folders_to_scan, current_exclude_patterns)

# TODO: Rerun gemini here

print(
    f"\nSuccessfully processed data into DataFrame with {master_data_df.shape[0]} entries."
)
create_graph1(
    master_data_df.copy(), conditions_spec, filename="frontier_models_yes_no_cot.pdf"
)

all_table_data.append((master_data_df.copy(), conditions_spec.copy()))

In [None]:
folders_to_scan = [
    "paper_data/score_output_frontier_meta_filtered", "paper_data/score_output_frontier_palantir"
]  # Current based on user's code snippet

# Llama had many failed responses due to open router issues
current_exclude_patterns = ["llama"]

conditions_spec = [
    {
        "inference_mode": "open_router",
        "job_description": "meta_job_description_filtered",
        "label": "Realistic Eval: Meta\nDiversity Phrases Removed",
    },
    {
        "inference_mode": "open_router",
        "job_description": "palantir_job_description",
        "label": "Realistic Eval: Palantir",
    },
]

master_data_df = get_data_df(folders_to_scan, current_exclude_patterns)

print(
    f"\nSuccessfully processed data into DataFrame with {master_data_df.shape[0]} entries."
)
create_graph1(
    master_data_df.copy(), conditions_spec, filename="frontier_models_yes_no_palantir_filtered.pdf"
)


# all_table_data.append((master_data_df.copy(), conditions_spec.copy()))

In [None]:
import pandas as pd
from collections import defaultdict

# --- Constants (as provided in the problem description) ---
ANTI_BIAS_LABELS = {
    "v1.txt": "Prompt 1",
    "v2.txt": "Prompt 2",
    "v3.txt": "Prompt 3",
    "v4.txt": "Prompt 4",
}

MODEL_DISPLAY_NAMES = defaultdict(lambda: "Unknown Model", {
    "google/gemini-2.5-flash-preview-05-20": "Gemini 2.5 Flash",
    "anthropic/claude-3.5-sonnet": "Claude 3.5 Sonnet",
    "openai/gpt-4o-2024-08-06": "GPT-4o",
    "meta-llama/llama-3.3-70b-instruct": "Llama 3.3 70B",
    "anthropic/claude-sonnet-4": "Claude Sonnet 4",
    "google/gemma-2-27b-it": "Gemma-2 27B",
    "google/gemma-3-12b-it": "Gemma-3 12B",
    "google/gemma-3-27b-it": "Gemma-3 27B",
    "mistralai/Mistral-Small-24B-Instruct-2501": "Mistral Small 24B",
})

def format_value(value, sig_digits, is_percent=False):
    """Formats a single value, typically the mean or point estimate."""
    if pd.isna(value):
        return "-"
    
    multiplier = 100 if is_percent else 1
    val_str = f"{value * multiplier:.{sig_digits}f}"
    # Prevent negative zero (e.g., -0.00)
    if val_str == f"-0.{'0'*sig_digits}": 
        val_str = f"0.{'0'*sig_digits}"
    return val_str

def create_compact_bias_tables(dataframe, conditions_spec, sig_digits=2, ignore_gender=False):
    """
    Generates compact LaTeX tables from a DataFrame by combining columns.

    Args:
        dataframe (pd.DataFrame): The input data.
        conditions_spec (list): A list of dictionaries, each specifying a table.
        sig_digits (int): The number of decimal places for formatting.
        ignore_gender (bool): If True, all gender-related values will display as "n/a".

    Returns:
        str: A string containing one or more compact LaTeX tables.
    """
    latex_tables = []

    for condition in conditions_spec:
        filter_criteria = {k: v for k, v in condition.items() if k != 'label'}
        label = condition.get('label', 'Table')

        mask = pd.Series([True] * len(dataframe), index=dataframe.index)
        for key, value in filter_criteria.items():
            mask &= (dataframe[key] == value)
        
        df_filtered = dataframe[mask]

        if df_filtered.empty:
            print(f"Warning: No data found for condition: {label}. Skipping table.")
            continue

        table_str_list = [
            # f"% Compact table for condition: {label}",
            "\\begin{table}[ht!]",
            "\\centering",
            "\\small",  # Use a slightly smaller font for the table
            f"\\caption{{Bias and Acceptance Rates for {label}. Acceptance rates are shown as Male / Female and White / Black.}}",
            # f"\\label{{tab:compact_{label.lower().replace(' ', '_').replace(':', '')}}}",
            "\\setlength{\\tabcolsep}{4pt}",  # Reduce space between columns
            "\\begin{tabular}{lrrrr}", # Reduced column count from lrrrrrr to lrrrr
            "\\toprule",
            # Combined and abbreviated headers
            "\\textbf{Prompt} & \\textbf{Race Bias} & \\textbf{Gender Bias} & \\textbf{M/F Acc. (\\%)} & \\textbf{W/B Acc. (\\%)} \\\\",
            "\\midrule",
        ]
        
        unique_models = sorted(df_filtered['model_name'].unique())
        for i, model_id in enumerate(unique_models):
            df_model = df_filtered[df_filtered['model_name'] == model_id]
            model_display_name = MODEL_DISPLAY_NAMES[model_id]
            
            table_str_list.append(f"\\multicolumn{{5}}{{l}}{{\\textbf{{{model_display_name}}}}} \\\\")

            for prompt_file, prompt_label in ANTI_BIAS_LABELS.items():
                row_data = df_model[df_model['anti_bias_statement_file'] == prompt_file]
                
                row_parts = [f"  {prompt_label}"]

                if len(row_data) > 1:
                    print(f"Warning: Found {len(row_data)} rows for model '{model_id}', prompt '{prompt_file}', condition '{label}'. Using first one.")

                if not row_data.empty:
                    series = row_data.iloc[0]
                    
                    # Race bias (not affected by ignore_gender)
                    row_parts.append(format_value(series['race_bias_diff'], sig_digits))
                    
                    # Gender bias
                    if ignore_gender:
                        row_parts.append("N/A")
                    else:
                        row_parts.append(format_value(series['gender_bias_diff'], sig_digits))

                    # Male/Female acceptance
                    if ignore_gender:
                        row_parts.append("N/A")
                    else:
                        male_acc = format_value(series['male_acceptance_rate'], sig_digits, is_percent=True)
                        female_acc = format_value(series['female_acceptance_rate'], sig_digits, is_percent=True)
                        row_parts.append(f"{male_acc} / {female_acc}")

                    # White/Black acceptance (not affected by ignore_gender)
                    white_acc = format_value(series['white_acceptance_rate'], sig_digits, is_percent=True)
                    black_acc = format_value(series['black_acceptance_rate'], sig_digits, is_percent=True)
                    row_parts.append(f"{white_acc} / {black_acc}")
                else:
                    row_parts.extend(["-"] * 4)

                table_str_list.append(" & ".join(row_parts) + " \\\\")
            
            if i < len(unique_models) - 1:
                table_str_list.append("\\midrule")

        table_str_list.extend([
            "\\bottomrule",
            "\\end{tabular}",
            "\\end{table}",
        ])
        
        latex_tables.append("\n".join(table_str_list))

    return "\n\n".join(latex_tables)
# Generate the compact LaTeX code

for i in range(len(all_table_data)):
    print(f"\n\n\nTable {i+1} ")
    sample_df = all_table_data[i][0]
    conditions_to_run = all_table_data[i][1]

    label = conditions_to_run[0]['label']
    if "college" in label.lower():
        ignore_gender = True
        print("\n\n\n\n\nCOLLEGE TABLE")
    else:
        ignore_gender = False

    latex_output = create_compact_bias_tables(sample_df, conditions_to_run, sig_digits=3, ignore_gender=ignore_gender)
    print(latex_output)

# 3. Generate the LaTeX code with 2 significant digits.
# latex_output = create_bias_tables(sample_df, conditions_to_run, sig_digits=2)

# 4. Print the result.
