In [2]:
# Install the sysidmjx package in editable mode.
# The repository location can be overridden with the SYSIDMJX_PATH environment variable
# instead of editing this cell; the default is the original author's checkout.
import importlib
import importlib.util
import os
import subprocess
import sys

SYSIDMJX_PATH = os.environ.get("SYSIDMJX_PATH", "/media/carlijn/500GB/mjx_sysid-main")

# Use the kernel's own interpreter so the package lands in this kernel's environment.
result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-e", SYSIDMJX_PATH],
    capture_output=True,
    text=True,
)
print(result.stdout)
if result.returncode != 0:
    print("STDERR:", result.stderr)
    # Stop "Run All" here instead of failing later with a confusing ModuleNotFoundError.
    raise RuntimeError(f"pip install -e {SYSIDMJX_PATH} failed (exit code {result.returncode})")

print("✓ sysidmjx installed successfully!")

# Editable installs are registered through a .pth file, and .pth files are only processed
# when the interpreter starts. A kernel that was already running during the install
# therefore usually cannot import the package until it is restarted.
importlib.invalidate_caches()
if importlib.util.find_spec("sysidmjx") is None:
    print("⚠️  sysidmjx is not importable in this kernel yet - restart the kernel, then run the notebook from the top.")

Obtaining file:///media/carlijn/500GB/mjx_sysid-main
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: sysidmjx
  Building editable for sysidmjx (pyproject.toml): started
  Building editable for sysidmjx (pyproject.toml): finished with status 'done'
  Created wheel for sysidmjx: filename=sysidmjx-0.1.0-0.editable-py3-none-any.whl size=3653 sha256=7a23abae51d2f505947a87f5db4c0f86e825de0efa31eebb40d56811ee848c4d
  Stored in directory: /tmp/pip-ephem-wheel-cache-66i24uil/wheels/b3/99/b6/98b73b7b9fbacb21729f0760a

# Exoskeleton System Identification

This notebook optimizes physical parameters (armature, friction loss, damping) for all 10 actuated joints of the March exoskeleton.

## ⚠️ CRITICAL: Execution Order

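**Run the install cell (the very first cell) once, then RESTART THE KERNEL so the editable install can be imported, and run the cells below in this exact order** (cell numbers are 0-based positions in the notebook, counting markdown cells):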

1. **Cell 2**: Imports and configuration
2. **Cell 4**: Define data loading function (does NOT load data yet)
3. **Cell 8**: Model setup - Updates JOINT_NAMES and loads data automatically
4. **Cell 10**: Parameter setup - Creates PID gains and optimization functions
5. **Cell 12**: Training loop

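**The key fix:** Cell 8 loads the model first, updates `JOINT_NAMES` to match the model's actuator order, and only THEN loads the data. The data columns are therefore reordered to match the model's actuator order. This relies on each actuator being named after the joint it drives, and on the actuators being declared in the same order as those joints appear in the model's `qpos`/`qvel`.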

## Data Requirements

Your CSV file should have the following columns:
- `timestamp`: Time in seconds
- For each joint in `[left_hip_aa, left_hip_fe, left_knee, left_ankle_dpf, left_ankle_ie, right_hip_aa, right_hip_fe, right_knee, right_ankle_dpf, right_ankle_ie]`:
  - `<joint_name>_pos`: Joint position (radians)
  - `<joint_name>_vel`: Joint velocity (rad/s)
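  - `<joint_name>_ctrl`: Control input, i.e. the desired joint position in radians. The PD controller in `make_action` treats this value as a position setpoint. Motor currents (Amps) are **not** converted automatically.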

Example CSV structure:
```
timestamp,left_hip_aa_pos,left_hip_aa_vel,left_hip_aa_ctrl,left_hip_fe_pos,...
0.000,0.1,0.0,0.15,0.2,...
0.001,0.101,0.1,0.15,0.21,...
```

## PID Gains Used
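The optimization uses the following gains from your system. Only the P and D terms are applied: the controller in `make_action` ignores the integral term (all I gains are 0 here anyway).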
- **Ankle DPF**: P=50, I=0, D=0.2
- **Hip AA**: P=80, I=0, D=0.4  
- **Hip FE**: P=120, I=0, D=0.8
- **Knee**: P=80, I=0, D=0.6
- **Ankle IE (Linear)**: P=3.0, I=0, D=0.03


In [3]:
import os
# Must be set before `import jax`: restricts JAX to the CPU backend.
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
from flax.training import train_state
import optax
import mujoco
from mujoco import mjx
from typing import Callable, Dict
from jax import Array

import pandas as pd
import numpy as np
from sysidmjx.core import generate_loss_train_functions, get_batch
from jax import config

# jax_debug_nans checks computation results for NaNs and raises at the first one found.
# It is a debugging aid with a runtime cost; set DEBUG_NANS = False for long training runs.
DEBUG_NANS = True
jax.config.update("jax_debug_nans", DEBUG_NANS)
# 64-bit floats; set at start-up, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_matmul_precision", "high")


class PARAMS:
    """Experiment configuration, grouped by concern (dataset, simulator, training)."""
    SEED = jax.random.PRNGKey(2)
    EXPERIMENT_NAME = "exo_optimization"

    class DATASET:
        PATH = "assets/exo/dataset/exo_data_1627.csv"  # Your exo data file
        DT = 0.001  # Adjust to your data sampling rate; also used as the simulation timestep

    class SIM:
        PATH = "assets/exo/model/exo_hydrax.xml"
        INTEGRATOR = mujoco.mjtIntegrator.mjINT_EULER
        ITERATIONS = 1

    class TRAIN:
        EPOCH_NUM = 100
        BATCH_SIZE = 100  # Adjust based on your data size
        LEARNING_RATE = 1e-3
        TX = optax.adam(LEARNING_RATE)

# Joint names in your exoskeleton (excluding the free joint for safety catcher).
# NOTE: the model-setup cell replaces this list in place with the model's actuator order.
JOINT_NAMES = [
    'left_hip_aa', 'left_hip_fe', 'left_knee', 
    'left_ankle_dpf', 'left_ankle_ie',
    'right_hip_aa', 'right_hip_fe', 'right_knee',
    'right_ankle_dpf', 'right_ankle_ie'
]

# PID gains for each joint (P, I, D)
PID_GAINS = {
    'left_ankle_dpf': (50, 0, 0.2),
    'right_ankle_dpf': (50, 0, 0.2),
    'left_hip_aa': (80, 0, 0.4),
    'right_hip_aa': (80, 0, 0.4),
    'left_hip_fe': (120, 0, 0.8),
    'right_hip_fe': (120, 0, 0.8),
    'left_knee': (80, 0, 0.6),
    'right_knee': (80, 0, 0.6),
    'left_ankle_ie': (3.0, 0, 0.03),  # Linear actuators
    'right_ankle_ie': (3.0, 0, 0.03),
}

# Fail early if a joint listed above has no gain entry (e.g. a typo in either table).
assert set(JOINT_NAMES) <= set(PID_GAINS), (
    f"No PID gains defined for: {sorted(set(JOINT_NAMES) - set(PID_GAINS))}"
)

ModuleNotFoundError: No module named 'sysidmjx'

# Load and prepare exoskeleton data

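This cell defines the loader for your exoskeleton data. The data itself is loaded in the model-setup cell, after `JOINT_NAMES` has been updated to the model's actuator order. Your CSV should have columns for: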
- Time stamps
- Joint positions for all 10 joints
- Joint velocities for all 10 joints  
- Control inputs (currents or desired positions) for all 10 joints

In [None]:
def load_exo_data(csv_path, joint_names=None):
    """
    Load exoskeleton data from CSV and reorder it to match the model actuator order.

    Expected CSV columns:
    - timestamp (not read by this function)
    - For each joint: <joint_name>_pos, <joint_name>_vel, <joint_name>_ctrl

    Args:
        csv_path: Path or file-like object accepted by ``pd.read_csv``.
        joint_names: Joint order for the output columns. Defaults to the global
            ``JOINT_NAMES``; the model-setup cell updates that list to the model's
            actuator order.

    Returns:
        (dataset, df): ``dataset`` is a dict of JAX arrays of shape (n_samples, n_joints)
        with keys 'qpos', 'qvel', 'qact' and 'qpos_next'. 'qpos_next' is 'qpos' shifted
        one sample forward, with the last sample repeated. ``df`` is the raw DataFrame.

    Raises:
        ValueError: if a required column is missing or the CSV has no data rows.
    """
    if joint_names is None:
        # Looked up at call time, so the order set by the model-setup cell is used.
        joint_names = JOINT_NAMES
    joint_names = list(joint_names)

    print(f"Loading data from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"CSV loaded: {len(df)} samples, {len(df.columns)} columns")
    print(f"Available columns: {list(df.columns)[:10]}...")  # Show first 10

    n_joints = len(joint_names)
    suffixes = ("pos", "vel", "ctrl")

    print(f"\nLoading data for {n_joints} joints in MODEL ORDER:")
    print(f"JOINT_NAMES (model actuator order): {joint_names}")

    # Check every joint before raising, so one error reports all missing columns.
    missing_cols = []
    for i, joint in enumerate(joint_names):
        missing = [f"{joint}_{suffix}" for suffix in suffixes if f"{joint}_{suffix}" not in df.columns]
        missing_cols.extend(missing)
        status = "✓" if not missing else "✗ MISSING " + ", ".join(missing)
        print(f"  [{i}] {joint}: {status}")

    if missing_cols:
        more = " ..." if len(missing_cols) > 5 else ""
        print(f"\n⚠️  ERROR: Missing columns in CSV: {missing_cols[:5]}{more}")
        print(f"Total missing: {len(missing_cols)} columns")
        raise ValueError(f"CSV file is missing required columns. First few: {missing_cols[:5]}")
    if len(df) == 0:
        # Without this check, building qpos_next below fails with an IndexError.
        raise ValueError(f"CSV file {csv_path} has no data rows")

    # Selecting columns by name, in joint_names order, is what reorders the CSV
    # into the model's actuator order.
    arrays = {
        suffix: df[[f"{joint}_{suffix}" for joint in joint_names]].to_numpy(dtype=float)
        for suffix in suffixes
    }
    qpos, qvel, qctrl = arrays["pos"], arrays["vel"], arrays["ctrl"]
    print(f"\n✓ Successfully loaded and reordered data for all {n_joints} joints")

    # Next-timestep targets: shift one sample forward. The last sample has no
    # successor, so it is repeated.
    qpos_next = np.roll(qpos, -1, axis=0)
    qpos_next[-1] = qpos[-1]

    dataset = {
        'qpos': jnp.array(qpos),
        'qvel': jnp.array(qvel),
        'qact': jnp.array(qctrl),
        'qpos_next': jnp.array(qpos_next)
    }

    print(f"Dataset shapes: qpos={dataset['qpos'].shape}, qvel={dataset['qvel'].shape}, qact={dataset['qact'].shape}")

    return dataset, df

print("✓ Data loading function defined. It is called by the model-setup cell, after JOINT_NAMES has been updated.")

# Test with synthetic data (optional)

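Before using real data, you can test the pipeline with synthetic sinusoidal joint trajectories. These are generated analytically, not simulated with the model. To use them, uncomment the cell below and run it after the model-setup cell, because it relies on the updated `JOINT_NAMES`.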

In [None]:
# def generate_synthetic_test_data(n_samples=1000):
#     """
#     Generate synthetic data to test the optimization pipeline.
#     This simulates simple sinusoidal joint movements.
    
#     NOTE: Call this AFTER running the model setup cell so JOINT_NAMES is correctly populated.
#     """
#     dt = 0.001
#     t = np.arange(n_samples) * dt
    
#     n_joints = len(JOINT_NAMES)
#     qpos = np.zeros((n_samples, n_joints))
#     qvel = np.zeros((n_samples, n_joints))
#     qctrl = np.zeros((n_samples, n_joints))
    
#     # Generate sinusoidal trajectories for each joint
#     for i in range(n_joints):
#         amplitude = 0.2 + 0.1 * i  # Different amplitudes
#         frequency = 0.5 + 0.1 * i   # Different frequencies
#         phase = i * np.pi / n_joints
        
#         qpos[:, i] = amplitude * np.sin(2 * np.pi * frequency * t + phase)
#         qvel[:, i] = amplitude * 2 * np.pi * frequency * np.cos(2 * np.pi * frequency * t + phase)
#         qctrl[:, i] = qpos[:, i]  # Perfect tracking for testing
    
#     # Create next timestep
#     qpos_next = np.roll(qpos, -1, axis=0)
#     qpos_next[-1] = qpos[-1]
    
#     dataset = {
#         'qpos': jnp.array(qpos),
#         'qvel': jnp.array(qvel),
#         'qact': jnp.array(qctrl),
#         'qpos_next': jnp.array(qpos_next)
#     }
    
#     # Create dataframe for visualization
#     df_data = {'timestamp': t}
#     for i, joint in enumerate(JOINT_NAMES):
#         df_data[f'{joint}_pos'] = qpos[:, i]
#         df_data[f'{joint}_vel'] = qvel[:, i]
#         df_data[f'{joint}_ctrl'] = qctrl[:, i]
#     df = pd.DataFrame(df_data)
    
#     print(f"Generated {n_samples} synthetic data samples for {n_joints} joints")
#     print(f"Data shape: qpos={dataset['qpos'].shape}, qvel={dataset['qvel'].shape}")
#     print(f"Joints: {JOINT_NAMES}")
    
#     return dataset, df

# # Use synthetic data for testing:
# dataset, df = generate_synthetic_test_data(n_samples=500)
# print("\n✓ Synthetic test data generated. You can now run the optimization cells.")

# Setup model

In [None]:
mj_model = mujoco.MjModel.from_xml_path(PARAMS.SIM.PATH)
mj_model.opt.timestep = PARAMS.DATASET.DT
mj_model.opt.iterations = PARAMS.SIM.ITERATIONS
mj_model.opt.integrator = PARAMS.SIM.INTEGRATOR
mj_data = mujoco.MjData(mj_model)

# Print model info
print(f"integrator: {mj_model.opt.integrator}")
print(f"timestep: {mj_model.opt.timestep}")
print(f"iterations: {mj_model.opt.iterations}")
print(f"Number of DOF: {mj_model.nv}")
print(f"Number of actuators: {mj_model.nu}")

# Get actual actuator names from model (this is the TRUE order!)
model_actuators = [mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_ACTUATOR, i) for i in range(mj_model.nu)]
print(f"Actuators (MODEL ORDER): {model_actuators}")

# CRITICAL: Update JOINT_NAMES to match the model's actuator order!
# The list is mutated in place, so any other reference to it sees the new order.
print(f"\n⚠️  IMPORTANT: Updating JOINT_NAMES to match model actuator order")
print(f"Old JOINT_NAMES: {JOINT_NAMES}")
JOINT_NAMES.clear()
JOINT_NAMES.extend(model_actuators)
print(f"New JOINT_NAMES: {JOINT_NAMES}")

# Verify that every actuator name is also a joint name. mj_name2id returns -1 for an
# unknown name (it does not raise), so the returned id itself has to be checked.
print(f"\nVerifying joints in model:")
joint_ids = []
for joint_name in JOINT_NAMES:
    jnt_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_JOINT, joint_name)
    if jnt_id < 0:
        print(f"  ✗ {joint_name}: NOT FOUND in model")
    else:
        print(f"  ✓ {joint_name}: ID={jnt_id}")
        joint_ids.append(jnt_id)

if len(joint_ids) != len(JOINT_NAMES):
    print("⚠️  Some actuator names are not joint names - the CSV columns will not map onto model joints.")
elif joint_ids != sorted(joint_ids):
    # qpos/qvel entries follow joint-id order, while the dataset follows actuator order.
    print("⚠️  Actuator order differs from joint order - dataset columns may not line up with qpos/qvel.")

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

print(f"\n✓ Model setup complete!")

# Load the data only now, so load_exo_data reads the columns in the updated JOINT_NAMES order.
dataset, df = load_exo_data(PARAMS.DATASET.PATH)
print(f"\n✓ Data loaded and reordered to match model actuator order!")

# Parameters to identify and the functions that use them

In [None]:
# Generate random keys for initialization. Only keys[0..2] are used below; the split
# size is kept unchanged so the sampled initial values stay reproducible.
n_joints = len(JOINT_NAMES)
keys = jax.random.split(PARAMS.SEED, num=n_joints + 2)

print(f"\n{'='*60}")
print(f"OPTIMIZATION SETUP")
print(f"{'='*60}")
print(f"Number of joints to optimize: {n_joints}")
print(f"Joint names: {JOINT_NAMES}")

# Baseline parameters: all zero (used for the loss comparison after training)
zero_params = {
    "armature": jnp.zeros(n_joints),
    "frictionloss": jnp.zeros(n_joints),
    "damping": jnp.zeros(n_joints),
}

# Initial guess: one value per joint, sampled uniformly within the ranges below
init_params = {
    "armature": jax.random.uniform(keys[0], minval=0.0, maxval=5.0, shape=(n_joints,)),
    "frictionloss": jax.random.uniform(keys[1], minval=0.0, maxval=10.0, shape=(n_joints,)),
    "damping": jax.random.uniform(keys[2], minval=0.0, maxval=5.0, shape=(n_joints,)),
}

print(f"\nInitial parameter values:")
print(f"  armature: {init_params['armature']}")
print(f"  frictionloss: {init_params['frictionloss']}")
print(f"  damping: {init_params['damping']}")

# DOF offset = number of velocity DOFs that are not actuated joints (e.g. 6 for one free
# joint). Downstream code assumes these non-actuated DOFs come first in qvel.
n_dof_total = mjx_model.nv
n_actuated_dof = n_joints
dof_offset = n_dof_total - n_actuated_dof
print(f"\nModel DOF structure:")
print(f"  Total DOF: {n_dof_total}")
print(f"  Actuated DOF: {n_actuated_dof}")
print(f"  DOF offset: {dof_offset}")

# PID gains in JOINT_NAMES order; joints missing from PID_GAINS fall back to a default triple
print(f"\nPID Gains for each joint:")
kp_list, ki_list, kd_list = [], [], []
for joint in JOINT_NAMES:
    if joint in PID_GAINS:
        p, i, d = PID_GAINS[joint]
        print(f"  {joint}: P={p}, I={i}, D={d}")
    else:
        print(f"  ⚠️  {joint}: No PID gains defined, using default (P=50, I=0, D=0.5)")
        p, i, d = 50.0, 0.0, 0.5
    kp_list.append(p)
    ki_list.append(i)
    kd_list.append(d)

# Convert to JAX arrays for use in make_action
KP_ARRAY = jnp.array(kp_list)
KD_ARRAY = jnp.array(kd_list)
KI_ARRAY = jnp.array(ki_list)  # not used by make_action (it has no integral term)

# Verification: print the shapes, then actually check them. A mismatch would otherwise
# only surface later as an opaque broadcasting error inside the jitted loss.
print(f"\n{'='*60}")
print(f"SHAPE VERIFICATION")
print(f"{'='*60}")
print(f"KP_ARRAY shape: {KP_ARRAY.shape}")
print(f"KD_ARRAY shape: {KD_ARRAY.shape}")
print(f"Dataset qpos shape: {dataset['qpos'].shape}")
print(f"Dataset qvel shape: {dataset['qvel'].shape}")
print(f"DOF offset: {dof_offset}")
print(f"Expected actuated DOF after slicing [dof_offset:]: {n_dof_total - dof_offset}")
assert dof_offset >= 0, (
    f"Model has {n_dof_total} DOFs but {n_joints} actuated joints were requested"
)
assert KP_ARRAY.shape == KD_ARRAY.shape == (n_joints,), "PID gain arrays do not match the number of joints"
for key in ("qpos", "qvel", "qact", "qpos_next"):
    assert dataset[key].shape[1:] == (n_joints,), (
        f"dataset['{key}'] has shape {dataset[key].shape}, expected (n_samples, {n_joints})"
    )
print(f"✓ All shapes match: {n_joints}")
print(f"{'='*60}\n")


@jax.jit
def change_model(params: Dict, old_model: mjx.Model):
    """
    Return a copy of ``old_model`` with the optimized joint parameters written in.

    Only the DOFs after ``dof_offset`` (the actuated joints) are replaced. The first
    ``dof_offset`` DOFs (e.g. a free joint) keep the values from the original model
    instead of being forced to zero. Absolute values are written, so negative optimizer
    values still yield non-negative armature, friction loss and damping.

    Args:
        params: Dictionary with 'armature', 'frictionloss', 'damping' arrays (one value per
            actuated joint, in JOINT_NAMES order)
        old_model: The MJX model to update

    Returns:
        changed_model: Updated model with new parameters
    """
    def _merge(model_values, new_values):
        # Keep the model's own leading (non-actuated) entries and append the new ones.
        # With dof_offset == 0 the kept prefix is simply empty.
        return jnp.concatenate([jnp.asarray(model_values)[:dof_offset], jnp.abs(new_values)])

    changed_model = old_model.replace(
        dof_armature=_merge(old_model.dof_armature, params["armature"]),
        dof_frictionloss=_merge(old_model.dof_frictionloss, params["frictionloss"]),
        dof_damping=_merge(old_model.dof_damping, params["damping"]),
    )

    return changed_model


@jax.jit
def make_action(params: Dict, data: mjx.Data, ctrl: Array):
    """
    Compute control torques for all actuated joints with a PD law.

    tau = KP * (ctrl - q) - KD * qdot. The derivative term acts on the measured velocity,
    i.e. the velocity setpoint is zero. The integral gains are not used.

    Args:
        params: Model parameters (not used in this controller, but kept for interface)
        data: Current state (qpos, qvel)
        ctrl: Desired positions for all actuated joints, in JOINT_NAMES order

    Returns:
        tau: Control torques for all actuated joints
    """
    n_act = KP_ARRAY.shape[0]
    # qpos and qvel can differ in length: a free joint has 7 qpos entries but 6 DOFs.
    # Slicing qpos with the DOF-based dof_offset would therefore misalign it and return
    # the wrong number of entries. Instead, take the trailing n_act qpos entries. This
    # assumes the actuated joints come last, the same layout dof_offset assumes for qvel.
    qpos_actuated = data.qpos[data.qpos.shape[0] - n_act:]
    qvel_actuated = data.qvel[dof_offset:]

    # PD control (ignoring integral term for now)
    position_error = ctrl - qpos_actuated
    tau = KP_ARRAY * position_error - KD_ARRAY * qvel_actuated

    return tau


# Build the loss and training-step functions from the MJX model, its initial data and the
# two hooks defined above; the third return value is not needed in this notebook.
total_loss, train_step, _ = generate_loss_train_functions(
    change_model=change_model,
    make_action=make_action,
    mjx_data=mjx_data,
    mjx_model=mjx_model,
)
print("✓ Optimization functions generated successfully!")

# Training loop

In [None]:
# Initialize training state (the flax TrainState only holds params + optimizer state here)
state = train_state.TrainState.create(
    apply_fn=None,
    params=init_params,
    tx=PARAMS.TRAIN.TX,
)

loss_hist = []
params_hist = []
indxs = jnp.arange(dataset["qpos"].shape[0])

print(f"Starting training for {PARAMS.TRAIN.EPOCH_NUM} epochs...")
print(f"Dataset size: {dataset['qpos'].shape[0]} samples")
print(f"Batch size: {PARAMS.TRAIN.BATCH_SIZE}")

for epoch in range(PARAMS.TRAIN.EPOCH_NUM):
    # VALIDATE - compute loss on full dataset (before this epoch's update)
    loss = total_loss(
        state.params,
        qpos=dataset["qpos"],
        qvel=dataset["qvel"],
        ctrl_vec=dataset["qact"],
        qpos_des=dataset["qpos_next"],
    )
    loss_hist.append(loss)
    print(f"Epoch {epoch:3d}, Loss: {loss:.6e}")

    # Store parameter history
    params_hist.append(state.params)

    # TRAIN - update parameters using a batch. Derive a fresh key per epoch: passing the
    # same PARAMS.SEED to every get_batch call would reuse identical randomness each time.
    batch_key = jax.random.fold_in(PARAMS.SEED, epoch)
    batch, indxs = get_batch(dataset, batch_key, indxs, PARAMS.TRAIN.BATCH_SIZE)
    state, grads = train_step(
        state,
        qpos=batch["qpos"],
        qvel=batch["qvel"],
        ctrl_vec=batch["qact"],
        qpos_des=batch["qpos_next"],
    )

print("Training complete!")
# These are the raw optimizer values; change_model writes their absolute values into the model.
print("\nFinal optimized parameters (raw values; the model uses their absolute value):")
for param_name, values in state.params.items():
    print(f"{param_name}:")
    for joint, val in zip(JOINT_NAMES, values):
        print(f"  {joint}: {val:.6f}")

# Loss Analysis

In [None]:
# Compare the training loss history with a baseline in which armature, frictionloss and
# damping are zero for every actuated joint. This cell runs after the training loop
# cell, which fills `loss_hist`, so no manual uncommenting is needed.
baseline_loss = total_loss(
    zero_params,
    qpos=dataset["qpos"],
    qvel=dataset["qvel"],
    ctrl_vec=dataset["qact"],
    qpos_des=dataset["qpos_next"],
)
adjusted_model_loss = np.array(loss_hist)
base_line = np.ones_like(adjusted_model_loss)

print(f"Baseline loss (no optimization): {baseline_loss:.6e}")
if adjusted_model_loss.size == 0:
    print("No training loss recorded (EPOCH_NUM is 0?) - nothing to compare.")
else:
    # loss_hist[-1] is evaluated at the start of the last epoch, i.e. before the final update.
    print(f"Final loss (optimized): {adjusted_model_loss[-1]:.6e}")
    print(f"Improvement: {(1 - adjusted_model_loss[-1]/baseline_loss)*100:.2f}%")

In [None]:
# NOTE: Uncomment when training is complete

# import matplotlib.pyplot as plt
# import os, joblib

# # Create comprehensive visualization
# fig = plt.figure(figsize=(16, 12))
# gs = fig.add_gridspec(4, 3, hspace=0.3, wspace=0.3)

# # 1. Loss convergence
# ax1 = fig.add_subplot(gs[0, :])
# ax1.plot(adjusted_model_loss / baseline_loss, label="Relative loss", linewidth=2)
# ax1.plot(base_line, '--', label="Baseline (no optimization)", alpha=0.7)
# ax1.set_title("Training Loss Convergence", fontsize=14, fontweight='bold')
# ax1.set_ylabel("Relative loss")
# ax1.set_xlabel("Epoch")
# ax1.legend()
# ax1.grid(True, alpha=0.3)

# # 2-4. Parameter convergence for each parameter type
# param_types = ['armature', 'frictionloss', 'damping']
# for idx, param_type in enumerate(param_types):
#     ax = fig.add_subplot(gs[1, idx])
#     for i, joint in enumerate(JOINT_NAMES):
#         values = [p[param_type][i] for p in params_hist]
#         ax.plot(values, label=joint, alpha=0.7)
#     ax.set_title(f"{param_type.capitalize()} Convergence")
#     ax.set_ylabel(param_type)
#     ax.set_xlabel("Epoch")
#     ax.legend(fontsize=8, loc='best')
#     ax.grid(True, alpha=0.3)

# # 5-7. Final parameter values by joint
# for idx, param_type in enumerate(param_types):
#     ax = fig.add_subplot(gs[2, idx])
#     final_values = state.params[param_type]
#     colors = plt.cm.viridis(np.linspace(0, 1, len(JOINT_NAMES)))
#     bars = ax.bar(range(len(JOINT_NAMES)), final_values, color=colors)
#     ax.set_title(f"Final {param_type.capitalize()} Values")
#     ax.set_ylabel(param_type)
#     ax.set_xticks(range(len(JOINT_NAMES)))
#     ax.set_xticklabels(JOINT_NAMES, rotation=45, ha='right', fontsize=8)
#     ax.grid(True, alpha=0.3, axis='y')

# # 8. Comparison table
# ax_table = fig.add_subplot(gs[3, :])
# ax_table.axis('off')
# table_data = []
# for i, joint in enumerate(JOINT_NAMES):
#     row = [
#         joint,
#         f"{state.params['armature'][i]:.4f}",
#         f"{state.params['frictionloss'][i]:.4f}",
#         f"{state.params['damping'][i]:.4f}"
#     ]
#     table_data.append(row)

# table = ax_table.table(
#     cellText=table_data,
#     colLabels=['Joint', 'Armature', 'Friction Loss', 'Damping'],
#     cellLoc='center',
#     loc='center',
#     colWidths=[0.3, 0.2, 0.2, 0.2]
# )
# table.auto_set_font_size(False)
# table.set_fontsize(9)
# table.scale(1, 2)
# for i in range(len(JOINT_NAMES) + 1):
#     if i == 0:
#         table[(i, 0)].set_facecolor('#40466e')
#         table[(i, 1)].set_facecolor('#40466e')
#         table[(i, 2)].set_facecolor('#40466e')
#         table[(i, 3)].set_facecolor('#40466e')
#         table[(i, 0)].set_text_props(weight='bold', color='white')
#         table[(i, 1)].set_text_props(weight='bold', color='white')
#         table[(i, 2)].set_text_props(weight='bold', color='white')
#         table[(i, 3)].set_text_props(weight='bold', color='white')

# # Save results
# folder_path = os.path.join("assets/experiments", PARAMS.EXPERIMENT_NAME)
# os.makedirs(folder_path, exist_ok=True)
# img_path = os.path.join(folder_path, "pictures")
# os.makedirs(img_path, exist_ok=True)

# # Save optimization results
# joblib.dump(
#     {
#         "params": params_hist,
#         "final_params": state.params,
#         "loss_hist": adjusted_model_loss,
#         "baseline_loss": baseline_loss,
#         "joint_names": JOINT_NAMES,
#         "pid_gains": PID_GAINS,
#     },
#     os.path.join(folder_path, "exo_optimization_results.joblib"),
# )

# # Save figure
# fig.savefig(
#     os.path.join(img_path, "exo_optimization.png"),
#     format="png",
#     bbox_inches="tight",
#     dpi=300
# )

# print(f"Results saved to {folder_path}")
# plt.show()

# Placeholder output only. The figure/saving code above stays commented out: it imports
# matplotlib and joblib, and reads `adjusted_model_loss`, `baseline_loss` and `base_line`
# (from the loss-analysis cell) as well as `params_hist` and `state` (from training).
print("Visualization code ready. Uncomment when training is complete.")