In [None]:
import sys
# Extend the module search path two directories up so that the
# `notebooks.star.plot_utils` import below can be resolved.
# NOTE(review): presumably '../..' is the repository root relative to this
# notebook's working directory — confirm when running from elsewhere.
sys.path.append('../..')

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

# Project-local plotting helpers; `method_style_map` is looked up per method
# name later in the notebook for "color", "linestyle" and "marker" entries.
from notebooks.star.plot_utils import set_default_matplotlib_settings, method_style_map

# Apply the project's matplotlib defaults before any figure is created.
# NOTE(review): the exact settings are defined in plot_utils, not visible here.
set_default_matplotlib_settings()

In [None]:
# Timestamps of the experiment runs to load. Exactly one of the presets below
# is meant to be activated; each id maps to star_results/experiment_<id>.csv.
experiment_ids = []
# experiment_ids = [
#     "20251001_115347",
#     "20251001_163806",
#     "20251001_203917",
#     "20251002_173317",
# ]  # target: rural
# experiment_ids = [
#     "20251001_131642",
#     "20251001_151204",
#     "20251001_220615",
#     "20251002_093614",
# ]  # target: urban
# experiment_ids = ['20251002_142531'] #kallus, target rural
# experiment_ids = ['20251002_134509', '20251002_141407'] #kallus, target urban

# Joined ids are reused later as the filename prefix of the saved figures.
experiment_ids_str = "-".join(experiment_ids)

# Fail fast when no preset has been activated.
if not experiment_ids:
    raise ValueError('experiment_ids is empty')

# Load every run that exists on disk. Missing files are reported and skipped
# (best effort), so a partially available set of runs can still be analysed.
dataframes = []
for ts in experiment_ids:
    filename = f"star_results/experiment_{ts}.csv"
    try:
        df = pd.read_csv(filename)
    except FileNotFoundError:
        print(f"File {filename} not found.")
    else:
        dataframes.append(df)

# Distinguish "no ids configured" (above) from "ids configured but none of
# their files exist", which previously produced the same misleading message.
if not dataframes:
    raise ValueError(
        "None of the result files could be found for "
        f"experiment_ids={experiment_ids}"
    )

# One row per experiment record, with a fresh 0..N-1 index across all runs.
results_df = pd.concat(dataframes, ignore_index=True)


############# Counts NaNs in data ###############
# Diagnostics are computed on the concatenated `results_df` so that every
# loaded experiment file is covered, not just the last CSV that was read.
row_has_nan = results_df.isna().any(axis=1)

# Number of rows containing at least one NaN, per method.
nan_counts_by_method = (
    row_has_nan.groupby(results_df["method"])
    .sum()
    .reset_index(name="rows_with_nan")
)
print(nan_counts_by_method)

# Rows with any NaN, for manual inspection.
nan_rows = results_df[row_has_nan]

# Print how many there are and show a sample.
print(f"Total rows with NaN: {len(nan_rows)}")
print(nan_rows.head(10))  # show first 10 rows


############# Filtering and preprocessing ###############
methods_to_exclude = ["Kallus et al. [2018]"]
# `.copy()` makes the filtered frame independent of the frame it was sliced
# from, so the column assignments below do not raise SettingWithCopyWarning.
results_df = results_df.loc[
    ~results_df["method"].isin(methods_to_exclude)
].copy()


############# Rename methods ###########################
# Exact-match relabelling; re-running this cell leaves already renamed
# methods untouched because the new labels are not keys of the mapping.
results_df["method"] = results_df["method"].replace(
    {"QR-learner": "QR-learner (ours)", "Combined learner": "Combined learner (ours)"}
)


# Result files may lack these columns; fill in defaults so that the
# downstream grouping and printing code can rely on their presence.
if "fraction_rural" not in results_df.columns:
    results_df["fraction_rural"] = np.nan
if "dropped_covar" not in results_df.columns:
    results_df["dropped_covar"] = "g1surban"

# Mean squared error derived from the stored RMSE.
results_df["mse"] = results_df["rmse"] ** 2

In [None]:
def _standard_error(values):
    """Return std(values) / sqrt(len(values)) using NumPy's default std."""
    return np.std(values) / np.sqrt(len(values))


# Output prefix -> source column in `results_df`; every entry yields a
# `<prefix>_mean` and a `<prefix>_se` column, in this order.
_metric_sources = {"rmse": "rmse", "bias": "abs_bias", "var": "var"}

_aggregations = {}
for _prefix, _source in _metric_sources.items():
    _aggregations[f"{_prefix}_mean"] = (_source, "mean")
    _aggregations[f"{_prefix}_se"] = (_source, _standard_error)

# One summary row per (n1, n0, dropped_covar, method) combination.
grouped_results = results_df.groupby(
    by=["n1", "n0", "dropped_covar", "method"]
).agg(**_aggregations)

# Report which values of each experimental factor are present.
for _factor in ("n1", "n0", "dropped_covar", "fraction_rural"):
    print(f"Unique {_factor} values:", results_df[_factor].unique())

print("\nGrouped results:")
# Only a short preview is printed for brevity.
print(grouped_results.head(6))  # show first 6 rows

In [None]:
import math
from matplotlib.ticker import MultipleLocator
import seaborn as sns

# Metrics to plot: (column prefix in `grouped_results`, y-axis label).
metrics = [
    ("rmse", "RMSE"),
    ("bias", "Bias"),
    #("var", "Variance"),
]

# Only results obtained with this dropped covariate are plotted.
DROPPED_COVAR = "g1surban"
# Style for methods that have no entry in `method_style_map`.
FALLBACK_STYLE = {"color": "black", "linestyle": "-", "marker": "o"}


def _legend_entries(axes_grid):
    """Collect one (handle, label) pair per method across all panels.

    Labels are ordered as in `method_style_map`; labels that are missing
    from the map are appended afterwards in order of first appearance, so
    an unmapped method no longer makes the legend construction fail.

    Returns a tuple ``(handles, labels)`` of equally long lists.
    """
    handle_by_label = {}
    for panel_ax in axes_grid.flat:
        for handle, label in zip(*panel_ax.get_legend_handles_labels()):
            handle_by_label.setdefault(label, handle)
    rank = {name: position for position, name in enumerate(method_style_map)}
    ordered_labels = sorted(
        handle_by_label, key=lambda name: rank.get(name, len(rank))
    )
    return [handle_by_label[name] for name in ordered_labels], ordered_labels


# Flatten the (n1, n0, dropped_covar, method) index once; panels are then
# selected with boolean masks, which yields an empty selection instead of a
# KeyError when a combination is absent.
plot_df = grouped_results.reset_index()
plot_df = plot_df[plot_df["dropped_covar"] == DROPPED_COVAR]
if plot_df.empty:
    raise ValueError(f"No grouped results for dropped_covar={DROPPED_COVAR!r}")

# One panel per n0 value that actually has data to show.
n0_values = plot_df["n0"].unique()
n_panels = len(n0_values)
# Every n0 shown in the figure goes into the filename; with a single n0 this
# is the same name as before.
n0_tag = "-".join(str(value) for value in n0_values)

# choose grid layout (square-ish)
ncols = math.ceil(math.sqrt(n_panels))
nrows = math.ceil(n_panels / ncols)

for metric_key, metric_label in metrics:
    mean_col = f"{metric_key}_mean"
    se_col = f"{metric_key}_se"

    fig, axes = plt.subplots(
        nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False
    )

    # `axes.flat` iterates row by row, i.e. panel i sits at divmod(i, ncols).
    for ax, n0_val in zip(axes.flat, n0_values):
        panel_df = plot_df[plot_df["n0"] == n0_val]

        # sort=False keeps the methods in their order of first appearance.
        for method, method_data in panel_df.groupby("method", sort=False):
            style = method_style_map.get(method, FALLBACK_STYLE)

            # Mean curve; the label feeds the shared figure legend below.
            sns.lineplot(
                data=method_data,
                x="n1",
                y=mean_col,
                label=method,
                color=style["color"],
                linestyle=style["linestyle"],
                marker=style["marker"],
                ax=ax,
                legend=False,
                alpha=0.8,
                markersize=7,
                linewidth=2.5
            )

            # Standard-error bars only (fmt="none" draws no extra line).
            ax.errorbar(
                method_data["n1"],
                method_data[mean_col],
                yerr=method_data[se_col],
                fmt="none",
                capsize=5,
                color=style["color"],
                linestyle=style["linestyle"],
            )

        # With several panels the title must identify the panel's n0.
        title = "STAR dataset"
        if n_panels > 1:
            title += f" ($n_0$ = {n0_val})"
        ax.set_title(title, fontsize=22)
        ax.set_xlabel("Trial sample size $n_1$", fontsize=22)
        ax.set_ylabel(metric_label, fontsize=22)
        ax.tick_params(labelsize=18)
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.xaxis.set_major_locator(MultipleLocator(100))         # add ticks every 100

    # Hide unused panels if grid too big
    for unused_ax in axes.flat[n_panels:]:
        unused_ax.set_visible(False)

    # Shared legend covering the methods of every panel.
    handles, labels = _legend_entries(axes)
    if handles:
        fig.legend(
            handles,
            labels,
            title="Method",
            fontsize=20,
            title_fontsize=20,
            loc="upper center",
            bbox_to_anchor=(0.57, 0.08),
            ncol=2,
        )

    fig.tight_layout(rect=[0, 0.05, 1, 1])  # leave space for legend
    fig.savefig(
        f"star_results/{experiment_ids_str}_{metric_key}_n0-{n0_tag}_results.pdf",  # filename includes metric name
        format="pdf",
        dpi=300,
        bbox_inches="tight"
    )