# DPC Torque Control — Training Example

This notebook demonstrates the basic usage of the Differentiable predictive control (DPC) codebase
and provides a minimal, end-to-end example that can be executed top-to-bottom.

## Overview
- **Goal:** Train a neural network policy to control PMSM torque.
- **Method:** Differentiable predictive control (DPC)
- **Output:** Trained policy network with convergence analysis

## Prerequisites & Setup

### Requirements
1. Clone the repository from GitHub
2. Install all dependencies listed in `requirements.txt`
3. Ensure JAX is configured to use GPU (optional but recommended for performance)

### Execution
- Run all cells sequentially from top to bottom
- Each cell depends on the previous ones having completed successfully
- Total runtime: ~5 hours (depends on hardware and training configuration)

## Notebook Sections

This notebook is organized into the following sections:

1. **Setup & Dependencies** - Import required libraries and configure JAX
2. **Motor Definition** - Initialize PMSM (Permanent Magnet Synchronous Motor) with physical parameters
3. **Training Configuration** - Define neural network architecture, loss functions, and training hyperparameters
4. **Trainer Initialization** - Create DPC trainer instance with configured components
5. **Training Execution** - Train the policy network 
6. **Results Analysis** - Plot training losses and visualize convergence

## 1. Setup & Dependencies

### Import External Libraries and Configure Environment

In [None]:
"""
Import external dependencies for numerical computation, JAX framework, and visualization.

Libraries used:
- JAX/Equinox: Automatic differentiation and neural network primitives
- NumPy: Numerical operations and array handling
- Optax: Gradient-based optimization algorithms
- Matplotlib: Data visualization
- CasADi: Symbolic computation (optional, for OCP solver)
"""

# Core scientific computing
import sys
from pathlib import Path
import numpy as np

# JAX ecosystem
import jax
import jax.nn as jnn
import jax.numpy as jnp
import equinox as eqx
import optax

# Visualization and data analysis
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator

# Utilities
from functools import partial
from IPython.display import display, clear_output
import pandas as pd

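# Symbolic computation (optional: only needed for the OCP solver)
try:
    import casadi as ca
except ImportError:
    ca = None  # CasADi not installed; this notebook does not use `ca` directly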

In [None]:
"""
Configure project root and add to Python path for local module imports.

This allows importing from policy/, utils/, and visualization/ directories.
"""

# Add the project root to the Python path for local imports.
# Assumes the notebook is run from a directory one level below the project root.
PROJECT_ROOT = Path().resolve().parents[0]
sys.path.insert(0, str(PROJECT_ROOT))
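`parents[0]` only works when the notebook sits exactly one directory below the project root. If the notebook may be launched from elsewhere, one more robust option is to walk upward until a directory containing the local packages is found. The helper `find_project_root` and its `markers` default are hypothetical, not part of the codebase; the marker names are taken from the packages this notebook imports:

```python
import tempfile
from pathlib import Path


def find_project_root(start: Path, markers=("policy", "utils", "visualization")) -> Path:
    """Return the first directory at or above `start` that contains every marker directory.

    `markers` is an assumption based on the local packages imported in this notebook.
    """
    start = start.resolve()
    for candidate in (start, *start.parents):
        if all((candidate / name).is_dir() for name in markers):
            return candidate
    raise FileNotFoundError(f"No directory at or above {start} contains {markers}")


# Demonstration on a throwaway tree: <tmp>/{policy,utils,visualization}/ and <tmp>/notebooks/
with tempfile.TemporaryDirectory() as tmp:
    demo_root = Path(tmp).resolve()
    for name in ("policy", "utils", "visualization"):
        (demo_root / name).mkdir()
    (demo_root / "notebooks").mkdir()
    found = find_project_root(demo_root / "notebooks")
    print(found == demo_root)

# In the notebook itself (hypothetical usage):
# PROJECT_ROOT = find_project_root(Path())
# sys.path.insert(0, str(PROJECT_ROOT))
```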

In [None]:
"""
Import project-specific modules for motor simulation, policy training, and visualization.

Components:
- Motor Environment: PMSM physical simulator
- Policy Training: DPC trainer and neural network architecture
- Loss Functions: Multi-objective loss components for policy optimization
- Utilities: Data generation, feature extraction, and visualization
"""

# Motor environment and simulation
from exciting_environments.pmsm.pmsm_env import PMSM

# Policy architecture and training
from policy.policy_training import DPCTrainer
from policy.networks import MLP
from policy.data_generation import reset, generate_feasible, node_dat_gen_sin, featurize

# Loss functions (6-component weighted loss)
from policy.losses import (
    ref_loss_fcn,                    # Torque reference tracking error
    efficincy_loss_fcn,              # Copper loss minimization
    posit_id_loss_fcn,               # Positive d-axis current constraint
    idq_nom_loss_fcn,                # Nominal current operating point
    idq_lim_loss,                    # Current limit constraint
    idq_SS_loss                      # Steady-state error
)

# Diagnostics and visualization
from policy.policy_training_diagnostics import plot_training_losses
from visualization.style import set_plot_style
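The six loss functions are combined into one scalar training objective. Assuming the trainer uses a plain weighted sum, `total = sum(w_i * L_i)` (the exact reduction inside `DPCTrainer` is not shown in this notebook), any component whose weight is 0.0 contributes nothing to the objective or its gradient. A minimal NumPy sketch; `weighted_total_loss` is a hypothetical helper and the component values are made up for illustration:

```python
import numpy as np


def weighted_total_loss(component_losses, weights):
    """Weighted sum of per-component losses (hypothetical helper, not the DPCTrainer API)."""
    component_losses = np.asarray(component_losses, dtype=float)
    weights = np.asarray(weights, dtype=float)
    if component_losses.shape != weights.shape:
        raise ValueError("need exactly one weight per loss component")
    return float(np.dot(weights, component_losses))


# Order follows loss_fcns: reference, efficiency, positive id, nominal idq, idq limit, steady state.
example_weights = [1.0, 0.008, 0.0, 0.0, 0.4, 0.0]
# Illustrative component values only (not results from this notebook).
example_components = [2.0, 50.0, 3.0, 7.0, 0.5, 1.0]
total = weighted_total_loss(example_components, example_weights)
print(total)
```

Changing the third, fourth, or sixth component value leaves `total` unchanged, which is why those terms are effectively switched off in this configuration.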

In [None]:
"""
Configure JAX and matplotlib for this notebook session.

JAX Setup:
- Query the devices of JAX's default backend (a GPU if one is available, otherwise the CPU)
- Make the first of these devices the default computation device

Matplotlib Setup:
- Apply custom plot styling for consistent visualization
"""

# Configure JAX: jax.devices() lists the default backend's devices (GPU if available, else CPU)
devices = jax.devices()
jax.config.update("jax_default_device", devices[0])

# Apply custom plot styling
set_plot_style()
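`jax.devices()` returns the devices of JAX's default backend, which is already the GPU whenever one is visible, so taking the first entry falls back to the CPU on its own. If you prefer to make the GPU-then-CPU choice explicit, a small sketch (the `pick_device` helper is hypothetical, not part of the codebase):

```python
import jax


def pick_device():
    """Return the first GPU device if JAX can see one, otherwise the first CPU device."""
    try:
        gpu_devices = jax.devices("gpu")
    except RuntimeError:
        # Raised when no GPU backend is available (e.g. a CPU-only jaxlib install).
        gpu_devices = []
    return gpu_devices[0] if gpu_devices else jax.devices("cpu")[0]


device = pick_device()
jax.config.update("jax_default_device", device)
print(device.platform)
```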

## 2. Motor Definition

### Initialize PMSM with Physical Parameters

In [None]:
"""
Create PMSM (Permanent Magnet Synchronous Motor) environment instance.

Motor: BRUSA (electric vehicle traction motor)
Physical Parameters:
- Pole pairs (p): 3
- Stator resistance (r_s): 15 mΩ
- D-axis inductance (l_d): 0.37 mH
- Q-axis inductance (l_q): 1.2 mH
- Permanent magnet flux linkage (psi_p): 65.6 mWb
- Deadtime: 0 µs (ideal inverter)

Configuration:
- Batch size: 1 (single trajectory simulation)
- Saturation: Disabled (linear motor model)
- LUT motor name: BRUSA (selects the look-up-table motor characterization data)
"""

motor_env = PMSM(
    LUT_motor_name="BRUSA",
    saturated=False,
    batch_size=1,
    control_state=[],
    static_params={
        "p": 3,                    # Pole pairs
        "r_s": 15e-3,              # Stator resistance (Ω)
        "l_d": 0.37e-3,            # D-axis inductance (H)
        "l_q": 1.2e-3,             # Q-axis inductance (H)
        "psi_p": 65.6e-3,          # Permanent magnet flux (Wb)
        "deadtime": 0,             # Inverter deadtime (s)
    }
)
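For intuition about what the policy is optimizing, the standard dq-frame PMSM torque equation is `T = 3/2 * p * (psi_p * i_q + (l_d - l_q) * i_d * i_q)`, with stator copper loss `P_cu = 3/2 * r_s * (i_d**2 + i_q**2)`. The 3/2 factor assumes the amplitude-invariant Park transform; whether `PMSM` uses the same convention internally is an assumption. Because l_q > l_d for this motor, a negative i_d adds reluctance torque, so for a fixed current magnitude the torque-maximizing current angle lies in the negative-i_d half-plane. That is the maximum torque per current (MTPC) operating point the loss weights in Section 3 are tuned for. A NumPy sketch using the parameters above; the function names and the 200 A magnitude are illustrative:

```python
import numpy as np

# BRUSA parameters from the cell above (SI units)
P, R_S, L_D, L_Q, PSI_P = 3, 15e-3, 0.37e-3, 1.2e-3, 65.6e-3


def torque(i_d, i_q):
    """Electromagnetic torque in Nm (amplitude-invariant dq convention assumed)."""
    return 1.5 * P * (PSI_P * i_q + (L_D - L_Q) * i_d * i_q)


def copper_loss(i_d, i_q):
    """Stator copper loss in W (same convention)."""
    return 1.5 * R_S * (i_d**2 + i_q**2)


# MTPC by brute force: for a fixed current magnitude, sweep the current angle beta
# (measured from the q-axis toward negative i_d) and keep the angle with the most torque.
i_mag = 200.0                              # illustrative current magnitude in A
beta = np.linspace(0.0, np.pi / 2, 9001)   # 0 rad means pure q-axis current
i_d = -i_mag * np.sin(beta)
i_q = i_mag * np.cos(beta)
t = torque(i_d, i_q)
k = int(np.argmax(t))

print(f"pure q-axis: T = {torque(0.0, i_mag):.2f} Nm")
print(f"MTPC: beta = {np.degrees(beta[k]):.1f} deg, i_d = {i_d[k]:.1f} A, "
      f"i_q = {i_q[k]:.1f} A, T = {t[k]:.2f} Nm")
print(f"copper loss at |i| = {i_mag:.0f} A: {copper_loss(i_d[k], i_q[k]):.1f} W")
```

Since the copper loss depends only on the current magnitude, maximizing torque at fixed magnitude is the same trade-off as minimizing copper loss for a given torque.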

## 3. Training Configuration

### Define Neural Network Architecture and Loss Weights

In [None]:
"""
Configure neural network policy architecture and training hyperparameters.

Neural Network Architecture:
- Input: 8 features (motor state: currents, speed, etc.)
- Hidden layers: 3 × 128 units (fully-connected with activation)
- Output: 2 actions (d-axis and q-axis voltage commands)
- Framework: MLP (Multi-Layer Perceptron) with Equinox

Optimization:
- Optimizer: Adam with learning rate 1e-4
- Batch size: 50 trajectories per gradient update
- Training steps: 500000 gradient descent iterations

Loss Function (6 weighted components; a weight of 0.0 disables that component):
1. Reference tracking (weight: 1.0) - Follow desired torque
2. Efficiency loss (weight: 0.008) - Minimize copper losses (I²R)
3. Positive id constraint (weight: 0.0) - Encourage d-axis current direction
4. Nominal current (weight: 0.0) - Penalize deviation from nominal operating point
5. Current limit constraint (weight: 0.4) - Soft constraint on current magnitude
6. Steady-state error (weight: 0.0) - Final state tracking error
"""

# Fixed seed for reproducible policy initialization
seed = 6
jax_key = jax.random.PRNGKey(seed)

# Neural network architecture parameters
layer_size = 128                   # Hidden layer width
num_layers = 3                     # Number of hidden layers
input_dim = 8                      # State observation dimension
output_dim = 2                     # Action dimension (u_d, u_q)

# Build MLP policy architecture: [input] → [128, 128, 128] → [output]
policy_archit = [input_dim] + [layer_size for _ in range(num_layers)] + [output_dim]
policy = MLP(policy_archit, key=jax_key)

# Optimizer setup
optimizer = optax.adam(1e-4)       # Adam optimizer with learning rate
opt_state = optimizer.init(policy) # Initialize optimizer state

# Loss functions (6 components for multi-objective optimization)
loss_fcns = [
    ref_loss_fcn,                  # Torque reference tracking
    efficincy_loss_fcn,            # Efficiency (minimize copper loss)
    posit_id_loss_fcn,             # Positive d-axis constraint
    idq_nom_loss_fcn,              # Nominal current operating point
    idq_lim_loss,                  # Current magnitude constraint
    idq_SS_loss                    # Steady-state error
]

# Loss weights for each component (tuned for maximum torque per current (MTPC) operation)
ieff = 0.008                       # Efficiency loss weight
iL = 0.4                           # Current limit constraint weight
loss_weights = [1.0, ieff, 0.0, 0.0, iL, 0.0]

# Training hyperparameters
batch_size = 50                    # Trajectories per batch
train_steps = 500_000              # Total gradient descent iterations
horizon = 60                       # Prediction/rollout horizon (timesteps)
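As a quick sanity check on model size: assuming `MLP` builds one fully connected layer with a bias between each pair of consecutive entries in `policy_archit` (the class itself is not shown here), the parameter count follows directly from the layer list. A small sketch; `count_dense_params` is a hypothetical helper:

```python
def count_dense_params(layer_sizes):
    """Weights plus biases of a stack of dense layers, e.g. [8, 128, 128, 128, 2]."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))


example_archit = [8] + [128] * 3 + [2]
n_params = count_dense_params(example_archit)
print(n_params)  # 1152 + 2 * 16512 + 258 = 34434
```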

## 4. Trainer Initialization

### Create DPC Trainer with Configured Components

In [None]:
"""
Create DPCTrainer instance that orchestrates the policy training process.

DPCTrainer Configuration:
- Manages batch generation, loss computation, and gradient updates
- Handles trajectory sampling across motor operating points (speed/torque grid)
- Applies analytical reference generation (AnalyticalRG) for reference signals
- Computes multi-objective weighted loss and backpropagation

Key Components:
- reset_env: Reinitialize motor state at trajectory start
- data_gen_sin: Single-data reference generator
- gen_feas: Generate feasible reference currents and torques
- featurize: Extract neural network input features from motor observations
"""

# Create trainer instance with all components
trainer = DPCTrainer(
    batch_size=batch_size,                        # Trajectories per gradient step
    train_steps=train_steps,                      # Total training iterations
    horizon_length=horizon,                       # Prediction horizon
    reset_env=reset,                              # Environment reset function
    data_gen_sin=node_dat_gen_sin,                # Reference trajectory generator
    gen_feas=generate_feasible,                   # Feasible reference generator
    featurize=featurize,                          # Feature extraction function
    policy_optimizer=optimizer,                   # Gradient optimizer (Adam)
    loss_fcns=loss_fcns,                          # 6-component loss functions
    loss_weights=loss_weights,                    # Loss weights vector
)

# Create one random key per trajectory for batch sampling.
# Drawing the integers from a generator seeded with `seed` makes the sampled
# trajectories repeatable for a given seed.
rng = np.random.default_rng(seed)
keys = jax.vmap(jax.random.PRNGKey)(
    rng.integers(0, 2**31, size=(batch_size,))
)
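The core idea behind DPC can be illustrated on a toy problem: roll the policy out in closed loop through a differentiable model over a finite horizon, accumulate the tracking loss, and differentiate that loss with respect to the policy parameters. The sketch below swaps the PMSM for a hypothetical scalar linear plant, the MLP for a two-parameter linear policy, and Adam for plain gradient descent. It illustrates the technique only and is not how `DPCTrainer` is implemented:

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar plant x[k+1] = A*x[k] + B*u[k], standing in for the PMSM model.
A, B = 0.9, 0.1
HORIZON = 20          # prediction horizon (this notebook uses 60 steps on the motor)
LEARNING_RATE = 5e-3  # plain gradient descent here; the notebook uses Adam


def policy(theta, x, ref):
    """Linear policy u = theta[0]*x + theta[1]*ref (stand-in for the MLP)."""
    return theta[0] * x + theta[1] * ref


def rollout_loss(theta, x0, ref):
    """Closed-loop rollout over the horizon; returns the summed squared tracking error."""
    def step(x, _):
        x_next = A * x + B * policy(theta, x, ref)
        return x_next, (x_next - ref) ** 2

    _, errors = jax.lax.scan(step, x0, xs=None, length=HORIZON)
    return jnp.sum(errors)


def batch_loss(theta, x0s, refs):
    """Mean rollout loss over a batch of initial states and references."""
    return jnp.mean(jax.vmap(rollout_loss, in_axes=(None, 0, 0))(theta, x0s, refs))


@jax.jit
def train_step(theta, x0s, refs):
    """One DPC update: differentiate through the whole rollout, then take a gradient step."""
    loss, grad = jax.value_and_grad(batch_loss)(theta, x0s, refs)
    return theta - LEARNING_RATE * grad, loss


key_x0, key_ref = jax.random.split(jax.random.PRNGKey(0))
x0s = jax.random.uniform(key_x0, (50,), minval=-1.0, maxval=1.0)    # 50 initial states
refs = jax.random.uniform(key_ref, (50,), minval=-1.0, maxval=1.0)  # one reference each

theta = jnp.zeros(2)
initial_loss = float(batch_loss(theta, x0s, refs))
for _ in range(200):
    theta, _ = train_step(theta, x0s, refs)
final_loss = float(batch_loss(theta, x0s, refs))
print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}, theta = {theta}")
```

The batch of 50 trajectories mirrors `batch_size` above; in the real setup the plant step would be the motor simulator and the loss the weighted sum of the six components.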

## 5. Training Execution

### Train Policy on PMSM Environment

Progress is displayed with a tqdm progress bar.

In [None]:
"""
Execute policy training loop using DPCTrainer.

This runs DPC training for `train_steps` iterations.

Returns:
- p2: Trained policy network (neural network weights)
- fin_opt_state: Final optimizer state (momentum/variance accumulators)
- fin_keys: Final random keys after training
- losses: List of total loss values per iteration
- ref_losses: Reference tracking error per iteration
- eff_losses: Efficiency (copper loss) per iteration
- i_lim_losses: Current limit constraint violation per iteration
- i_ss_losses: Steady-state error per iteration
- acts_norm_losses: Action magnitude penalty per iteration
- train_data: Intermediate training trajectories for analysis
"""

# Train DPC policy on PMSM environment
(p2, fin_opt_state, fin_keys, losses, ref_losses, eff_losses,
 i_lim_losses, i_ss_losses, acts_norm_losses, train_data) = trainer.fit_non_jit(
    policy,              # Initial policy network
    motor_env,           # PMSM motor simulator
    keys,                # Random keys for sampling
    opt_state            # Initial optimizer state
)

In [None]:
"""
Visualize training convergence by plotting loss components.

This generates diagnostic plots showing:
1. Total loss convergence
2. Reference tracking error decay (primary objective)
3. Efficiency loss trajectory (copper loss minimization)
4. Current limit constraint violations
5. Steady-state error evolution
6. Action magnitude penalties

Interpretation:
- Losses should trend downward overall, but a monotonic decrease is not guaranteed: the problem is non-convex and each step uses a random mini-batch
- The reference loss typically shows the fastest initial decay
- Efficiency loss may oscillate due to trade-offs with tracking
- A current limit loss near zero indicates that the policy respects the current limit
"""

# Plot all loss components during training
plot_training_losses(
    losses,              # Total weighted loss
    ref_losses,          # Torque reference tracking error
    eff_losses,          # Efficiency (copper loss) I²R
    i_lim_losses,        # Current limit constraint
    i_ss_losses,         # Steady-state error
    acts_norm_losses     # Action magnitude penalty
)
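Raw per-iteration losses from mini-batch training are noisy, which makes the downward trend hard to judge by eye. One common option is to overlay a moving average on the raw curve. A NumPy sketch; `moving_average` is a hypothetical helper, the window size is a free choice, and it assumes the loss history converts to a 1-D array (e.g. `np.asarray(losses)`):

```python
import numpy as np


def moving_average(values, window=1000):
    """Trailing moving average; returns len(values) - window + 1 points."""
    values = np.asarray(values, dtype=float).ravel()
    if not 1 <= window <= values.size:
        raise ValueError("window must be between 1 and len(values)")
    cumsum = np.cumsum(np.insert(values, 0, 0.0))
    return (cumsum[window:] - cumsum[:-window]) / window


# Synthetic noisy, decaying curve for illustration (not training output).
rng = np.random.default_rng(0)
steps = np.arange(5000)
noisy = np.exp(-steps / 1000) + 0.05 * rng.standard_normal(steps.size)
smooth = moving_average(noisy, window=200)
print(smooth.shape)  # (4801,)
```

Plotted on a log-scaled y-axis next to the raw curve, the smoothed line makes plateaus and slow drifts easier to spot than the raw per-iteration values.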