# Estimating “bias currents” from the affine model

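The affine model writes the field as  
$$
\mathbf{b}(\mathbf{p}, \mathbf{i}) = \mathcal{A}_b(\mathbf{p})\,\mathbf{i} + \mathbf{b}_0(\mathbf{p}),
$$  
where $\mathbf{p}$ is the position, $\mathbf{i}$ the vector of coil currents, $\mathcal{A}_b(\mathbf{p})$ the field actuation matrix, and $\mathbf{b}_0(\mathbf{p})$ the field offset (bias), i.e. the field the model predicts at zero current. The goal here is to find the "bias currents": a single current vector whose field best matches that offset.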

---

### Stacking several positions

Choose $k$ positions $\mathbf{p}_1,\dots,\mathbf{p}_k$ and evaluate the model at each, giving for each position $\mathbf{p}_j$ ($j = 1,\dots,k$) a field actuation matrix $\mathcal{A}_{b,j}$ and a field offset $\mathbf{b}_{0,j}$.

Stack them:
$$
\mathcal{A}_{b,\text{stack}} =
\begin{bmatrix}
\mathcal{A}_{b,1}\\
\mathcal{A}_{b,2}\\
\vdots\\
\mathcal{A}_{b,k}
\end{bmatrix}
\in \mathbb{R}^{3k\times 8},\qquad
\mathbf{b}_{0,\text{stack}} =
\begin{bmatrix}
\mathbf b_{0,1}\\
\mathbf b_{0,2}\\
\vdots\\
\mathbf b_{0,k}
\end{bmatrix}
\in \mathbb R^{3k}.
$$

We want one current vector $\mathbf{i}_{\text{equiv}}$ such that
$$
\mathcal{A}_{b,\text{stack}} \,\mathbf{i}_{\text{equiv}} \approx \mathbf{b}_{0, \text{stack}}.
$$

This is a least-squares problem. If $\mathcal{A}_{b,\text{stack}}$ has full column rank (rank $8$), the least-squares solution is unique and given by
$$
\mathbf{i}_{\text{equiv}} = \mathcal{A}_{b,\text{stack}}^{\dagger}\,\mathbf{b}_{0,\text{stack}},
$$
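where $\mathcal{A}_{b,\text{stack}}^{\dagger}$ is the Moore–Penrose pseudoinverse, which for a full-column-rank matrix equals $(\mathcal{A}_{b,\text{stack}}^{\top}\mathcal{A}_{b,\text{stack}})^{-1}\mathcal{A}_{b,\text{stack}}^{\top}$.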

Interpretation: $\mathbf{i}_{\text{equiv}}$ are the equivalent currents that best reproduce the learned bias field across all sampled positions.

---

### Rank, conditioning, and residual

We keep adding positions until $\mathcal{A}_{b,\text{stack}}$  

- has rank $8$ (full column rank), and  
- has condition number $\kappa(\mathcal{A}_{b,\text{stack}})$ below some threshold,

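so that the pseudoinverse and the inferred currents are numerically stable. We also report the relative residual $\lVert \mathcal{A}_{b,\text{stack}}\,\mathbf{i}_{\text{equiv}} - \mathbf{b}_{0,\text{stack}} \rVert / \lVert \mathbf{b}_{0,\text{stack}} \rVert$, which measures how well a single current vector reproduces the bias field at the sampled positions.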

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import pandas as pd
import os
import sys
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Resolve project directories relative to the current working directory:
# parent_dir is one level above the cwd and base_dir is two levels above.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
base_dir = os.path.dirname(parent_dir)
data_dir = base_dir + "/data/octomag_data/split_dataset"
src_dir = base_dir + "/src"
params_dir = parent_dir + "/training/params"

# Put src_dir at the front of sys.path before importing `calibration`
# (presumably the module lives in src_dir -- confirm against the repo layout).
sys.path.insert(0, src_dir)

from calibration import MPEM, MPEM_AVAILABLE, ActuationNet

In [None]:
# Load data: the pickled training set. Later cells read its "x", "y", "z"
# columns as sample positions.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
data_path = data_dir + "/training_data.pkl"
data = pd.read_pickle(data_path)

In [None]:
# Load affine model from the saved parameter file in params_dir.
# NOTE(review): `ActuationNet.load_from` is defined in the project's
# `calibration` module; its exact restore semantics are not visible here.
model_path = params_dir + "/ActuationNet_100_512x512x512.pt"
model = ActuationNet.load_from(model_path)

## Notation alert:

In the following code we use $J$ (as in Jacobian) for the field actuation matrix $\mathcal{A}_b$ and $b$ for the field offset $\mathbf{b}_0$.

In [None]:
def _assemble_stack_result(J_blocks, b_blocks, k, cond):
    """Stack the first `k` blocks and build the tuple returned to callers."""
    J_stack = np.vstack(J_blocks[:k])        # (3*k, 8)
    b_stack = np.concatenate(b_blocks[:k])   # (3*k,)
    return np.linalg.pinv(J_stack), cond, k, J_stack, b_stack


def build_well_conditioned_stack(
    data,
    affine_model,
    max_num_samples=50,
    cond_threshold=5,
    min_samples=3,
    svd_tol=1e-8,
    seed=0,
    verbose=True,
):
    """
    Incrementally stack the actuation matrices J and offsets b that
    `affine_model` returns at positions sampled (without replacement) from
    `data`, until:

      - rank(J_stack) == 8 (full column rank), and
      - cond(J_stack) <= cond_threshold

    or until max_num_samples positions have been used.

    J and b are stacked exactly as returned by the model; no rescaling or
    un-normalisation is applied here.

    Args:
        data             : DataFrame with position columns "x", "y", "z".
        affine_model     : torch module called as `affine_model(pos, currents=None)`
                           with pos of shape (1, 3); expected to return
                           (J, b) with shapes (1, 3, 8) and (1, 3).
        max_num_samples  : upper bound on sampled rows (clipped to len(data)).
        cond_threshold   : stop as soon as cond(J_stack) is at or below this.
        min_samples      : do not check rank/conditioning before this many samples.
        svd_tol          : absolute threshold; singular values above it count
                           towards the numerical rank.
        seed             : anything accepted by `np.random.default_rng`
                           (None, int, numpy integer, ...), or an existing
                           Generator / RandomState, which is then used (and
                           advanced) directly so that repeated calls draw
                           different samples.
        verbose          : print per-sample progress. The final warnings for
                           "never full rank" / "threshold not reached" are
                           always printed.

    Returns:
        J_pinv_final   : np.ndarray of shape (8, 3*k_used) or None if never valid
        cond_final     : condition number of the returned stack, or np.inf
        k_used         : int, number of samples in the returned stack
        J_stack_final  : np.ndarray of shape (3*k_used, 8) or None
        b_stack_final  : np.ndarray of shape (3*k_used,) or None

    If the threshold is never met, the full-rank prefix with the smallest
    condition number seen is returned instead.
    """
    affine_model.eval()
    device = next(affine_model.parameters()).device

    n_available = len(data)
    max_num_samples = min(max_num_samples, n_available)

    # Reuse a caller-supplied generator so its state advances across calls;
    # everything else (None, Python or numpy integers, ...) seeds a fresh one.
    if isinstance(seed, (np.random.Generator, np.random.RandomState)):
        rng = seed
    else:
        rng = np.random.default_rng(seed)
    sample_indices = rng.choice(n_available, size=max_num_samples, replace=False)

    # Pull all sampled positions out of the DataFrame once instead of row by row.
    positions = data[["x", "y", "z"]].iloc[sample_indices].to_numpy(dtype=np.float32)

    J_blocks = []  # list of (3, 8)
    b_blocks = []  # list of (3,)

    best_cond = np.inf
    best_k = 0

    n_cols = 8  # J has shape (3, 8): one column per coil current

    for k, (idx, pos_np) in enumerate(zip(sample_indices, positions), start=1):
        pos_t = torch.tensor(pos_np, dtype=torch.float32, device=device).unsqueeze(0)  # (1, 3)
        with torch.no_grad():
            J_batch, b_batch = affine_model(pos_t, currents=None)  # (1,3,8), (1,3)

        J_blocks.append(J_batch[0].cpu().numpy())  # (3, 8)
        b_blocks.append(b_batch[0].cpu().numpy())  # (3,)

        # stack so far
        J_stack = np.vstack(J_blocks)  # (3*k, 8)

        if verbose:
            print(f"After {k:2d} samples (row {idx}), J_stack shape = {J_stack.shape}")

        # Don't check conditioning until we have enough rows & samples
        if J_stack.shape[0] < n_cols or k < min_samples:
            continue

        # Only the singular values are needed for rank and condition number.
        S = np.linalg.svd(J_stack, compute_uv=False)

        # numerical rank
        rank = int(np.count_nonzero(S > svd_tol))

        if rank < n_cols:
            cond = np.inf
            if verbose:
                print(f"  rank(J_stack) = {rank} < {n_cols} -> ill-conditioned (cond = ∞)")
        else:
            # S is sorted in descending order, so S[rank - 1] is the smallest
            # singular value that still counts towards the rank.
            cond = S[0] / S[rank - 1]
            if verbose:
                print(f"  rank(J_stack) = {rank}, cond(J_stack) = {cond:.3e}")

        # Only full-rank stacks with a finite condition number are candidates.
        if rank != n_cols or not np.isfinite(cond):
            continue

        # track best (smallest) cond so far
        if cond < best_cond:
            best_cond = cond
            best_k = k

        # well-conditioned enough: stop and return
        if cond <= cond_threshold:
            if verbose:
                print(
                    f"✅ Reached full rank (8) and desired conditioning: "
                    f"rank = {rank}, cond = {cond:.3e} ≤ {cond_threshold:.3e}"
                )
            return _assemble_stack_result(J_blocks, b_blocks, best_k, best_cond)

    # If we get here, we never reached full-rank & good cond within max_num_samples
    if best_k == 0:
        print(
            f"⚠️ Never achieved full column rank 8 within {max_num_samples} samples. "
            f"No valid pseudo-inverse returned."
        )
        return None, np.inf, 0, None, None

    print(
        f"⚠️ Did not reach cond <= {cond_threshold:.3e} "
        f"within {max_num_samples} samples, but best full-rank cond = {best_cond:.3e} at k={best_k}."
    )
    return _assemble_stack_result(J_blocks, b_blocks, best_k, best_cond)

In [None]:
# Single run: sample up to 100 rows and stop as soon as the stacked matrix
# is full rank with cond(J_stack) <= 5.
J_pinv, cond_final, k_used, J_stack, b_stack = build_well_conditioned_stack(
    data=data,
    affine_model=model,
    max_num_samples=100,
    cond_threshold=5.0,
    seed=4,
)

print("Final cond:", cond_final)
print("Samples used:", k_used)
print("J_pinv shape:", None if J_pinv is None else J_pinv.shape)

# J_pinv is None when no full-rank stack was found; skip the solve in that case.
if J_pinv is not None:
    # currents that reproduce the bias (least-squares solution of J_stack I = b_stack)
    I_equiv = J_pinv @ b_stack          # shape (8,)

    print("Equivalent currents (A) that reproduce b:", I_equiv)

    # sanity check relative residual
    # NOTE: the ratio is undefined if b_stack is all zeros (division by zero).
    res = J_stack @ I_equiv - b_stack   # (3*k,)
    rel_res = np.linalg.norm(res) / np.linalg.norm(b_stack)
    print("Relative residual ||J I_equiv - b|| / ||b|| =", rel_res)

In [None]:
def run_multiple_stacks(
    data,
    affine_model,
    n_runs: int = 10,
    max_num_samples: int = 50,
    cond_threshold: float = 5.0,
    min_samples: int = 3,
    svd_tol: float = 1e-8,
    seed: int = 0,
    verbose: bool = True,
    store_matrices: bool = False,
):
    """
    Repeat build_well_conditioned_stack `n_runs` times and tabulate the outcome.

    One generator is created from `seed` and handed to every run, so each run
    draws a different set of positions while the whole experiment stays
    reproducible.

    Columns of the returned DataFrame (one row per run):
      - run_id, max_num_samples, cond_threshold, min_samples, svd_tol
      - best_cond   : condition number reported for the returned stack
      - k_used      : number of samples in that stack
      - success     : True when a pseudo-inverse was returned
      - res_norm    : ||J_stack I_equiv - b_stack||  (NaN on failure)
      - rel_res     : res_norm / ||b_stack||         (NaN on failure or zero norm)
      - em_*        : equivalent current per coil    (NaN on failure)
      - J_pinv, J_stack, b_stack : only when `store_matrices` is True
    """
    shared_rng = np.random.default_rng(seed)

    # Output column names for the coil currents, in the order of I_equiv.
    # NOTE(review): the list skips "em_6" and ends at "em_8" -- confirm this
    # matches the intended coil naming.
    coil_cols = ["em_0", "em_1", "em_2", "em_3", "em_4", "em_5", "em_7", "em_8"]

    # Arguments that are identical for every run.
    stack_kwargs = dict(
        data=data,
        affine_model=affine_model,
        max_num_samples=max_num_samples,
        cond_threshold=cond_threshold,
        min_samples=min_samples,
        svd_tol=svd_tol,
        seed=shared_rng,  # the generator itself, so its state carries over
        verbose=verbose,
    )

    records = []
    for run_id in range(n_runs):
        J_pinv, cond, k_used, J_stack, b_stack = build_well_conditioned_stack(**stack_kwargs)

        # Start from a "failed run" record: residuals and currents are NaN.
        record = {
            "run_id": run_id,
            "max_num_samples": max_num_samples,
            "cond_threshold": cond_threshold,
            "min_samples": min_samples,
            "svd_tol": svd_tol,
            "best_cond": cond,
            "k_used": k_used,
            "success": (J_pinv is not None),
            "res_norm": np.nan,
            "rel_res": np.nan,
            **dict.fromkeys(coil_cols, np.nan),
        }

        solved = not (J_pinv is None or J_stack is None or b_stack is None)
        if solved:
            # Least-squares currents for the stacked bias.
            currents = J_pinv @ b_stack

            # zip stops at the shorter of the two, so extra names or extra
            # currents are ignored.
            record.update(
                (name, float(value)) for name, value in zip(coil_cols, currents)
            )

            residual_norm = float(np.linalg.norm(J_stack @ currents - b_stack))
            bias_norm = float(np.linalg.norm(b_stack))
            record["res_norm"] = residual_norm
            record["rel_res"] = residual_norm / bias_norm if bias_norm > 0 else np.nan

        if store_matrices:
            record.update(J_pinv=J_pinv, J_stack=J_stack, b_stack=b_stack)

        records.append(record)

    return pd.DataFrame(records)

In [None]:
# Repeat the stack construction 50 times from one seeded generator and keep
# only the summary columns (matrices are not stored).
df_runs = run_multiple_stacks(
    data=data,
    affine_model=model,
    n_runs=50,
    max_num_samples=50,
    cond_threshold=5.0,
    min_samples=3,
    svd_tol=1e-8,
    seed=42,
    verbose=False,
    store_matrices=False,
)

# Preview the first rows of the per-run results.
df_runs.head()

In [None]:
def plot_stack_boxplots_with_currents(df, figsize=(10, 6), title=None):
    """
    Draw a two-row boxplot summary of the per-run stack results and show it.

    Top row: one boxplot each for best_cond, k_used and res_norm; every panel
    has its own y-axis.
    Bottom row: boxplots for all em_* columns on a single axis spanning the
    full width.

    NaN entries are dropped per column and outliers are not drawn.

    Args:
        df      : DataFrame with columns "best_cond", "k_used", "res_norm" and,
                  optionally, columns whose names start with "em_".
        figsize : figure size passed to `plt.figure`.
        title   : optional figure-level title.

    Returns:
        (fig, axes_metrics, ax_em): the figure, the list of the three top-row
        axes, and the bottom axis.
    """
    metrics = ["best_cond", "k_used", "res_norm"]
    data_metrics = [df[m].dropna().to_numpy() for m in metrics]

    # em_* columns
    em_cols = [col for col in df.columns if col.startswith("em_")]
    em_data = [df[c].dropna().to_numpy() for c in em_cols]

    fig = plt.figure(figsize=figsize, constrained_layout=True)
    gs = fig.add_gridspec(2, len(metrics))

    if title is not None:
        fig.suptitle(title, fontsize=14)

    # ---- Top row: metrics ----
    axes_metrics = []
    for i, (m, vals) in enumerate(zip(metrics, data_metrics)):
        ax = fig.add_subplot(gs[0, i])
        ax.boxplot(vals, showfliers=False)
        ax.set_title(m)
        ax.set_xticks([])
        if i == 0:
            ax.set_ylabel("Value")
        axes_metrics.append(ax)

    # ---- Bottom row: em_* currents ----
    ax_em = fig.add_subplot(gs[1, :])   # span all columns

    if em_cols:
        ax_em.boxplot(em_data, showfliers=False)
        # Label the boxes afterwards instead of via boxplot(labels=...): that
        # keyword is deprecated in recent Matplotlib releases, while
        # set_xticklabels works the same way in old and new versions.
        ax_em.set_xticklabels(em_cols)
        ax_em.set_ylabel("Equivalent current (A)")
        ax_em.set_title("Estimated equivalent currents per coil")
        ax_em.tick_params(axis="x", rotation=30)
    else:
        ax_em.text(0.5, 0.5, "No em_* columns found", ha="center", va="center")
        ax_em.set_axis_off()

    plt.show()
    return fig, axes_metrics, ax_em

# Plot the distribution of the multi-run results collected in df_runs.
fig, axes_metrics, ax_em = plot_stack_boxplots_with_currents(
    df_runs,
    title="Stack selection + equivalent currents"
)