In [None]:
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd


def get_all_fids(
    search_key: str,
    exp_dir: str = "all_yahpo_25",
    results_dir: Path = Path("/home/soham/Master_Thesis/code/momf_priors_results"),
) -> dict:
    """Collect the sampled fidelities of every matching run, grouped by benchmark and seed.

    Recursively finds ``*{search_key}*.parquet`` files under
    ``results_dir / exp_dir`` (in sorted path order) and reads the
    ``fidelity`` column of each, cast to ``np.int32``.

    Args:
        search_key: Substring that matching parquet file names must contain.
        exp_dir: Experiment sub-directory of ``results_dir`` to search.
        results_dir: Root directory holding experiment results (a ``str`` is
            also accepted). Defaults to the original hard-coded location.

    Returns:
        ``{benchmark: {seed: [fidelity, ...]}}``. The benchmark and seed of a
        file are taken from the first row of its ``benchmark``/``seed``
        columns. Files with no rows are skipped.

    Raises:
        FileNotFoundError: If ``results_dir / exp_dir`` is not a directory.
    """
    import warnings

    exp_path = Path(results_dir) / exp_dir
    if not exp_path.is_dir():
        # Fail loudly: a typo in exp_dir would otherwise yield {} and an empty plot.
        raise FileNotFoundError(f"Experiment directory not found: {exp_path}")

    all_fidelities: dict = {}
    for file in sorted(exp_path.glob(f"**/*{search_key}*.parquet")):
        pq = pd.read_parquet(file)
        if pq.empty:
            # .iloc[0] below would raise on a file with no rows.
            continue
        benchmark = pq["benchmark"].iloc[0]
        # NOTE(review): assumes each file holds a single benchmark and seed — confirm.
        seed = pq["seed"].iloc[0]
        per_seed = all_fidelities.setdefault(benchmark, {})
        if seed in per_seed:
            warnings.warn(
                f"Duplicate results for benchmark={benchmark!r}, seed={seed!r} "
                f"in {file}; keeping the last file.",
                stacklevel=2,
            )
        per_seed[seed] = pq["fidelity"].astype(np.int32).tolist()
    return all_fidelities


def get_percentiles(all_fids: dict) -> dict:
    """Summarise the sampled-fidelity distribution of each benchmark, averaged over seeds.

    For every seed of a benchmark this computes:

    - the 25th/50th/75th/90th/95th percentiles;
    - the mean and the (population) standard deviation;
    - ``max_fid_num_sampled``, the number of samples equal to that seed's
      largest fidelity.

    Each statistic is then averaged across the benchmark's seeds.

    Args:
        all_fids: ``{benchmark: {seed: [fidelity, ...]}}`` as produced by
            ``get_all_fids``.

    Returns:
        ``{benchmark: {stat_name: seed-averaged value}}``. Seeds with no
        fidelities are ignored. Benchmarks with no non-empty seed are omitted.
    """
    quantiles = (25, 50, 75, 90, 95)
    percentiles = {}
    for benchmark, seed_dict in all_fids.items():
        per_seed_stats = []
        for fids in seed_dict.values():
            arr = np.asarray(fids)
            if arr.size == 0:
                # Percentiles and max are undefined for an empty sample.
                continue
            # One vectorised call gives the same values as one call per quantile.
            stats = {f"{q}th": value for q, value in zip(quantiles, np.percentile(arr, quantiles))}
            stats["mean"] = arr.mean()
            stats["std"] = arr.std()
            # Compare against a single precomputed max rather than recomputing
            # max(fids) for every element (which was O(n^2)).
            stats["max_fid_num_sampled"] = int(np.count_nonzero(arr == arr.max()))
            per_seed_stats.append(stats)
        if not per_seed_stats:
            continue
        # Now average across seeds.
        percentiles[benchmark] = {
            key: np.mean([s[key] for s in per_seed_stats]) for key in per_seed_stats[0]
        }
    return percentiles


def plot_fid_distribution(all_fids: dict, ncols: int = 3):
    """Plot, per benchmark, the seed-averaged number of samples at each fidelity.

    One subplot is drawn per benchmark, in a grid with ``ncols`` columns and
    as many rows as needed. Unused grid cells are hidden.

    For each benchmark, ``max_fid`` is the largest fidelity seen across all of
    its seeds. Every seed's fidelities are histogrammed with ``np.bincount``
    over ``1..max_fid``, and the counts are averaged across seeds.

    Args:
        all_fids: ``{benchmark: {seed: [fidelity, ...]}}`` as produced by
            ``get_all_fids``. Fidelities must be non-negative integers, as
            required by ``np.bincount``. Samples at fidelity 0 are not shown.
        ncols: Number of subplot columns.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when no benchmark has a non-empty seed.
    """
    # Drop seeds with no samples (max() of an empty list would raise), then
    # drop benchmarks that have no seeds left.
    runs_by_benchmark = {}
    for benchmark, seed_dict in all_fids.items():
        runs = [fids for fids in seed_dict.values() if len(fids) > 0]
        if runs:
            runs_by_benchmark[benchmark] = runs
    if not runs_by_benchmark:
        return

    import matplotlib.pyplot as plt

    # Size the grid to the number of benchmarks. A fixed 3x3 grid raised
    # IndexError for more than nine benchmarks.
    nrows = -(-len(runs_by_benchmark) // ncols)  # ceiling division
    _, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 10 / 3 * nrows), squeeze=False)
    axes = axes.flatten()

    for ax, (benchmark, runs) in zip(axes, runs_by_benchmark.items()):
        max_fid = max(max(fids) for fids in runs)
        # minlength aligns every seed to the same length; [1:] drops fidelity 0.
        avg_counts = np.mean(
            [np.bincount(fids, minlength=max_fid + 1)[1:] for fids in runs], axis=0
        )
        # avg_counts[k] is the count for fidelity k + 1, so plot on the
        # fidelity axis (1..max_fid) to line up with the ticks set below.
        # NOTE(review): width 2.0 exceeds the unit bar spacing, so adjacent bars
        # overlap — presumably deliberate for visibility; confirm.
        ax.bar(np.arange(1, max_fid + 1), avg_counts, width=2.0)
        ax.set_xticks(np.linspace(1, max_fid, num=5))
        ax.set_title(benchmark)
        ax.set_xlabel("Fidelity")
        ax.set_ylabel("Count")

    for ax in axes[len(runs_by_benchmark):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()


def do_all_eda(search_key: str, exp_dir: str = "all_yahpo_25"):
    """Run the fidelity EDA for all result files whose names contain *search_key*.

    Loads the fidelities with ``get_all_fids``, pretty-prints the
    seed-averaged statistics from ``get_percentiles`` (``depth=4``), and then
    plots the fidelity histograms with ``plot_fid_distribution``.

    Args:
        search_key: Substring used to select parquet result files.
        exp_dir: Experiment sub-directory to search.
    """
    fidelities_by_benchmark = get_all_fids(search_key, exp_dir=exp_dir)
    pprint(get_percentiles(fidelities_by_benchmark), depth=4)  # noqa: T203
    plot_fid_distribution(fidelities_by_benchmark)

In [None]:
# EDA for result files whose names contain "NepsHyperbandRW".
do_all_eda(search_key="NepsHyperbandRW")

In [None]:
# EDA for result files whose names contain "NepsMOASHA".
do_all_eda(search_key="NepsMOASHA")

In [None]:
# EDA for result files whose names contain "MOMFBO".
do_all_eda(search_key="MOMFBO")