# Checking linearity

In [None]:
import sys
import os
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import re
import numpy as np
from tqdm.auto import tqdm

# Resolve the project layout relative to the directory the notebook runs in:
# the project root is two levels above the working directory.
cwd = os.getcwd()
plot_dir = f"{cwd}/plots/"  # kept as a string with trailing slash; filenames are appended to it

parent_dir = os.path.dirname(cwd)
base_dir = os.path.dirname(parent_dir)

# Make the project's `src` directory importable ahead of anything else on sys.path.
src_dir = os.path.join(base_dir, "src")
sys.path.insert(0, src_dir)

# Input data locations.
octomag_data_path = os.path.join(base_dir, "data", "octomag_data", "data")
octomag_sensor_biases_path = os.path.join(octomag_data_path, "sensor_bias.csv")
navion_data_path = os.path.join(base_dir, "data", "navion_data", "data")

from data_analysis import load_data, convert_frames, plot_positions, plot_quiver_slice, plot_quiver_3d, correct_sensor_bias, load_raw_set_of_pkls, load_octomag_format
from paper import latex_utils

## Load data

### Octomag data

In [None]:
# Load -> bias-correct -> convert frames, using the helpers imported from
# data_analysis (their exact semantics are assumed from their names).
df_octomag = convert_frames(
    correct_sensor_bias(load_data(octomag_data_path), bias_csv=octomag_sensor_biases_path)
)

field_cols = ['Bx', 'By', 'Bz']
em_cols_octomag = [col for col in df_octomag.columns if col.startswith('em_')]

# --- quantize position onto a grid of spacing `res` and create pos_id ---
# NOTE(review): this comment used to say 1.6 mm, but res = 0.001 is the same value
# the Navion cell labels as 1 mm; confirm which resolution is intended.
res = 0.001
df_octomag["x_q"] = np.round(df_octomag["x"] / res) * res
df_octomag["y_q"] = np.round(df_octomag["y"] / res) * res
df_octomag["z_q"] = np.round(df_octomag["z"] / res) * res

# Wrap the (x_q, y_q, z_q) tuples in an object Series: handing a plain list to
# pd.factorize is deprecated. Ids are assigned in order of first appearance.
_octomag_pos_keys = pd.Series(
    list(zip(df_octomag["x_q"], df_octomag["y_q"], df_octomag["z_q"])), dtype=object
)
df_octomag["pos_id"] = pd.factorize(_octomag_pos_keys)[0].astype(np.int32)

# --- active coil + current (exactly one coil nonzero per row) ---
# With a single nonzero coil per row, the row sum equals that coil's signed current.
df_octomag["active_coil"] = df_octomag[em_cols_octomag].ne(0).idxmax(axis=1)
df_octomag["current"] = df_octomag[em_cols_octomag].sum(axis=1).astype(np.float64)

### Navion data

In [None]:
# Load the Navion recordings and convert frames, using helpers imported from
# data_analysis (their exact semantics are assumed from their names).
df_navion = convert_frames(load_raw_set_of_pkls(navion_data_path))
em_cols_navion = [col for col in df_navion.columns if col.startswith('em_')]
df_navion[field_cols] *=1000 # T to mT

# --- quantize position to 1 mm and create pos_id ---
res = 0.001
df_navion["x_q"] = np.round(df_navion["x"] / res) * res
df_navion["y_q"] = np.round(df_navion["y"] / res) * res
df_navion["z_q"] = np.round(df_navion["z"] / res) * res

# Wrap the (x_q, y_q, z_q) tuples in an object Series: handing a plain list to
# pd.factorize is deprecated. Ids are assigned in order of first appearance.
_navion_pos_keys = pd.Series(
    list(zip(df_navion["x_q"], df_navion["y_q"], df_navion["z_q"])), dtype=object
)
df_navion["pos_id"] = pd.factorize(_navion_pos_keys)[0].astype(np.int32)
# --- active coil + current (exactly one coil nonzero per row) ---
# With a single nonzero coil per row, the row sum equals that coil's signed current.
df_navion["active_coil"] = df_navion[em_cols_navion].ne(0).idxmax(axis=1)
df_navion["current"] = df_navion[em_cols_navion].sum(axis=1).astype(np.float64)

# make sure these are numeric
df_navion[field_cols] = df_navion[field_cols].astype(np.float64)

# Bare expression: shows the frame as the notebook cell output.
df_navion

# Compute linear fits

In [None]:
from tqdm import tqdm

def add_pos_id(df: pd.DataFrame, res: float = 0.001) -> pd.DataFrame:
    """Quantize (x, y, z) onto a grid and label each distinct grid point.

    Args:
        df: Frame with numeric ``x``, ``y`` and ``z`` columns. It is not modified.
        res: Grid spacing, in the same units as the position columns. Must be > 0.

    Returns:
        A copy of ``df`` with ``x_q``, ``y_q``, ``z_q`` (positions rounded to the
        nearest multiple of ``res``) and ``pos_id`` (int32 label per distinct
        quantized position, numbered in order of first appearance).

    Raises:
        ValueError: If ``res`` is not strictly positive.
    """
    if not res > 0:
        # A zero/negative/NaN spacing would silently produce inf/NaN positions.
        raise ValueError(f"res must be positive, got {res!r}")

    out = df.copy()
    for axis in ("x", "y", "z"):
        out[f"{axis}_q"] = np.round(out[axis] / res) * res

    # Wrap the tuples in an object Series: passing a plain list to pd.factorize
    # is deprecated in recent pandas versions.
    keys = pd.Series(list(zip(out["x_q"], out["y_q"], out["z_q"])), dtype=object)
    out["pos_id"] = pd.factorize(keys)[0].astype(np.int32)
    return out

def add_active_coil_and_current(df: pd.DataFrame, em_cols) -> pd.DataFrame:
    """Tag each row with its dominant coil and that coil's signed current.

    The active coil is the one with the largest absolute current, so a single
    driven coil is identified correctly for negative currents as well. For
    rows whose currents are all non-negative this matches a plain max / idxmax.

    Args:
        df: Frame containing the coil-current columns. It is not modified.
        em_cols: Names of the coil-current columns.

    Returns:
        A copy of ``df`` with ``active_coil`` (column name of the dominant
        coil; the first column when every coil is zero) and ``current``
        (float64 signed current of that coil).
    """
    out = df.copy()
    cols = list(em_cols)
    currents = out[cols].to_numpy(dtype=np.float64)

    # Rank by magnitude; NaNs rank last so they are never chosen over a real value.
    magnitude = np.where(np.isnan(currents), -np.inf, np.abs(currents))
    winner = magnitude.argmax(axis=1)

    out["active_coil"] = np.asarray(cols, dtype=object)[winner]
    out["current"] = currents[np.arange(len(out)), winner]
    return out

def compute_linear_fits(
    df: pd.DataFrame,
    dataset: str,
    group_cols=("pos_id", "active_coil"),
    field_cols=("Bx", "By", "Bz"),
    min_n: int = 3,
    keep_pos=True,
    pos_cols_q=("x_q", "y_q", "z_q"),
    pos_cols=("x", "y", "z"),
) -> pd.DataFrame:
    """Fit every field component against current, y = a*x + b, per group.

    Args:
        df: Measurements with a ``current`` column plus the columns named in
            ``group_cols`` and ``field_cols``.
        dataset: Label stored in the ``dataset`` column and shown in the
            progress bar.
        group_cols: Columns whose unique combinations define one fit each.
        field_cols: Field-component columns fitted against ``current``.
        min_n: Groups with fewer rows are skipped.
        keep_pos: If true, copy position information into the output.
        pos_cols_q: Quantized position columns; the first row's values are
            stored as ``x_q``, ``y_q``, ``z_q`` when all are present.
        pos_cols: Raw position columns; their group means are stored as
            ``x_mean``, ``y_mean``, ``z_mean`` when all are present.

    Returns:
        One row per fitted group with: ``dataset``, ``slopes`` and
        ``intercepts`` (arrays, one entry per field column), ``r2`` (pooled
        over all components; NaN when the fields have zero variance),
        ``rmse`` (root mean square of the per-sample residual vector norm),
        ``n``, the group key columns, optional position columns,
        ``group_id``, ``intercept_mag`` and ``slope_mag``. Groups with fewer
        than two distinct currents are skipped. When nothing is fitted, an
        empty frame with these columns is returned.
    """
    group_cols = list(group_cols)
    field_cols = list(field_cols)

    have_q = all(c in df.columns for c in pos_cols_q)
    have_xyz = all(c in df.columns for c in pos_cols)

    # Stream the groups instead of materializing them all in a list; ngroups
    # still gives the progress bar its total.
    grouped = df.groupby(group_cols, sort=False)

    rows = []
    for key, g in tqdm(grouped, total=grouped.ngroups, desc=f"Computing fits for {dataset}"):
        if len(g) < min_n:
            continue

        x = g["current"].to_numpy(dtype=np.float64)
        # A line needs at least two distinct abscissae (also covers len(x) < 2).
        if np.unique(x).size < 2:
            continue

        Y = g[field_cols].to_numpy(dtype=np.float64)  # (N, K)

        # polyfit accepts a 2-D y and fits every column against the same x in
        # one least-squares solve; the result has shape (2, K).
        slopes, intercepts = np.polyfit(x, Y, deg=1)
        Y_pred = np.outer(x, slopes) + intercepts

        E = Y - Y_pred
        sse = np.sum(E ** 2)
        sst = np.sum((Y - Y.mean(axis=0, keepdims=True)) ** 2)
        r2 = np.nan if sst == 0 else 1 - sse / sst
        rmse = np.sqrt(np.mean(np.linalg.norm(E, axis=1) ** 2))

        rec = {
            "dataset": dataset,
            "slopes": slopes,
            "intercepts": intercepts,
            "r2": r2,
            "rmse": rmse,
            "n": len(g),
        }

        # unpack group keys into columns (a single key may arrive as a scalar)
        if not isinstance(key, tuple):
            key = (key,)
        for col, val in zip(group_cols, key):
            rec[col] = val

        # --- keep position in fits ---
        if keep_pos:
            if have_q:
                # should be constant inside the group; take first
                rec["x_q"] = float(g[pos_cols_q[0]].iloc[0])
                rec["y_q"] = float(g[pos_cols_q[1]].iloc[0])
                rec["z_q"] = float(g[pos_cols_q[2]].iloc[0])
            if have_xyz:
                # in case raw xyz varies slightly, store mean
                rec["x_mean"] = float(g[pos_cols[0]].mean())
                rec["y_mean"] = float(g[pos_cols[1]].mean())
                rec["z_mean"] = float(g[pos_cols[2]].mean())

        rec["group_id"] = " | ".join([f"{c}={rec[c]}" for c in group_cols])
        rows.append(rec)

    if not rows:
        # Keep the schema stable so callers can still select columns by name.
        columns = ["dataset", "slopes", "intercepts", "r2", "rmse", "n", *group_cols]
        if keep_pos and have_q:
            columns += ["x_q", "y_q", "z_q"]
        if keep_pos and have_xyz:
            columns += ["x_mean", "y_mean", "z_mean"]
        columns += ["group_id", "intercept_mag", "slope_mag"]
        return pd.DataFrame(columns=list(dict.fromkeys(columns)))

    fits = pd.DataFrame(rows)
    fits["intercept_mag"] = fits["intercepts"].apply(lambda v: float(np.linalg.norm(v)))
    fits["slope_mag"] = fits["slopes"].apply(lambda v: float(np.linalg.norm(v)))
    return fits

def get_fit_row(fits_df: pd.DataFrame, **keys) -> pd.Series:
    """Return the single row of ``fits_df`` whose columns equal all of ``keys``.

    Each keyword names a column and the value it must equal.

    Raises:
        ValueError: If the number of matching rows is not exactly one.
    """
    selected = np.full(len(fits_df), True)
    for column, wanted in keys.items():
        selected = selected & (fits_df[column] == wanted).to_numpy()

    matches = fits_df.loc[selected]
    n_matches = len(matches)
    if n_matches == 1:
        return matches.iloc[0]
    raise ValueError(f"Expected 1 fit row for {keys}, got {n_matches}")

def plot_group_fit_oneplot(
    df: pd.DataFrame,
    fits: pd.DataFrame,
    title: str = "",
    field_cols=("Bx", "By", "Bz"),
    legend=True,
    **group_keys,
):
    """Plot one group's samples and fitted lines for every field component.

    Generic: works for OctoMag or Navion as long as ``group_keys`` matches the
    ``group_cols`` used when ``fits`` was computed.

    Args:
        df: Measurements with a ``current`` column, the field columns and the
            columns named in ``group_keys``.
        fits: Fit table holding per-row ``slopes`` and ``intercepts`` arrays.
        title: Axes title.
        field_cols: Field-component columns to draw, in the same order as the
            entries of ``slopes`` / ``intercepts``.
        legend: Whether to draw the legend (one entry per fitted line).
        **group_keys: ``column=value`` filters selecting the group.

    Returns:
        ``(fig, ax)`` of the created figure.

    Raises:
        ValueError: If no rows match ``group_keys``, or if ``get_fit_row``
            does not find exactly one fit for them.
    """
    field_cols = list(field_cols)

    # Build one boolean mask instead of copying the full frame before filtering.
    mask = np.ones(len(df), dtype=bool)
    for k, v in group_keys.items():
        mask &= (df[k] == v).to_numpy()
    g = df[mask]
    if len(g) == 0:
        raise ValueError(f"No data for {group_keys}")

    fit = get_fit_row(fits, **group_keys)

    x = g["current"].to_numpy(float)
    Y = g[field_cols].to_numpy(float)

    # Draw each fitted line only over the measured current range.
    xx = np.linspace(np.min(x), np.max(x), 200)

    fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.0), constrained_layout=True)
    for i, c in enumerate(field_cols):
        # Samples are hidden from the legend; the dashed fit line carries the label.
        ax.scatter(x, Y[:, i], s=12, alpha=0.55, label="_nolegend_")
        ax.plot(xx, fit["slopes"][i] * xx + fit["intercepts"][i], "--", linewidth=1.6, label=c)

    ax.set_title(title)
    ax.set_xlabel("Current (A)")
    ax.set_ylabel("B (mT)")
    ax.grid(True, alpha=0.25)
    if legend:
        ax.legend(frameon=False, ncol=3)

    return fig, ax

In [None]:
def plot_linearity_boxplots(fits_list, clip_q=0.995):
    """Draw stacked boxplots of R^2 and intercept magnitude per dataset.

    Args:
        fits_list: Iterable of ``(name, fits_df)`` pairs; each ``fits_df`` must
            have ``r2`` and ``intercept_mag`` columns.
        clip_q: Quantile of ``intercept_mag`` (over all datasets) used as the
            upper clipping value for the plotted intercepts.

    Returns:
        ``(fig, axes)`` with the R^2 panel first and the intercept panel second.
    """
    # One tidy frame: the two metrics plus the dataset label.
    tagged = [
        fits[["r2", "intercept_mag"]].assign(dataset=name)
        for name, fits in fits_list
    ]
    plot_df = pd.concat(tagged, ignore_index=True)

    # Keep only rows where both metrics are finite.
    finite = np.isfinite(plot_df["r2"]) & np.isfinite(plot_df["intercept_mag"])
    plot_df = plot_df.loc[finite].copy()

    upper = plot_df["intercept_mag"].quantile(clip_q)
    plot_df["intercept_mag_plot"] = plot_df["intercept_mag"].clip(upper=upper)

    fig, axes = plt.subplots(2, 1, figsize=(3.35, 4), constrained_layout=True)
    panels = (
        (axes[0], "r2", r"$R^2$"),
        (axes[1], "intercept_mag_plot", r"$\|\mathbf{b}\|$ (mT)"),
    )
    for ax, column, ylabel in panels:
        sns.boxplot(data=plot_df, x="dataset", y=column, ax=ax, showfliers=False)
        ax.set_ylabel(ylabel)
        ax.set_xlabel("")

    # Fixed window on the R^2 panel.
    axes[0].set_ylim(0.99, 1.0)

    return fig, axes

In [None]:
# Fit both datasets with identical grouping: one fit per (position, active coil).
_fit_options = {"group_cols": ("pos_id", "active_coil"), "min_n": 3}
oct_fits = compute_linear_fits(df_octomag, dataset="OctoMag", **_fit_options)
nav_fits = compute_linear_fits(df_navion, dataset="Navion", **_fit_options)

In [None]:
def make_combined_figure_poscoil(
    df_octomag, octomag_fits,
    df_navion, navion_fits,

    # choose which line-panels to show
    oct_pos_id=0,
    oct_active_coil="em_2",
    nav_pos_id=0,
    nav_active_coil="em_2",

    # intercept clipping for boxplot
    clip_q=1.0,

    # legend tuning
    line_legend_loc="upper left",

    # --- LaTeX + output ---
    latex_params=None,
    usetex=True,
    out="combined_figure.pdf",
    save=True,
    show=True,

    # --- layout ---
    box_hspace=0.06,
    wspace=0.25,
    line_wspace=0.55,
    left=0.07, right=0.99, bottom=0.14, top=0.92,

    # --- legend font ---
    legend_fontsize=None,
    legend_font_delta=2,

    # --- panel labels ---
    panel_labels=("a", "b", "c"),
    panel_label_fs=None,
    panel_label_y_offset=0.15,

    # --- baseline alignment tweaks ---
    line_xlabel_pad=4.0,
    box_xtick_pad=10.0,

    # --- box styling ---
    box_width=0.35,
    ds_palette=None, # dict {"OctoMag": color, "Navion": color} or None

    # --- R^2 axis padding / limits ---
    r2_lower=0.985,
    r2_bottom_pad=0.0005,
    r2_top_pad=0.0005,
):
    """Build the combined linearity figure: two boxplots plus two fit panels.

    Layout: a left column with stacked boxplots (R^2 on top, intercept
    magnitude below) comparing both datasets, and two panels on the right
    showing the samples and fitted lines of one selected (pos_id, active_coil)
    group for OctoMag and for Navion.

    Args:
        df_octomag, df_navion: Measurement frames with ``pos_id``,
            ``active_coil``, ``current``, ``Bx``, ``By`` and ``Bz`` columns.
        octomag_fits, navion_fits: Fit tables with ``r2``, ``intercept_mag``,
            ``pos_id``, ``active_coil``, ``slopes`` and ``intercepts`` columns.
        oct_pos_id, oct_active_coil, nav_pos_id, nav_active_coil: Group shown
            in the respective line panel.
        clip_q: Quantile of ``intercept_mag`` used as the upper clip for the
            intercept boxplot (1.0 clips at the maximum, i.e. no clipping).
        line_legend_loc: Legend location used by a line panel that is not
            given an explicit location (here: the Navion panel).
        latex_params: Keyword arguments forwarded to
            ``latex_utils.rc_context_latex``; ``usetex`` is added if missing.
        usetex: Default for the ``usetex`` entry of ``latex_params``.
        out: Output path handed to ``fig.savefig``.
        save: Save the figure when true and ``out`` is non-empty.
        show: Call ``plt.show()`` when true, otherwise close the figure.
        box_hspace, wspace, line_wspace: Grid spacings (between the boxplots,
            between both columns, between the line panels).
        left, right, bottom, top: Passed to ``fig.subplots_adjust``.
        legend_fontsize: Legend font size; defaults to the ``axes.labelsize``
            rcParam plus ``legend_font_delta``.
        legend_font_delta: See ``legend_fontsize``.
        panel_labels: Three labels drawn as ``(label)`` under the intercept
            boxplot and the two line panels.
        panel_label_fs: Panel-label font size; defaults to ``axes.labelsize``.
        panel_label_y_offset: Distance (figure fraction) of the labels below
            the lowest labelled axes.
        line_xlabel_pad: ``labelpad`` of the line panels' x labels.
        box_xtick_pad: Tick-label padding on the intercept boxplot's x axis.
        box_width: Box width passed to seaborn.
        ds_palette: Mapping dataset name -> color; defaults to the first two
            colors of seaborn's "deep" palette.
        r2_lower: Hard lower bound for the R^2 axis.
        r2_bottom_pad, r2_top_pad: Padding below the smallest R^2 and above 1.

    Returns:
        The created matplotlib figure (already closed when ``show`` is false).

    Raises:
        ValueError: If the selected group has no data, or if the fit table
            does not hold exactly one matching row.
    """

    sns.set_theme(style="whitegrid", context="paper")

    ds_order = ["OctoMag", "Navion"]
    if ds_palette is None:
        ds_palette = {
            "OctoMag": sns.color_palette("deep")[0],
            "Navion":  sns.color_palette("deep")[1],
        }

    field_cols = ["Bx", "By", "Bz"]
    comp_palette = {
        "Bx": sns.color_palette("colorblind")[0],
        "By": sns.color_palette("colorblind")[1],
        "Bz": sns.color_palette("colorblind")[2],
    }
    comp_label = {"Bx": r"$b_x$", "By": r"$b_y$", "Bz": r"$b_z$"}

    def _get_fit_row(fits_df, **keys):
        # Return the unique fit row matching every column=value pair.
        m = np.ones(len(fits_df), dtype=bool)
        for k, v in keys.items():
            m &= (fits_df[k] == v)
        sub = fits_df[m]
        if len(sub) == 0:
            raise ValueError(f"No fit row found for {keys}")
        if len(sub) > 1:
            raise ValueError(f"Multiple fit rows found for {keys} (expected 1).")
        return sub.iloc[0]

    def _plot_lines_panel(
        ax, x, Y, slopes, intercepts,
        title="",
        y_label=r"$b$ (mT)",
        show_legend=True,
        legend_loc=None,
        legend_fs=None,
        xlabel_pad=None,
    ):
        # Scatter the samples and overlay the dashed fitted line per component.
        x = np.asarray(x, dtype=float).ravel()
        Y = np.asarray(Y, dtype=float)
        slopes = np.asarray(slopes, dtype=float).ravel()
        intercepts = np.asarray(intercepts, dtype=float).ravel()

        xx = np.linspace(np.min(x), np.max(x), 200)

        for i, c in enumerate(field_cols):
            color = comp_palette[c]
            ax.scatter(
                x, Y[:, i],
                s=35, alpha=0.55,
                color=color, edgecolor="none",
                label="_nolegend_"
            )
            ax.plot(
                xx, slopes[i] * xx + intercepts[i],
                "--", linewidth=1.6, color=color,
                label=comp_label[c]
            )

        if title:
            ax.set_title(title)

        ax.set_xlabel("Current (A)", labelpad=xlabel_pad)
        ax.set_ylabel(y_label)
        ax.grid(True, alpha=0.25)

        if show_legend:
            ax.legend(
                loc=(line_legend_loc if legend_loc is None else legend_loc),
                frameon=False,
                fontsize=legend_fs,
                handlelength=2.4,
                handletextpad=0.6,
                borderaxespad=0.2,
            )

    def _add_panel_labels_fig(fig, axes, labels, y_offset, fs):
        # Put "(label)" centred under each axes, all on one common baseline
        # below the lowest of the given axes (figure coordinates).
        y0s = [ax.get_position().y0 for ax in axes]
        y = min(y0s) - y_offset
        for ax, lab in zip(axes, labels):
            pos = ax.get_position()
            x = 0.5 * (pos.x0 + pos.x1)
            fig.text(x, y, f"({lab})", ha="center", va="top", fontsize=fs)

    # -----------------------
    # Boxplot dataframe
    # -----------------------
    oct_plot = octomag_fits[["r2", "intercept_mag"]].copy()
    oct_plot["dataset"] = "OctoMag"
    nav_plot = navion_fits[["r2", "intercept_mag"]].copy()
    nav_plot["dataset"] = "Navion"

    plot_df = pd.concat([oct_plot, nav_plot], ignore_index=True)
    plot_df = plot_df[np.isfinite(plot_df["r2"]) & np.isfinite(plot_df["intercept_mag"])].copy()

    imax = plot_df["intercept_mag"].quantile(clip_q)
    plot_df["intercept_mag_plot"] = plot_df["intercept_mag"].clip(upper=imax)

    # --- R^2 limits: the lower limit never goes below r2_lower ---
    r2_min = float(plot_df["r2"].min()) if len(plot_df) else float(r2_lower)
    r2_bottom = max(float(r2_lower), r2_min - float(r2_bottom_pad))  # <-- hard clamp
    r2_top = 1.0 + float(r2_top_pad)

    # -----------------------
    # OctoMag selection (pos_id + active_coil)
    # -----------------------
    g_oct = df_octomag[(df_octomag["pos_id"] == oct_pos_id) & (df_octomag["active_coil"] == oct_active_coil)].copy()
    if len(g_oct) == 0:
        raise ValueError(f"No OctoMag data for pos_id={oct_pos_id}, active_coil={oct_active_coil}")
    fit_oct = _get_fit_row(octomag_fits, pos_id=oct_pos_id, active_coil=oct_active_coil)
    x_oct = g_oct["current"].to_numpy(dtype=float)
    Y_oct = g_oct[field_cols].to_numpy(dtype=float)

    # -----------------------
    # Navion selection (pos_id + active_coil)
    # -----------------------
    g_nav = df_navion[(df_navion["pos_id"] == nav_pos_id) & (df_navion["active_coil"] == nav_active_coil)].copy()
    if len(g_nav) == 0:
        raise ValueError(f"No Navion data for pos_id={nav_pos_id}, active_coil={nav_active_coil}")
    fit_nav = _get_fit_row(navion_fits, pos_id=nav_pos_id, active_coil=nav_active_coil)
    x_nav = g_nav["current"].to_numpy(dtype=float)
    Y_nav = g_nav[field_cols].to_numpy(dtype=float)

    # -----------------------
    # Plotting (LaTeX rc_context)
    # -----------------------
    # Copy so the caller's dict is not mutated by setdefault.
    latex_params = {} if latex_params is None else dict(latex_params)
    latex_params.setdefault("usetex", usetex)

    with latex_utils.rc_context_latex(**latex_params):
        fig = plt.figure(figsize=(12.5, 4.2), constrained_layout=False)

        outer = fig.add_gridspec(
            nrows=1, ncols=2,
            width_ratios=[1.05, 2.4],
            wspace=wspace
        )
        left_gs = outer[0].subgridspec(nrows=2, ncols=1, hspace=box_hspace)
        right_gs = outer[1].subgridspec(nrows=1, ncols=2, wspace=line_wspace)

        ax_r2  = fig.add_subplot(left_gs[0, 0])
        ax_int = fig.add_subplot(left_gs[1, 0])
        ax_oct = fig.add_subplot(right_gs[0, 0])
        ax_nav = fig.add_subplot(right_gs[0, 1])

        fig.subplots_adjust(left=left, right=right, bottom=bottom, top=top)

        # NOTE(review): assumes axes.labelsize is numeric inside this rc
        # context; a named size such as "medium" would break the addition below.
        axes_label_fs = plt.rcParams.get("axes.labelsize", 10)
        if legend_fontsize is None:
            legend_fontsize = axes_label_fs + legend_font_delta
        if panel_label_fs is None:
            panel_label_fs = axes_label_fs

        # --- R^2 boxplot
        sns.boxplot(
            data=plot_df, x="dataset", y="r2",
            order=ds_order, palette=ds_palette,
            width=box_width,
            showfliers=False, linewidth=1.0, ax=ax_r2
        )
        ax_r2.set_ylabel(r"$R^2$")
        ax_r2.set_xlabel("")
        ax_r2.set_ylim(r2_bottom, r2_top)
        # Hide the upper boxplot's tick labels; the lower one carries them.
        ax_r2.tick_params(axis="x", labelbottom=False)

        # --- intercept magnitude boxplot
        sns.boxplot(
            data=plot_df, x="dataset", y="intercept_mag_plot",
            order=ds_order, palette=ds_palette,
            width=box_width,
            showfliers=False, linewidth=1.0, ax=ax_int
        )
        # NOTE(review): \field is not a standard LaTeX/mathtext command; it is
        # presumably defined by the preamble set up through latex_params — confirm.
        ax_int.set_ylabel(r"$\|\field_0\|$ (mT)")
        ax_int.set_xlabel("")

        ax_int.tick_params(axis="x", labelsize=axes_label_fs, rotation=0, pad=box_xtick_pad)
        for t in ax_int.get_xticklabels():
            t.set_horizontalalignment("center")

        # Drop any legend seaborn may have attached to the boxplots.
        for ax in (ax_r2, ax_int):
            leg = ax.get_legend()
            if leg is not None:
                leg.remove()

        fig.align_ylabels([ax_r2, ax_int])

        # --- line panels
        _plot_lines_panel(
            ax_oct, x_oct, Y_oct,
            slopes=fit_oct["slopes"], intercepts=fit_oct["intercepts"],
            title="",
            y_label=r"$b$ (mT)",
            show_legend=True,
            legend_loc="upper right",
            legend_fs=legend_fontsize,
            xlabel_pad=line_xlabel_pad,
        )

        _plot_lines_panel(
            ax_nav, x_nav, Y_nav,
            slopes=fit_nav["slopes"], intercepts=fit_nav["intercepts"],
            title="",
            y_label=r"$b$ (mT)",
            show_legend=True,
            legend_loc=None,
            legend_fs=legend_fontsize,
            xlabel_pad=line_xlabel_pad,
        )

        # bring y-labels closer to the axes (smaller pad = closer)
        ax_oct.set_ylabel(r"$b$ (mT)", labelpad=-8)
        ax_nav.set_ylabel(r"$b$ (mT)", labelpad=-8)

        # --- panel labels
        a, b, c = panel_labels
        _add_panel_labels_fig(
            fig,
            axes=[ax_int, ax_oct, ax_nav],
            labels=[a, b, c],
            y_offset=panel_label_y_offset,
            fs=panel_label_fs,
        )

        if save and out:
            # Create the target directory if the path has one. A bare filename
            # has an empty dirname, and os.makedirs("") raises FileNotFoundError.
            out_dir = os.path.dirname(out)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(out, bbox_inches="tight")

        if show:
            plt.show()
        else:
            plt.close(fig)

    return fig

In [None]:
# Groups shown in the two line panels of the combined figure.
oct_pos_id = 10483
nav_pos_id = 3

fig = make_combined_figure_poscoil(
    df_octomag, oct_fits,
    df_navion, nav_fits,
    # group selection
    oct_pos_id=oct_pos_id,
    oct_active_coil="em_3",
    nav_pos_id=nav_pos_id,
    nav_active_coil="em_2",
    # LaTeX + output
    latex_params=latex_utils.latex_prms_2img,
    usetex=True,
    out=f"{plot_dir}linearity_check.pdf",
    save=True,
    show=True,
    # layout and box styling
    wspace=0.17,
    line_wspace=0.30,
    box_width=0.28,
    ds_palette={"OctoMag": "#4C72B0", "Navion": "#DD8452"},
    # R^2 axis limits
    r2_lower=0.985,
    r2_bottom_pad=0.0005,
    r2_top_pad=0.0005,
)

In [None]:
# Show the quantized and raw coordinates of the first sample at each selected position.
_position_cols = ["x_q", "y_q", "z_q", "x", "y", "z"]
for _label, _frame, _pos_id in (
    ("OctoMag:", df_octomag, oct_pos_id),
    ("Navion :", df_navion, nav_pos_id),
):
    print(_label, _frame.loc[_frame["pos_id"] == _pos_id, _position_cols].head(1))