# Setup
## Imports

In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
import json
from collections.abc import Sequence
from typing import Union

import numpy as np
import pandas as pd

from generative_social_choice.utils.postprocessing import (
    plot_candidate_distribution_stacked,
    plot_sorted_utility_distributions,
    scalar_utility_metrics,
)


## Load data

In [None]:
from dataclasses import dataclass
from typing import Literal
from generative_social_choice.utils.helper_functions import get_results_paths
from generative_social_choice.utils.postprocessing import plot_sorted_utility_CIs
from generative_social_choice.ratings.utility_matrix import get_baseline_generate_slate_results


def get_results(labelling_model: str, run_id: str, embedding_type: str = "llm"):
    """Load per-algorithm utilities and assignments for one run of our pipeline.

    Every ``*.json`` file in the run's ``"assignments"`` results directory (as
    resolved by ``get_results_paths``) is treated as the output of one
    algorithm, named after the file stem. Each file must contain the keys
    ``"agent_ids"``, ``"utilities"`` and ``"assignments"``.

    Args:
        labelling_model: Forwarded to ``get_results_paths``.
        run_id: Forwarded to ``get_results_paths``.
        embedding_type: Forwarded to ``get_results_paths``.

    Returns:
        ``(utilities, algo_assignments)``: two DataFrames indexed by agent id,
        with one column per algorithm (file stem).
    """
    result_paths = get_results_paths(
        labelling_model=labelling_model,
        baseline=False,
        embedding_type=embedding_type,
        run_id=run_id,
    )
    algo_assignment_result_dir = result_paths["assignments"]

    utility_columns = {}
    assignment_columns = {}
    # Sort so the column order is deterministic; glob order depends on the filesystem.
    for file_path in sorted(algo_assignment_result_dir.glob("*.json")):
        with open(file_path, "r", encoding="utf-8") as f:
            algo_assignment_data = json.load(f)
        agent_ids = algo_assignment_data["agent_ids"]
        utility_columns[file_path.stem] = pd.Series(algo_assignment_data["utilities"], index=agent_ids)
        assignment_columns[file_path.stem] = pd.Series(algo_assignment_data["assignments"], index=agent_ids)

    # Constructing from a dict of Series aligns on the union of agent ids, so an
    # agent missing from some file becomes NaN there. Assigning column by column
    # into an empty frame would instead keep only the first file's agents and
    # silently drop the rest.
    utilities = pd.DataFrame(utility_columns)
    algo_assignments = pd.DataFrame(assignment_columns)
    return utilities, algo_assignments


@dataclass
class ResultConfig:
    """One experimental configuration whose runs are aggregated by ``show_results``."""

    # Label for this configuration; used as the key in the dicts returned by show_results.
    name: str
    # Embedding type forwarded to the result loaders (e.g. "llm" or "fish").
    embedding_type: str
    # Run identifiers to aggregate over. Both ints (e.g. range(10)) and strings
    # (e.g. "fish_0") are used, so the annotation accepts either.
    run_ids: Sequence[Union[int, str]]
    # Labelling model forwarded to get_results; only consulted by the "ours" pipeline.
    labelling_model: str = "4o-mini"
    # "ours" loads this project's per-algorithm assignment files; "fish" loads
    # the baseline generate-slate results.
    pipeline: Literal["ours", "fish"] = "ours"


def _mean_ci_summary(metrics_df: pd.DataFrame) -> pd.DataFrame:
    """Summarize each metric column by its mean and a normal-approximation 95% CI.

    ``metrics_df`` has one row per run and one column per metric. The interval
    is ``mean ± 1.96 * std / sqrt(n)``, where ``n`` is the number of rows
    actually present. With a single run the sample std is NaN, so the interval
    bounds are NaN too.
    """
    n_runs = len(metrics_df)
    mean_metrics = metrics_df.mean()
    ci_95 = 1.96 * metrics_df.std() / np.sqrt(n_runs)
    return pd.DataFrame({
        'Mean': mean_metrics,
        '95% CI Lower': mean_metrics - ci_95,
        '95% CI Upper': mean_metrics + ci_95
    })


def show_results(configs: list[ResultConfig], method: str) -> tuple[dict[str, pd.DataFrame], dict[str, list], dict[str, list]]:
    """Aggregate metrics across runs for each config and plot per-run utility CIs.

    For ``pipeline == "ours"`` configs, each run is loaded with ``get_results``
    and only the metrics row for ``method`` is kept. For ``pipeline == "fish"``
    configs, all runs come from ``get_baseline_generate_slate_results`` at once.
    Each of its utility columns is treated as one run.

    Args:
        configs: Configurations to evaluate; their ``name`` values become dict keys.
        method: Algorithm (utilities column) to evaluate for "ours" configs.

    Returns:
        performance_dfs: config name -> DataFrame indexed by metric with columns
            'Mean', '95% CI Lower' and '95% CI Upper'.
        utility_dfs: config name -> list with one entry per run. For "ours" each
            entry is that run's full utilities DataFrame (one column per
            algorithm); for "fish" each entry is a single utilities Series.
        all_algo_assignments: config name -> list with one entry per run,
            shaped like ``utility_dfs``.
    """
    utility_dfs = {}
    performance_dfs = {}
    all_algo_assignments = {}
    pipelines = {}

    for config in configs:
        pipelines[config.name] = config.pipeline

        if config.pipeline == "fish":
            utilities, algo_assignments = get_baseline_generate_slate_results(run_ids=config.run_ids, embedding_type=config.embedding_type)
            # Split column-wise; each column is handled as one run below.
            utility_dfs[config.name] = [utilities[col] for col in utilities.columns]
            all_algo_assignments[config.name] = [algo_assignments[col] for col in algo_assignments.columns]
            metrics_df = scalar_utility_metrics(utilities)
        else:
            run_metrics = []
            run_utilities = []
            run_assignments = []
            for run_id in config.run_ids:
                utilities, algo_assignments = get_results(labelling_model=config.labelling_model, run_id=run_id, embedding_type=config.embedding_type)
                # Keep only the metrics row of the algorithm under evaluation.
                run_metrics.append(scalar_utility_metrics(utilities).loc[method])
                run_utilities.append(utilities)
                run_assignments.append(algo_assignments)
            metrics_df = pd.DataFrame(run_metrics)
            utility_dfs[config.name] = run_utilities
            all_algo_assignments[config.name] = run_assignments

        # Size the CI by the rows actually loaded rather than len(config.run_ids),
        # which can differ for the baseline loader.
        performance_dfs[config.name] = _mean_ci_summary(metrics_df)

    # One column per (config name, run index); "ours" runs contribute only `method`.
    utilities_for_CI = pd.DataFrame({
        (name, i): run[method] if pipelines[name] == "ours" else run
        for name, runs in utility_dfs.items()
        for i, run in enumerate(runs)
    })
    utilities_for_CI.columns = pd.MultiIndex.from_tuples(utilities_for_CI.columns)

    plot_sorted_utility_CIs(utilities_for_CI)

    return performance_dfs, utility_dfs, all_algo_assignments


In [None]:

# Algorithm whose per-run metrics are aggregated for the "ours" pipeline configs.
METHOD = "exact"
# Number of runs per configuration; shared so the configs cannot drift apart.
N_RUNS = 10

result_configs = [
    ResultConfig(
        name="Ours (LLM Embeddings)",
        embedding_type="llm",
        run_ids=list(range(N_RUNS)),
    ),
    ResultConfig(
        name="Ours (Fish Embeddings)",
        embedding_type="fish",
        run_ids=[f"fish_{i}" for i in range(N_RUNS)],
    ),
    ResultConfig(
        name="Baseline (LLM Embeddings)",
        embedding_type="llm",
        run_ids=list(range(N_RUNS)),
        pipeline="fish",
    ),
]

performance_dfs, utility_dfs, algo_assignments = show_results(result_configs, method=METHOD)

for name, performance_df in performance_dfs.items():
    print(name)
    print(performance_df.head())


### Plots

In [None]:
# Utilities of the first run of "Ours (LLM Embeddings)"; the key must match a ResultConfig.name.
fig = plot_sorted_utility_distributions(utility_dfs["Ours (LLM Embeddings)"][0])

In [None]:

# Assignments of the first run of "Ours (LLM Embeddings)"; the key must match a ResultConfig.name.
fig = plot_candidate_distribution_stacked(algo_assignments["Ours (LLM Embeddings)"][0])