# RANSAC experiment per coil and per position. We fit a RANSAC model to the affine relationship between coil current and magnetic field.

# Note for reviewer:

During cleanup, a small error was found in this notebook. As a result, the percentage of outliers detected in the OctoMag dataset that is reported in the paper is slightly wrong. The corrected value will be included in the revision.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import sys
from tqdm.auto import tqdm
from sklearn.linear_model import LinearRegression, RANSACRegressor


# Walk two directory levels up from the working directory to get the base
# directory, then expose its "src" folder to the import system.
cwd = os.getcwd()
parent_dir = os.path.split(cwd)[0]
base_dir = os.path.split(parent_dir)[0]

src_dir = f"{base_dir}/src"

# Put src first so the project modules imported below win over any
# same-named package already on the path.
sys.path.insert(0, src_dir)

from data_analysis import data_processing, load_data, convert_frames, correct_sensor_bias, plot_quiver_slice, load_navion_format, plot_positions

## Select which eMNS dataset to use

In [None]:
#####################################
########### Which eMNS? #############
#####################################
emns = "octomag"  # "octomag" or "navion"

# Reject unknown systems up front so nothing below runs with a bad selection.
if emns.lower() not in ("octomag", "navion"):
    raise ValueError("eMNS type not recognized. Please use 'octomag' or 'navion'.")

# Per-system input/output directories.
if emns.lower() == "octomag":
    data_dir = f"{base_dir}/data/octomag_data/data"
    sensor_bias_file = f"{data_dir}/sensor_bias.csv"
    clean_data_dir = f"{base_dir}/data/octomag_data/clean_data"
else:
    data_dir = f"{base_dir}/data/navion_data/data"
    clean_data_dir = f"{data_dir}/../clean_data"

# Output files are named identically for both systems.
processed_data_file = f"{clean_data_dir}/outlier_data_currents.pkl"
clean_data_file = f"{clean_data_dir}/clean_data.pkl"

# OctoMag data gets a sensor-bias correction before the frame conversion;
# Navion data is loaded with its own reader and converted directly.
data_og = convert_frames(
    correct_sensor_bias(load_data(data_dir), bias_csv=sensor_bias_file)
    if emns.lower() == "octomag"
    else load_navion_format(data_dir)
)

In [None]:
# Field-magnitude distribution of the original dataset, with its quartiles
# overlaid as vertical lines.
magnitudes = np.linalg.norm(data_og[["Bx", "By", "Bz"]].to_numpy(), axis=1)
quartiles = np.percentile(magnitudes, [25, 50, 75])

plt.figure(figsize=(8, 5))
plt.hist(magnitudes, bins=50, edgecolor='black')
for q in quartiles:
    plt.axvline(q, color='r', linestyle='dashed', linewidth=1)

plt.xlabel("Magnetic Field Magnitude (mT)")
plt.ylabel("Number of Samples")
plt.title("Histogram of Magnetic Field Magnitudes. Quartiles marked in red dashed lines.")
plt.show()

# Report the quartile values numerically as well
print("Quartiles of Magnetic Field Magnitudes (mT):")
quartile_labels = ("25th percentile", "50th percentile (median)", "75th percentile")
for label, q in zip(quartile_labels, quartiles):
    print(f"{label}: {q:.4f} mT")

# Keep the first quartile; it is passed to the RANSAC step as the floor of
# the relative-error denominator.
first_quartile = quartiles[0]

In [None]:
# NOTE: this is an alias, not a copy -- columns added to `data` further down
# are also added to `data_og`.
data = data_og
total_data_points = data.shape[0]
data

## Create a key to identify different positions

Coordinates are treated as the same position if they fall into the same tolerance-sized grid cell, i.e. if they round to the same multiple of the tolerance along x, y and z.

In [None]:
# Grid cell size used to merge nearly identical positions (system dependent)
if emns.lower() == "navion":
    tol = 0.002
else:
    tol = 0.0016
cols = ['x', 'y', 'z']

# Integer index of the tolerance-sized cell each sample falls into
df_key = data[cols].div(tol).round().astype(int)

# Join the three cell indices into one "ix_iy_iz" string per row
_cell_str = df_key.astype(str)
data['pos_key'] = _cell_str['x'] + '_' + _cell_str['y'] + '_' + _cell_str['z']

In [None]:
# How many distinct position cells exist at the chosen tolerance?
unique_positions = data['pos_key'].nunique()
print(f"Number of unique positions (within {tol*1000} mm tolerance): {unique_positions}")

# Summary statistics of the number of samples falling into each cell
samples_per_position = data['pos_key'].value_counts()
print("Samples per position stats:", samples_per_position.describe(), sep="\n")

## Also separate by active coil

In [None]:
# Current columns are recognised by their "em_" prefix
em_cols = [name for name in data.columns if name.startswith('em_')]

# The active coil is the one carrying the largest absolute current ...
candidate = data[em_cols].abs().idxmax(axis=1)

# ... except for rows where every current is zero, which get NaN instead
_all_currents_zero = data[em_cols].eq(0).all(axis=1)
data['active_coil'] = candidate.mask(_all_currents_zero)

In [None]:
# Rows without an active coil (NaN in 'active_coil') have all em_* = 0.
# Let's give every coil its own copy of these rows, so each coil's group
# also contains the zero-current samples.
has_coil = data['active_coil'].notna()
data_with = data.loc[has_coil].copy()
data_without = data.loc[~has_coil].copy()

_zero_current_copies = []
for coil in em_cols:
    _zero_current_copies.append(data_without.assign(active_coil=coil))
repeated = pd.concat(_zero_current_copies, ignore_index=True)

data_expanded = pd.concat([data_with, repeated], ignore_index=True)

In [None]:
# Extend the position key with the active coil: "<pos_key>_<active_coil>"
data_expanded['pos_coil_key'] = data_expanded['pos_key'].str.cat(
    data_expanded['active_coil'], sep='_'
)

In [None]:
# Group sizes: number of samples sharing each position+coil key
samples_per_key = data_expanded['pos_coil_key'].value_counts()

keys = data_expanded['pos_coil_key'].unique()
print(f"Number of unique position+coil keys: {len(keys)}")

In [None]:
# Groups with fewer than this many samples are excluded from the RANSAC step
min_samples = 5
small_counts = samples_per_key.loc[samples_per_key.lt(min_samples)]
bad_keys = small_counts.index

# Flag rows belonging to an under-sampled key, then keep only the rest
data_expanded['bad_key'] = data_expanded['pos_coil_key'].isin(bad_keys)
data_good = data_expanded.loc[~data_expanded['bad_key']].copy()

# The expanded frame is no longer needed; drop the reference to free memory
del data_expanded

## Now run a RANSAC affine regression per position + active coil (should be an affine mapping from current to field)

In [None]:
def _flag_unfit_group(df, currents):
    """Label every row of a group that could not be fitted as an outlier.

    Sets the same final columns ("outlier", "current") as the successful
    path of ``ransac_one_group`` so that all per-group results share one
    schema once concatenated, plus the failure-specific flags.
    """
    df["outlier_in_currents"] = True
    df["ransac_residual_norm"] = np.nan
    df["ransac_residual_rel"] = np.nan
    df["ransac_vector_outlier"] = True
    # Without these two columns the concatenated result would hold NaN in
    # "outlier"/"current" for this group, which breaks boolean masking and
    # distorts the outlier count downstream.
    df["outlier"] = True
    df["current"] = currents
    return df


def ransac_one_group(
    df_group,
    min_distinct_currents=3,
    residual_percent=20.0,   # allowed error (%)
    max_trials=100,
    min_abs_threshold=3.27 if emns.lower() == "navion" else 2.66,
):
    """
    Run RANSAC for a single pos_coil_key group using a *relative* L2 vector loss.

    Model:
      B(I) = a * I + b, B in R^3 (Bx, By, Bz)

    Residual per sample i:
      e_i    = ||B_meas_i - B_pred_i||_2
      Bmag_i = ||B_meas_i||

      denom_i = max(Bmag_i, min_abs_threshold)
      rel_i   = e_i / denom_i   (dimensionless relative error)

    RANSAC inlier condition:
      rel_i <= residual_percent / 100

    Parameters
    ----------
    df_group : pandas.DataFrame
        Rows of one position+coil group. Must contain "Bx", "By", "Bz",
        "active_coil", "pos_coil_key" and the current column named by
        "active_coil". The first row's "active_coil" value is used for the
        whole group. The input frame is not modified.
    min_distinct_currents : int
        Minimum number of distinct current values required to attempt a fit.
    residual_percent : float
        Maximum relative error, in percent, for a sample to be an inlier.
    max_trials : int
        Maximum number of RANSAC iterations.
    min_abs_threshold : float
        Lower bound of the relative-error denominator, in the same units as
        the field columns; prevents near-zero fields from inflating the
        relative error.

    Returns
    -------
    pandas.DataFrame
        A copy of the group with exact duplicates (same current and same
        field) removed and these columns added: "ransac_fit_ok",
        "ransac_n_points", "ransac_n_inliers", "ransac_inlier",
        "ransac_slope_{x,y,z}", "ransac_offset_{x,y,z}",
        "ransac_residual_norm", "ransac_residual_rel", "outlier" and
        "current". If the group has too few distinct currents or the fit
        fails, all rows are returned with outlier=True, NaN model
        parameters/residuals, and the extra flags "outlier_in_currents" and
        "ransac_vector_outlier" set to True.
    """
    field_cols = ["Bx", "By", "Bz"]
    current_col = df_group["active_coil"].iloc[0]  # e.g. "em_0", ...

    # Drop exact duplicates (same current + same field). Row order is kept
    # as-is (no sorting). The explicit copy lets us add columns without
    # touching the caller's frame or triggering SettingWithCopy warnings.
    df = df_group.drop_duplicates(subset=[current_col] + field_cols).copy()

    # Extract arrays
    currents = df[current_col].to_numpy()
    B = df[field_cols].to_numpy()  # (n_samples, 3)
    n_points = len(df)
    n_distinct_currents = np.unique(currents).shape[0]

    # Init bookkeeping columns
    df["ransac_fit_ok"] = False
    df["ransac_n_points"] = n_points
    df["ransac_n_inliers"] = 0
    df["ransac_inlier"] = False

    for name in ["ransac_slope_x", "ransac_slope_y", "ransac_slope_z",
                 "ransac_offset_x", "ransac_offset_y", "ransac_offset_z"]:
        df[name] = np.nan

    # Not enough distinct currents → can't fit reliably
    if n_distinct_currents < min_distinct_currents:
        return _flag_unfit_group(df, currents)

    # Relative threshold
    residual_threshold_rel = residual_percent / 100.0

    # Prepare data for sklearn
    X = currents.reshape(-1, 1)  # (n_samples, 1)
    y = B                        # (n_samples, 3)

    base_model = LinearRegression(fit_intercept=True)

    # --- custom relative L2 loss ---
    def rel_l2_loss(y_true, y_pred, min_abs_threshold=min_abs_threshold):
        diff = y_true - y_pred
        num = np.linalg.norm(diff, axis=1)      # ||ΔB||
        Bmag = np.linalg.norm(y_true, axis=1)   # ||B||
        denom = np.maximum(Bmag, min_abs_threshold)   # elementwise max

        return num / denom                      # relative error per sample

    ransac = RANSACRegressor(
        estimator=base_model,
        min_samples=2,  # two points define a line in the current
        loss=rel_l2_loss,
        residual_threshold=residual_threshold_rel,
        max_trials=max_trials,
        random_state=0,  # fixed seed → reproducible sampling
    )

    try:
        ransac.fit(X, y)
    except Exception as e:
        # Deliberately best-effort: a failed fit (e.g. no valid consensus
        # set) must not abort the whole per-group loop. The group is
        # reported and flagged as outliers instead.
        print(f"RANSAC failed for group {df['pos_coil_key'].iloc[0]}: {e}")
        return _flag_unfit_group(df, currents)

    inlier_mask = ransac.inlier_mask_
    n_inliers = int(inlier_mask.sum())

    # Positional assignment: the mask is in row order of df
    df["ransac_inlier"] = inlier_mask
    df["ransac_fit_ok"] = True
    df["ransac_n_inliers"] = n_inliers

    # Slopes and offsets of the final estimator, one per field component
    est = ransac.estimator_
    slopes = est.coef_.reshape(-1)
    offsets = est.intercept_.reshape(-1)

    df["ransac_slope_x"] = slopes[0]
    df["ransac_slope_y"] = slopes[1]
    df["ransac_slope_z"] = slopes[2]
    df["ransac_offset_x"] = offsets[0]
    df["ransac_offset_y"] = offsets[1]
    df["ransac_offset_z"] = offsets[2]

    # For diagnostics: absolute + relative residuals for the final model.
    # NOTE(review): these come from ransac.predict (the final estimator),
    # while the inlier mask is the one stored by RANSAC; the two may not
    # agree exactly at the threshold -- confirm against sklearn docs.
    y_pred = ransac.predict(X)
    res_norm = np.linalg.norm(y - y_pred, axis=1)
    res_rel = res_norm / np.maximum(np.linalg.norm(B, axis=1), min_abs_threshold)

    df["ransac_residual_norm"] = res_norm
    df["ransac_residual_rel"] = res_rel

    # Make main label consistent with the RANSAC inlier mask
    df["outlier"] = ~df["ransac_inlier"]

    # Add current column (current of the active coil)
    df["current"] = currents

    return df

### Apply RANSAC

In [None]:
# Fit every position+coil group independently. The first quartile of the
# field magnitudes is used as the floor of the relative-error denominator.
groups = data_good.groupby("pos_coil_key")
n_groups = len(groups)

results = [
    ransac_one_group(group, min_abs_threshold=first_quartile)
    for _, group in tqdm(groups, total=n_groups, desc="RANSAC per pos_coil_key")
]

# Drop the reference to the input frame before building the combined result
del data_good

print("Concatenating results, this may take a while...")
data_ransac = pd.concat(results, ignore_index=True)
print("Done RANSAC processing.")

### Check results

In [None]:
# Headline numbers.
# NOTE(review): the denominator is the number of rows in data_ransac, not the
# number of rows in the original dataset. Zero-current rows were replicated
# once per coil and under-sampled keys / exact duplicates were removed
# earlier, so this percentage may differ from one computed over the raw
# samples -- confirm which one is intended.
n_ransac_rows = len(data_ransac)
num_outliers = data_ransac["outlier"].sum()
print(
    f"Total number of outlier samples detected by RANSAC: {num_outliers}, "
    f"out of {n_ransac_rows} total samples. "
    f"Corresponding to {num_outliers/n_ransac_rows*100:.2f}% of the dataset."
)

outlier_df = data_ransac.loc[data_ransac["outlier"]]

# Field magnitude of each outlier sample
B_magnitude = np.linalg.norm(outlier_df[["Bx", "By", "Bz"]].to_numpy(), axis=1)

# Histogram: how the outliers are distributed over the field magnitude
plt.figure(figsize=(8, 5))
plt.hist(B_magnitude, bins=100, edgecolor='black')
plt.xlabel("Magnetic Field Magnitude (mT)")
plt.ylabel("Number of Outlier Samples")
plt.title("Distribution of Outlier Samples vs Magnetic Field Magnitude")
plt.show()

## Save the df

In [None]:
# Measurement columns to keep: position, field, coil currents and, for the
# OctoMag dataset only, its extra metadata columns.
if emns.lower() == "octomag":
    _extra_cols = ['lvl', 'pos', 'rot', 'base_x', 'base_y', 'z_offset']
else:
    _extra_cols = []
keep_cols = ["x", "y", "z", "Bx", "By", "Bz", *em_cols, *_extra_cols]

# Columns produced by the per-group RANSAC step
ransac_cols = [
    "ransac_slope_x", "ransac_slope_y", "ransac_slope_z",
    "ransac_offset_x", "ransac_offset_y", "ransac_offset_z",
    "ransac_residual_norm", "ransac_residual_rel",
    "ransac_inlier", "ransac_fit_ok",
    "ransac_n_points", "ransac_n_inliers", "outlier",
    "active_coil", "current",
]

data_keep = data_ransac.loc[:, keep_cols + ransac_cols]

In [None]:
# Ensure the output directory exists (no error if it is already there)
os.makedirs(clean_data_dir, exist_ok=True)

# Persist the annotated dataset, outliers and RANSAC diagnostics included
print(f"Saving cleaned data to {processed_data_file}...")
data_keep.to_pickle(path=processed_data_file)

In [None]:
# Final clean dataset: inlier rows only, measurement columns only. Dropping
# duplicates collapses rows that are identical in all kept columns.
data_clean = data_keep.loc[~data_keep["outlier"], keep_cols].drop_duplicates()

print(f"Saving final clean data to {clean_data_file}...")
data_clean.to_pickle(path=clean_data_file)