# Launch the MPEM training ROS node

This notebook requires a working ROS installation as well as the MPEM framework (ROS package `mpem`) from:
```text
External link retracted according to guidelines.
```

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import sys
import subprocess

# Resolve the base directory as two levels above the working directory;
# presumably the notebook runs from <project>/notebooks/<subdir> — confirm.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
base_dir = os.path.dirname(parent_dir)

## Select the eMNS and define the parameter files

In [None]:
#####################################
########### Which eMNS? #############
#####################################
# Selects which eMNS's data directory and MPEM parameter files are used below.
emns = "octomag"  # "navion" or "octomag"


if emns == "octomag":
    data_dir = base_dir + "/data/octomag_data/split_dataset"
    mpem_dir = base_dir + "/mpem"

    # Define path to initial and trained parameters
    initial_parameters_file = mpem_dir + "/initial_parameters/affine_first_order.yaml"
    trained_parameters_file = mpem_dir + "/optimized_parameters/optimized_dipole_5.yaml"

elif emns == "navion":
    # No trailing slash, so paths built as data_dir + "/file" contain no "//".
    data_dir = base_dir + "/data/navion_data"
    mpem_dir = base_dir + "/mpem_navion"

    # Define path to initial and trained parameters
    initial_parameters_file = mpem_dir + "/initial_parameters/Navion_dipole_affine.yaml"
    trained_parameters_file = mpem_dir + "/optimized_parameters/Navion_dipole.yaml"

else:
    # Fail here instead of with a NameError on data_dir/mpem_dir in a later cell.
    raise ValueError(f"Unknown eMNS {emns!r}; expected 'octomag' or 'navion'.")

## Set data to use

In [None]:
# Downsampled splits, produced by /notebooks/processing/downsample.ipynb
training_data_file = f"{data_dir}/training_data_5.pkl"
validation_data_file = f"{data_dir}/validation_data_5.pkl"

# CSV copies written further below for the MPEM fitting tool
mpem_data_dir = f"{mpem_dir}/data"
mpem_training_data_file = f"{mpem_data_dir}/training_data.csv"
mpem_validation_data_file = f"{mpem_data_dir}/validation_data.csv"

# The eMNS name is passed unchanged as an argument to the MPEM fitting command.
emns_name = emns

## Load data

In [None]:
# Load the downsampled training/validation splits.
# NOTE(review): read_pickle can execute arbitrary code; only load files you produced.
training_data, validation_data = (
    pd.read_pickle(path) for path in (training_data_file, validation_data_file)
)

# Column groups of the source data
pos_cols = ["x", "y", "z"]
field_cols = ["Bx", "By", "Bz"]
# Electromagnet columns are found by their "em_" prefix, in the DataFrame's column order.
em_cols = [name for name in training_data.columns if name.startswith("em_")]

pos_cols, field_cols, em_cols

## Format dataframe

The CSV columns must be `px, py, pz, i0, ..., iN, bx, by, bz`: position, one `i*` column per `em_*` input column, then the field in tesla.

In [None]:
# MPEM column names: positions, one "i<k>" column per "em_" column, then field.
mpem_pos_cols = ["px", "py", "pz"]
mpem_field_cols = ["bx", "by", "bz"]
mpem_em_cols = [f"i{k}" for k, _ in enumerate(em_cols)]

mpem_pos_cols, mpem_field_cols, mpem_em_cols

In [None]:
# Fraction of each split to keep (1 keeps everything; lower it for quick tests).
# Sampling with a fixed random_state also shuffles the rows reproducibly.
to_keep = 1

# Source column name -> MPEM column name
rename_map = dict(
    zip(
        pos_cols + field_cols + em_cols,
        mpem_pos_cols + mpem_field_cols + mpem_em_cols,
    )
)


def _to_mpem_frame(df):
    """Return *df* reduced, shuffled, converted to tesla and renamed to MPEM columns."""
    # Column order fixes the CSV layout: positions, electromagnet columns, field.
    frame = df[pos_cols + em_cols + field_cols]
    frame = frame.sample(frac=to_keep, random_state=42).reset_index(drop=True)
    # Change units (mT -> T)
    frame[field_cols] = frame[field_cols] * 1e-3
    return frame.rename(columns=rename_map)


training_data = _to_mpem_frame(training_data)
validation_data = _to_mpem_frame(validation_data)

In [None]:
# Write both splits (without the index column) to the paths handed to MPEM.
os.makedirs(mpem_data_dir, exist_ok=True)
for frame, csv_path in (
    (training_data, mpem_training_data_file),
    (validation_data, mpem_validation_data_file),
):
    frame.to_csv(csv_path, index=False)

## Launch training

In [None]:
import shlex

# Arguments passed to `rosrun mpem dipole-model-fit`, in this order:
# eMNS name, training CSV, validation CSV, initial parameters, trained parameters.
cmd = [
    "rosrun", "mpem", "dipole-model-fit",
    emns_name,
    mpem_training_data_file,
    mpem_validation_data_file,
    initial_parameters_file,
    trained_parameters_file,
]

print("Running MPEM training with command:")
# Quote each argument so the printed command can be pasted into a shell even
# when paths contain spaces.
print(" ".join(shlex.quote(arg) for arg in cmd))

try:
    subprocess.run(cmd, check=True)
except FileNotFoundError:
    # Raised when the `rosrun` executable itself cannot be found on PATH.
    print("Could not find 'rosrun'. Is ROS installed and its environment sourced?")
except subprocess.CalledProcessError as e:
    print(f"An error occurred during MPEM training: {e}")
else:
    print("Whooo! MPEM training completed successfully.")