# Question: RL for PID Auto-Tuning in Multi-Loop Systems (Offline)

## Goal:

Use **offline RL** to **learn a policy** that:
* Takes system context (plant dynamics, recent states, error history, etc.)
* Outputs a full set of optimal PID parameters $K = [K_p^1, K_i^1, K_d^1, ..., K_p^N, K_i^N, K_d^N]$
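
A minimal sketch of how the flat gain vector $K$ can be handled, assuming the gains are ordered $[K_p, K_i, K_d]$ per loop as written above. It maps a normalized action in $[-1, 1]$ to physical gains and back, using the same affine mapping and the same two-loop bounds as the data-generation code further down. The helper names `denormalize_gains` and `normalize_gains` are hypothetical and not part of the notebook.

```python
import numpy as np

# Per-gain bounds for N = 2 loops, assumed to be ordered [Kp, Ki, Kd] per loop
# (same values as pid_lower / pid_upper in the data-generation code).
PID_LOWER = np.array([-5.0, 0.0, 0.02, 0.0, 0.0, 0.01])
PID_UPPER = np.array([25.0, 20.0, 10.0, 1.0, 2.0, 1.0])


def denormalize_gains(action, lower=PID_LOWER, upper=PID_UPPER):
    """Map a normalized action in [-1, 1] to physical PID gains."""
    action = np.clip(np.asarray(action, dtype=float), -1.0, 1.0)
    return (action + 1.0) / 2.0 * (upper - lower) + lower


def normalize_gains(gains, lower=PID_LOWER, upper=PID_UPPER):
    """Inverse of denormalize_gains: physical gains -> [-1, 1]."""
    gains = np.asarray(gains, dtype=float)
    return 2.0 * (gains - lower) / (upper - lower) - 1.0


action = np.zeros(6)                 # midpoint of every gain range
gains = denormalize_gains(action)    # flat vector K of length 3N
per_loop = gains.reshape(-1, 3)      # row i holds [Kp, Ki, Kd] of loop i
print(per_loop)
```

Keeping the policy output in normalized units means every gain contributes on a comparable scale to the training loss, regardless of its physical range.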

## Challenges:

* High-dimensional action space for large N loops
* Inter-loop coupling -> optimal gains depend on other loops
* Multi-modal solutions: multiple gain sets may perform similarly well
* Offline-only data -> can't interact online to improve

## Strengths of diffusion policies

1. Multimodal action modeling:
    - Many different gain sets can yield good control; a diffusion model can represent several such modes, whereas a unimodal Gaussian policy tends to average them
2. Smooth action trajectories:
    - outputting sequences of gains (e.g., adapting gains over a window) benefits from smooth, coherent generation
3. Flexible conditioning:
    - Can condition on error curves, setpoint changes, prior performance - allowing fine-grained gain generation
4. Action space scalability: 
    - Diffusion models scale well to large output spaces (hundreds of dimensions), which suits multi-loop systems. 
5. Offline training performance:
    - Diffusion policies can outperform traditional behavior cloning (BC) in offline settings, mainly because they model multimodal expert behavior more faithfully.
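
The multimodality point can be illustrated with a toy example. Assume, hypothetically, that two values of one normalized gain (-0.6 and +0.6) are equally good for the same context. A single Gaussian fitted by maximum likelihood puts its mean between the two modes, where no good gain exists. The sketch uses synthetic numbers only and does not involve a diffusion model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic "expert" data: two equally good modes for one normalized gain.
modes = np.array([-0.6, 0.6])
labels = rng.integers(0, 2, size=10_000)
samples = modes[labels] + 0.05 * rng.standard_normal(10_000)

# Maximum-likelihood fit of a single Gaussian = sample mean and std.
mu, sigma = samples.mean(), samples.std()

# Distance from the Gaussian mean to the nearest true mode.
gap = np.min(np.abs(modes - mu))
print(f"Gaussian mean: {mu:+.3f}, std: {sigma:.3f}, gap to nearest mode: {gap:.3f}")
```

A policy that outputs the Gaussian mean would therefore propose a gain far from both demonstrated solutions, which is the failure mode a multimodal generative model is meant to avoid.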


## Challenges of diffusion policies:

1. Inference speed:
    - If gains must be tuned frequently (e.g., per setpoint change), slow sampling may hinder real-time applicability
2. Interpretability:
    - we care why a certain gain set was chosen - diffusion policies are harder to interpret unless distilled
3. Reward alignment:
    - Training depends on how clearly performance (e.g., overshoot, settling time) is reflected in data
4. Sparse or biased datasets:
    - If most offline data uses "safe but suboptimal" gains, the learned model may not generalize to better ones. 

# Design considerations: Diffusion-Based PID Auto-Tuner

**Inputs**:

* Current plant state and recent history $x_{t-w:t}$
* Setpoint $r_t$ or setpoint change trajectory
* Past controller performance metrics (e.g., IAE, overshoot)


**Output**:
* Full vector of PID gains for all loops
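
As a rough sketch of how the inputs above could be assembled, the snippet below stacks the last $w$ steps of state, tracking error, and setpoint into one context window. The function name `build_context_window`, its arguments, and the toy values are assumptions made for illustration; left-padding short histories with zeros is one possible choice, not something the notebook prescribes.

```python
import numpy as np


def build_context_window(states, setpoints, controlled_idx, window):
    """Stack the last `window` steps of [state, error, setpoint] features.

    states:         array [T, state_dim]
    setpoints:      array [T, n_loops]
    controlled_idx: column index in `states` of each controlled variable
    """
    states = np.asarray(states, dtype=float)
    setpoints = np.asarray(setpoints, dtype=float)
    errors = setpoints - states[:, controlled_idx]   # r_t - y_t per loop
    features = np.concatenate([states, errors, setpoints], axis=1)
    recent = features[-window:]
    # Left-pad with zeros if fewer than `window` steps are available.
    pad = window - len(recent)
    if pad > 0:
        recent = np.pad(recent, ((pad, 0), (0, 0)), mode="constant")
    return recent


# Toy trajectory with arbitrary values: 3 states (Cb, T, V), setpoints for Cb and V.
T = 25
states = np.column_stack([np.full(T, 0.6), np.full(T, 400.0), np.full(T, 100.0)])
setpoints = np.column_stack([np.full(T, 0.8), np.full(T, 100.5)])
ctx = build_context_window(states, setpoints, controlled_idx=[0, 2], window=10)
print(ctx.shape)  # (10, 7)
```

Scalar performance metrics such as IAE or overshoot could be appended as extra conditioning features alongside this window.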

**Training Flow**:
1. Collect offline dataset of:
    $$ (context, K_{used}, performance) $$
2. Corrupt PID gains with noise over T steps
3. Train diffusion model to denoise PID gains, conditioned on context
4. At inference, sample noise, denoise it based on new plant context -> get tuned PID
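
Steps 2 to 4 correspond to the standard denoising diffusion (DDPM) formulation. The equations below write it out for the gain vector, as a reference sketch rather than a derivation. The notation is introduced here and is not the original's: $K_0$ are the stored gains, $K_t$ the gains after $t$ noising steps, $c$ the context, $\epsilon_\theta$ the denoising network, $\beta_t$ the noise schedule, $\alpha_t = 1 - \beta_t$, and $\bar{\alpha}_t = \prod_{s \le t} \alpha_s$.

Forward corruption (step 2):

$$ K_t = \sqrt{\bar{\alpha}_t}\, K_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) $$

Training objective (step 3):

$$ \mathcal{L}(\theta) = \mathbb{E}_{K_0, c, t, \epsilon} \left[ \lVert \epsilon - \epsilon_\theta(K_t, t, c) \rVert^2 \right] $$

Reverse step at inference (step 4), starting from $K_T \sim \mathcal{N}(0, I)$:

$$ \hat{K}_0 = \frac{K_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(K_t, t, c)}{\sqrt{\bar{\alpha}_t}} $$

$$ K_{t-1} = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, \hat{K}_0 + \frac{\sqrt{\alpha_t}\, (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, K_t + \sqrt{\tilde{\beta}_t}\, z, \qquad \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\, \beta_t $$

Here $z \sim \mathcal{N}(0, I)$ for all but the final step, where $z = 0$ and, with the convention $\bar{\alpha}_0 = 1$, the update returns $\hat{K}_0$ directly.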

# Compared with other RL algorithms:

1. SAC / TD3 (continuous RL): can learn good gains, but struggles with multimodality and large joint action spaces
2. Decision Transformer (DT): can learn to condition on performance goals and trajectories; still needs large datasets and reward shaping
3. Diffusion policy: Excels in multimodal, high-dimensional action spaces - a very promising fit for offline PID gain learning

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import tqdm
import gymnasium as gym
import copy
from collections import defaultdict

# Load the CSTR environment
# Make sure CSTR_model_plus.py is in the same directory
from CSTR_model_plus import CSTRRLEnv

# Set random seeds for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

#######################
# Data Generation
#######################

def generate_pid_dataset(num_episodes=1000, time_steps=100, save_path='pid_dataset.pt'):
    """Generate a dataset of PID gain trajectories and their performance on the CSTR system.
    
    Args:
        num_episodes: Number of episodes (different PID settings) to generate
        time_steps: Number of time steps per episode
        save_path: Path to save the dataset
        
    Returns:
        dataset: Dictionary containing the dataset
    """
    # Initialize environment
    env = CSTRRLEnv(simulation_steps=time_steps, 
                    uncertainty_level=0.1,
                    noise_level=0.02,
                    actuator_delay_steps=1,
                    transport_delay_steps=2,
                    enable_disturbances=True)
    
    # Lower and upper bounds for PID gains
    pid_lower = np.array([-5, 0, 0.02, 0, 0, 0.01])
    pid_upper = np.array([25, 20, 10, 1, 2, 1])
    
    dataset = {
        "state_traj": [],           # State trajectory
        "error_traj": [],           # Error trajectory  
        "setpoint_traj": [],        # Setpoint trajectory
        "pid_gains": [],            # PID gains (actions)
        "performance": [],          # Performance metrics
        "metadata": {               # Metadata
            "pid_lower": pid_lower.tolist(),
            "pid_upper": pid_upper.tolist(),
            "time_steps": time_steps,
            "num_episodes": num_episodes,
        }
    }
    
    print(f"Generating {num_episodes} episodes...")
    
    for episode in tqdm.tqdm(range(num_episodes)):
        # Reset the environment with different setpoints for better exploration
        # Create varying setpoints for Cb and V
        setpoints_Cb = [0.5 + 0.5 * np.random.rand(), 
                        0.5 + 0.5 * np.random.rand()]
        setpoints_V = [99.0 + 2.0 * np.random.rand(), 
                      99.0 + 2.0 * np.random.rand()]
        setpoint_durations = [time_steps // 2, time_steps // 2 + time_steps % 2]
        
        options = {
            'setpoints_Cb': setpoints_Cb,
            'setpoints_V': setpoints_V,
            'setpoint_durations': setpoint_durations
        }
        
        obs, _ = env.reset(seed=episode, options=options)
        
        # Sample random PID gains (normalized between -1 and 1 for the action space)
        # We'll have a mix of:
        # 1. Random gains (50%)
        # 2. Expert-tuned gains with noise (50%)
        
        if np.random.rand() < 0.5:
            # Random gains
            action = np.random.uniform(-1, 1, size=6)
        else:
            # Start with some good PID values and add noise
            # These values are heuristics that work reasonably well for this system
            good_pid_normalized = np.array([0.5, 0.3, 0.2, 0.4, 0.3, 0.2])  # Normalized good PID values
            # Convert from normalized to actual
            good_pid = ((good_pid_normalized + 1) / 2) * (pid_upper - pid_lower) + pid_lower
            # Add noise
            good_pid *= (1 + 0.3 * np.random.randn(6))
            # Convert back to normalized
            action = 2 * (good_pid - pid_lower) / (pid_upper - pid_lower) - 1
            # Clip to valid range
            action = np.clip(action, -1, 1)
        
        # Storage for trajectories
        state_traj = []
        error_traj = []
        setpoint_traj = []
        
        # Run the episode with fixed PID gains
        done = False
        step = 0
        cumulative_reward = 0
        
        # Statistics for performance metrics
        abs_Cb_errors = []
        abs_V_errors = []
        Cb_setpoint_changes = []
        V_setpoint_changes = []
        controller_efforts = []
        
        while not done:
            # Store trajectory information
            # Extract values from observation
            current_Cb = obs[0]  # Current Cb is at index 0
            current_T = obs[1]   # Current T is at index 1
            current_V = obs[2]   # Current V is at index 2
            
            current_setpoint_Cb = obs[9]   # Current setpoint Cb is at index 9
            current_setpoint_V = obs[10]   # Current setpoint V is at index 10
            
            # Calculate errors
            error_Cb = current_setpoint_Cb - current_Cb
            error_V = current_setpoint_V - current_V
            
            # Store trajectories
            state_traj.append([current_Cb, current_T, current_V])
            error_traj.append([error_Cb, error_V])
            setpoint_traj.append([current_setpoint_Cb, current_setpoint_V])
            
            # Take action (use the same PID gains throughout the episode)
            next_obs, reward, terminated, truncated, info = env.step(action)
            
            # Track performance metrics
            abs_Cb_errors.append(abs(error_Cb))
            abs_V_errors.append(abs(error_V))
            
            # Track control effort (changes in control signals)
            if step > 0:
                Tc_change = abs(info["control_action"][0] - prev_control_action[0])
                Fin_change = abs(info["control_action"][1] - prev_control_action[1])
                controller_efforts.append([Tc_change, Fin_change])
            
            # Copy so a later in-place update by the environment cannot alias this value
            prev_control_action = np.array(info["control_action"])
            
            # Track setpoint changes
            if step > 0:
                Cb_setpoint_change = abs(current_setpoint_Cb - prev_setpoint_Cb)
                V_setpoint_change = abs(current_setpoint_V - prev_setpoint_V)
                if Cb_setpoint_change > 0:
                    Cb_setpoint_changes.append(step)
                if V_setpoint_change > 0:
                    V_setpoint_changes.append(step)
            
            prev_setpoint_Cb = current_setpoint_Cb
            prev_setpoint_V = current_setpoint_V
            
            # Update for next iteration
            obs = next_obs
            cumulative_reward += reward
            step += 1
            done = terminated or truncated
        
        # Calculate performance metrics
        iae_Cb = np.sum(abs_Cb_errors)  # Integral Absolute Error for Cb
        iae_V = np.sum(abs_V_errors)    # Integral Absolute Error for V
        
        # Placeholders: rise time and settling time per setpoint change are not computed yet
        rise_times = []
        settling_times = []
        
        # Calculate mean control effort
        mean_control_effort_Tc = np.mean([e[0] for e in controller_efforts]) if controller_efforts else 0
        mean_control_effort_Fin = np.mean([e[1] for e in controller_efforts]) if controller_efforts else 0
        
        # Overall performance score (lower is better)
        performance_score = iae_Cb * 0.7 + iae_V * 0.2 + mean_control_effort_Tc * 0.05 + mean_control_effort_Fin * 0.05
        
        # Placeholder: overshoot is not computed here (see evaluate_pid_gains)
        overshoots = []
        
        # Store all the information in the dataset
        dataset["state_traj"].append(np.array(state_traj))
        dataset["error_traj"].append(np.array(error_traj))
        dataset["setpoint_traj"].append(np.array(setpoint_traj))
        dataset["pid_gains"].append(action)  # Store the normalized action
        
        # Store performance metrics
        dataset["performance"].append({
            "iae_Cb": iae_Cb,
            "iae_V": iae_V,
            "mean_control_effort_Tc": mean_control_effort_Tc,
            "mean_control_effort_Fin": mean_control_effort_Fin,
            "performance_score": performance_score,
            "cumulative_reward": cumulative_reward
        })
    
    # Convert lists to numpy arrays
    dataset["pid_gains"] = np.array(dataset["pid_gains"])
    
    # Save the dataset
    torch.save(dataset, save_path)
    print(f"Dataset saved to {save_path}")
    
    return dataset


def filter_dataset(dataset, performance_threshold=0.7):
    """Filter the dataset based on performance metrics to keep only good examples.
    
    Args:
        dataset: Original dataset
        performance_threshold: Percentile cutoff in [0, 1]; episodes whose score is at or
            below this percentile are kept (0.7 keeps the best 70%, since lower is better)
        
    Returns:
        filtered_dataset: Filtered dataset
    """
    # Get performance scores
    performance_scores = np.array([p["performance_score"] for p in dataset["performance"]])
    
    # Lower score is better, so we want to keep the bottom percentage
    threshold = np.percentile(performance_scores, performance_threshold * 100)
    
    # Find indices of episodes to keep
    keep_indices = np.where(performance_scores <= threshold)[0]
    
    # Create filtered dataset
    filtered_dataset = {
        "state_traj": [dataset["state_traj"][i] for i in keep_indices],
        "error_traj": [dataset["error_traj"][i] for i in keep_indices],
        "setpoint_traj": [dataset["setpoint_traj"][i] for i in keep_indices],
        "pid_gains": dataset["pid_gains"][keep_indices],
        "performance": [dataset["performance"][i] for i in keep_indices],
        "metadata": dataset["metadata"]
    }
    
    print(f"Filtered dataset: kept {len(keep_indices)} out of {len(dataset['pid_gains'])} episodes ({performance_threshold * 100:.1f}% best performers)")
    
    return filtered_dataset


class PIDDataset(Dataset):
    """PyTorch Dataset for PID controller tuning."""
    
    def __init__(self, dataset, window_size=10):
        """
        Args:
            dataset: Dictionary containing the dataset
            window_size: Size of the sliding window for state/error/setpoint trajectories
        """
        self.state_traj = dataset["state_traj"]
        self.error_traj = dataset["error_traj"]
        self.setpoint_traj = dataset["setpoint_traj"]
        self.pid_gains = torch.tensor(dataset["pid_gains"], dtype=torch.float32)
        self.window_size = window_size
        
        # Metadata
        self.pid_lower = torch.tensor(dataset["metadata"]["pid_lower"], dtype=torch.float32)
        self.pid_upper = torch.tensor(dataset["metadata"]["pid_upper"], dtype=torch.float32)
        
    def __len__(self):
        return len(self.pid_gains)
    
    def __getitem__(self, idx):
        # Get trajectories
        state = self.state_traj[idx]
        error = self.error_traj[idx]
        setpoint = self.setpoint_traj[idx]
        
        # Get random window if trajectory is longer than window_size
        traj_len = len(state)
        if traj_len > self.window_size:
            # randint's upper bound is exclusive, so add 1 to allow the last valid window
            start_idx = np.random.randint(0, traj_len - self.window_size + 1)
            end_idx = start_idx + self.window_size
            
            state_window = state[start_idx:end_idx]
            error_window = error[start_idx:end_idx]
            setpoint_window = setpoint[start_idx:end_idx]
        else:
            # Pad with zeros if trajectory is shorter than window_size
            state_window = np.pad(state, ((0, self.window_size - traj_len), (0, 0)), mode='constant')
            error_window = np.pad(error, ((0, self.window_size - traj_len), (0, 0)), mode='constant')
            setpoint_window = np.pad(setpoint, ((0, self.window_size - traj_len), (0, 0)), mode='constant')
        
        # Convert to tensors
        state_tensor = torch.tensor(state_window, dtype=torch.float32)
        error_tensor = torch.tensor(error_window, dtype=torch.float32)
        setpoint_tensor = torch.tensor(setpoint_window, dtype=torch.float32)
        pid_gains = self.pid_gains[idx]
        
        return {
            "state_traj": state_tensor,
            "error_traj": error_tensor,
            "setpoint_traj": setpoint_tensor,
            "pid_gains": pid_gains
        }


#######################
# Model Architecture
#######################

class TimeEmbedding(nn.Module):
    """Sinusoidal time embedding for diffusion timesteps."""
    
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        
    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        embeddings = torch.cat((torch.sin(embeddings), torch.cos(embeddings)), dim=-1)
        
        # Zero-pad if dim is odd
        if self.dim % 2 == 1:
            embeddings = F.pad(embeddings, (0, 1, 0, 0))
            
        return embeddings


class TrajectoryEncoder(nn.Module):
    """Encoder for trajectory data to provide context for the diffusion model."""
    
    def __init__(self, state_dim=3, error_dim=2, setpoint_dim=2, hidden_dim=128, context_dim=128, window_size=10):
        super().__init__()
        self.window_size = window_size
        
        # Input dimensions for combined trajectory features
        input_dim = state_dim + error_dim + setpoint_dim
        
        # LSTM to process the sequence
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.1
        )
        
        # Projection to context embedding
        self.context_projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, context_dim)
        )
        
    def forward(self, state_traj, error_traj, setpoint_traj):
        batch_size = state_traj.shape[0]
        
        # Concatenate all features along feature dimension
        # Shape: [batch_size, window_size, state_dim + error_dim + setpoint_dim]
        combined_input = torch.cat([state_traj, error_traj, setpoint_traj], dim=2)
        
        # Process sequence with LSTM
        lstm_out, (hidden, _) = self.lstm(combined_input)
        
        # Use the final hidden state
        # Shape: [batch_size, hidden_dim]
        final_hidden = hidden[-1]
        
        # Project to context embedding
        # Shape: [batch_size, context_dim]
        context = self.context_projection(final_hidden)
        
        return context


class DiffusionBlock(nn.Module):
    """Basic building block for the diffusion model."""
    
    def __init__(self, hidden_dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
        self.layer_norm = nn.LayerNorm(hidden_dim)
        
    def forward(self, x):
        return self.layer_norm(x + self.block(x))


class DiffusionModel(nn.Module):
    """Diffusion model for predicting PID gains."""
    
    def __init__(self, pid_dim=6, time_dim=128, context_dim=128, hidden_dim=256, depth=6):
        super().__init__()
        self.pid_dim = pid_dim
        
        # Time embedding
        self.time_embedding = TimeEmbedding(time_dim)
        
        # Initial projection
        self.input_projection = nn.Linear(pid_dim, hidden_dim)
        
        # Combined embedding of time and context
        self.combined_projection = nn.Sequential(
            nn.Linear(time_dim + context_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Main diffusion blocks
        self.diffusion_blocks = nn.ModuleList([
            DiffusionBlock(hidden_dim) for _ in range(depth)
        ])
        
        # Final output layer
        self.output_projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, pid_dim)
        )
        
    def forward(self, x, t, context):
        """
        Args:
            x: Noisy PID gains [batch_size, pid_dim]
            t: Diffusion timesteps [batch_size]
            context: Context embeddings [batch_size, context_dim]
        """
        # Time embedding
        t_emb = self.time_embedding(t)
        
        # Combine time and context
        cond_emb = torch.cat([t_emb, context], dim=1)
        cond = self.combined_projection(cond_emb)
        
        # Initial projection of noisy PID gains
        h = self.input_projection(x)
        
        # Apply diffusion blocks with conditioning
        for block in self.diffusion_blocks:
            h = h + cond  # Add the conditioning at each step
            h = block(h)
        
        # Output projection
        output = self.output_projection(h)
        
        return output


#######################
# Diffusion Process
#######################

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear schedule for variance of the forward process."""
    return torch.linspace(beta_start, beta_end, timesteps)


class DiffusionTrainer:
    """Trainer for the diffusion model."""
    
    def __init__(self, timesteps=1000):
        # Set up the noise schedule
        self.timesteps = timesteps
        self.betas = linear_beta_schedule(timesteps)
        
        # Pre-compute values for sampling and training
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # Calculations for diffusion q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        
        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        
    def q_sample(self, x_0, t, noise=None):
        """Forward diffusion process: q(x_t | x_0)."""
        if noise is None:
            noise = torch.randn_like(x_0)
            
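        # Get the scheduled values for the given timestep
        # (the schedule tensors are created on the CPU, so move them to the data's device first)
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod.to(x_0.device)[t]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod.to(x_0.device)[t]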
        
        # Reshape for broadcasting
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod_t.view(-1, 1)
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.view(-1, 1)
        
        # Forward process
        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_sample(self, model, x_t, t, context):
        """Single step of reverse diffusion process: p(x_{t-1} | x_t)."""
        # Get model prediction (predicted noise)
        pred_noise = model(x_t, t, context)
        
        # Get schedule values for timestep t on the same device as x_t,
        # reshaped to [batch_size, 1] so they broadcast over the gain dimension
        dev = x_t.device
        alpha = self.alphas.to(dev)[t].view(-1, 1)
        alpha_cumprod = self.alphas_cumprod.to(dev)[t].view(-1, 1)
        alpha_cumprod_prev = self.alphas_cumprod_prev.to(dev)[t].view(-1, 1)
        beta = self.betas.to(dev)[t].view(-1, 1)
        
        # Estimate x_0 from the predicted noise:
        # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
        pred_x0 = (x_t - (1. - alpha_cumprod).sqrt() * pred_noise) / alpha_cumprod.sqrt()
        
        # Posterior mean of q(x_{t-1} | x_t, x_0)
        posterior_mean = (
            (alpha_cumprod_prev.sqrt() * beta / (1. - alpha_cumprod)) * pred_x0 +
            (alpha.sqrt() * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)) * x_t
        )
        
        # Posterior variance (zero at t = 0, where no noise is added)
        posterior_variance = beta * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)
        
        # No noise on the final step
        noise = torch.randn_like(x_t) if t[0] > 0 else torch.zeros_like(x_t)
        
        # Sample from the posterior
        x_t_prev = posterior_mean + posterior_variance.sqrt() * noise
        
        return x_t_prev
    
    def p_sample_loop(self, model, context, shape, device, verbose=True):
        """Full reverse diffusion from noise to data."""
        # Start from pure noise
        img = torch.randn(shape).to(device)
        
        # Iteratively denoise
        iterator = range(self.timesteps - 1, -1, -1)
        if verbose:
            iterator = tqdm.tqdm(iterator, desc="Sampling")
            
        for i in iterator:
            # Same timestep for entire batch
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            
            # Apply single denoising step
            with torch.no_grad():
                img = self.p_sample(model, img, t, context)
                
        return img
    
    def train_step(self, model, optimizer, batch, device):
        """Single training step."""
        # Extract data
        state_traj = batch["state_traj"].to(device)
        error_traj = batch["error_traj"].to(device)
        setpoint_traj = batch["setpoint_traj"].to(device)
        pid_gains = batch["pid_gains"].to(device)
        
        # Get context
        context = model.trajectory_encoder(state_traj, error_traj, setpoint_traj)
        
        # Sample timestep
        batch_size = pid_gains.shape[0]
        t = torch.randint(0, self.timesteps, (batch_size,), device=device).long()
        
        # Sample noise
        noise = torch.randn_like(pid_gains)
        
        # Forward diffusion
        x_t = self.q_sample(pid_gains, t, noise)
        
        # Predict noise
        pred_noise = model(x_t, t, context)
        
        # Calculate loss
        loss = F.mse_loss(pred_noise, noise)
        
        # Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        return loss.item()


#######################
# Complete Diffusion Policy
#######################

class DiffusionPolicy(nn.Module):
    """Complete diffusion policy for PID controller tuning."""
    
    def __init__(self, pid_dim=6, state_dim=3, error_dim=2, setpoint_dim=2, 
                 time_dim=128, context_dim=128, hidden_dim=256, 
                 depth=6, window_size=10):
        super().__init__()
        
        # Trajectory encoder
        self.trajectory_encoder = TrajectoryEncoder(
            state_dim=state_dim,
            error_dim=error_dim,
            setpoint_dim=setpoint_dim,
            hidden_dim=hidden_dim,
            context_dim=context_dim,
            window_size=window_size
        )
        
        # Diffusion core model
        self.diffusion_core = DiffusionModel(
            pid_dim=pid_dim,
            time_dim=time_dim,
            context_dim=context_dim,
            hidden_dim=hidden_dim,
            depth=depth
        )
    
    def forward(self, x, t, context):
        """Forward pass for the diffusion model."""
        return self.diffusion_core(x, t, context)


#######################
# Training Function
#######################

def train_diffusion_policy(dataset, num_epochs=50, batch_size=32, lr=3e-4, 
                           timesteps=1000, save_path='diffusion_policy.pt',
                           window_size=10):
    """Train the diffusion policy on the dataset.
    
    Args:
        dataset: Dictionary containing the dataset
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        lr: Learning rate
        timesteps: Number of diffusion timesteps
        save_path: Path to save the model
        window_size: Size of the trajectory window
        
    Returns:
        model: Trained model
        trainer: Diffusion trainer
        train_losses: List of training losses
    """
    # Create dataset and dataloader
    train_dataset = PIDDataset(dataset, window_size=window_size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    # Get device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Create model and optimizer
    model = DiffusionPolicy(pid_dim=6, window_size=window_size).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    
    # Create diffusion trainer
    trainer = DiffusionTrainer(timesteps=timesteps)
    
    # Training loop
    train_losses = []
    best_loss = float('inf')
    
    for epoch in range(num_epochs):
        epoch_losses = []
        
        # Training
        model.train()
        with tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for batch in pbar:
                loss = trainer.train_step(model, optimizer, batch, device)
                epoch_losses.append(loss)
                pbar.set_postfix({"loss": loss})
        
        # Calculate average loss
        avg_loss = np.mean(epoch_losses)
        train_losses.append(avg_loss)
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")
        
        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'loss': best_loss,
                'timesteps': timesteps,
                'window_size': window_size,
            }, save_path)
            print(f"Model saved to {save_path}")
    
    return model, trainer, train_losses


#######################
# Evaluation Functions
#######################
def evaluate_pid_gains(env, pid_gains, num_episodes=5, render=False):
    """Evaluate PID gains on the environment.
    
    Args:
        env: Environment instance
        pid_gains: PID gains to evaluate (normalized between -1 and 1)
        num_episodes: Number of evaluation episodes
        render: Whether to render the environment
        
    Returns:
        metrics: Dictionary of evaluation metrics
    """
    # Storage for metrics
    all_metrics = defaultdict(list)
    
    for episode in range(num_episodes):
        # Reset environment with different setpoints each time
        setpoints_Cb = [0.5 + 0.5 * np.random.rand(), 
                       0.5 + 0.5 * np.random.rand()]
        setpoints_V = [99.0 + 2.0 * np.random.rand(), 
                      99.0 + 2.0 * np.random.rand()]
        setpoint_durations = [env.sim_steps // 2, env.sim_steps // 2 + env.sim_steps % 2]
        
        options = {
            'setpoints_Cb': setpoints_Cb,
            'setpoints_V': setpoints_V,
            'setpoint_durations': setpoint_durations
        }
        
        obs, _ = env.reset(seed=episode*100, options=options)
        
        # Storage for episode metrics
        abs_Cb_errors = []
        abs_V_errors = []
        Cb_values = []
        V_values = []
        T_values = []
        setpoint_Cb_values = []
        setpoint_V_values = []
        control_Tc_values = []
        control_Fin_values = []
        
        # Run episode with fixed PID gains
        done = False
        step = 0
        cumulative_reward = 0
        max_overshoot_Cb = 0
        max_overshoot_V = 0
        
        # For tracking setpoint changes
        last_setpoint_Cb = None
        last_setpoint_V = None
        setpoint_change_steps_Cb = []
        setpoint_change_steps_V = []
        
        while not done:
            # Take action (use the same PID gains throughout the episode)
            next_obs, reward, terminated, truncated, info = env.step(pid_gains)
            
            # Get current values
            current_Cb = info["true_state"][1]
            current_T = info["true_state"][3]
            current_V = info["true_state"][4]
            current_setpoint_Cb = info["setpoint_Cb"]
            current_setpoint_V = info["setpoint_V"]
            
            # Track setpoint changes
            if last_setpoint_Cb is not None and abs(current_setpoint_Cb - last_setpoint_Cb) > 0.01:
                setpoint_change_steps_Cb.append(step)
            if last_setpoint_V is not None and abs(current_setpoint_V - last_setpoint_V) > 0.01:
                setpoint_change_steps_V.append(step)
                
            last_setpoint_Cb = current_setpoint_Cb
            last_setpoint_V = current_setpoint_V
            
            # Calculate errors
            error_Cb = current_setpoint_Cb - current_Cb
            error_V = current_setpoint_V - current_V
            
            # Store values
            abs_Cb_errors.append(abs(error_Cb))
            abs_V_errors.append(abs(error_V))
            Cb_values.append(current_Cb)
            V_values.append(current_V)
            T_values.append(current_T)
            setpoint_Cb_values.append(current_setpoint_Cb)
            setpoint_V_values.append(current_setpoint_V)
            control_Tc_values.append(info["control_action"][0])
            control_Fin_values.append(info["control_action"][1])
            
            # Calculate overshoot after the most recent setpoint change.
            # Overshoot means passing the setpoint in the direction of the change,
            # so the sign of the change decides which side of the setpoint counts.
            if setpoint_change_steps_Cb:
                last_change = setpoint_change_steps_Cb[-1]
                # Only consider steps after the change; skip a zero setpoint to avoid dividing by zero
                if step > last_change and current_setpoint_Cb != 0:
                    direction = np.sign(setpoint_Cb_values[last_change]
                                        - setpoint_Cb_values[last_change - 1])
                    # Positive excursion: the output has gone past the setpoint
                    excursion = direction * (current_Cb - current_setpoint_Cb)
                    if excursion > 0:
                        overshoot = excursion / abs(current_setpoint_Cb) * 100
                        max_overshoot_Cb = max(max_overshoot_Cb, overshoot)
            
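            # Same for volume: overshoot is measured in the direction of the last change
            if setpoint_change_steps_V:
                last_change = setpoint_change_steps_V[-1]
                if step > last_change and current_setpoint_V != 0:
                    direction = np.sign(setpoint_V_values[last_change]
                                        - setpoint_V_values[last_change - 1])
                    excursion = direction * (current_V - current_setpoint_V)
                    if excursion > 0:
                        overshoot = excursion / abs(current_setpoint_V) * 100
                        max_overshoot_V = max(max_overshoot_V, overshoot)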
            
            # Update for next iteration
            obs = next_obs
            cumulative_reward += reward
            step += 1
            done = terminated or truncated
            
            # Render if required
            if render and episode == 0:
                env.render()
        
        # Calculate metrics
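        # IAE - Integral Absolute Error (sum of |error| per step, not scaled by the sample time)
        iae_Cb = np.sum(abs_Cb_errors)
        iae_V = np.sum(abs_V_errors)
        
        # ISE - Integral Squared Error (sum of error^2 per step, not scaled by the sample time)
        ise_Cb = np.sum(np.array(abs_Cb_errors)**2)
        ise_V = np.sum(np.array(abs_V_errors)**2)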
        
        # Mean error
        mean_error_Cb = np.mean(abs_Cb_errors)
        mean_error_V = np.mean(abs_V_errors)
        
        # Control effort: total variation of each control signal
        # (sum of absolute step-to-step changes)
        control_effort_Tc = np.sum(np.abs(np.diff(control_Tc_values)))
        control_effort_Fin = np.sum(np.abs(np.diff(control_Fin_values)))
        
        # Rise time and settling time
        rise_times_Cb = []
        settling_times_Cb = []
        rise_times_V = []
        settling_times_V = []
        
        # Calculate rise and settling times for each setpoint change
        for change_step in setpoint_change_steps_Cb:
            if change_step + 1 < len(Cb_values):
                # Get the new setpoint value
                new_setpoint = setpoint_Cb_values[change_step]
                start_value = Cb_values[change_step]
                
                # Calculate the change
                total_change = new_setpoint - start_value
                
                # Only process significant changes
                if abs(total_change) > 0.01:
                    # Find rise time (time to reach 90% of the change)
                    target_90 = start_value + 0.9 * total_change
                    rise_time = float('inf')
                    
                    for i in range(change_step + 1, len(Cb_values)):
                        if (total_change > 0 and Cb_values[i] >= target_90) or \
                           (total_change < 0 and Cb_values[i] <= target_90):
                            rise_time = i - change_step
                            break
                            
                    # Find settling time: first step from which the output stays within
                    # a band of 5% of the step size around the setpoint for 5 samples
                    settling_band = 0.05 * abs(total_change)
                    settling_time = float('inf')
                    settled = False
                    
                    for i in range(change_step + 1, len(Cb_values)):
                        # Require a full 5-sample window, starting at i, inside the band
                        window = Cb_values[i:i + 5]
                        if len(window) == 5 and all(abs(v - new_setpoint) <= settling_band for v in window):
                            settling_time = i - change_step
                            settled = True
                            break
                    
                    if rise_time < float('inf'):
                        rise_times_Cb.append(rise_time)
                    if settled:
                        settling_times_Cb.append(settling_time)
        
        # Similar calculation for Volume
        for change_step in setpoint_change_steps_V:
            if change_step + 1 < len(V_values):
                new_setpoint = setpoint_V_values[change_step]
                start_value = V_values[change_step]
                total_change = new_setpoint - start_value
                
                if abs(total_change) > 0.1:  # Higher threshold for volume
                    target_90 = start_value + 0.9 * total_change
                    rise_time = float('inf')
                    
                    for i in range(change_step + 1, len(V_values)):
                        if (total_change > 0 and V_values[i] >= target_90) or \
                           (total_change < 0 and V_values[i] <= target_90):
                            rise_time = i - change_step
                            break
                            
                    # Settling band is 5% of the step size, as for Cb
                    settling_band = 0.05 * abs(total_change)
                    settling_time = float('inf')
                    settled = False
                    
                    for i in range(change_step + 1, len(V_values)):
                        # Require a full 5-sample window, starting at i, inside the band
                        window = V_values[i:i + 5]
                        if len(window) == 5 and all(abs(v - new_setpoint) <= settling_band for v in window):
                            settling_time = i - change_step
                            settled = True
                            break
                    
                    if rise_time < float('inf'):
                        rise_times_V.append(rise_time)
                    if settled:
                        settling_times_V.append(settling_time)
        
        # Calculate mean rise and settling times
        mean_rise_time_Cb = np.mean(rise_times_Cb) if rise_times_Cb else float('inf')
        mean_settling_time_Cb = np.mean(settling_times_Cb) if settling_times_Cb else float('inf')
        mean_rise_time_V = np.mean(rise_times_V) if rise_times_V else float('inf')
        mean_settling_time_V = np.mean(settling_times_V) if settling_times_V else float('inf')
        
        # Store all metrics
        all_metrics["iae_Cb"].append(iae_Cb)
        all_metrics["iae_V"].append(iae_V)
        all_metrics["ise_Cb"].append(ise_Cb)
        all_metrics["ise_V"].append(ise_V)
        all_metrics["mean_error_Cb"].append(mean_error_Cb)
        all_metrics["mean_error_V"].append(mean_error_V)
        all_metrics["control_effort_Tc"].append(control_effort_Tc)
        all_metrics["control_effort_Fin"].append(control_effort_Fin)
        all_metrics["mean_rise_time_Cb"].append(mean_rise_time_Cb)
        all_metrics["mean_settling_time_Cb"].append(mean_settling_time_Cb)
        all_metrics["mean_rise_time_V"].append(mean_rise_time_V)
        all_metrics["mean_settling_time_V"].append(mean_settling_time_V)
        all_metrics["max_overshoot_Cb"].append(max_overshoot_Cb)
        all_metrics["max_overshoot_V"].append(max_overshoot_V)
        all_metrics["cumulative_reward"].append(cumulative_reward)
        
        # Store trajectories for the first episode only (used for plotting)
        if episode == 0:
            all_metrics["Cb_values"] = Cb_values
            all_metrics["V_values"] = V_values
            all_metrics["T_values"] = T_values
            all_metrics["setpoint_Cb_values"] = setpoint_Cb_values
            all_metrics["setpoint_V_values"] = setpoint_V_values
            all_metrics["control_Tc_values"] = control_Tc_values
            all_metrics["control_Fin_values"] = control_Fin_values
    
    # Calculate average metrics over episodes; trajectories are passed through unchanged.
    # Note: an episode with no measurable rise or settling time contributes inf,
    # which makes the corresponding average inf.
    trajectory_keys = ("Cb_values", "V_values", "T_values",
                       "setpoint_Cb_values", "setpoint_V_values",
                       "control_Tc_values", "control_Fin_values")
    avg_metrics = {}
    for key, values in all_metrics.items():
        if key in trajectory_keys:
            avg_metrics[key] = values
        elif isinstance(values, list):
            avg_metrics[key] = np.mean(values)
    
    return avg_metrics