In [None]:
import os
import pickle
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Project layout, resolved from the directory the notebook is launched in:
# cwd -> parent_dir -> base_dir (two levels up).
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
base_dir = os.path.dirname(parent_dir)

# Input locations (data, importable sources, trained/optimised parameters).
data_dir = f"{base_dir}/data/navion_data/split_dataset/"
src_dir = f"{base_dir}/src"
nn_params_dir = f"{parent_dir}/training/params_navion"
mpem_params_dir = f"{base_dir}/mpem/mpem_navion/optimized_parameters"

# Cached analysis output. Set `recompute = True` to rebuild it instead of loading it.
results_dir = f"{cwd}/results/"
results_path = f"{results_dir}results_navion.pkl"
recompute = False

# Put the project sources first on the import path so `calibration` resolves.
sys.path.insert(0, src_dir)
from calibration import MPEM, MPEM_AVAILABLE, ActuationNet, PotentialNet, DirectNet

## Load models

In [None]:
# Load models.
# Every MPEM is declared once as (parameter file, display name). The display names
# key the per-model entries of the results dict, so deriving them from this single
# list keeps the cached-results branch in sync with the recompute branch.
segment_type = "cyl" # "cyl" or "lin". In the paper we used "cyl"
segments = [0, 1, 2]

mpem_specs = [
    (f"{mpem_params_dir}/Navion_{segment_type}_{seg}.yaml", f"MPEM Segment {seg}") for seg in segments
] + [
    (f"{mpem_params_dir}/Navion_dipole.yaml", "MPEM Dipole"),
    (f"{mpem_params_dir}/Navion_quadrupole.yaml", "MPEM Quadrupole"),
    (f"{mpem_params_dir}/Navion_octopole.yaml", "MPEM Octopole"),
]

if recompute:
    models = [MPEM(path, name=name) for path, name in mpem_specs]
    model_names = [model.name for model in models]
else:
    # Cached results only need the names; the (expensive) models are not built.
    model_names = [name for _, name in mpem_specs]

## Define helper functions

In [None]:
def get_A_and_I(model, positions, target_field, As=None, Is=None):
    """Compute the actuation matrix and required currents at each position.

    At every position the model supplies a Jacobian ``A`` and bias ``b`` via
    ``model.currents_field_jacobian_and_bias(pos)``. The currents are
    ``I = pinv(A) @ (target_field - b)``, i.e. the minimum-norm least-squares
    solution of ``A @ I = target_field - b``.

    Args:
        model: object exposing ``currents_field_jacobian_and_bias(pos) -> (A, b)``.
        positions: sequence of N positions.
        target_field: desired field vector.
        As: optional output buffer of shape (N, 3, 3). Reused when its shape
            matches; otherwise (including ``None``) a new array is allocated.
        Is: optional output buffer of shape (N, 3), same reuse rule as ``As``.

    Returns:
        (As, Is): the filled actuation-matrix and current buffers.

    Note:
        The buffer shapes assume ``A`` is 3x3 (three field components, three
        coils), matching how this notebook allocates them.
    """
    n_positions = len(positions)
    target_field = np.asarray(target_field, dtype=float)

    # Check the full shape, not just the length: a right-length buffer with the
    # wrong trailing dimensions would otherwise fail on assignment below.
    if As is None or np.shape(As) != (n_positions, 3, 3):
        As = np.empty((n_positions, 3, 3))
    if Is is None or np.shape(Is) != (n_positions, 3):
        Is = np.empty((n_positions, 3))

    for i, pos in enumerate(positions):
        A, b = model.currents_field_jacobian_and_bias(pos)
        As[i] = A
        Is[i] = np.linalg.pinv(A) @ (target_field - b)
    return As, Is


def compute_metrics(models, pos, target_field, max_current):
    """Evaluate actuation metrics for every model over a set of positions.

    For each model, ``get_A_and_I`` provides the actuation matrices ``A`` and the
    currents ``I`` needed to produce ``target_field`` at each position. These are
    summarised as:

    - ``conditions``: 2-norm condition number of each ``A``.
    - ``current_margins``: ``max_current`` minus the largest absolute coil current
      at that position; negative means the limit is exceeded.
    - ``feasible``: boolean mask ``current_margins >= 0``.

    Args:
        models: iterable of models exposing ``name`` and
            ``currents_field_jacobian_and_bias(pos)``.
        pos: (N, 3) array of query positions.
        target_field: field vector to produce at every position.
        max_current: per-coil current limit. A non-positive value emits a
            ``UserWarning`` and its absolute value is used.

    Returns:
        dict with keys ``"positions"``, ``"target_field"``, ``"max_current"`` (the
        limit actually used) and one entry per ``model.name`` holding ``"As"``,
        ``"Is"``, ``"conditions"``, ``"current_margins"`` and ``"feasible"``.
    """
    # The limit is a magnitude: recover from a sign slip with a warning rather
    # than aborting the whole sweep.
    if max_current <= 0:
        warnings.warn(
            "Only silly billies use non-positive max current values!\nUsing the absolute value instead.",
            UserWarning,
            stacklevel=2,
        )
        max_current = abs(max_current)

    model_metrics = {"positions": pos, "target_field": target_field, "max_current": max_current}

    for model in models:
        As = np.empty((len(pos), 3, 3))
        Is = np.empty((len(pos), 3))
        As, Is = get_A_and_I(model, pos, target_field, As, Is)

        conditions = np.linalg.cond(As, p=2)
        # Headroom left by the most heavily loaded coil at each position.
        current_margins = max_current - np.max(np.abs(Is), axis=1)
        feasibility_mask = current_margins >= 0

        model_metrics[model.name] = {
            "As": As,
            "Is": Is,
            "conditions": conditions,
            "current_margins": current_margins,
            "feasible": feasibility_mask,
        }

    return model_metrics

## Set up the slice to analyze

In [None]:
if recompute:
    # Horizontal x-y slice at fixed z, in metres.
    z_coord = 0.0
    x_min, x_max = -0.30, 0.30
    y_min, y_max = -0.20, 0.20
    voxel_size = 0.001

    # np.linspace with an explicit sample count includes both endpoints exactly
    # once. np.arange with a float step can gain or lose the last sample through
    # rounding. If the span is not a multiple of voxel_size, the step is adjusted
    # slightly so the grid still spans [min, max].
    n_x = int(round((x_max - x_min) / voxel_size)) + 1
    n_y = int(round((y_max - y_min) / voxel_size)) + 1
    x_coords = np.linspace(x_min, x_max, n_x)
    y_coords = np.linspace(y_min, y_max, n_y)
    xx, yy = np.meshgrid(x_coords, y_coords)

    # (N, 3) array of grid points, one row per (x, y) pair, all at z_coord.
    positions = np.column_stack([xx.ravel(), yy.ravel(), np.full(xx.size, z_coord)])

    print(f"Grid spans x: [{x_min}, {x_max}] m, y: [{y_min}, {y_max}] m at z: {z_coord} m")
    print(f"Voxel size: {voxel_size} m")
    print(f"Corresponds to {len(positions)} positions")

## Set up the maximum current and target field

In [None]:
# Per-coil current limit and the field to produce at every grid position.
# Defined unconditionally (they are cheap config values) so they remain available
# when cached results are loaded instead of recomputed.
# NOTE: when recompute is False, nothing checks that these match the settings the
# cached results were built with.
max_current = 25 # Amps
target_field = np.array([0.0, 0.0, 20.0])  # mT

## Run analysis

In [None]:
if recompute:
    results = compute_metrics(models, positions, target_field, max_current)

    # Cache the results so later sessions can skip the expensive sweep.
    os.makedirs(results_dir, exist_ok=True)
    with open(results_path, "wb") as f:
        pickle.dump(results, f)
else:
    # SECURITY: pickle.load can execute arbitrary code; only load result files
    # produced by this notebook, never ones from an untrusted source.
    try:
        with open(results_path, "rb") as f:
            results = pickle.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"No cached results found at {results_path}. Set recompute = True to generate them."
        ) from err

## Set up plotting

In [None]:
def _navion_transform(x, y):
    """Map slice coordinates to the Navion plotting frame: ``(x, y) -> (-x, 0.2 - y)``."""
    return -x, -y + 0.2


def _pixel_edges(coords):
    """Return ``(low, high)`` image edges that centre ``imshow`` pixels on ``coords``.

    ``coords`` must be sorted ascending; its spacing is taken as the mean step.
    A single coordinate yields a zero-width span (no padding).
    """
    if len(coords) < 2:
        return coords[0], coords[-1]
    half_step = 0.5 * (coords[-1] - coords[0]) / (len(coords) - 1)
    return coords[0] - half_step, coords[-1] + half_step


def plot_workspace_suite(model_metrics, positions, model_name=None,
                         log_condition=False, mask_infeasible=True,
                         apply_navion_transform=False, max_current=None,
                         data_points=None):
    """Plot the condition number, current margin and feasibility over an x-y slice.

    Draws three side-by-side heatmaps for a single model and shows the figure.

    Args:
        model_metrics: ``results[model.name]`` dict with keys ``"conditions"``,
            ``"current_margins"`` and ``"feasible"``, each of length N.
        positions: (N, 3) array used to compute the metrics. Only x and y are
            used, and they should form a rectilinear grid in x-y.
        model_name: optional name shown in the figure title.
        log_condition: plot ``log10`` of the condition number.
        mask_infeasible: hide condition and margin values where ``feasible`` is False.
        apply_navion_transform: map (x, y) to (-x, 0.2 - y) before plotting. The
            same mapping is applied to ``data_points``.
        max_current: if not None, fix the current-margin colour limits to
            ``[-max_current, +max_current]``.
        data_points: optional array-like of shape (M, 2) or (M, >=2). The first
            two columns are overlaid as (x, y) circles on every panel.

    Raises:
        ValueError: if ``data_points`` is not 2-D with at least two columns.
    """
    pos = np.asarray(positions, dtype=float)
    x, y = pos[:, 0], pos[:, 1]
    if apply_navion_transform:
        x, y = _navion_transform(x, y)

    # np.unique gives the sorted axis coordinates and, via return_inverse, each
    # sample's column/row index, which lets us rasterise in one vectorised step.
    x_coords, col_idx = np.unique(x, return_inverse=True)
    y_coords, row_idx = np.unique(y, return_inverse=True)
    grid_shape = (len(y_coords), len(x_coords))

    def to_grid(values):
        # Scatter per-sample values into a (Ny, Nx) image; unsampled cells stay NaN.
        grid = np.full(grid_shape, np.nan)
        grid[row_idx, col_idx] = np.asarray(values, dtype=float).reshape(-1)
        return grid

    conditions = to_grid(model_metrics["conditions"])
    margins = to_grid(model_metrics["current_margins"])
    feasible_grid = to_grid(model_metrics["feasible"])  # 1.0 yes, 0.0 no, NaN unsampled
    # Compare instead of casting: bool(NaN) is True, which would mark unsampled
    # cells as feasible.
    feasible = feasible_grid == 1.0

    cond_plot = conditions.copy()
    marg_plot = margins.copy()
    if mask_infeasible:
        cond_plot[~feasible] = np.nan
        marg_plot[~feasible] = np.nan
    if log_condition:
        cond_plot = np.log10(cond_plot)

    # imshow's extent gives the outer image edges. Padding by half a voxel centres
    # each pixel on its sample, so the image lines up with the axes and with any
    # overlaid data points.
    x_lo, x_hi = _pixel_edges(x_coords)
    y_lo, y_hi = _pixel_edges(y_coords)
    image_kwargs = dict(origin="lower", extent=[x_lo, x_hi, y_lo, y_hi],
                        aspect="auto", interpolation="nearest")

    fig, axes = plt.subplots(1, 3, figsize=(18, 5), constrained_layout=True)

    # Condition heatmap
    im0 = axes[0].imshow(cond_plot, **image_kwargs)
    axes[0].set_title("log10(cond₂(A))" if log_condition else "cond₂(A)")
    cbar0 = fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
    cbar0.set_label("log10(cond₂)" if log_condition else "cond₂")

    # Current margin heatmap, optionally with a symmetric colour range.
    margin_kwargs = dict(image_kwargs)
    if max_current is not None:
        margin_kwargs.update(vmin=-abs(max_current), vmax=abs(max_current))
    im1 = axes[1].imshow(marg_plot, **margin_kwargs)
    axes[1].set_title("Current margin (A)")
    cbar1 = fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
    cbar1.set_label("A")

    # Feasibility mask; unsampled (NaN) cells are left blank.
    im2 = axes[2].imshow(feasible_grid, vmin=0, vmax=1, **image_kwargs)
    axes[2].set_title("Feasible")
    cbar2 = fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
    cbar2.set_ticks([0, 1])
    cbar2.set_ticklabels(["No", "Yes"])

    # Overlay data points on all subplots
    if data_points is not None:
        dp = np.asarray(data_points, dtype=float)
        # (M, 3) input is accepted too; only the first two columns are used.
        if dp.ndim != 2 or dp.shape[1] < 2:
            raise ValueError(f"data_points must be (M,2) or (M,>=2). Got {dp.shape}")

        dp_x, dp_y = dp[:, 0], dp[:, 1]
        if apply_navion_transform:
            dp_x, dp_y = _navion_transform(dp_x, dp_y)

        for ax in axes:
            ax.scatter(dp_x, dp_y, s=40, facecolors="none", edgecolors="k",
                       linewidths=1.5, marker="o", zorder=10)

    for ax in axes:
        ax.set_xlabel("x (m)")
        ax.set_ylabel("y (m)")
        ax.set_aspect("equal", adjustable="box")

    title = f"{model_name} Workspace Suite" if model_name else "Workspace Suite"
    fig.suptitle(title, y=1.1)

    plt.show()

In [None]:
# Plot the workspace metrics of every model on the shared x-y slice.
# The positions are the same for every model, so look them up once.
positions = results["positions"]

# Fail before plotting anything if the cached results are stale relative to model_names.
missing_models = [name for name in model_names if name not in results]
if missing_models:
    raise KeyError(f"Results have no entry for {missing_models}; rerun with recompute = True.")

for model_name in model_names:
    results_model = results[model_name]
    plot_workspace_suite(results_model, positions, model_name=model_name,
                         apply_navion_transform=True, max_current=None)