In [None]:
import os, sys
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import seaborn as sns

# Directory layout, resolved relative to the notebook's working directory.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)        # one level above the working directory
base_dir = os.path.dirname(parent_dir)   # two levels above the working directory

src_dir, package_dir = (
    os.path.join(base_dir, "src"),
    os.path.join(parent_dir, "testing", "evaluation_packages"),
)
plot_dir = f"{cwd}/plots/"               # keeps the trailing slash: file names are appended directly

# Put src_dir at the front of the import path so the project imports that
# follow are looked up there first.
sys.path.insert(0, src_dir)

from evaluate import ModelPhantom, EvaluationPackage, metrics # From main src dir 
from paper import latex_utils

In [None]:
# File names of the pickled evaluation packages; they are appended to
# package_dir when the packages are loaded.
training_package_name = "train_eval_pack.pkl"
test_package_name = "test_eval_pack.pkl"

In [None]:
# Both packages start as None so that a failed load leaves an explicit
# "not loaded" marker instead of an undefined name.
test_package = training_package = None

# Loading is best-effort: any failure is reported and the notebook continues
# with the corresponding package left as None.
try:
    print("Loading test set evaluation packages...")
    test_package = EvaluationPackage.load_from(f"{package_dir}/{test_package_name}")
except Exception as load_error:
    print(f"Error loading test set evaluation package: {load_error}\n")
else:
    print("Successfully loaded test set evaluation packages.\n")

try:
    print("Loading training set evaluation packages...")
    training_package = EvaluationPackage.load_from(f"{package_dir}/{training_package_name}")
except Exception as load_error:
    print(f"Error loading training set evaluation package: {load_error}\n")
else:
    print("Successfully loaded training set evaluation packages.\n")

In [None]:
# ---------------------------------------------------------------
# Metrics to evaluate (each one is applied to both packages below)
# ---------------------------------------------------------------
metrics_list = [metrics.grad_curl_div]

# ---------------------------------------------------------------
# Models to evaluate, identified by (name, dataset %, structure)
# ---------------------------------------------------------------
models = [
    ModelPhantom(name="ActuationNet", dataset_percentage=100, structure=(512, 512, 512)),
    ModelPhantom(name="ActuationNet", dataset_percentage=100, structure=(256, 256, 256)),
    ModelPhantom(name="ActuationNet", dataset_percentage=100, structure=(256, 256)),
    ModelPhantom(name="ActuationNet", dataset_percentage=50, structure=(512, 512, 512)),
    ModelPhantom(name="ActuationNet", dataset_percentage=20, structure=(512, 512, 512)),
    ModelPhantom(name="ActuationNet", dataset_percentage=5, structure=(512, 512, 512)),
    ModelPhantom(name="ActuationNet", dataset_percentage=1, structure=(512, 512, 512)),
    ModelPhantom(name="PotentialNet", dataset_percentage=100, structure=(512, 512, 512)),
    ModelPhantom(name="PotentialNet", dataset_percentage=100, structure=(256, 256, 256)),
    ModelPhantom(name="PotentialNet", dataset_percentage=100, structure=(256, 256)),
    ModelPhantom(name="PotentialNet", dataset_percentage=50, structure=(512, 512, 512)),
    ModelPhantom(name="PotentialNet", dataset_percentage=20, structure=(512, 512, 512)),
    ModelPhantom(name="PotentialNet", dataset_percentage=5, structure=(512, 512, 512)),
    ModelPhantom(name="PotentialNet", dataset_percentage=1, structure=(512, 512, 512)),
    ModelPhantom(name="DirectNet", dataset_percentage=100, structure=(512, 512, 512)),
    ModelPhantom(name="DirectNet", dataset_percentage=100, structure=(256, 256, 256)),
    ModelPhantom(name="DirectNet", dataset_percentage=100, structure=(256, 256)),
    ModelPhantom(name="DirectNet", dataset_percentage=50, structure=(512, 512, 512)),
    ModelPhantom(name="DirectNet", dataset_percentage=20, structure=(512, 512, 512)),
    ModelPhantom(name="DirectNet", dataset_percentage=5, structure=(512, 512, 512)),
    ModelPhantom(name="DirectNet", dataset_percentage=1, structure=(512, 512, 512)),
    ModelPhantom(name="DirectGBT", dataset_percentage=100, structure=(128,)),
    ModelPhantom(name="DirectGBT", dataset_percentage=100, structure=(64,)),
    ModelPhantom(name="DirectGBT", dataset_percentage=100, structure=(32,)),
    ModelPhantom(name="DirectGBT", dataset_percentage=50, structure=(128,)),
    ModelPhantom(name="DirectGBT", dataset_percentage=20, structure=(128,)),
    ModelPhantom(name="DirectGBT", dataset_percentage=5, structure=(128,)),
    ModelPhantom(name="DirectGBT", dataset_percentage=1, structure=(128,)),
    ModelPhantom(name="MPEM", dataset_percentage=5, structure=(1,)),
    ModelPhantom(name="MPEM", dataset_percentage=5, structure=(2,)),
    ModelPhantom(name="MPEM", dataset_percentage=5, structure=(3,)),
    ModelPhantom(name="MPEM", dataset_percentage=1, structure=(3,)),
]

# Derived subsets of `models`:
#   full_models       - entries with dataset_percentage == 100
#   large_models      - entries whose structure is (512, 512, 512) or (128,)
#   large_full_models - entries satisfying both conditions
full_models = list(filter(lambda m: m.dataset_percentage == 100, models))
large_models = [m for m in models if m.structure in ((512, 512, 512), (128,))]
large_full_models = list(filter(lambda m: m.dataset_percentage == 100, large_models))

# Compute metrics
# A package is None when its load failed in the loading cell. Skip it with an
# explicit message instead of crashing on `None.apply_gradient_metric`, so the
# package that did load still gets its metrics computed.
for metric in metrics_list:
    for set_name, package in (("test", test_package), ("training", training_package)):
        if package is None:
            print(f"Skipping {metric.__name__} for {set_name} set: evaluation package was not loaded.")
            continue
        print(f"Computing {metric.__name__} for {set_name} set...")
        package.apply_gradient_metric(metric)
    print()

In [None]:
def plot_div_curl_box_suite_latex(
    pkg,
    models,
    *,
    div_key="div",
    curl_key="curl",          # or "curl_mag"
    verbose=True,

    # --- single-column sizing ---
    figsize=(3.45, 5.25),

    # --- overall figure title  ---
    suptitle=r"Maxwell residuals",

    # --- labels (WITH units) ---
    ylabel_div=r"$\nabla \cdot \field \;\;(\mathrm{mT}\,\mathrm{mm}^{-1})$",
    ylabel_curl=r"$\|\nabla\times \field\| \;\;(\mathrm{mT}\,\mathrm{mm}^{-1})$",

    # --- per-subplot titles ---
    title_div="Divergence",
    title_curl="Curl",

    # --- x tick labels ---
    rotate_xticks=28,
    xtick_ha="right",

    # --- box appearance / spacing ---
    showfliers=False,
    add_points=False,
    points_alpha=0.12,
    points_size=1.6,
    point_jitter=0.18,

    box_width=0.45,
    linewidth=1.0,
    xpad=-0.01,

    # --- layout  ---
    hspace=0.25,
    top=0.8,
    bottom=0.30,
    left=0.22,
    right=0.98,

    # --- same color interface as RMSE suite ---
    name_to_color=None,   # dict {model_name: color}; keeps given colors, fills missing deterministically

    # --- title panel prefixes ---
    panel_prefixes=("", ""),      # e.g. ("(a)", "(b)") renders "(a) <title_div>" and "(b) <title_curl>"

    # --- latex rc context ---
    latex_params=None,
    usetex=True,

    # --- output ---
    out="maxwell.pdf",
    save=True,
    show=True,
):
    """
    Draw a 2x1 single-column box plot of per-sample divergence (top panel) and
    curl (bottom panel) values, one box per model, from ``pkg.gradient_metrics``.

    Parameters
    ----------
    pkg
        Object exposing ``gradient_metrics``, looked up as
        ``gradient_metrics[name][int(dataset_percentage)][structure]``; each
        leaf must be a mapping containing ``div_key`` and ``curl_key``.
    models
        Iterable of phantoms. Each must provide ``keys()`` returning
        ``(name, dataset_percentage, structure)`` and ``string(verbose=True)``
        (used in error messages). Names must be unique, since the name alone is
        the x-axis label.
    div_key, curl_key
        Leaf keys holding the divergence / curl values (scalar or 1D).
    verbose
        Accepted for interface compatibility; currently unused.
    figsize
        Figure size passed to ``plt.subplots``; also the default ``"figsize"``
        entry of ``latex_params``.
    suptitle
        Figure-level title; skipped when falsy.
    ylabel_div, ylabel_curl, title_div, title_curl
        Axis labels and panel titles. NOTE(review): the default labels use a
        ``\\field`` macro, presumably defined by the LaTeX rc context - confirm.
    rotate_xticks, xtick_ha
        Rotation and horizontal alignment of the bottom panel's tick labels
        (the top panel hides its tick labels).
    showfliers, box_width, linewidth
        Forwarded to ``sns.boxplot``.
    add_points, points_alpha, points_size, point_jitter
        When ``add_points`` is true, overlay a ``sns.stripplot`` with these
        settings.
    xpad
        x-limits are set to ``(-0.5 + xpad, n - 0.5 - xpad)`` for ``n`` models.
    hspace, top, bottom, left, right
        Forwarded to ``fig.subplots_adjust``.
    name_to_color
        Optional ``{model_name: color}``; it is copied, and names missing from
        it are filled from seaborn's "deep" palette.
    panel_prefixes
        Pair of strings prepended (with a space) to ``title_div`` and
        ``title_curl``; an empty prefix leaves the title unchanged.
    latex_params, usetex
        ``latex_params`` (copied, with ``figsize``/``usetex`` filled in as
        defaults) is passed as keyword arguments to
        ``latex_utils.rc_context_latex``.
    out, save, show
        When ``save`` and ``out`` are truthy the figure is written to ``out``
        (its directory is created if it names one). ``show`` selects between
        ``plt.show()`` and closing the figure.

    Returns
    -------
    tuple
        ``(fig, (ax_div, ax_curl), df, palette)``. ``df`` is the tidy frame of
        raw values (columns ``model_label``, ``metric``, ``value``,
        ``model_name``); the 1e-3 plot scaling is NOT applied to it.

    Raises
    ------
    ValueError
        If ``models`` is empty, model names are not unique, or a metric array
        has more than one dimension.
    KeyError
        If a model has no gradient-metric leaf, or the leaf lacks
        ``div_key`` / ``curl_key``.

    Notes
    -----
    Calls ``sns.set_style("whitegrid")``, which changes seaborn's global style.
    """

    def _metric_leaf_grad(pkg_, phantom_):
        name, dataset_percentage, structure = phantom_.keys()
        try:
            return pkg_.gradient_metrics[name][int(dataset_percentage)][structure]
        except Exception as e:
            raise KeyError(
                f"Missing gradient metrics for model '{phantom_.string(verbose=True)}'. "
                f"Did you run pkg.apply_gradient_metric(metrics.grad_curl_div) first?"
            ) from e

    def _base_name(ph):
        return ph.keys()[0]

    def _as_1d(a, key, lbl):
        a = np.asarray(a)
        if a.ndim == 0:
            return a.reshape(1)
        if a.ndim != 1:
            raise ValueError(f"{key} for '{lbl}' must be 1D (per-sample), got shape {a.shape}")
        return a

    def _with_prefix(prefix, title):
        # An empty prefix must not leave a stray leading space in the title.
        return f"{prefix} {title}" if prefix else title

    # `models` is iterated more than once below, so materialise it first; this
    # also lets one-shot iterables (generators) be passed safely.
    models = list(models)
    if not models:
        raise ValueError("No models given: pass at least one phantom to plot.")

    # ----------------------------
    # Build tidy DF with base-name labels only
    # ----------------------------
    labels = [_base_name(m) for m in models]
    if len(set(labels)) != len(labels):
        raise ValueError(
            "Base-name labels are not unique. Pass only one phantom per base model "
            "(e.g., large_full_models), or change labeling."
        )

    rows = []
    for ph, lbl in zip(models, labels):
        leaf = _metric_leaf_grad(pkg, ph)

        if div_key not in leaf or curl_key not in leaf:
            raise KeyError(
                f"Gradient-metric leaf for '{lbl}' missing keys. "
                f"Expected '{div_key}' and '{curl_key}'. Available keys: {sorted(list(leaf.keys()))}"
            )

        div_vals = _as_1d(leaf[div_key], div_key, lbl)
        curl_vals = _as_1d(leaf[curl_key], curl_key, lbl)

        rows += [{"model_label": lbl, "metric": div_key,  "value": float(v)} for v in div_vals]
        rows += [{"model_label": lbl, "metric": curl_key, "value": float(v)} for v in curl_vals]

    df = pd.DataFrame(rows)
    # Ordered categorical keeps the boxes in the order the models were passed.
    df["model_label"] = pd.Categorical(df["model_label"], categories=labels, ordered=True)
    df["model_name"] = df["model_label"].astype(str)

    # ----------------------------
    # Palette keyed ONLY by model name (same style as RMSE suite)
    # ----------------------------
    model_names = sorted(df["model_name"].unique())
    sns_palette = sns.color_palette("deep", n_colors=len(model_names))

    if name_to_color is None:
        palette = {nm: sns_palette[i] for i, nm in enumerate(model_names)}
    else:
        palette = dict(name_to_color)  # don't mutate caller
        for i, nm in enumerate(model_names):
            if nm not in palette:
                palette[nm] = sns_palette[i]

    label_palette = {lbl: palette[lbl] for lbl in labels}

    sns.set_style("whitegrid")

    # ----------------------------
    # LaTeX rc context
    # ----------------------------
    latex_params = {} if latex_params is None else dict(latex_params)
    latex_params.setdefault("figsize", figsize)
    latex_params.setdefault("usetex", usetex)

    # Prefixed titles
    prefix_div, prefix_curl = panel_prefixes
    title_div_pref = _with_prefix(prefix_div, title_div)
    title_curl_pref = _with_prefix(prefix_curl, title_curl)

    with latex_utils.rc_context_latex(**latex_params):
        fig, axes = plt.subplots(
            2, 1,
            figsize=figsize,
            constrained_layout=False,
            sharex=True
        )
        ax_div, ax_curl = axes

        fig.subplots_adjust(
            hspace=hspace, top=top, bottom=bottom, left=left, right=right
        )

        if suptitle:
            fig.suptitle(suptitle)

        def _panel(ax, metric_key, ylabel, panel_title, show_xticks):
            # Scale values by 1e-3 for display; the axis labels are in mT/mm.
            # NOTE(review): presumably the stored gradients are in mT/m - confirm.
            # `.assign` builds a new frame, so `df` (returned to the caller)
            # stays unscaled and no SettingWithCopyWarning is raised.
            sub = df.loc[df["metric"] == metric_key].assign(
                value=lambda frame: frame["value"] * 1e-3
            )

            sns.boxplot(
                data=sub,
                x="model_label",
                y="value",
                ax=ax,
                order=labels,
                palette=label_palette,
                width=box_width,
                showfliers=showfliers,
                linewidth=linewidth,
                whis=1.5,
            )

            if add_points:
                sns.stripplot(
                    data=sub,
                    x="model_label",
                    y="value",
                    ax=ax,
                    order=labels,
                    color="k",
                    alpha=points_alpha,
                    size=points_size,
                    jitter=point_jitter,
                )

            ax.set_title(panel_title)
            ax.set_xlabel("")
            ax.set_ylabel(ylabel)

            ax.grid(True, alpha=0.22, linewidth=0.6)
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

            if show_xticks:
                ax.tick_params(axis="x", rotation=rotate_xticks)
                for t in ax.get_xticklabels():
                    t.set_horizontalalignment(xtick_ha)
            else:
                ax.tick_params(axis="x", labelbottom=False)

            n = len(labels)
            ax.set_xlim(-0.5 + xpad, (n - 1) + 0.5 - xpad)

        _panel(ax_div,  div_key,  ylabel_div,  title_div_pref,  show_xticks=False)
        _panel(ax_curl, curl_key, ylabel_curl, title_curl_pref, show_xticks=True)

        fig.align_ylabels([ax_div, ax_curl])

        if save and out:
            # Create the target directory only when `out` names one: a bare
            # file name (such as the default "maxwell.pdf") has an empty
            # dirname, and os.makedirs("") raises FileNotFoundError.
            out_dir = os.path.dirname(out)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            fig.savefig(out, bbox_inches="tight")

        if show:
            plt.show()
        else:
            plt.close(fig)

    return fig, (ax_div, ax_curl), df, palette

In [None]:
# Fixed model-name -> colour mapping: the i-th name in palette_order gets the
# i-th colour of seaborn's "deep" palette.
palette_order = ["ActuationNet", "PotentialNet", "DirectNet", "DirectGBT", "MPEM"]
sns_pal = sns.color_palette("deep", n_colors=len(palette_order))
name_to_color = dict(zip(palette_order, sns_pal))

# Box plot of divergence / curl on the test set for the large, full-data
# models; the figure is shown and also written to <plot_dir>maxwell.pdf.
(fig_maxwell, axes_maxwell, df_maxwell, pal_used) = plot_div_curl_box_suite_latex(
    test_package,
    large_full_models,
    suptitle="",
    name_to_color=name_to_color,
    usetex=True,
    latex_params=latex_utils.latex_prms_singlecol,
    show=True,
    save=True,
    out=plot_dir + "maxwell.pdf",
)