# Plotting the condition number across the workspace

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import pickle
from matplotlib import lines as mlines
import re

# Resolve project directories relative to the notebook's working directory:
# the notebook lives in <base_dir>/<something>/<cwd>, so base_dir is two levels up.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
base_dir = os.path.dirname(parent_dir)

# Data inputs, importable sources and plot outputs (trailing "/" kept so that
# file names can be appended directly).
data_dir = f"{base_dir}/data/navion_data/split_dataset/"
src_dir = f"{base_dir}/src"
plot_dir = f"{cwd}/plots/"

# Pickled output of the Navion workspace analysis notebook.
results_path = f"{parent_dir}/workspace_analysis/results/results_navion.pkl"

# Put the project's src/ first on the import path so `paper` resolves to it.
sys.path.insert(0, src_dir)
from paper import latex_utils

## Load results from Navion workspace analysis

In [None]:
# Load the pickled results of the Navion workspace analysis; the context
# manager closes the file handle once unpickling is done.
# NOTE: pickle.load can execute arbitrary code — only load files produced by
# the project's own workspace-analysis step, never untrusted data.
with open(results_path, "rb") as results_file:
    results = pickle.load(results_file)

## Model names

In [None]:
# Result keys of the models that differ by expansion order
# (drawn in the right-hand column of the six-panel figure, top to bottom).
model_order_names = ["MPEM Dipole", "MPEM Quadrupole", "MPEM Octopole"]

# Result keys of the per-segment models, 0-indexed as stored in `results`
# (drawn in the left-hand column, top to bottom).
model_segment_names = ["MPEM Segment 0", "MPEM Segment 1", "MPEM Segment 2"]

## Six-panel plot: segment models (left) vs. Dipole, Quadrupole, Octopole (right)

In [None]:
_MM_PER_INCH = 25.4  # exact by definition of the international inch

def mm_to_in(mm):
    """Convert a length from millimetres to inches (the unit of matplotlib's figsize)."""
    inches = mm / _MM_PER_INCH
    return inches

def _to_grid_from_positions(values_1d, pos_xy):
    """Scatter per-point values onto a dense 2-D grid indexed by the unique x/y coordinates.

    Parameters
    ----------
    values_1d : array_like
        One value per point; flattened before use.
    pos_xy : array_like, shape (N, >=2)
        Point coordinates; column 0 is x, column 1 is y (extra columns are
        ignored). Coordinates are matched by exact equality, so points meant
        to share a grid row/column must carry bit-identical coordinates.

    Returns
    -------
    grid : ndarray of float, shape (Ny, Nx)
        ``grid[j, i]`` is the value of the point at ``(x_coords[i], y_coords[j])``,
        where ``x_coords`` / ``y_coords`` are the sorted unique coordinates.
        Cells with no point are NaN. If several points share the same (x, y),
        which of their values lands in the cell is unspecified.
    extent : list of 4 numbers
        ``[x_min, x_max, y_min, y_max]`` of the point coordinates.

    Raises
    ------
    ValueError
        If ``pos_xy`` is not a non-empty 2-D array with at least two columns,
        or if the number of values differs from the number of points.
    """
    pos = np.asarray(pos_xy)
    if pos.ndim != 2 or pos.shape[0] == 0 or pos.shape[1] < 2:
        raise ValueError(f"pos_xy must have shape (N>0, >=2), got {pos.shape}")

    values = np.asarray(values_1d).reshape(-1)
    if values.size != pos.shape[0]:
        # Previously a shorter `values` silently left points unfilled.
        raise ValueError(
            f"got {values.size} values for {pos.shape[0]} positions"
        )

    # return_inverse maps every point to the index of its coordinate inside the
    # sorted unique array, i.e. directly to its grid column (x) / row (y).
    x_coords, col_idx = np.unique(pos[:, 0], return_inverse=True)
    y_coords, row_idx = np.unique(pos[:, 1], return_inverse=True)

    grid = np.full((y_coords.size, x_coords.size), np.nan, dtype=float)
    grid[row_idx, col_idx] = values

    # NOTE(review): these are the outermost *sample* coordinates. Axes.imshow
    # interprets `extent` as the outer pixel *edges*, so plotted cells are
    # stretched by up to one cell width relative to the true sample centres;
    # padding by half a grid step would align them — confirm before changing.
    extent = [x_coords.min(), x_coords.max(), y_coords.min(), y_coords.max()]
    return grid, extent

_ORDER_PATTERN = re.compile(r"(Dipole|Quadrupole|Octopole)")
_SEGMENT_PATTERN = re.compile(r"Segment\s*(\d+)")

def _clean_title(name: str) -> str:
    """Turn a model key into a short panel title.

    The "MPEM" prefix is dropped. If an order name (Dipole/Quadrupole/Octopole)
    appears, it alone is returned. Otherwise "Segment <k>" is returned with k
    shifted to 1-based numbering (0, 1, 2 -> 1, 2, 3). Anything else is
    returned with its whitespace collapsed to single spaces.
    """
    stripped = name.replace("MPEM", "").strip()

    order_match = _ORDER_PATTERN.search(stripped)
    if order_match is not None:
        return order_match.group(1)

    segment_match = _SEGMENT_PATTERN.search(stripped)
    if segment_match is None:
        return " ".join(stripped.split())

    one_based = int(segment_match.group(1)) + 1
    return f"Segment {one_based}"

def plot_log_condition_two_columns(
    results,
    model_order_names,
    model_segment_names,
    apply_navion_transform=True,
    mask_infeasible=False,
    cmap="RdYlGn_r",
    rc_params=None,
    width_mm=140,
    height_mm=70,
    savepath=None,
    dashed_line=True,
    units_scale=100.0,
    units_label="cm",
    panel_labels=True,   # add (a), (b)
    panel_label_offset=0.18,
):
    """Plot log10 condition-number maps in a 3x2 grid: segment models left, order models right.

    Each panel shows ``log10(cond)`` of one model over the sampled (x, y)
    positions. All panels share one colour scale, and a single colorbar is
    attached to the whole figure. Points whose condition value is non-finite
    or non-positive (and, optionally, infeasible points) are left blank (NaN).

    Parameters
    ----------
    results : dict
        ``results["positions"]``: per-point coordinates, first two columns x, y.
        ``results[name]["conditions"]``: one condition value per point for each
        model ``name``. ``results[name]["feasible"]``: one boolean per point
        (read only when ``mask_infeasible`` is True).
    model_order_names, model_segment_names : sequence of str
        Exactly three model keys each. Order models fill the right column and
        segment models the left column, top to bottom.
    apply_navion_transform : bool
        Map raw positions with x -> -x, y -> -y + 0.2 (applied before scaling).
    mask_infeasible : bool
        Blank out points whose ``feasible`` flag is False.
    cmap : str or Colormap
        Colormap shared by all panels.
    rc_params : dict, optional
        Keyword arguments for ``latex_utils.rc_context_latex``; defaults to
        ``latex_utils.latex_prms_2img``.
    width_mm, height_mm : float
        Figure size in millimetres.
    savepath : str, optional
        If given, save the figure there; the parent directory is created if needed.
    dashed_line : bool
        Draw a vertical dashed separator between the two columns.
    units_scale : float
        Factor applied to x/y after the transform. The default of 100 together
        with the "cm" label suggests positions are stored in metres — TODO confirm.
    units_label : str
        Unit shown in the axis labels.
    panel_labels : bool
        Add "(a)" / "(b)" below the left / right column.
    panel_label_offset : float
        Vertical distance, in figure coordinates, from the bottom of the lowest
        axes down to the panel labels.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : ndarray of Axes, shape (3, 2)

    Raises
    ------
    ValueError
        If either name list does not contain exactly three entries.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(model_order_names) != 3 or len(model_segment_names) != 3:
        raise ValueError(
            "expected exactly three order models and three segment models, got "
            f"{len(model_order_names)} and {len(model_segment_names)}"
        )

    # Float copy: integer positions would otherwise truncate the +0.2 offset
    # and the unit scaling, and the caller's array must not be modified.
    pos = np.array(results["positions"], dtype=float)
    if apply_navion_transform:
        pos[:, 0] = -pos[:, 0]
        pos[:, 1] = -pos[:, 1] + 0.2
    pos[:, :2] *= units_scale

    def log_condition_grid(name):
        # Fresh float copy so masking never writes into `results`.
        cond = np.array(results[name]["conditions"], dtype=float)
        if mask_infeasible:
            feasible = np.asarray(results[name]["feasible"], dtype=bool)
            cond[~feasible] = np.nan

        # log10 only where it is defined; everything else stays NaN (blank cell).
        logc = np.full_like(cond, np.nan, dtype=float)
        valid = np.isfinite(cond) & (cond > 0)
        logc[valid] = np.log10(cond[valid])
        return _to_grid_from_positions(logc, pos[:, :2])

    order_results = [log_condition_grid(n) for n in model_order_names]
    segment_results = [log_condition_grid(n) for n in model_segment_names]
    order_grids = [grid for grid, _ in order_results]
    segment_grids = [grid for grid, _ in segment_results]
    # Every grid is built from the same `pos`, so all extents are identical.
    extent = order_results[0][1]

    # One colour scale for all six panels.
    finite_vals = np.concatenate(
        [g[np.isfinite(g)] for g in order_grids + segment_grids]
    )
    if finite_vals.size:
        vmin, vmax = finite_vals.min(), finite_vals.max()
    else:
        # Nothing finite to show: let imshow autoscale instead of failing on
        # the min/max of an empty array.
        vmin = vmax = None

    if rc_params is None:
        rc_params = latex_utils.latex_prms_2img
    rc_ctx = latex_utils.rc_context_latex(**rc_params)

    figsize = (mm_to_in(width_mm), mm_to_in(height_mm))

    with rc_ctx:
        fig, axes = plt.subplots(
            nrows=3, ncols=2,
            figsize=figsize,
            constrained_layout=True,
            sharex=True, sharey=True
        )

        # Column 0: segment models, column 1: order models.
        columns = (
            (segment_grids, model_segment_names),
            (order_grids, model_order_names),
        )
        mappable = None
        for row in range(3):
            for col, (grids, names) in enumerate(columns):
                ax = axes[row, col]
                mappable = ax.imshow(
                    grids[row],
                    origin="lower",
                    extent=extent,
                    aspect="equal",
                    interpolation="nearest",
                    cmap=cmap,
                    vmin=vmin, vmax=vmax
                )
                ax.set_title(_clean_title(names[row]), loc="center", pad=3)
                ax.grid(False)
            axes[row, 0].set_ylabel(rf"$y$ ({units_label})")

        for ax in axes[-1]:
            ax.set_xlabel(rf"$x$ ({units_label})")

        # Any panel works as the mappable: all share cmap, vmin and vmax.
        cbar = fig.colorbar(
            mappable,
            ax=axes.ravel().tolist(),
            fraction=0.04,
            pad=0.02
        )
        # NOTE(review): \actuation and \field are not standard LaTeX macros;
        # presumably defined in the preamble set up by latex_utils — confirm.
        cbar.set_label(r"$\log_{10}\!\left(\mathrm{cond}_2(\actuation_{\field})\right)$")

        # Constrained layout positions the axes at draw time; draw once so the
        # get_position() calls below see the final layout.
        fig.canvas.draw()

        if dashed_line:
            x_sep = 0.5 * (
                axes[1, 0].get_position().x1 + axes[1, 1].get_position().x0
            )
            left_boxes = [axes[r, 0].get_position() for r in range(3)]
            y0 = min(box.y0 for box in left_boxes)
            y1 = max(box.y1 for box in left_boxes)

            fig.add_artist(
                mlines.Line2D(
                    [x_sep, x_sep], [y0, y1],
                    transform=fig.transFigure,
                    linestyle=(0, (4, 4)),
                    linewidth=1.0,
                    color="0.25",
                    alpha=0.9,
                )
            )

        if panel_labels:
            box_left = axes[-1, 0].get_position()
            box_right = axes[-1, 1].get_position()
            # May fall below y=0 (outside the canvas); savefig with
            # bbox_inches="tight" still includes it.
            y_lbl = min(box_left.y0, box_right.y0) - panel_label_offset

            fig.text(0.5 * (box_left.x0 + box_left.x1), y_lbl, r"(a)", ha="center", va="top")
            fig.text(0.5 * (box_right.x0 + box_right.x1), y_lbl, r"(b)", ha="center", va="top")

        if savepath is not None:
            save_dir = os.path.dirname(savepath)
            # A bare file name has no directory part; os.makedirs("") would raise.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(savepath, bbox_inches="tight")

        plt.show()

    return fig, axes

In [None]:
# Render the six-panel figure (segments left, orders right) and save it as a PDF.
plot_log_condition_two_columns(
    results=results,
    model_order_names=model_order_names,
    model_segment_names=model_segment_names,
    # data handling / coordinate frame
    apply_navion_transform=True,
    mask_infeasible=False,
    units_scale=100.0,
    units_label="cm",
    # appearance
    cmap="RdYlGn_r",
    rc_params=latex_utils.latex_prms_singlecol,  # alternatives: latex_prms_2img / latex_prms_3img
    width_mm=140,
    height_mm=70,
    dashed_line=True,
    panel_labels=True,
    # output
    savepath=plot_dir + "condition_six_pack.pdf",
)