# Autoencoder Formulation for Pareto Front Analysis

Considered limitations:
1. The optimal point (energy/time → acceleration) is defined at the start of the system; the user can select it.
2. All data features should be taken into account, i.e. decision variables, problem parameters, and objectives.


Offline:
1. Train the CVAE on Pareto profiles $x_i$ and their conditions $c_i$.
2. For each training sample, compute an anchor pair:
   $(c_i, \mu_i) \leftarrow \text{Encoder}(x_i, c_i)$.
3. Train a conditional prior $f_\psi: c_i \mapsto \mu_i$.


In [32]:
# === Scientific Python Stack ===
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# === PyTorch Core ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# import torch.distributions

# === PyTorch Utils / Vision ===
# import torchvision

# === Other Libraries ===
from tqdm import tqdm

# === Configuration ===
torch.manual_seed(0)
plt.rcParams["figure.dpi"] = 200


## Data reading

In [33]:
pareto_data = pd.read_pickle("../data/processed/pareto_data.pkl")
train_df = pd.DataFrame(pareto_data["train"])
test_df = pd.DataFrame(pareto_data["test"])

train_df

| | accel_ms2 | decel_ms2 | time_min | energy_kwh | weighted_time | weighted_energy | distance_km |
|---|---|---|---|---|---|---|---|
| 264 | 0.021956 | 0.868852 | 0.675099 | 0.654848 | 0.037219 | 0.962781 | 0.666667 |
| 615 | 0.263490 | 0.990406 | 0.327851 | 0.340762 | 0.887128 | 0.112872 | 0.333333 |
| 329 | 0.166521 | 0.986375 | 0.649817 | 0.663206 | 0.705583 | 0.294417 | 0.666667 |
| 342 | 0.497278 | 0.996336 | 0.636558 | 0.671084 | 0.977346 | 0.022654 | 0.666667 |
| 394 | 0.094652 | 0.908151 | 0.658117 | 0.660406 | 0.379313 | 0.620687 | 0.666667 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 71 | 0.216862 | 0.865053 | 0.961821 | 0.990308 | 0.831446 | 0.168554 | 1.000000 |
| 106 | 0.006263 | 0.982190 | 0.997049 | 0.978252 | 0.014206 | 0.985794 | 1.000000 |
| 270 | 0.025989 | 0.865614 | 0.673712 | 0.655410 | 0.047202 | 0.952798 | 0.666667 |
| 435 | 0.094725 | 0.912358 | 0.026695 | 0.009812 | 0.381008 | 0.618992 | 0.000000 |


## Conditional VAE Architecture

### Why use a VAE?
In traditional autoencoders, inputs are mapped deterministically to a latent vector
$z=e(x)$. In variational autoencoders, inputs are mapped to a probability distribution over latent vectors, and a latent vector is then sampled from that distribution. Because the decoder must reconstruct the input from any sample of that distribution, it becomes more robust at decoding latent vectors it has not seen exactly during training.

### Variational Autoencoder (VAE) Latent Space Mapping

                     Input Conditions
                      (SOC, Distance,
                      Velocity Profile)
                           │
                           ▼
                ┌───────────────────────┐
                │      Conditioner      │
                │  (MLP Transformer)    │
                └───────────────────────┘
                           │
                           ├──────────────────────┐
                           ▼                      ▼
                ┌───────────────────────┐ ┌───────────────────────┐
                │       Encoder         │ │      Decoder          │
                │  (q(z|x,c))           │ │  (p(x|z,c))           │
                └───────────────────────┘ └───────────────────────┘
                           │                      ▲
                           └───────► z ◄─────────┘
                                  Latent Space


Instead of mapping the input **$ x $** directly to a latent vector **$ z = e(x) $**, the encoder maps it to:

- **Mean vector**: **$ \mu(x) $**
- **Standard deviation vector**: **$ \sigma(x) $**

These parameters define a **diagonal Gaussian distribution** $\mathcal{N}(\mu(x), \operatorname{diag}(\sigma^2(x)))$, from which the latent vector $z$ is sampled:

$$
z \sim \mathcal{N}\big(\mu(x), \operatorname{diag}(\sigma^2(x))\big)
$$

This formulation allows the model to learn a probabilistic latent space representation where each input $ x $ defines its own distribution over latent codes rather than a single deterministic point.
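A minimal sketch of drawing $z$ from this diagonal Gaussian with the reparameterization trick $z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$, the same form that `CVAE.reparameterize` uses further down. The helper name `sample_latent` and the toy $\mu$ and $\log\sigma^2$ values are illustrative assumptions; the encoder is assumed to output the log-variance, as this notebook's CVAE encoder does.

```python
import torch

torch.manual_seed(0)


def sample_latent(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw z ~ N(mu, diag(exp(logvar))) with the reparameterization trick."""
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)    # eps ~ N(0, I), independent of the parameters
    return mu + eps * std          # differentiable w.r.t. mu and logvar


# Hypothetical encoder outputs: 4096 inputs, 2-D latent space
mu = torch.tensor([1.0, -2.0]).expand(4096, 2).clone().requires_grad_()
logvar = torch.log(torch.tensor([0.25, 4.0])).expand(4096, 2).clone().requires_grad_()

z = sample_latent(mu, logvar)
z.sum().backward()  # gradients reach mu and logvar through z
```

The empirical mean and standard deviation of `z` approach $\mu = (1, -2)$ and $\sigma = (0.5, 2)$, and gradients flow back to `mu` and `logvar` because all randomness sits in `eps`.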


### 1. Problem Definition
Let:
- **x** ∈ ℝ⁴ : Solution vector (accel_ms2, decel_ms2, time_min, energy_kwh)
- **c** ∈ ℝ⁵ : Condition vector (current_speed, remaining_battery (SOC), distance, energy_weight, time_weight)
- **z** ∈ ℝᴸ : Latent representation (L=8)

### 2. Probabilistic Model
**Objective**: Learn the conditional distribution
$$p_\theta(x|z,c) \quad \text{where} \quad z \sim q_\phi(z|x,c)$$

**Evidence Lower Bound (ELBO)**, which is maximized during training (the training loss is $-\mathcal{L}$):
$$
\mathcal{L}(\theta,\phi;x,c) = \mathbb{E}_{q_\phi(z|x,c)}[\log p_\theta(x|z,c)] - \beta D_{KL}(q_\phi(z|x,c) \| p(z))
$$
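A minimal sketch of the negative ELBO as a training loss, assuming a Gaussian decoder with fixed unit variance (so $-\log p_\theta(x|z,c)$ reduces to a squared error up to a constant) and the standard normal prior $p(z) = \mathcal{N}(0, I)$. The function name `negative_elbo` is hypothetical; the notebook's `CVAE.loss` uses a mean rather than a half-sum for the reconstruction term, which only rescales its weight relative to the KL term.

```python
import torch


def negative_elbo(x, x_hat, mu, logvar, beta=1.0):
    """Batch mean of -ELBO (beta-weighted form).

    Assumes p(x|z,c) = N(x_hat, I), so -log p = 0.5 * ||x - x_hat||^2 + const
    (constant dropped), and q(z|x,c) = N(mu, diag(exp(logvar))) vs p(z) = N(0, I).
    """
    recon = 0.5 * ((x - x_hat) ** 2).sum(dim=1)
    kl = 0.5 * (logvar.exp() + mu.pow(2) - 1.0 - logvar).sum(dim=1)
    return (recon + beta * kl).mean()


x = torch.rand(16, 4)
logvar = torch.zeros(16, 8)
# Perfect reconstruction and q equal to the prior: loss is 0
loss_zero = negative_elbo(x, x, torch.zeros(16, 8), logvar)
# Error of 1 on each of 4 features (recon = 2), mu = 1 on 8 dims (KL = 4), beta = 0.5
loss_off = negative_elbo(x, x + 1.0, torch.ones(16, 8), logvar, beta=0.5)  # 2 + 0.5 * 4 = 4
```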

 
### Encoder Network (Compression)
$$
\mathbf{z} = g_\phi(\mathbf{x}) = \text{LeakyReLU}(\mathbf{W}_2 \cdot \text{ELU}(\mathbf{W}_1\mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2)
$$

### Decoder Network (Reconstruction)
$$
\hat{\mathbf{x}} = f_\theta(\mathbf{z}) = \text{Sigmoid}(\mathbf{W}_4 \cdot \text{ELU}(\mathbf{W}_3\mathbf{z} + \mathbf{b}_3) + \mathbf{b}_4)
$$

**Dimensionality**:
- Input/Output: $\mathbb{R}^2$ (normalized objectives)
- Latent space: $\mathbb{R}^1$ (bottleneck)
- Hidden layers: 32 neurons with ELU activation


## Architecture
A sigmoid output activation keeps outputs in the normalized $[0,1]$ range.

### Encoder
Maps 2D Pareto solutions to a 1D latent space:
$$
\mathbf{z} = \text{Encoder}(\mathbf{x}) = \sigma(\mathbf{W}_2 \cdot \text{ReLU}(\mathbf{W}_1\mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2)
$$

Where:
- $\mathbf{W}_1 \in \mathbb{R}^{h \times 2}$, $\mathbf{W}_2 \in \mathbb{R}^{1 \times h}$ are weight matrices
- $\mathbf{b}_1 \in \mathbb{R}^h$, $\mathbf{b}_2 \in \mathbb{R}^1$ are bias terms
- $h$ is hidden layer size
- $\sigma$ is the sigmoid activation

### Decoder
Reconstructs solutions from the latent space:
$$
\hat{\mathbf{x}} = \text{Decoder}(\mathbf{z}) = \sigma(\mathbf{W}_4 \cdot \text{ReLU}(\mathbf{W}_3\mathbf{z} + \mathbf{b}_3) + \mathbf{b}_4)
$$

With:
- $\mathbf{W}_3 \in \mathbb{R}^{h \times 1}$, $\mathbf{W}_4 \in \mathbb{R}^{2 \times h}$
- $\mathbf{b}_3 \in \mathbb{R}^h$, $\mathbf{b}_4 \in \mathbb{R}^2$
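A minimal PyTorch sketch of this deterministic 2-D → 1-D → 2-D autoencoder, following the Encoder and Decoder equations above (ReLU hidden layers, sigmoid outputs). The class name `ParetoAutoencoder`, the hidden size $h = 32$ (taken from the Dimensionality list), and the synthetic convex front used for training are illustrative assumptions, not the notebook's data.

```python
import torch
import torch.nn as nn


class ParetoAutoencoder(nn.Module):
    """Deterministic autoencoder: 2-D objectives -> 1-D latent -> 2-D reconstruction.

    z     = sigmoid(W2 ReLU(W1 x + b1) + b2)
    x_hat = sigmoid(W4 ReLU(W3 z + b3) + b4)
    """

    def __init__(self, h: int = 32):
        super().__init__()
        # W1: h x 2, W2: 1 x h
        self.encoder = nn.Sequential(nn.Linear(2, h), nn.ReLU(), nn.Linear(h, 1), nn.Sigmoid())
        # W3: h x 1, W4: 2 x h
        self.decoder = nn.Sequential(nn.Linear(1, h), nn.ReLU(), nn.Linear(h, 2), nn.Sigmoid())

    def forward(self, x: torch.Tensor):
        z = self.encoder(x)
        return self.decoder(z), z


torch.manual_seed(0)
# Synthetic convex front in normalized (time, energy) space, for illustration only
t = torch.rand(256, 1)
front = torch.cat([t, (1 - t.sqrt()) ** 2], dim=1)

model = ParetoAutoencoder(h=32)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
losses = []
for step in range(300):
    x_hat, z = model(front)
    # L_recon: squared L2 error per sample, averaged over the batch
    loss = ((front - x_hat) ** 2).sum(dim=1).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

with torch.no_grad():
    x_hat, z = model(front)
```

The sigmoid on the 1-D code bounds $z$ to $[0, 1]$, so the latent value can be read as a position along the front.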


### Activation function

### Loss Function
Mean Squared Error (MSE) reconstruction loss:
$$
\mathcal{L}_{recon} = \frac{1}{N}\sum_{i=1}^N \|\mathbf{x}_i - \hat{\mathbf{x}}_i\|^2_2
$$

## 1. Composite Loss Components
The total training objective combines three terms:

$$
\mathcal{L}_{total} = \mathcal{L}_{recon} + \beta\mathcal{L}_{KL} + \mathcal{L}_{phys}
$$

#### 1.1 Weighted Reconstruction Loss
MSE with each objective normalized by its maximum and weighted by its operational priority:
$$
\mathcal{L}_{recon} = \frac{1}{N}\sum_{i=1}^N \left[w_t\left(\frac{\hat{t}_i - t_i}{t_{max}}\right)^2 + w_e\left(\frac{\hat{e}_i - e_i}{e_{max}}\right)^2\right]
$$

| Parameter | Value | Rationale |
|-----------|-------|-----------|
| $w_t$     | 0.7   | Time minimization priority |
| $w_e$     | 0.3   | Energy conservation importance |
| $t_{max}$ | 120 min | Maximum allowable trip time |
| $e_{max}$ | 10 kWh | Battery capacity limit |
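A minimal sketch of this weighted reconstruction term with the table's values, assuming $t$ and $e$ are in raw units (minutes and kWh) rather than the min-max normalized values shown in `train_df`. The function name `weighted_recon_loss` is hypothetical.

```python
import torch

# Values from the table above: t_max in minutes, e_max in kWh
W_T, W_E = 0.7, 0.3
T_MAX, E_MAX = 120.0, 10.0


def weighted_recon_loss(t_hat, t, e_hat, e, w_t=W_T, w_e=W_E, t_max=T_MAX, e_max=E_MAX):
    """Batch mean of w_t*((t_hat - t)/t_max)^2 + w_e*((e_hat - e)/e_max)^2."""
    time_term = w_t * ((t_hat - t) / t_max) ** 2
    energy_term = w_e * ((e_hat - e) / e_max) ** 2
    return (time_term + energy_term).mean()


# Worked example: 12 min time error and 1 kWh energy error on one sample
loss = weighted_recon_loss(
    torch.tensor([72.0]), torch.tensor([60.0]),
    torch.tensor([5.0]), torch.tensor([4.0]),
)
# 0.7 * (12/120)^2 + 0.3 * (1/10)^2 = 0.007 + 0.003 = 0.01
```

Normalizing by $t_{max}$ and $e_{max}$ puts both errors on a comparable scale before the priority weights are applied.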

#### 1.2 KL Divergence Regularization
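Closed-form KL divergence between the diagonal Gaussian posterior and the standard normal prior $\mathcal{N}(0, I)$, summed over the $L = 8$ latent dimensions: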
$$
\mathcal{L}_{KL} = \frac{1}{2}\sum_{j=1}^8 \left[\exp(\log\sigma_j^2) + \mu_j^2 - 1 - \log\sigma_j^2\right]
$$
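A quick numerical check, as a sketch, that this closed form matches PyTorch's `torch.distributions.kl_divergence` between $\mathcal{N}(\mu_j, \sigma_j^2)$ and $\mathcal{N}(0, 1)$, summed over $L = 8$ dimensions. The random $\mu$ and $\log\sigma^2$ values are illustrative. The `kld` term in `CVAE.loss` below uses the same expression with the sign pulled outside.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5, 8)      # batch of 5, latent size L = 8
logvar = torch.randn(5, 8)

# Closed form from the equation above, summed over the latent dimensions
kl_closed = 0.5 * (logvar.exp() + mu.pow(2) - 1.0 - logvar).sum(dim=1)

# Reference: per-dimension KL(N(mu, sigma^2) || N(0, 1)), then summed
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum(dim=1)
```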

#### 1.3 Physics-Informed Penalty
Hard physical constraints are enforced approximately through soft (squared hinge) penalties:
$$
\mathcal{L}_{phys} = \lambda_1\max(0, \hat{e} - e_{max})^2 + \lambda_2\max(0, t_{min} - \hat{t})^2
$$

| Constraint | Formula | Weight $\lambda$ |
|------------|---------|-------|
| Energy Cap | $\hat{e} \leq 10\,\text{kWh}$ | 1.5 |
| Time Floor | $\hat{t} \geq \frac{d}{v_{max}}$ | 0.8 |
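A minimal sketch of the two penalties with the table's weights, using `torch.relu` for $\max(0, \cdot)$ so the terms stay differentiable. `physics_penalty` is a hypothetical helper, averaging over the batch is an assumption, and distance and $v_{max}$ must use units consistent with $\hat{t}$ (here km and km/min, so times are in minutes).

```python
import torch

LAMBDA_E, LAMBDA_T = 1.5, 0.8  # penalty weights from the table above
E_MAX = 10.0                   # energy cap in kWh


def physics_penalty(e_hat, t_hat, distance, v_max, e_max=E_MAX):
    """Squared hinge penalties: zero while constraints hold, quadratic in the violation."""
    t_min = distance / v_max                      # time floor d / v_max
    energy_violation = torch.relu(e_hat - e_max)  # max(0, e_hat - e_max)
    time_violation = torch.relu(t_min - t_hat)    # max(0, t_min - t_hat)
    return (LAMBDA_E * energy_violation**2 + LAMBDA_T * time_violation**2).mean()


# Sample 1 violates both constraints by 2 units; sample 2 is feasible
pen = physics_penalty(
    e_hat=torch.tensor([12.0, 8.0]),
    t_hat=torch.tensor([4.0, 10.0]),
    distance=torch.tensor([12.0, 12.0]),  # km
    v_max=2.0,                            # km/min, so t_min = 6 min
)
# (1.5 * 2^2 + 0.8 * 2^2 + 0) / 2 = (6.0 + 3.2) / 2 = 4.6
```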



### Input features

[
    target_distance,        # Scalar (e.g., 7.5 km)
    time_weight,            # 0.0 (min) ↔ 1.0 (max)
    energy_weight,          # 0.0 (min) ↔ 1.0 (max)
    constraints_vector      # [max_jerk, battery_limit, ...]
]

### Output Features
[
    acceleration_profile,   # Time-series (100 steps)
    deceleration_profile    # Time-series (100 steps)
]


In [34]:
# # Example Training Batch
# batch = {
#     "input_conditions": torch.tensor(
#         [
#             # [distance, time_weight, energy_weight, max_jerk]
#             [7.5, 0.9, 0.1, 0.3],  # Minimum time
#             [10.0, 0.5, 0.5, 0.3],  # Balanced
#             [15.0, 0.1, 0.9, 0.3],  # Minimum energy
#         ],
#         dtype=torch.float32,
#     ),
#     "output_profiles": torch.tensor(
#         [
#             # [accel_profile (100 steps), decel_profile (100 steps)]
#             [0.8, 0.82, ..., 1.2, 0.5, 0.48, ..., 0.1],  # Aggressive accel
#             [0.5, 0.52, ..., 0.6, 0.4, 0.38, ..., 0.2],  # Moderate
#             [0.3, 0.31, ..., 0.4, 0.6, 0.62, ..., 0.8],  # Conservative
#         ],
#         dtype=torch.float32,
#     ),
# }


### Loss function

loss = reconstruction_loss + α*physics_loss + β*constraint_penalty


### Training process
Keep only the non-dominated solutions, which together form the Pareto front.
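A minimal sketch of non-dominated filtering when both objectives (time and energy) are minimized. `non_dominated_mask` is a hypothetical helper using an $O(n^2)$ pairwise check; duplicated points are both kept, since neither strictly dominates the other.

```python
import numpy as np


def non_dominated_mask(objectives: np.ndarray) -> np.ndarray:
    """Boolean mask of Pareto-optimal rows when every objective is minimized.

    Row j dominates row i if it is <= in all objectives and < in at least one.
    """
    n = objectives.shape[0]
    mask = np.ones(n, dtype=bool)
    for i in range(n):
        le_all = np.all(objectives <= objectives[i], axis=1)
        lt_any = np.any(objectives < objectives[i], axis=1)
        if np.any(le_all & lt_any):
            mask[i] = False
    return mask


# Toy (time_min, energy_kwh) pairs; [3, 4] is dominated by [2, 3]
pts = np.array([[1.0, 5.0], [2.0, 3.0], [3.0, 4.0], [4.0, 1.0], [2.0, 3.0]])
mask = non_dominated_mask(pts)
```

On the notebook's data this would likely be applied per trip condition, for example `train_df.groupby("distance_km", group_keys=False).apply(lambda g: g[non_dominated_mask(g[["time_min", "energy_kwh"]].to_numpy())])`, since solutions for different distances should not dominate each other; that grouping is an assumption.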



In [35]:
class Encoder(nn.Module):
    """
    Tabular CVAE Encoder

    Encodes a 4-dimensional feature vector (accel1, accel2, time, energy)
    concatenated with a 5-dimensional condition vector
    (current_speed, remaining_battery, distance_to_pass, energy_weight, time_weight)
    into latent distribution parameters (mean and log-variance).

    Architecture:
    - Input layer: 9 -> 32
    - Hidden layer: 32 -> 16
    - Output heads: 16 -> latent_dim (for both mean and log-variance)

    Args:
        input_dim (int): Dimensionality of the input features (4)
        condition_dim (int): Dimensionality of the condition vector (5)
        latent_dim (int): Dimensionality of the latent space
    """

    def __init__(self, input_dim: int = 4, condition_dim: int = 5, latent_dim: int = 8):
        super().__init__()
        self.input_dim = input_dim
        self.condition_dim = condition_dim
        hidden_dim1 = 32
        hidden_dim2 = 16

        # Shared MLP to extract latent statistics
        self.shared = nn.Sequential(
            nn.Linear(input_dim + condition_dim, hidden_dim1),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(hidden_dim1),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(hidden_dim2),
        )
        # Output layers for mean and log-variance
        self.mu_layer = nn.Linear(hidden_dim2, latent_dim)
        self.logvar_layer = nn.Linear(hidden_dim2, latent_dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        Forward pass

        Args:
            x (Tensor): Feature tensor [batch_size, 4]
            cond (Tensor): Condition tensor [batch_size, 5]

        Returns:
            mu (Tensor): Predicted latent means [batch_size, latent_dim]
            logvar (Tensor): Predicted latent log-variances [batch_size, latent_dim]
        """
        assert x.shape[1] == self.input_dim
        assert cond.shape[1] == self.condition_dim
        # concatenate features and conditions
        h = torch.cat([x, cond], dim=1)
        h = self.shared(h)
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar


class Decoder(nn.Module):
    """
        Tabular CVAE Decoder

        Decodes a latent vector and 5-dimensional condition
    to reconstruct the original 4-dimensional features.

        Architecture:
        - Input layer: latent_dim + 5 -> 16
        - Hidden layer: 16 -> 32
        - Output layer: 32 -> 4

        Args:
            latent_dim (int): Dimensionality of the latent space
            condition_dim (int): Dimensionality of the condition vector (5)
            output_dim (int): Dimensionality of reconstructed features (4)
    """

    def __init__(
        self, latent_dim: int = 8, condition_dim: int = 5, output_dim: int = 4
    ):
        super().__init__()
        hidden_dim1 = 16
        hidden_dim2 = 32

        self.decoder_net = nn.Sequential(
            nn.Linear(latent_dim + condition_dim, hidden_dim1),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(hidden_dim1),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(hidden_dim2),
            nn.Linear(hidden_dim2, output_dim),
        )
        # initialize output layer for stability
        nn.init.xavier_uniform_(self.decoder_net[-1].weight)
        nn.init.zeros_(self.decoder_net[-1].bias)

    def forward(self, z: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            z (Tensor): Latent tensor [batch_size, latent_dim]
            cond (Tensor): Condition tensor [batch_size, condition_dim]

        Returns:
            recon (Tensor): Reconstructed features [batch_size, output_dim]
        """
        assert z.dim() == 2 and cond.dim() == 2
        # concatenate latent and conditions
        h = torch.cat([z, cond], dim=1)
        recon = self.decoder_net(h)
        return recon


class PhysicsModel(nn.Module):
    """
    Physics-Informed Model for Vehicle Dynamics and Constraints

    This module computes physical metrics from acceleration profiles,
    including energy consumption, travel time estimation, and jerk statistics.
    It encapsulates fundamental equations of motion and resistive forces,
    enabling integration into CVAE training and real-time inference pipelines.

    Args:
        config (dict): Configuration dictionary with keys:
            - vehicle_mass (float): Vehicle mass in kg
            - air_density (float): Air density in kg/m^3
            - drag_coeff (float): Aerodynamic drag coefficient
            - rolling_res (float): Rolling resistance coefficient
            - motor_eff (float): Motor efficiency (0 < eff <= 1)
    """

    def __init__(self, config: dict):
        super().__init__()
        # Vehicle physical parameters
        self.vehicle_mass = config.get("vehicle_mass", 1500.0)  # kg
        self.air_density = config.get("air_density", 1.225)  # kg/m^3
        self.drag_coeff = config.get("drag_coeff", 0.3)  # unitless
        self.rolling_res = config.get("rolling_res", 0.01)  # rolling resistance coeff
        self.motor_eff = config.get("motor_eff", 0.9)  # motor efficiency fraction

    def calculate_energy(
        self,
        accel_profile: torch.Tensor,
        time_profile: torch.Tensor,
        remaining_battery: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute total energy consumption over a trip using an acceleration profile.

        Args:
            accel_profile (Tensor): [batch, time_steps] acceleration in m/s^2
            time_profile (Tensor): [batch, time_steps] timestamps in seconds
            remaining_battery (Tensor): [batch] current battery state-of-charge (0-1)

        Returns:
            Tensor: energy used in kWh [batch]
        """
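        # Integrate acceleration to velocity; the time grid is uniform, so one
        # dt per sample suffices (unsqueeze so it broadcasts over time steps)
        dt = (time_profile[:, 1] - time_profile[:, 0]).unsqueeze(1)  # [batch, 1]
        velocity = torch.cumsum(accel_profile * dt, dim=1)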
        # Compute resistive forces
        drag_force = 0.5 * self.air_density * self.drag_coeff * velocity**2
        rolling_force = self.rolling_res * self.vehicle_mass * 9.81
        # Mechanical power: (mass*a + drag + rolling) * v
        mech_power = (
            self.vehicle_mass * accel_profile + drag_force + rolling_force
        ) * velocity
        # Account for motor efficiency: electrical power drawn
        elec_power = mech_power / self.motor_eff
        # Integrate power over time to get energy (Ws to Wh)
        energy_Wh = torch.trapz(elec_power, time_profile, dim=1) / 3600.0
        # Convert to kWh and clamp to remaining battery
        return torch.clamp(energy_Wh / 1000.0, max=remaining_battery)

    def calculate_jerk(
        self, accel_profile: torch.Tensor, dt: float = 0.1
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute mean and maximum jerk (rate of change of acceleration).

        Args:
            accel_profile (Tensor): [batch, time_steps] acceleration in m/s^2
            dt (float): time step between acceleration measurements

        Returns:
            mean_jerk (Tensor): mean absolute jerk [batch]
            max_jerk (Tensor): maximum absolute jerk [batch]
        """
        # Jerk: finite-difference derivative of acceleration (absolute value)
        abs_jerk = torch.abs(accel_profile[:, 1:] - accel_profile[:, :-1]) / dt
        # Compute metrics; amax returns values only, whereas torch.max(..., dim=1)
        # returns a (values, indices) tuple
        mean_jerk = abs_jerk.mean(dim=1)
        max_jerk = abs_jerk.amax(dim=1)
        return mean_jerk, max_jerk

    def forward(
        self, profiles: torch.Tensor, conditions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass: compute time, energy, average jerk, and max jerk.

        Args:
            profiles (Tensor): [batch, time_steps] acceleration profiles
            conditions (Tensor): [batch, condition_dim] includes remaining_battery at idx 1

        Returns:
            time (Tensor): estimated travel time in seconds [batch]
            energy (Tensor): energy consumption in kWh [batch]
            mean_jerk (Tensor): mean absolute jerk [batch]
            max_jerk (Tensor): maximum absolute jerk [batch]
        """
        batch_size, seq_len = profiles.shape
        # Build a uniform time grid with a fixed 0.1 s step
        dt = 0.1
        t = torch.linspace(0.0, dt * (seq_len - 1), seq_len, device=profiles.device)
        time_profile = t.unsqueeze(0).repeat(batch_size, 1)

        # Compute energy based on current battery
        remaining_batt = conditions[:, 1]
        energy = self.calculate_energy(profiles, time_profile, remaining_batt)

        # Velocity integration for time estimation
        velocity = torch.cumsum(profiles * dt, dim=1)
        # Distance traveled
        distance = torch.trapz(velocity, time_profile, dim=1)
        # Time to cover distance: distance / avg speed
        time_est = distance / (torch.mean(velocity, dim=1) + 1e-6)

        # Compute jerk metrics
        mean_jerk, max_jerk = self.calculate_jerk(profiles, dt)

        return time_est, energy, mean_jerk, max_jerk


class CVAE(nn.Module):
    """
    Conditional Variational Autoencoder (CVAE) integrating a physics-based loss.

    The CVAE learns a mapping from input acceleration profiles (tabular features)
    and condition vectors (vehicle state and user preferences) to a latent space,
    then reconstructs profiles via a conditional decoder.
    Physics-informed constraints (time, energy, jerk) are incorporated into the loss
    to enforce feasibility of generated profiles.

    Components:
    - Encoder: Produces μ and logσ² for qϕ(z|x,c)
    - Reparameterizer: Samples z differentiably via z=μ+σ·ε, ε~N(0, I)
    - Decoder: Generates reconstructed profiles pθ(x̂|z,c)
    - PhysicsModel: Computes time, energy, and jerk for constraint losses

    Args:
        config (dict): Configuration with keys:
            input_dim (int): Dimensionality of raw features (4)
            latent_dim (int): Dimensionality of latent vector z
            condition_dim (int): Dimensionality of condition vector (5)
            device (str or torch.device): Compute device
            max_jerk (float): Maximum allowable jerk for comfort
            max_energy (float): Battery capacity (kWh)
            max_time (float): Maximum allowable travel time (s)
            base_time_weight (float): Base weight for time loss
            base_energy_weight (float): Base weight for energy loss
            base_jerk_weight (float): Base weight for jerk loss
            adaptive_weighting (bool): Use dynamic weights from conditions
    """

    def __init__(self, config: dict):
        super().__init__()
        # Save config and device
        self.device = torch.device(config.get("device", "cpu"))

        # Instantiate encoder, decoder, and physics model
        self.encoder = Encoder(
            input_dim=config["input_dim"],
            condition_dim=config.get("condition_dim", 5),
            latent_dim=config["latent_dim"],
        )
        self.decoder = Decoder(
            latent_dim=config["latent_dim"],
            condition_dim=config.get("condition_dim", 5),
            output_dim=config["input_dim"],
        )
        self.physics_model = PhysicsModel(config)

        # Physics/constraint parameters
        self.max_jerk = config.get("max_jerk", 2.5)
        self.max_energy = config.get("max_energy", 10.0)
        self.max_time = config.get("max_time", 120.0)

        # Loss weights
        self.adaptive = config.get("adaptive_weighting", True)
        self.base_time_w = config.get("base_time_weight", 0.5)
        self.base_energy_w = config.get("base_energy_weight", 0.5)
        self.base_jerk_w = config.get("base_jerk_weight", 0.2)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Apply the reparameterization trick to sample z ~ N(mu, sigma^2) in a differentiable way.

        Args:
            mu (Tensor): Mean of latent Gaussian [batch, latent_dim]
            logvar (Tensor): Log-variance of latent Gaussian [batch, latent_dim]
        Returns:
            z (Tensor): Sampled latent vectors [batch, latent_dim]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # same device and dtype as std
        return mu + eps * std

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        Forward pass through CVAE: encode, sample, and decode.

        Args:
            x (Tensor): Original feature profiles [batch, input_dim]
            cond (Tensor): Condition vectors [batch, condition_dim]
        Returns:
            recon (Tensor): Reconstructed profiles [batch, input_dim]
            mu (Tensor), logvar (Tensor): Latent distribution parameters
        """
        # Encode to latent distribution parameters
        mu, logvar = self.encoder(x, cond)
        # Sample latent vectors via reparameterization
        z = self.reparameterize(mu, logvar)
        # Decode conditioned on z and conditions
        recon = self.decoder(z, cond)
        return recon, mu, logvar

    def loss_function(
        self,
        recon: torch.Tensor,
        x: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        cond: torch.Tensor,
    ) -> dict:
        """
        Compute the composite loss combining reconstruction, KL divergence,
        and physics-based constraint penalties.

        Steps:
        1. Weighted MSE reconstruction loss based on time/energy weights
        2. KL divergence term to regularize latent space
        3. Physics model computes time, energy, jerk from recon
        4. Compute huber losses for time and energy targets
        5. Jerk penalty for exceeding comfort threshold

        Args:
            recon (Tensor): Decoded profiles [batch, input_dim]
            x (Tensor): Original profiles [batch, input_dim]
            mu, logvar (Tensor): Latent params [batch, latent_dim]
            cond (Tensor): Conditions [batch, condition_dim]
        Returns:
            dict: {"total": Tensor, "recon":, "kld":, "time":, "energy":, "jerk":}
        """
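        # Per-sample preference weights from the conditions (adaptive),
        # or the base weights broadcast to the batch (non-adaptive)
        if self.adaptive:
            tw, ew = cond[:, -1], cond[:, -2]
        else:
            tw = torch.full_like(cond[:, 0], self.base_time_w)
            ew = torch.full_like(cond[:, 0], self.base_energy_w)

        # Reconstruction loss (MSE) weighted by time and energy preferences;
        # the [tw, ew] pair is tiled across features ([tw, ew, tw, ew] for 4 features,
        # so input_dim must be even)
        mse = F.mse_loss(recon, x, reduction="none")  # [batch, features]
        weights = torch.stack([tw, ew], dim=1).repeat(1, x.size(1) // 2)
        recon_loss = (mse * weights).mean()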

        # KL divergence
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()

        # Physics-based metrics on reconstructions
        time_pred, energy_pred, mean_jerk, max_jerk = self.physics_model(recon, cond)

        # Targets: ideal time and energy based on conditions
        speed, batt, dist, _, _ = cond.T
        ideal_time = dist / (speed + 1e-6)
        ideal_energy = batt * dist * 0.001  # placeholder conversion

        time_loss = F.huber_loss(time_pred, ideal_time)
        energy_loss = F.huber_loss(energy_pred, ideal_energy)
        jerk_loss = F.relu(max_jerk - self.max_jerk).mean()

        # Combine losses
        total_loss = (
            recon_loss
            + kld
            + tw.mean() * time_loss
            + ew.mean() * energy_loss
            + self.base_jerk_w * jerk_loss
        )
        return {
            "total": total_loss,
            "recon": recon_loss,
            "kld": kld,
            "time": time_loss,
            "energy": energy_loss,
            "jerk": jerk_loss,
        }


def train(model: nn.Module, dataloader, config: dict) -> None:
    """
    Train a Conditional Variational Autoencoder (CVAE) model with physics-informed loss.

    This routine runs the full training loop, including:
      1. Moving the model to the specified device.
      2. Iterating over epochs and batches.
      3. Performing forward and backward passes.
      4. Updating model parameters via AdamW.
      5. Scheduling the learning rate based on the average training loss.

    Args:
        model (nn.Module): The CVAE instance to train.
        dataloader (DataLoader): Provides batches of dicts with keys:
            - "profile": Tensor of shape [batch, input_dim]
            - "conditions": Tensor of shape [batch, condition_dim]
        config (dict): Training configuration with entries:
            - "device": torch.device or string, e.g., "cuda" or "cpu"
            - "epochs": int, total number of epochs
            - "lr": float, initial learning rate
            - "weight_decay": float, L2 regularization coefficient
            - "batch_size": int, number of samples per batch
            - "scheduler_patience": int, epochs to wait for LR plateau
            - "scheduler_factor": float, LR reduction factor on plateau

    Returns:
        None: The function trains the model in place and prints progress.
    """
    # Move model to the requested compute device
    device = torch.device(config["device"])
    model.to(device)
    model.train()

    # Set up optimizer with weight decay for regularization
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config.get("weight_decay", 1e-5),
    )

    # Reduce LR when total loss has plateaued
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        patience=config.get("scheduler_patience", 5),
        factor=config.get("scheduler_factor", 0.5),
    )

    # Training loop
    for epoch in range(1, config["epochs"] + 1):
        epoch_loss = 0.0
        progress = tqdm(dataloader, desc=f"[Epoch {epoch}/{config['epochs']}]")

        for batch in progress:
            # Unpack and move inputs to device
            x = batch["profile"].to(device)  # [batch, input_dim]
            cond = batch["conditions"].to(device)  # [batch, condition_dim]

            # Zero gradients from previous step
            optimizer.zero_grad()

            # Forward pass through CVAE
            recon, mu, logvar = model(x, cond)

            # Compute all loss components
            losses = model.loss_function(recon, x, mu, logvar, cond)

            # Backward pass: gradients flow through reparameterization
            losses["total"].backward()

            # Gradient clipping for stability
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Parameter update
            optimizer.step()

            # Accumulate the per-batch mean loss and display it
            batch_loss = losses["total"].item()
            epoch_loss += batch_loss
            progress.set_postfix(total_loss=batch_loss)

        # Average of the per-batch mean losses over the epoch
        avg_loss = epoch_loss / len(dataloader)
        # Step the scheduler on the epoch's average training loss
        scheduler.step(avg_loss)
        print(f"Epoch {epoch:>3}: Avg Loss = {avg_loss:.4f}")


def get_device() -> torch.device:
    """
    Utility to select the compute device.

    Returns:
        torch.device: 'cuda' if available, otherwise 'cpu'.
    """
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [37]:
# Training

# Configuration Example
config = {
    # —— Model dimensions ——————————————————————————————————
    "input_dim": 200,  # length of accel+decel profile vector
    "output_dim": 200,  # same as input_dim
    "latent_dim": 32,  # size of z
    "condition_dim": 5,  # number of condition features (e.g. [speed, soc, distance, energy_w, time_w])
    # —— Training settings ——————————————————————————————
    "device": get_device(),
    "epochs": 100,
    "lr": 1e-4,
    "batch_size": 64,
    "weight_decay": 1e-5,  # for AdamW
    "beta": 1.0,  # KLD annealing coefficient (if you use a β–VAE schedule)
    "adaptive_weighting": True,  # enable condition‐driven time/energy weighting
    # —— Physics & Constraint hyperparameters —————————————
    "max_jerk": 2.5,  # m/s³ comfort threshold
    "max_energy": 10.0,  # kWh battery capacity
    "max_time": 120.0,  # minutes or seconds consistent with your loss
    # —— Loss base weights (if not adaptive) —————————————
    "base_time_weight": 0.5,
    "base_energy_weight": 0.5,
    "base_jerk_weight": 0.2,
    # —— Scheduler & Seed ——————————————————————————————
    "seed": 42,  # reproducibility
    "scheduler_patience": 5,  # for ReduceLROnPlateau
    "scheduler_factor": 0.5,  # LR reduction factor
}

cvae = CVAE(config)

# Create synthetic dataset
train_data = [...]  # Implement proper Dataset class
dataloader = DataLoader(train_data, batch_size=config["batch_size"], shuffle=True)

# Train model
train(cvae, dataloader, config)

# Save model
torch.save(cvae.state_dict(), "monocap_cvae.pth")


[Epoch 1/100]:   0%|          | 0/1 [00:00<?, ?it/s]


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'ellipsis'>
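The `TypeError` above comes from `train_data = [...]`: `DataLoader` cannot collate a literal `Ellipsis`. A minimal sketch of a `Dataset` that yields the dictionaries `train()` unpacks (`"profile"` and `"conditions"`) follows. The class name `ProfileDataset` is hypothetical, and the random tensors are placeholders for checking shapes and batching only, not training data.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ProfileDataset(Dataset):
    """Hypothetical dataset yielding {"profile": [input_dim], "conditions": [condition_dim]}."""

    def __init__(self, profiles: torch.Tensor, conditions: torch.Tensor):
        assert len(profiles) == len(conditions)
        self.profiles = profiles.float()
        self.conditions = conditions.float()

    def __len__(self) -> int:
        return len(self.profiles)

    def __getitem__(self, idx: int) -> dict:
        return {"profile": self.profiles[idx], "conditions": self.conditions[idx]}


# Random placeholder tensors, only to check shapes and batching
torch.manual_seed(0)
demo = ProfileDataset(torch.rand(256, 200), torch.rand(256, 5))
loader = DataLoader(demo, batch_size=64, shuffle=True)
batch = next(iter(loader))  # default collation stacks each dict key into a batch tensor
```

Passing `drop_last=True` to `DataLoader` avoids a trailing batch of size 1, which `nn.BatchNorm1d` rejects in training mode.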

## Visualize latent space

In [None]:
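@torch.no_grad()
def plot_latent(model, dataloader, num_batches=100, color_idx=-1):
    """Scatter the first two dims of the encoder mean, colored by one condition (default: time_weight)."""
    model.eval()
    device = next(model.parameters()).device
    for i, batch in enumerate(dataloader):
        x = batch["profile"].to(device)
        cond = batch["conditions"].to(device)
        # The CVAE encoder needs the condition and returns (mu, logvar)
        mu, _ = model.encoder(x, cond)
        mu = mu.cpu().numpy()
        colors = cond[:, color_idx].cpu().numpy()
        plt.scatter(mu[:, 0], mu[:, 1], c=colors, cmap="viridis", s=5)
        if i + 1 >= num_batches:
            break
    plt.colorbar()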

## Anchor Extraction & Storage

In [None]:
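from torch.utils.data import DataLoader, Dataset


class AnchorBuilder:
    """Collects (condition, encoder mean) anchor pairs from a trained CVAE."""

    def __init__(self, cvae: CVAE, dataset: Dataset):
        self.cvae = cvae.eval()  # eval mode: BatchNorm uses running stats, Dropout is off
        self.dataset = dataset
        self.anchors = []  # list of (cond, mu)

    def build(self, device=None):
        # Default to the device the model's parameters live on
        if device is None:
            device = next(self.cvae.parameters()).device
        for batch in DataLoader(self.dataset, batch_size=64):
            x = batch["profile"].to(device)
            c = batch["conditions"].to(device)
            with torch.no_grad():
                mu, _ = self.cvae.encoder(x, c)
            for cond_vec, mu_vec in zip(c.cpu().numpy(), mu.cpu().numpy()):
                self.anchors.append((cond_vec, mu_vec))
        # cond and mu have different lengths, so wrap the pairs in an object array;
        # np.save on the raw list of ragged pairs fails in recent NumPy versions
        anchor_array = np.empty(len(self.anchors), dtype=object)
        for i, pair in enumerate(self.anchors):
            anchor_array[i] = pair
        # Save anchors to disk for reuse (load with allow_pickle=True)
        np.save("anchors.npy", anchor_array, allow_pickle=True)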


## KD-Tree for Fast NN Search

In [None]:
from sklearn.neighbors import KDTree


class ConditionIndex:
    def __init__(self, anchors_path="anchors.npy"):
        self.anchors = np.load(anchors_path, allow_pickle=True)
        conds = np.stack([a[0] for a in self.anchors])
        self.tree = KDTree(conds)  # leaf_size defaults to 40

    def query(self, c_new, k=3):
        dists, idxs = self.tree.query([c_new], k=k)
        return [(self.anchors[i][1], dists[0][j]) for j, i in enumerate(idxs[0])]
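`ConditionIndex.query` returns the $k$ nearest anchors as `(mu, dist)` pairs but leaves open how to combine them into one latent code. One possible choice, sketched here as an assumption rather than the notebook's method, is an inverse-distance-weighted average of the neighbours' means; `idw_latent` is a hypothetical helper and the neighbour values are toy numbers.

```python
import numpy as np


def idw_latent(neighbours, eps=1e-8):
    """Inverse-distance-weighted mean of neighbour latent means.

    neighbours: list of (mu, dist) pairs, in the format ConditionIndex.query returns.
    """
    mus = np.stack([mu for mu, _ in neighbours])
    dists = np.array([d for _, d in neighbours], dtype=float)
    if np.any(dists < eps):
        # Exact (or near-exact) condition match: reuse that anchor's mean directly
        return mus[np.argmin(dists)]
    weights = 1.0 / dists
    return (weights[:, None] * mus).sum(axis=0) / weights.sum()


# Toy neighbours at distances 1 and 3, so weights are 1 and 1/3
z_interp = idw_latent([(np.array([0.0, 2.0]), 1.0), (np.array([4.0, 0.0]), 3.0)])
z_exact = idw_latent([(np.array([1.0, 1.0]), 0.0), (np.array([5.0, 5.0]), 2.0)])
```

The resulting latent vector would then be decoded together with the new condition.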


In [None]:
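from torch.utils.data import DataLoader, TensorDataset


class ConditionalPrior(nn.Module):
    """Train a lightweight MLP to predict encoder means directly from conditions, avoiding NN lookups"""

    def __init__(self, condition_dim, latent_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(condition_dim, hidden), nn.ReLU(), nn.Linear(hidden, latent_dim)
        )

    def forward(self, c):
        return self.net(c)


device = torch.device(config["device"])
# Load the (cond, mu) anchor pairs saved by AnchorBuilder.build()
anchors = np.load("anchors.npy", allow_pickle=True)
anchor_conds = torch.tensor(np.stack([a[0] for a in anchors]), dtype=torch.float32)
anchor_mus = torch.tensor(np.stack([a[1] for a in anchors]), dtype=torch.float32)
anchors_dataset = TensorDataset(anchor_conds, anchor_mus)

prior = ConditionalPrior(config["condition_dim"], config["latent_dim"]).to(device)
opt = torch.optim.Adam(prior.parameters(), lr=1e-3)
for epoch in range(50):
    for cond_vec, mu_vec in DataLoader(anchors_dataset, batch_size=64, shuffle=True):
        pred = prior(cond_vec.to(device))
        loss = F.mse_loss(pred, mu_vec.to(device))
        opt.zero_grad()
        loss.backward()
        opt.step()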
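Once $f_\psi$ is trained, generating a profile for a new condition needs neither the encoder nor a nearest-neighbour lookup: predict $\mu = f_\psi(c_{new})$ and decode it together with $c_{new}$. The sketch below uses small untrained stand-in networks (hypothetical `prior_net` and `decoder_net`) with the same input and output sizes as `ConditionalPrior` and `Decoder`, so it runs on its own; with the notebook's objects the equivalent call would be `cvae.decoder(prior(c_new), c_new)` after `cvae.eval()` and `prior.eval()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
condition_dim, latent_dim, output_dim = 5, 8, 4

# Hypothetical untrained stand-ins with the same sizes as ConditionalPrior and Decoder
prior_net = nn.Sequential(nn.Linear(condition_dim, 64), nn.ReLU(), nn.Linear(64, latent_dim))
decoder_net = nn.Sequential(
    nn.Linear(latent_dim + condition_dim, 32), nn.LeakyReLU(0.1), nn.Linear(32, output_dim)
)


@torch.no_grad()
def generate(c_new: torch.Tensor) -> torch.Tensor:
    """Predict mu = f_psi(c_new) and decode it together with the condition."""
    prior_net.eval()
    decoder_net.eval()
    z = prior_net(c_new)  # point estimate of the latent code for this condition
    return decoder_net(torch.cat([z, c_new], dim=1))


c_new = torch.tensor([[0.5, 0.8, 0.3, 0.4, 0.6]])  # hypothetical normalized condition
x_hat = generate(c_new)  # [1, output_dim]
```

Using the predicted mean rather than a random sample makes generation deterministic for a given condition.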

