In [None]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Resolve the project layout relative to the directory the notebook runs in:
#   base_dir/src                                  -> importable project sources
#   parent_dir/testing/evaluation_packages        -> pickled evaluation packages
#   cwd/plots/                                    -> output directory for figures
cwd = os.getcwd()
parent_dir = os.path.split(cwd)[0]
base_dir = os.path.split(parent_dir)[0]

src_dir = os.path.join(base_dir, "src")
package_dir = os.path.join(parent_dir, "testing", "evaluation_packages")
plot_dir = f"{cwd}/plots/"

# Put the project sources first on the import path so the local imports below resolve.
sys.path.insert(0, src_dir)

from evaluate import ModelPhantom, EvaluationPackage, metrics # From main src dir 
from paper import latex_utils

In [None]:
# File names of the evaluation packages, resolved against `package_dir` when loading.
training_package_name = "train_eval_pack.pkl"
test_package_name = "test_eval_pack.pkl"

In [None]:
# Both packages start out as None; a package stays None if loading it fails,
# in which case the error is printed and the notebook keeps going.
test_package = None
training_package = None

print("Loading test set evaluation packages...")
try:
    test_package = EvaluationPackage.load_from(package_dir + "/" + test_package_name)
except Exception as load_error:
    print(f"Error loading test set evaluation package: {load_error}\n")
else:
    print("Successfully loaded test set evaluation packages.\n")

print("Loading training set evaluation packages...")
try:
    training_package = EvaluationPackage.load_from(package_dir + "/" + training_package_name)
except Exception as load_error:
    print(f"Error loading training set evaluation package: {load_error}\n")
else:
    print("Successfully loaded training set evaluation packages.\n")

In [None]:
###############################
####### Which metrics? ########
###############################

# Field metrics applied to both evaluation packages, in this order.
# NOTE: metrics.mag_and_angle is computed but not actually used in the paper.
metrics_list = [metrics.rse, metrics.mag_and_angle]

##############################
####### Which models? ########
##############################

# One row per evaluated model: (name, dataset_percentage, structure).
# The row order is the order of `models` and therefore of everything derived from it.
_model_specs = [
    ("ActuationNet", 100, (512, 512, 512)),
    ("ActuationNet", 100, (256, 256, 256)),
    ("ActuationNet", 100, (256, 256)),
    ("ActuationNet", 50, (512, 512, 512)),
    ("ActuationNet", 20, (512, 512, 512)),
    ("ActuationNet", 5, (512, 512, 512)),
    ("ActuationNet", 1, (512, 512, 512)),

    ("PotentialNet", 100, (512, 512, 512)),
    ("PotentialNet", 100, (256, 256, 256)),
    ("PotentialNet", 100, (256, 256)),
    ("PotentialNet", 50, (512, 512, 512)),
    ("PotentialNet", 20, (512, 512, 512)),
    ("PotentialNet", 5, (512, 512, 512)),
    ("PotentialNet", 1, (512, 512, 512)),

    ("DirectNet", 100, (512, 512, 512)),
    ("DirectNet", 100, (256, 256, 256)),
    ("DirectNet", 100, (256, 256)),
    ("DirectNet", 50, (512, 512, 512)),
    ("DirectNet", 20, (512, 512, 512)),
    ("DirectNet", 5, (512, 512, 512)),
    ("DirectNet", 1, (512, 512, 512)),

    ("DirectGBT", 100, (128,)),
    ("DirectGBT", 100, (64,)),
    ("DirectGBT", 100, (32,)),
    ("DirectGBT", 50, (128,)),
    ("DirectGBT", 20, (128,)),
    ("DirectGBT", 5, (128,)),
    ("DirectGBT", 1, (128,)),

    ("MPEM", 5, (1,)),
    ("MPEM", 5, (2,)),
    ("MPEM", 5, (3,)),
    ("MPEM", 1, (3,)),
]

models = [
    ModelPhantom(name=spec_name, dataset_percentage=spec_pct, structure=spec_structure)
    for spec_name, spec_pct, spec_structure in _model_specs
]

# Models trained on the full dataset; MPEM has no 100% entry, so its 5% entries stand in.
full_models = [
    m for m in models
    if m.dataset_percentage == 100 or (m.name == "MPEM" and m.dataset_percentage == 5)
]

# Largest structure of each model family, across all dataset percentages.
large_models = [
    m for m in models
    if m.structure in ((512, 512, 512), (128,), (3,))
]

# Largest structures restricted to the full dataset.
large_full_models = [m for m in large_models if m.dataset_percentage == 100]

# Compute metrics
# The loading cell leaves a package as None when loading fails. Check both up
# front so the failure is reported clearly, instead of surfacing as an
# AttributeError on None after metrics were already applied to the other package.
_unloaded_packages = [
    label
    for label, package in (("test", test_package), ("training", training_package))
    if package is None
]
if _unloaded_packages:
    raise RuntimeError(
        "Cannot compute metrics; evaluation package(s) not loaded: "
        + ", ".join(_unloaded_packages)
    )

# Apply every metric to both packages (test first, then training).
for metric in metrics_list:
    print(f"Computing {metric.__name__} for test set...")
    test_package.apply_field_metric(metric)
    print(f"Computing {metric.__name__} for training set...")
    training_package.apply_field_metric(metric)
    print()

In [None]:
def plot_train_test_rmse_rank_and_pct_suite_latex(
    train_pkg,
    test_pkg,
    models_rank,
    models_pct,
    *,
    title=None,
    figsize=(10, 4.8),
    ylabel_test=r"Test RMSE (mT)",
    ylabel_gap=r"Test $-$ Train RMSE (mT)",
    xlabel_rank=r"Model complexity rank",
    xlabel_pct=r"Dataset percentage (\%)",
    marker="o",
    ymin_test=None,
    ymax_test=None,
    ymin_gap=None,
    ymax_gap=None,
    legend_out=True,
    legend_ncol=1,
    latex_params=None,      # e.g. latex_prms_2img or latex_prms_3img
    usetex=True,
    panel_labels=("a", "b", "c", "d"),
    panel_label_y_top=-0.15,
    panel_label_y_bot=-0.38,
    # visual grouping controls
    column_titles=(r"vs model complexity", r"vs dataset percentage"),
    show_column_titles=True,

    # --- legend sizing control ---
    legend_match_axes_labelsize=True,
    legend_fontsize=None,               # override (takes precedence)

    # --- pass a fixed name->color mapping (kept as-is for overlapping models) ---
    name_to_color=None,                 # dict: { "ActuationNet": (r,g,b), ... } or {name: "C0", ...}

    store_to="rmse_suite.pdf"
):
    """Draw a 2x2 figure of test RMSE and (test - train) RMSE gap.

    Left column: one line per model name over a per-name complexity rank
    (rank = position after sorting that name's models by the product of the
    structure entries). Right column: one line per (name, structure) over the
    dataset percentage. Top row shows the test RMSE, bottom row the gap.

    Parameters
    ----------
    train_pkg, test_pkg :
        Objects whose ``field_metrics[name][int(dataset_percentage)][structure]``
        entry holds an ``"rmse"`` value.
    models_rank, models_pct :
        Iterables of model descriptors for the left / right column. Each must
        provide ``keys()`` returning ``(name, dataset_percentage, structure)``
        and ``string(verbose=True)`` (used in error messages only).
    title :
        Figure suptitle; ``None`` selects a default title.
    figsize, usetex :
        Used as defaults for the ``"figsize"`` / ``"usetex"`` entries of
        ``latex_params`` when those are not already present.
    ylabel_test, ylabel_gap, xlabel_rank, xlabel_pct :
        Axis labels (y labels are set on the left column only).
    marker :
        Marker passed to every line plot.
    ymin_test, ymax_test, ymin_gap, ymax_gap :
        Optional y-limits for the top / bottom row.
    legend_out :
        ``True`` places a figure-level legend to the right of the axes;
        otherwise the legend is drawn inside the top-right axes.
    legend_ncol :
        Number of legend columns (used only when ``legend_out`` is true).
    latex_params :
        Keyword arguments forwarded to ``latex_utils.rc_context_latex``.
        The caller's dict is copied, not mutated.
    panel_labels :
        Four labels, drawn as ``(label)`` below the rank-top, rank-bottom,
        pct-top and pct-bottom axes respectively.
    panel_label_y_top, panel_label_y_bot :
        Axes-fraction y position of the panel labels for the top / bottom row.
    column_titles, show_column_titles :
        Titles for the two columns and whether to draw them.
    legend_match_axes_labelsize, legend_fontsize :
        ``legend_fontsize`` wins when given; otherwise, if
        ``legend_match_axes_labelsize`` is true, the ``"axes_labelsize"`` entry
        of ``latex_params`` (if any) is used.
    name_to_color :
        Optional mapping from model name to color. Names missing from it get a
        color from seaborn's "deep" palette.
    store_to :
        Output path for the figure. Missing parent directories are created.
        A falsy value (``None`` or ``""``) skips saving.

    Returns
    -------
    tuple
        ``(fig, (ax_rank_top, ax_rank_bot, ax_pct_top, ax_pct_bot),
        {"rank": df_rank, "pct": df_pct}, palette)``.

    Raises
    ------
    KeyError
        If the metrics entry of a model cannot be looked up in a package.
    ValueError
        If ``models_rank`` or ``models_pct`` yields no rows.
    """

    def _metric_leaf(pkg, phantom):
        # Fetch the metrics entry of one model; any lookup failure becomes a KeyError.
        name, dataset_percentage, structure = phantom.keys()
        try:
            return pkg.field_metrics[name][int(dataset_percentage)][structure]
        except Exception as e:
            # Only repr() the raw values here: calling int() again could itself
            # fail and replace the intended KeyError with a different exception.
            raise KeyError(
                f"Missing field metrics for model '{phantom.string(verbose=True)}'. "
                f"Looked up field_metrics[{name!r}][int({dataset_percentage!r})][{structure!r}]."
            ) from e

    def _rmse_row(phantom):
        # Columns shared by both panels for one model.
        name, dataset_percentage, structure = phantom.keys()
        rmse_train = float(_metric_leaf(train_pkg, phantom)["rmse"])
        rmse_test = float(_metric_leaf(test_pkg, phantom)["rmse"])
        return {
            "name": name,
            "dataset_percentage": int(dataset_percentage),
            "structure": None if structure is None else tuple(structure),
            "rmse_train": rmse_train,
            "rmse_test": rmse_test,
            "rmse_gap": rmse_test - rmse_train,
        }

    # ----------------------------
    # Build DF for rank panel
    # ----------------------------
    rows_rank = []
    for ph in models_rank:
        row = _rmse_row(ph)
        structure_tuple = row["structure"]
        # Complexity proxy: product of the structure entries (1.0 when there is no structure).
        row["size_score"] = (
            1.0 if structure_tuple is None
            else float(np.prod(np.asarray(structure_tuple, dtype=np.float64)))
        )
        rows_rank.append(row)

    df_rank = pd.DataFrame(
        rows_rank,
        columns=["name", "dataset_percentage", "structure", "size_score",
                 "rmse_train", "rmse_test", "rmse_gap"],
    )
    if df_rank.empty:
        raise ValueError("models_rank produced no rows (empty list or missing rmse).")

    df_rank = df_rank.sort_values(["name", "size_score", "structure"], na_position="last").reset_index(drop=True)
    df_rank["size_rank"] = df_rank.groupby("name").cumcount() + 1

    # ----------------------------
    # Build DF for pct panel
    # ----------------------------
    rows_pct = []
    for ph in models_pct:
        row = _rmse_row(ph)
        structure_tuple = row["structure"]
        # One line per (name, structure) combination.
        row["line_id"] = (
            row["name"] if structure_tuple is None
            else f"{row['name']}_{'x'.join(map(str, structure_tuple))}"
        )
        rows_pct.append(row)

    df_pct = pd.DataFrame(
        rows_pct,
        columns=["name", "line_id", "dataset_percentage", "structure",
                 "rmse_train", "rmse_test", "rmse_gap"],
    )
    if df_pct.empty:
        raise ValueError("models_pct produced no rows (empty list or missing rmse).")

    df_pct = df_pct.sort_values(["line_id", "dataset_percentage"]).reset_index(drop=True)

    # ----------------------------
    # Palette keyed by model name
    # ----------------------------
    model_names = sorted(set(df_rank["name"].unique()).union(set(df_pct["name"].unique())))
    sns_palette = sns.color_palette("deep", n_colors=len(model_names))

    if name_to_color is None:
        palette = {nm: sns_palette[i] for i, nm in enumerate(model_names)}
    else:
        palette = dict(name_to_color)  # don't mutate caller
        for i, nm in enumerate(model_names):
            if nm not in palette:
                palette[nm] = sns_palette[i]

    # ----------------------------
    # Legend order = order models were passed in
    # ----------------------------
    legend_order = []
    for ph in list(models_rank) + list(models_pct):
        nm = ph.keys()[0]
        if nm not in legend_order:
            legend_order.append(nm)

    # ----------------------------
    # Plotting (inside LaTeX rc_context)
    # ----------------------------
    latex_params = {} if latex_params is None else dict(latex_params)
    latex_params.setdefault("figsize", figsize)
    latex_params.setdefault("usetex", usetex)

    if legend_fontsize is None and legend_match_axes_labelsize:
        legend_fontsize = latex_params.get("axes_labelsize", None)

    with latex_utils.rc_context_latex(**latex_params):
        fig = plt.figure(constrained_layout=True)
        gs = fig.add_gridspec(
            2, 2,
            height_ratios=[1, 1],
            width_ratios=[1, 1],
            wspace=0.20,
            hspace=0.06
        )

        ax_rank_top = fig.add_subplot(gs[0, 0])
        ax_rank_bot = fig.add_subplot(gs[1, 0], sharex=ax_rank_top)

        ax_pct_top  = fig.add_subplot(gs[0, 1], sharey=ax_rank_top)
        ax_pct_bot  = fig.add_subplot(gs[1, 1], sharey=ax_rank_bot, sharex=ax_pct_top)

        def _prettify(ax, show_xticklabels=True):
            # Light grid, no top/right spines, optionally hide x tick labels.
            ax.grid(True, alpha=0.25, linewidth=0.6)
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
            if not show_xticklabels:
                ax.tick_params(axis="x", labelbottom=False)

        # First line drawn for each model name; used as that name's legend handle.
        legend_handles = {}

        # left: rank (one line per name)
        for name, g in df_rank.groupby("name", sort=False):
            g = g.sort_values("size_rank")
            color = palette[name]

            ln, = ax_rank_top.plot(g["size_rank"], g["rmse_test"], marker=marker, color=color)
            ax_rank_bot.plot(g["size_rank"], g["rmse_gap"], marker=marker, color=color)

            legend_handles.setdefault(name, ln)

        # right: pct (connect per line_id; same color per name)
        xmax_pct = int(df_pct["dataset_percentage"].max())

        for line_id, g in df_pct.groupby("line_id", sort=False):
            g = g.sort_values("dataset_percentage")
            name = g["name"].iloc[0]
            color = palette[name]

            ln, = ax_pct_top.plot(g["dataset_percentage"], g["rmse_test"], marker=marker, color=color)
            ax_pct_bot.plot(g["dataset_percentage"], g["rmse_gap"],  marker=marker, color=color)

            # Names that appear only in models_pct would otherwise be missing
            # from the legend, although legend_order includes them.
            legend_handles.setdefault(name, ln)

            # Dashed horizontal extension for MPEM: carry its last value out to
            # the largest plotted dataset percentage.
            if name == "MPEM":
                x_last = int(g["dataset_percentage"].iloc[-1])
                if x_last < xmax_pct:
                    y_last_top = float(g["rmse_test"].iloc[-1])
                    y_last_bot = float(g["rmse_gap"].iloc[-1])

                    ax_pct_top.plot([x_last, xmax_pct], [y_last_top, y_last_top],
                                    linestyle="--", linewidth=1.2, color=color)
                    ax_pct_bot.plot([x_last, xmax_pct], [y_last_bot, y_last_bot],
                                    linestyle="--", linewidth=1.2, color=color)

        ax_rank_top.set_ylabel(ylabel_test)
        ax_rank_bot.set_ylabel(ylabel_gap)
        ax_rank_bot.set_xlabel(xlabel_rank)
        ax_pct_bot.set_xlabel(xlabel_pct)

        # The right column shares its y axes with the left one; hide the duplicate labels.
        ax_pct_top.tick_params(axis="y", labelleft=False)
        ax_pct_bot.tick_params(axis="y", labelleft=False)

        # Zero-gap reference line.
        ax_rank_bot.axhline(0.0, linewidth=1.0, alpha=0.6)
        ax_pct_bot.axhline(0.0, linewidth=1.0, alpha=0.6)

        # Ranks are shown as upright roman numerals; beyond 15 fall back to digits.
        roman_upper = {
            1:"I",2:"II",3:"III",4:"IV",5:"V",6:"VI",7:"VII",8:"VIII",9:"IX",10:"X",
            11:"XI",12:"XII",13:"XIII",14:"XIV",15:"XV",
        }
        max_rank = int(df_rank["size_rank"].max())
        ax_rank_bot.set_xticks(np.arange(1, max_rank + 1))
        ax_rank_bot.set_xticklabels([rf"$\mathrm{{{roman_upper.get(i, str(i))}}}$" for i in range(1, max_rank + 1)])

        xticks_pct = np.sort(df_pct["dataset_percentage"].unique())
        ax_pct_bot.set_xticks(xticks_pct)

        if (ymin_test is not None) or (ymax_test is not None):
            ax_rank_top.set_ylim(bottom=ymin_test, top=ymax_test)
        if (ymin_gap is not None) or (ymax_gap is not None):
            ax_rank_bot.set_ylim(bottom=ymin_gap, top=ymax_gap)

        _prettify(ax_rank_top, show_xticklabels=False)
        _prettify(ax_rank_bot, show_xticklabels=True)
        _prettify(ax_pct_top,  show_xticklabels=False)
        _prettify(ax_pct_bot,  show_xticklabels=True)

        if show_column_titles and column_titles is not None:
            ax_rank_top.set_title(column_titles[0])
            ax_pct_top.set_title(column_titles[1])

        fig.suptitle(title if title is not None else r"Test RMSE and (Test $-$ Train) gap")
        fig.align_ylabels([ax_rank_top, ax_rank_bot])

        a, b, c, d = panel_labels
        ax_rank_top.text(0.5, panel_label_y_top, f"({a})", transform=ax_rank_top.transAxes,
                         ha="center", va="top", clip_on=False)
        ax_rank_bot.text(0.5, panel_label_y_bot, f"({b})", transform=ax_rank_bot.transAxes,
                         ha="center", va="top", clip_on=False)
        ax_pct_top.text(0.5, panel_label_y_top, f"({c})", transform=ax_pct_top.transAxes,
                        ha="center", va="top", clip_on=False)
        ax_pct_bot.text(0.5, panel_label_y_bot, f"({d})", transform=ax_pct_bot.transAxes,
                        ha="center", va="top", clip_on=False)

        leg_labels = [nm for nm in legend_order if nm in legend_handles]
        leg_handles = [legend_handles[nm] for nm in leg_labels]

        if legend_out:
            fig.legend(
                leg_handles, leg_labels,
                frameon=False,
                loc="center left",
                bbox_to_anchor=(1.01, 0.5),
                ncol=legend_ncol,
                fontsize=legend_fontsize,
            )
        else:
            ax_pct_top.legend(
                leg_handles, leg_labels,
                frameon=False,
                loc="best",
                fontsize=legend_fontsize,
            )

        if store_to:
            # A bare file name has an empty dirname, and os.makedirs("") raises,
            # so only create the directory when there is one.
            out_dir = os.path.dirname(store_to)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(store_to, bbox_inches="tight")
        plt.show()

    return fig, (ax_rank_top, ax_rank_bot, ax_pct_top, ax_pct_bot), {"rank": df_rank, "pct": df_pct}, palette

In [None]:
# Fixed model-name -> color assignment so every figure uses the same color per model.
palette_order = ["ActuationNet", "PotentialNet", "DirectNet", "DirectGBT", "MPEM"]
sns_pal = sns.color_palette("deep", n_colors=len(palette_order))
name_to_color = dict(zip(palette_order, sns_pal))

# Left column: full-dataset models ranked by complexity.
# Right column: largest structures across dataset percentages.
(
    fig_rmse,
    axes_rmse,
    dfs_rmse,
    rmse_palette,
) = plot_train_test_rmse_rank_and_pct_suite_latex(
    training_package,
    test_package,
    models_rank=full_models,
    models_pct=large_models,
    title="",
    ylabel_gap=r"Gap RMSE (mT)",
    ymin_test=0.0,
    latex_params=latex_utils.latex_prms_2img,
    usetex=True,
    name_to_color=name_to_color,
    store_to=plot_dir + "rmse_suite.pdf",
)