# Evaluate model evaluation time (Python interface; C++ interface to follow)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import sys
import torch
import time
import pickle
import seaborn as sns
from tqdm.auto import tqdm
from matplotlib.lines import Line2D

# --- Directory layout (base_dir is the grandparent of the working directory) ---
cwd = os.getcwd()
parent_dir, base_dir = os.path.dirname(cwd), os.path.dirname(os.path.dirname(cwd))

# Inputs
data_dir = f"{base_dir}/data/octomag_data/split_dataset"
src_dir = f"{base_dir}/src"
params_dir = f"{parent_dir}/training/params"
mpem_params_dir = f"{base_dir}/mpem/optimized_parameters"

# Outputs
results_dir = f"{cwd}/data"
results_file = f"{results_dir}/eval_time_results.pkl"
plot_dir = f"{cwd}/plots/"  # trailing slash: file names are appended directly

# Put the project sources first on the import path for the project imports that follow.
sys.path.insert(0, src_dir)

from calibration import MPEM, MPEM_AVAILABLE, ActuationNet, PotentialNet
from evaluate import ModelPhantom
from paper import latex_utils

# ------------------------------------------------------------------
# Configuration switches
# ------------------------------------------------------------------
# Include MPEM models in the timing comparison?
include_mpem = False

# Recompute timings? When False, previously pickled results are loaded instead.
recompute_times = True

print(f"MPEM_AVAILABLE: {MPEM_AVAILABLE}")

## Load the dataset

In [None]:
# Positions and EM currents from the test split; both stay None when the
# split is not loaded (the timing cell then falls back to constant data).
currents = None
positions = None
if recompute_times:
    test_data_path = f"{data_dir}/test_data.pkl"
    try:
        # NOTE: read_pickle unpickles arbitrary objects; only load trusted files.
        test_data = pd.read_pickle(test_data_path)
    except FileNotFoundError:
        print(f"Test data file not found at {test_data_path}. Running with constant data.")
    else:
        # Position columns plus every column whose name starts with "em_".
        pos_cols = ["x", "y", "z"]
        em_cols = [col for col in test_data.columns if col.startswith("em_")]

        currents = test_data[em_cols].values
        positions = test_data[pos_cols].values

## Load the models

In [None]:
if recompute_times:
    # Learned calibration models, each size variant loaded onto the CPU
    # (ActuationNet variants first, then PotentialNet, smallest to largest).
    models = [
        net_cls.load_from(f"{params_dir}/{prefix}_100_{arch}.pt", map_location="cpu")
        for net_cls, prefix in ((ActuationNet, "ActuationNet"), (PotentialNet, "PotentialNet"))
        for arch in ("256x256", "256x256x256", "512x512x512")
    ]

    if not (MPEM_AVAILABLE and include_mpem):
        mpem_model = None
        print("MPEM model not available/included.")
    else:
        for i, order in enumerate(("dipole", "quadrupole", "octopole")):
            mpem_model = MPEM(f"{mpem_params_dir}/optimized_{order}_5.yaml", f"MPEM_5_{i+1}")
            models.append(mpem_model)
            print(f"MPEM model of order {order} loaded.")

    # Index the models as models_dict[phantom name][phantom structure] = {"model": ...}.
    models_dict = {}
    for model in models:
        phantom = ModelPhantom.from_calibration(model)
        name, structure = phantom.name, phantom.structure
        models_dict.setdefault(name, {})[structure] = {"model": model}

## Run the timing evaluation

In [None]:
# Helpers
def _sync_if_cuda(model):
    try:
        p = next(model.parameters())
        if p.is_cuda:
            torch.cuda.synchronize()
    except Exception:
        pass

def _is_cuda(model) -> bool:
    try:
        return next(model.parameters()).is_cuda
    except Exception:
        return False


# Timing loop (GPU-safe)
# Upper bound on the number of positions timed per model and method.
N = 10_000

# Untimed warmup calls per method, to exclude first-call overhead.
# NOTE: 0 disables warmup, so first-call costs are included in the timings.
WARMUP = 0


def _time_per_call_ms(eval_fn, samples, model, desc):
    """Return the mean wall-clock time per ``eval_fn(pos)`` call over *samples*, in ms.

    ``WARMUP`` untimed calls run first. ``_sync_if_cuda(model)`` brackets the
    timed loop so pending GPU work is counted when *model* lives on CUDA.
    """
    for pos in samples[:WARMUP]:
        _ = eval_fn(pos)
    _sync_if_cuda(model)

    t0 = time.perf_counter()
    # NOTE: tqdm's per-iteration bookkeeping is inside the timed region.
    for pos in tqdm(samples, desc=desc, leave=False):
        _ = eval_fn(pos)
    _sync_if_cuda(model)
    t1 = time.perf_counter()

    # Average over the calls actually made: the test set may hold fewer than N rows.
    return (t1 - t0) / len(samples) * 1e3


# Fall back to constant data when no test data was loaded
if (positions is None or currents is None) and recompute_times:
    print("No test data available. Using constant data for timing.")
    positions = np.zeros((N, 3))
    currents = np.ones((N, 8))  # assuming octomag

# Make sure the results directory exists
os.makedirs(results_dir, exist_ok=True)

if recompute_times:
    samples = positions[:N]
    if len(samples) == 0:
        raise ValueError("No positions available for timing.")

    for name, structures in models_dict.items():
        for structure, entry in structures.items():
            model = entry["model"]

            # --- field only ---
            entry["time_field"] = _time_per_call_ms(
                model.currents_field_jacobian_and_bias,
                samples,
                model,
                f"Timing {name} {structure} field eval",
            )

            # --- full (field + grad5) ---
            entry["time_full"] = _time_per_call_ms(
                model.currents_full_jacobian_and_bias,
                samples,
                model,
                f"Timing {name} {structure} full eval",
            )

            # Drop the model object so the results dict can be pickled.
            entry["model"] = None

    with open(results_file, "wb") as f:
        pickle.dump(models_dict, f)


## Compare evaluation times

In [None]:
# Use the in-memory results when they were just computed; otherwise load the
# previously saved ones (pickle: only load trusted, locally produced files).
if recompute_times:
    timings = models_dict
else:
    with open(results_file, "rb") as f:
        timings = pickle.load(f)

In [None]:
def timings_dict_to_rank_df(timings):
    """Flatten nested timing results into a DataFrame with a per-name size rank.

    Args:
        timings: ``{name: {structure: entry}}`` where each *entry* provides
            ``"time_field"`` and ``"time_full"``. *structure* may be ``None``,
            a tuple of numbers, or a single scalar.

    Returns:
        DataFrame with columns ``name``, ``structure`` (tuple or ``None``),
        ``size_score`` (product of the structure entries; 1.0 for ``None``),
        ``time_field``, ``time_full`` and ``size_rank``. ``size_rank`` is the
        1-based position of the row within its ``name`` group when ordered by
        ``size_score`` (ties broken by ``structure``).

    Raises:
        ValueError: if *timings* yields no rows.
    """
    rows = []
    for name, structs in timings.items():
        for structure, d in structs.items():
            # Normalize the structure key to a tuple (or None). A bare scalar
            # becomes a 1-tuple, since tuple() on a scalar would raise.
            if structure is None:
                st = None
            elif isinstance(structure, tuple):
                st = structure
            elif np.ndim(structure) == 0:
                st = (structure,)
            else:
                st = tuple(structure)

            size_score = 1.0 if st is None else float(np.prod(np.asarray(st, dtype=np.float64)))

            rows.append({
                "name": name,
                "structure": st,
                "size_score": size_score,
                "time_field": float(d["time_field"]),
                "time_full": float(d["time_full"]),
            })

    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("No timing rows found in timings dict.")

    # Rank structures within each model name from smallest to largest.
    df = df.sort_values(["name", "size_score", "structure"], na_position="last").reset_index(drop=True)
    df["size_rank"] = df.groupby("name").cumcount() + 1
    return df

In [None]:
def _roman_numeral(n):
    """Return the upper-case Roman numeral for the positive integer *n*."""
    parts = []
    for value, symbol in (
        (1000, "M"), (900, "CM"), (500, "D"), (400, "CD"),
        (100, "C"), (90, "XC"), (50, "L"), (40, "XL"),
        (10, "X"), (9, "IX"), (5, "V"), (4, "IV"), (1, "I"),
    ):
        count, n = divmod(n, value)
        parts.append(symbol * count)
    return "".join(parts)


def plot_eval_time_rank_single_plot(
    timings,
    *,
    title="",
    figsize=(3.5, 2.4),
    ylabel=r"Evaluation time (ms)",
    xlabel=r"Model complexity rank",
    latex_params=None,
    usetex=True,
    name_to_color=None,
    marker_ab="o",
    marker_abg="^",
    linestyle_ab="-",
    linestyle_abg="--",
    lw=1.6,
    ms=4.2,
    legend_fontsize=None,
    savepath="eval_time_rank_singleplot.pdf",
):
    """Plot evaluation time against within-model complexity rank.

    One colour per model name. Solid lines with ``marker_ab`` show
    ``time_field`` and dashed lines with ``marker_abg`` show ``time_full``.
    The x ticks are the ``size_rank`` values from ``timings_dict_to_rank_df``,
    labelled with Roman numerals.

    Args:
        timings: nested timing dict accepted by ``timings_dict_to_rank_df``.
        title: axes title; ``None`` skips setting it.
        figsize, usetex: merged into *latex_params* with ``setdefault``, so
            values already present in *latex_params* take precedence.
        latex_params: keyword arguments for ``latex_utils.rc_context_latex``.
        name_to_color: optional ``{name: color}``. Its key order fixes the
            legend order. Plotted names missing from it (or mapped to None)
            get colours from the rcParams colour cycle.
        marker_ab, marker_abg, linestyle_ab, linestyle_abg, lw, ms: styling
            for the field-only and full curves.
        legend_fontsize: defaults to ``latex_params["axes_labelsize"]`` if present.
        savepath: output file. Parent directories are created when the path
            has any; a falsy value skips saving.

    Returns:
        ``(fig, (ax,), df, palette)``. *df* is the rank DataFrame and
        *palette* maps names to colours.
    """
    df = timings_dict_to_rank_df(timings)
    model_names = sorted(df["name"].unique())

    # Fallback colours come from the active colour cycle ("k" if it is empty).
    cycle_colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", []) or ["k"]

    # palette + ordering
    if name_to_color is None:
        palette = {nm: cycle_colors[i % len(cycle_colors)] for i, nm in enumerate(model_names)}
        legend_order = model_names
    else:
        palette = dict(name_to_color)
        # Give every plotted name a concrete colour. With color=None, ax.plot
        # advances the axes cycle per call, so a model's two curves and its
        # legend entry would not share one colour.
        unassigned = [nm for nm in model_names if palette.get(nm) is None]
        for i, nm in enumerate(unassigned):
            palette[nm] = cycle_colors[i % len(cycle_colors)]
        legend_order = [nm for nm in name_to_color if nm in model_names]
        legend_order += [nm for nm in model_names if nm not in legend_order]

    latex_params = {} if latex_params is None else dict(latex_params)
    latex_params.setdefault("figsize", figsize)
    latex_params.setdefault("usetex", usetex)

    if legend_fontsize is None:
        legend_fontsize = latex_params.get("axes_labelsize", None)

    with latex_utils.rc_context_latex(**latex_params):
        fig, ax = plt.subplots()

        # ---- plot data ----
        for name, g in df.groupby("name", sort=False):
            g = g.sort_values("size_rank")
            c = palette[name]
            ax.plot(g["size_rank"], g["time_field"],
                    color=c, lw=lw, linestyle=linestyle_ab,
                    marker=marker_ab, markersize=ms)
            ax.plot(g["size_rank"], g["time_full"],
                    color=c, lw=lw, linestyle=linestyle_abg,
                    marker=marker_abg, markersize=ms)

        # ---- axes cosmetics ----
        ax.grid(True, alpha=0.25, linewidth=0.6)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)

        # Integer ticks labelled with Roman numerals (any number of ranks).
        xticks = np.arange(1, int(df["size_rank"].max()) + 1)
        ax.set_xticks(xticks)
        ax.set_xticklabels([_roman_numeral(int(k)) for k in xticks])
        if title is not None:
            ax.set_title(title)

        # bottom=0.0 lets the axes reach the figure's bottom edge. The x-label
        # and the legends anchored below the axes then lie outside the canvas
        # and are recovered when saving with bbox_inches="tight".
        fig.subplots_adjust(bottom=0.0)

        # --- spacing knobs (axes-fraction anchors below the axes; tune these) ---
        y_models   = -0.18   # closer to axes -> more separation from x-label
        y_quantity = -0.28   # bring quantity legend closer to model legend

        # ---- Legend 1: models (colors) ----
        model_handles = [Line2D([0], [0], color=palette[nm], lw=lw) for nm in legend_order]
        leg1 = ax.legend(
            model_handles, legend_order,
            loc="upper center",
            bbox_to_anchor=(0.5, y_models),
            ncol=min(3, max(1, len(legend_order))),
            frameon=False,
            fontsize=legend_fontsize,
            handlelength=2.2,
            columnspacing=1.0,
        )
        # Keep the first legend when ax.legend is called again below.
        ax.add_artist(leg1)

        # ---- Legend 2: Ab vs Abg (styles) ----
        style_handles = [
            Line2D([0], [0], color="k", lw=lw, linestyle=linestyle_ab,
                   marker=marker_ab, markersize=ms),
            Line2D([0], [0], color="k", lw=lw, linestyle=linestyle_abg,
                   marker=marker_abg, markersize=ms),
        ]
        style_labels = [r"$\actuation_b$", r"$\actuation_{b,g}$"]

        ax.legend(
            style_handles, style_labels,
            loc="upper center",
            bbox_to_anchor=(0.5, y_quantity),
            ncol=2,
            frameon=False,
            fontsize=legend_fontsize,
            handlelength=2.2,
            columnspacing=1.4,
        )

        if savepath:
            save_dir = os.path.dirname(savepath)
            if save_dir:  # a bare file name has no directory to create
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(savepath, bbox_inches="tight")
        plt.show()

    return fig, (ax,), df, palette

In [None]:
# Colours are assigned over every model family (not just the plotted ones) so
# each family keeps the same colour across figures.
palette_order = ["ActuationNet", "PotentialNet", "DirectNet", "DirectGBT", "MPEM"]
sns_pal = sns.color_palette("deep", n_colors=len(palette_order))
name_to_color = dict(zip(palette_order, sns_pal))

if not recompute_times:
    with open(results_file, "rb") as f:
        timings = pickle.load(f)
else:
    timings = models_dict

# NOTE(review): the plotting function merges figsize via setdefault, so
# (3.5, 2) only applies if latex_prms_singlecol has no "figsize" entry — confirm.
fig_time, axes_time, df_time, time_palette = plot_eval_time_rank_single_plot(
    timings,
    latex_params=latex_utils.latex_prms_singlecol,
    usetex=True,
    title="",
    name_to_color=name_to_color,
    figsize=(3.5, 2),
    savepath=f"{plot_dir}eval_time_rank_singleplot.pdf",
)