Neural SDE on Synthetic Data

In [1]:
# Display the Python code block below
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
# --- SDE Import ---
try:
    import torchsde
except ImportError:
    print("Error: torchsde not found. Please install it: pip install torchsde")
    exit()
# --- ODE Solver (Keep for comparison/reference if needed, but not used for SDE) ---
from scipy.integrate import odeint as scipy_odeint # For generating synthetic data
# Metrics and plotting will be used in later sections
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt

# --- 0. Configuration & Setup ---
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Device setup
# Check for MPS (Apple Silicon GPU) first, then CUDA, then CPU
#if torch.backends.mps.is_available():
#    device = torch.device("mps")
#elif torch.cuda.is_available():
#    device = torch.device("cuda")
#else:
#    device = torch.device("cpu")
device = torch.device("cpu")
print(f"Using device: {device}")


# Hyperparameters for Neural SDE Training (Combining stability focus with user's structure)
LEARNING_RATE = 1e-2     # Keep stability-focused LR
NUM_EPOCHS = 1000        # Adopted from user code
HIDDEN_DIM_NN = 32
SDE_SOLVER_METHOD = 'srk'  # Stochastic Runge-Kutta
SDE_SOLVER_DT = 1e-2     # Keep stability-focused dt
NOISE_TYPE = "diagonal"  
BROWNIAN_SIZE = 2        # Dimension of Brownian motion (must match state_dim for diagonal)
PRINT_EVERY = 10        # Adopted from user code
WEIGHT_DECAY = 1e-6
GRADIENT_CLIP_MAX_NORM = 1.0 # Max norm for gradient clipping

# Noise Option - ADDED from user code
ADD_NOISE = True         # Set to False to train on clean data
NOISE_LEVEL_SYNTH = 0.10  # Noise magnitude if ADD_NOISE is True

# Parameters for Synthetic Data Generation
TRUE_ALPHA = 1.5
TRUE_BETA = 1.0
TRUE_GAMMA = 2.0
TRUE_DELTA = 0.5
TRUE_PARAMS = [TRUE_ALPHA, TRUE_BETA, TRUE_GAMMA, TRUE_DELTA]
U0_SYNTH = [8.0, 3.0]
T_START = 0
T_END = 10
N_POINTS = 200

# --- 1a. Generate Synthetic Lotka-Volterra Data ---
print("\n--- Generating Synthetic Lotka-Volterra Data ---")
def lotka_volterra(y, t, alpha, beta, gamma, delta):
    prey, predator = y
    d_prey = alpha * prey - beta * prey * predator
    d_predator = delta * prey * predator - gamma * predator
    return [d_prey, d_predator]

t_synth_np = np.linspace(T_START, T_END, N_POINTS).astype(np.float32) # Full time numpy array
synth_solution = scipy_odeint(
    lotka_volterra, U0_SYNTH, t_synth_np,
    args=(TRUE_ALPHA, TRUE_BETA, TRUE_GAMMA, TRUE_DELTA)
)

# --- Add Noise to Synthetic Data (Optional - based on user code) ---
if ADD_NOISE:
    print(f"Adding Gaussian noise with std dev {NOISE_LEVEL_SYNTH} to synthetic data.")
    noise = np.random.normal(scale=NOISE_LEVEL_SYNTH, size=synth_solution.shape)
    noisy_solution = synth_solution + noise
    # Clamp to ensure non-negativity and avoid exact zero after adding noise
    u_data_np_full = np.maximum(1e-4, noisy_solution).astype(np.float32)
    print("Applied noise and non-negativity clamp.")
else:
    print("Using clean synthetic data.")
    u_data_np_full = synth_solution.astype(np.float32) # Full data numpy array

print(f"Generated {N_POINTS} data points from t={T_START} to t={T_END}.")
print(f"True LV Parameters: alpha={TRUE_ALPHA}, beta={TRUE_BETA}, gamma={TRUE_GAMMA}, delta={TRUE_DELTA}")

# --- 1b. Data Preparation ---
# Calculate split point
split_ratio = 0.8
split_idx = int(split_ratio * N_POINTS)
n_train = split_idx
n_test = N_POINTS - n_train

if n_train < 1 or n_test < 1: print("Error: Not enough data for train/test split."); exit()
print(f"Splitting data: {n_train} training points, {n_test} test points.")

# Split numpy arrays first (keep original scale for evaluation/plotting)
t_train_np = t_synth_np[:split_idx]
u_train_np_orig = u_data_np_full[:split_idx, :] # Original scale train data
t_test_np = t_synth_np[split_idx:]
u_test_np_orig = u_data_np_full[split_idx:, :]   # Original scale test data

# --- Data Scaling ---
print("Scaling data using StandardScaler (fitted on training data)...")
scaler = StandardScaler()
# Fit scaler ONLY on training data, then transform train data
u_train_np_scaled = scaler.fit_transform(u_train_np_orig)
print(f"Scaler Mean: {scaler.mean_}, Scaler Scale: {scaler.scale_}")

# Convert necessary parts to PyTorch tensors
t_train = torch.tensor(t_train_np, dtype=torch.float32).to(device) # Time points for training SDE solve
u_train = torch.tensor(u_train_np_scaled, dtype=torch.float32).to(device) # SCALED target for loss calculation
t_data = torch.tensor(t_synth_np, dtype=torch.float32).to(device)  # Full time tensor for evaluation solve

# Initial condition from the first point of the SCALED training data
u0 = u_train[0].clone().detach().to(device) # SCALED initial condition

print("Data prepared for PyTorch:")
print(f"  Training time shape: {t_train.shape}, Training data shape (scaled): {u_train.shape}")
print(f"  Test time shape (np):{t_test_np.shape}, Test data shape (orig): {u_test_np_orig.shape}")
print(f"  Initial condition (u0, scaled): {u0.cpu().numpy()}")


# --- 2. Define the Neural SDE Structure --- (Operates on SCALED data)

# --- 2a. Known Physics Part (Drift) ---
class KnownDynamicsDrift(nn.Module):
    def __init__(self, initial_params=None):
        super().__init__()
        if initial_params is None: # Generic positive guesses
             initial_params = torch.tensor([0.5, 0.1, 0.5, 0.1], dtype=torch.float32)
        self.log_params = nn.Parameter(torch.log(initial_params + 1e-8))

    def forward(self, t, u): # u is SCALED and already clamped in SDEDynamics.f
        if u.ndim == 1: u = u.unsqueeze(0)
        params = torch.exp(self.log_params)
        alpha, beta, gamma, delta = params[0], params[1], params[2], params[3]
        # Apply dynamics to scaled state
        algae_scaled = u[:, 0]
        rotifers_scaled = u[:, 1]
        d_algae_scaled = alpha * algae_scaled - beta * algae_scaled * rotifers_scaled
        d_rotifers_scaled = delta * algae_scaled * rotifers_scaled - gamma * rotifers_scaled
        du_dt_scaled = torch.stack([d_algae_scaled, d_rotifers_scaled], dim=1)
        return du_dt_scaled

# --- 2b. Unknown Part (Neural Network for Drift Correction) ---
class NeuralNetworkDrift(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=HIDDEN_DIM_NN, output_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, output_dim)
        )
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('tanh'))
                if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, u): # u is SCALED and already clamped in SDEDynamics.f
         return self.net(u)

# --- 2c. Unknown Part (Neural Network for Diffusion) ---
class NeuralNetworkDiffusion(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=HIDDEN_DIM_NN, output_dim=2): # output_dim unused
        super().__init__()
        self.input_dim = input_dim
        if NOISE_TYPE == "diagonal": final_output_dim = input_dim
        elif NOISE_TYPE == "scalar": final_output_dim = 1
        elif NOISE_TYPE == "general": final_output_dim = input_dim * BROWNIAN_SIZE
        else: raise ValueError(f"Unknown noise_type: {NOISE_TYPE}")
        self.final_output_dim = final_output_dim

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, final_output_dim) # No final activation
        )
        # Initialize weights
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                if m != self.net[-1]: # Hidden layers
                    nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('tanh'))
                    if m.bias is not None: nn.init.zeros_(m.bias)
                else: # Last layer -> small initial diffusion
                    nn.init.uniform_(m.weight, a=-0.01, b=0.01)
                    if m.bias is not None: nn.init.constant_(m.bias, 0.0)

    def forward(self, u): # u is SCALED and already clamped in SDEDynamics.g
        return self.net(u)

# --- 2d. Combined Neural SDE Dynamics ---
class SDEDynamics(nn.Module):
    sde_type = "ito"
    noise_type = NOISE_TYPE

    def __init__(self, initial_known_params=None):
        super().__init__()
        self.known_drift = KnownDynamicsDrift(initial_known_params).to(device)
        self.nn_drift = NeuralNetworkDrift().to(device)
        self.nn_diffusion = NeuralNetworkDiffusion().to(device)

    def f(self, t, u): # Drift function (operates on SCALED u)
        u_nonneg = torch.relu(u) # Clamp input state >= 0
        known = self.known_drift(t, u_nonneg)
        nn_drift = self.nn_drift(u_nonneg)
        return known + nn_drift

    def g(self, t, u): # Diffusion function (operates on SCALED u)
        u_nonneg = torch.relu(u) # Clamp input state >= 0
        return self.nn_diffusion(u_nonneg)

# --- 3. Setup Model, Optimizer, and Loss ---
sde_func = SDEDynamics().to(device)
parameters = list(sde_func.parameters())
optimizer = optim.Adam(parameters, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
loss_fn = nn.MSELoss() # Loss on SCALED data


Using device: cpu

--- Generating Synthetic Lotka-Volterra Data ---
Adding Gaussian noise with std dev 0.1 to synthetic data.
Applied noise and non-negativity clamp.
Generated 200 data points from t=0 to t=10.
True LV Parameters: alpha=1.5, beta=1.0, gamma=2.0, delta=0.5
Splitting data: 160 training points, 40 test points.
Scaling data using StandardScaler (fitted on training data)...
Scaler Mean: [4.00458    1.50427472], Scaler Scale: [2.8739644  1.24541893]
Data prepared for PyTorch:
  Training time shape: torch.Size([160]), Training data shape (scaled): torch.Size([160, 2])
  Test time shape (np):(40,), Test data shape (orig): (40, 2)
  Initial condition (u0, scaled): [1.4074953 1.1898799]


In [2]:
# --- 4. Training Loop ---
print(f"\n--- Starting Training Neural SDE on Synthetic Data ({NUM_EPOCHS} Epochs) ---")
losses = []
min_loss = float('inf')
best_model_state = None
best_epoch = 0
ts_train = t_train # Use training time tensor

for epoch in range(1, NUM_EPOCHS + 1):
    sde_func.train()
    optimizer.zero_grad()
    u0_batch = u0.unsqueeze(0) # Scaled initial condition with batch dim

    try:
        # Solve SDE on training time points
        u_pred_train_scaled = torchsde.sdeint(sde_func, u0_batch, ts_train, method=SDE_SOLVER_METHOD, dt=SDE_SOLVER_DT, names={'drift': 'f', 'diffusion': 'g'})
        u_pred_train_scaled = u_pred_train_scaled.squeeze(1) # Remove batch dim

        if torch.isnan(u_pred_train_scaled).any() or torch.isinf(u_pred_train_scaled).any():
            print(f"Epoch {epoch}: NaN/Inf detected in prediction. Skipping update.")
            if best_model_state:
                 try: sde_func.load_state_dict(best_model_state); print("  Restored best model state.")
                 except Exception as load_e: print(f"  Failed to restore best model state: {load_e}")
            else: print("  No previous best state to restore.")
            losses.append(float('nan'))
            continue

        # Loss calculation using SCALED prediction vs SCALED true data (u_train)
        loss = loss_fn(u_pred_train_scaled, u_train)

        if torch.isnan(loss): print(f"Error: Loss is NaN at epoch {epoch}. Stopping."); break

        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=GRADIENT_CLIP_MAX_NORM) # Clipping
        optimizer.step()
        current_loss = loss.item()
        losses.append(current_loss)

    except Exception as e:
        print(f"Error during training loop at epoch {epoch}: {e}")
        if best_model_state: # Try restore before breaking
            try: sde_func.load_state_dict(best_model_state); print("Restored best state before stopping.")
            except Exception as load_e: print(f"Failed to restore best state after error: {load_e}")
        break

    # Logging
    if epoch % PRINT_EVERY == 0 or epoch == 1:
        if not np.isnan(current_loss):
            print(f"Epoch {epoch}/{NUM_EPOCHS}, Loss (Scaled): {current_loss:.6f}")
            # Optional: Check learned params & diffusion magnitude
            with torch.no_grad():
                 current_params = torch.exp(sde_func.known_drift.log_params).cpu().numpy()
                 # print(f"  Learned Params (exp): {np.round(current_params, 3)}")
                 g_u0 = sde_func.g(ts_train[0], u0_batch)
                 # print(f"  Diffusion ||g(0, u0)||: {torch.norm(g_u0).item():.4f}")

    # Save best model
    if not np.isnan(current_loss) and current_loss < min_loss:
        min_loss = current_loss
        best_model_state = sde_func.state_dict()
        best_epoch = epoch

# --- End of Training ---
if epoch < NUM_EPOCHS: print(f"--- Training stopped early at epoch {epoch} ---")
else: print("--- Training Complete ---")

if losses and not np.all(np.isnan(losses)): print(f"Final Training Loss (Scaled): {losses[-1]:.6f}")
if best_model_state:
    print(f"Best Training Loss Achieved (Scaled): {min_loss:.6f} at epoch {best_epoch}")
    print("Loading best model state for evaluation...")
    try: sde_func.load_state_dict(best_model_state); print("Best model state loaded.")
    except Exception as e: print(f"Warning: Failed to load best model state: {e}")
else: print("Warning: No valid best model state saved.")




--- Starting Training Neural SDE on Synthetic Data (1000 Epochs) ---
Epoch 1/1000, Loss (Scaled): 208.175415
Epoch 10/1000, Loss (Scaled): 0.935773
Epoch 20/1000, Loss (Scaled): 0.759084
Epoch 30/1000, Loss (Scaled): 0.820251
Epoch 40/1000, Loss (Scaled): 0.828238
Epoch 50/1000, Loss (Scaled): 0.774556
Epoch 60/1000, Loss (Scaled): 0.721235
Epoch 70/1000, Loss (Scaled): 0.657098
Epoch 80/1000, Loss (Scaled): 0.631252
Epoch 90/1000, Loss (Scaled): 0.821171
Epoch 100/1000, Loss (Scaled): 0.677611
Epoch 110/1000, Loss (Scaled): 0.641683
Epoch 120/1000, Loss (Scaled): 0.624559
Epoch 130/1000, Loss (Scaled): 0.646389
Epoch 140/1000, Loss (Scaled): 0.624991
Epoch 150/1000, Loss (Scaled): 0.687060
Epoch 160/1000, Loss (Scaled): 0.716171
Epoch 170/1000, Loss (Scaled): 0.811312
Epoch 180/1000, Loss (Scaled): 0.629290
Epoch 190/1000, Loss (Scaled): 0.690032
Epoch 200/1000, Loss (Scaled): 0.760886
Epoch 210/1000, Loss (Scaled): 0.808700
Epoch 220/1000, Loss (Scaled): 0.707179
Epoch 230/1000, Los

Orig SDE Synth code

In [3]:


# --- Helper Functions for Metrics (copied from your provided ODE code) ---
def nrmse(y_true, y_pred):
    """Calculate Normalized Root Mean Squared Error (NRMSE) by range."""
    if y_true.ndim > 1 and y_true.shape[1] > 1: # Cannot calculate range for multi-output directly
        # Calculate NRMSE for each dimension separately and average? Or return NaN?
        # Let's return NaN for simplicity, as interpretation gets complex.
        # You could modify this to return per-dimension NRMSE if needed.
        # print("Warning: NRMSE calculated per dimension for multi-output.")
        # nrmse_vals = [nrmse(y_true[:, i], y_pred[:, i]) for i in range(y_true.shape[1])]
        # return np.nanmean(nrmse_vals)
         return np.nan # Simpler: return NaN for multi-output total NRMSE
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    data_range = np.max(y_true) - np.min(y_true)
    if data_range < 1e-8: # Avoid division by zero or near-zero range
        return np.nan if rmse > 1e-8 else 0.0
    return rmse / data_range

def smape(y_true, y_pred, epsilon=1e-8):
    """Calculate Symmetric Mean Absolute Percentage Error (sMAPE)."""
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2.0 # The average scale
    # Add epsilon to prevent division by zero where both true and pred are zero
    ratio = numerator / (denominator + epsilon)
    # Need to handle potential NaNs if both y_true and y_pred are exactly 0 + epsilon
    # But mean should handle this okay unless *all* values are zero.
    return np.mean(ratio) * 100.0 # Return as percentage


# --- 5. Evaluation (Fit on Train + Forecast on Test) ---
print("\n--- Evaluating SDE Model Fit (Train) and Forecast (Test) ---")

# --- Configuration for Evaluation ---
N_ENSEMBLE = 1000 # Number of SDE paths to simulate for evaluation (>= 1)
                # Use > 1 for mean/std dev, 1 for single path evaluation.
EVAL_SDE_SOLVER_METHOD = SDE_SOLVER_METHOD # Use the same solver as training or specify another
EVAL_SDE_SOLVER_DT = SDE_SOLVER_DT         # Use the same dt as training or specify another

sde_func.eval() # Set model to evaluation mode

with torch.no_grad():
    # --- Generate Ensemble Predictions over FULL time span ---
    print(f"Generating {N_ENSEMBLE} ensemble predictions using SDE solver...")
    all_u_pred_full = []

    # Prepare initial condition for batch solving if N_ENSEMBLE > 1
    if N_ENSEMBLE > 1:
        # Repeat u0 N_ENSEMBLE times along a new batch dimension
        u0_ensemble = u0.unsqueeze(0).repeat(N_ENSEMBLE, 1) # Shape: (N_ENSEMBLE, state_dim)
    else:
        u0_ensemble = u0.unsqueeze(0) # Shape: (1, state_dim)

    try:
        # Use torchsde.sdeint for prediction
        # The output shape will be (time, batch, state_dim)
        u_pred_ensemble = torchsde.sdeint(
            sde_func,
            u0_ensemble,
            t_data, # Full time tensor
            method=EVAL_SDE_SOLVER_METHOD,
            dt=EVAL_SDE_SOLVER_DT,
            names={'drift': 'f', 'diffusion': 'g'}
        ).to(device)

        # Check for NaNs/Infs in the ensemble predictions
        if torch.isnan(u_pred_ensemble).any() or torch.isinf(u_pred_ensemble).any():
            print("Error: NaN or Inf detected in SDE predictions during evaluation.")
            # Handle error case - perhaps skip metrics/plotting or try to analyze
            metrics_calculated = False
            u_pred_full_np = np.full((len(t_data), u0.shape[0]), np.nan) # Placeholder NaN array
            u_pred_full_mean = None # Indicate mean calculation failed
            u_pred_full_std = None  # Indicate std calculation failed

        else:
            # --- Calculate Mean and Std Dev of Ensemble ---
            # u_pred_ensemble shape: (time, batch, state_dim) -> (N_POINTS, N_ENSEMBLE, 2)
            u_pred_full_mean = torch.mean(u_pred_ensemble, dim=1) # Mean over ensemble dim, shape (N_POINTS, 2)
            u_pred_full_std = torch.std(u_pred_ensemble, dim=1)   # Std dev over ensemble dim, shape (N_POINTS, 2)

            # Convert MEAN prediction to NumPy for metrics and plotting
            u_pred_full_np = u_pred_full_mean.cpu().numpy() # Use mean for comparison

            # Split the MEAN predictions into train and test portions
            u_pred_train_np = u_pred_full_np[:split_idx, :] # Mean prediction on training timespan
            u_pred_test_np = u_pred_full_np[split_idx:, :]  # Mean prediction on test timespan

            # Also get std dev as numpy if needed for plotting uncertainty
            u_pred_std_np = u_pred_full_std.cpu().numpy() if N_ENSEMBLE > 1 else None

            metrics_calculated = True # Flag that metrics can be calculated

    except Exception as e:
        print(f"Error during SDE prediction for evaluation: {e}")
        metrics_calculated = False
        u_pred_full_np = np.full((len(t_data), u0.shape[0]), np.nan)
        u_pred_full_mean = None
        u_pred_full_std = None


    # --- Calculate Metrics (using MEAN prediction vs True/Noisy Data) ---
    print("\n--- Metrics Calculation (Comparing Mean SDE Prediction to Data) ---")
    metrics_train = {}
    metrics_test = {}

    if metrics_calculated:
        # Separate true components for easier metric calculation
        u_train_prey = u_train_np[:, 0]
        u_train_predator = u_train_np[:, 1]
        u_test_prey = u_test_np[:, 0]
        u_test_predator = u_test_np[:, 1]

        # Separate predicted (mean) components
        u_pred_train_prey = u_pred_train_np[:, 0]
        u_pred_train_predator = u_pred_train_np[:, 1]
        u_pred_test_prey = u_pred_test_np[:, 0]
        u_pred_test_predator = u_pred_test_np[:, 1]

        # --- Training Set Metrics (Fit) ---
        metrics_train['MSE_prey'] = mean_squared_error(u_train_prey, u_pred_train_prey)
        metrics_train['MSE_predator'] = mean_squared_error(u_train_predator, u_pred_train_predator)
        metrics_train['MSE_total'] = mean_squared_error(u_train_np, u_pred_train_np)
        metrics_train['RMSE_prey'] = np.sqrt(metrics_train['MSE_prey'])
        metrics_train['RMSE_predator'] = np.sqrt(metrics_train['MSE_predator'])
        metrics_train['RMSE_total'] = np.sqrt(metrics_train['MSE_total'])
        metrics_train['MAE_prey'] = mean_absolute_error(u_train_prey, u_pred_train_prey)
        metrics_train['MAE_predator'] = mean_absolute_error(u_train_predator, u_pred_train_predator)
        metrics_train['MAE_total'] = mean_absolute_error(u_train_np, u_pred_train_np)
        metrics_train['NRMSE_prey'] = nrmse(u_train_prey, u_pred_train_prey)
        metrics_train['NRMSE_predator'] = nrmse(u_train_predator, u_pred_train_predator)
        # Total NRMSE calculated per component, averaged, or NaN based on helper function
        metrics_train['NRMSE_total'] = nrmse(u_train_np, u_pred_train_np)
        metrics_train['sMAPE_prey'] = smape(u_train_prey, u_pred_train_prey)
        metrics_train['sMAPE_predator'] = smape(u_train_predator, u_pred_train_predator)
        # Total sMAPE is often calculated by averaging component sMAPEs or on flattened arrays
        metrics_train['sMAPE_total'] = smape(u_train_np.flatten(), u_pred_train_np.flatten())
        metrics_train['R2_prey'] = r2_score(u_train_prey, u_pred_train_prey)
        metrics_train['R2_predator'] = r2_score(u_train_predator, u_pred_train_predator)
        metrics_train['R2_total'] = r2_score(u_train_np, u_pred_train_np)

        # --- Test Set Metrics (Forecast) ---
        metrics_test['MSE_prey'] = mean_squared_error(u_test_prey, u_pred_test_prey)
        metrics_test['MSE_predator'] = mean_squared_error(u_test_predator, u_pred_test_predator)
        metrics_test['MSE_total'] = mean_squared_error(u_test_np, u_pred_test_np)
        metrics_test['RMSE_prey'] = np.sqrt(metrics_test['MSE_prey'])
        metrics_test['RMSE_predator'] = np.sqrt(metrics_test['MSE_predator'])
        metrics_test['RMSE_total'] = np.sqrt(metrics_test['MSE_total'])
        metrics_test['MAE_prey'] = mean_absolute_error(u_test_prey, u_pred_test_prey)
        metrics_test['MAE_predator'] = mean_absolute_error(u_test_predator, u_pred_test_predator)
        metrics_test['MAE_total'] = mean_absolute_error(u_test_np, u_pred_test_np)
        metrics_test['NRMSE_prey'] = nrmse(u_test_prey, u_pred_test_prey)
        metrics_test['NRMSE_predator'] = nrmse(u_test_predator, u_pred_test_predator)
        metrics_test['NRMSE_total'] = nrmse(u_test_np, u_pred_test_np)
        metrics_test['sMAPE_prey'] = smape(u_test_prey, u_pred_test_prey)
        metrics_test['sMAPE_predator'] = smape(u_test_predator, u_pred_test_predator)
        metrics_test['sMAPE_total'] = smape(u_test_np.flatten(), u_pred_test_np.flatten())
        metrics_test['R2_prey'] = r2_score(u_test_prey, u_pred_test_prey)
        metrics_test['R2_predator'] = r2_score(u_test_predator, u_pred_test_predator)
        metrics_test['R2_total'] = r2_score(u_test_np, u_pred_test_np)

        # --- Print Metrics ---
        print("\n--- Training Set Metrics (Mean SDE Fit vs Data) ---")
        print(f"  Metric       |   Prey    | Predator  |   Total   ")
        print(f"----------------------------------------------------")
        print(f"  MSE          | {metrics_train['MSE_prey']:<9.4f} | {metrics_train['MSE_predator']:<9.4f} | {metrics_train['MSE_total']:<9.4f}")
        print(f"  RMSE         | {metrics_train['RMSE_prey']:<9.4f} | {metrics_train['RMSE_predator']:<9.4f} | {metrics_train['RMSE_total']:<9.4f}")
        print(f"  MAE          | {metrics_train['MAE_prey']:<9.4f} | {metrics_train['MAE_predator']:<9.4f} | {metrics_train['MAE_total']:<9.4f}")
        print(f"  NRMSE (range)| {metrics_train['NRMSE_prey']:<9.4f} | {metrics_train['NRMSE_predator']:<9.4f} | {metrics_train['NRMSE_total']:<9.4f}")
        print(f"  sMAPE (%)    | {metrics_train['sMAPE_prey']:<9.2f} | {metrics_train['sMAPE_predator']:<9.2f} | {metrics_train['sMAPE_total']:<9.2f}")
        print(f"  R^2          | {metrics_train['R2_prey']:<9.4f} | {metrics_train['R2_predator']:<9.4f} | {metrics_train['R2_total']:<9.4f}")

        print("\n--- Test Set Metrics (Mean SDE Forecast vs Data) ---")
        print(f"  Metric       |   Prey    | Predator  |   Total   ")
        print(f"----------------------------------------------------")
        print(f"  MSE          | {metrics_test['MSE_prey']:<9.4f} | {metrics_test['MSE_predator']:<9.4f} | {metrics_test['MSE_total']:<9.4f}")
        print(f"  RMSE         | {metrics_test['RMSE_prey']:<9.4f} | {metrics_test['RMSE_predator']:<9.4f} | {metrics_test['RMSE_total']:<9.4f}")
        print(f"  MAE          | {metrics_test['MAE_prey']:<9.4f} | {metrics_test['MAE_predator']:<9.4f} | {metrics_test['MAE_total']:<9.4f}")
        print(f"  NRMSE (range)| {metrics_test['NRMSE_prey']:<9.4f} | {metrics_test['NRMSE_predator']:<9.4f} | {metrics_test['NRMSE_total']:<9.4f}")
        print(f"  sMAPE (%)    | {metrics_test['sMAPE_prey']:<9.2f} | {metrics_test['sMAPE_predator']:<9.2f} | {metrics_test['sMAPE_total']:<9.2f}")
        print(f"  R^2          | {metrics_test['R2_prey']:<9.4f} | {metrics_test['R2_predator']:<9.4f} | {metrics_test['R2_total']:<9.4f}")
        print("-" * 52)

    else:
        print("Metrics calculation skipped due to errors in prediction generation.")


    # --- Learned Parameter Printout (Drift part) ---
    try:
        # Access parameters from the known_drift submodule of sde_func
        learned_log_params = sde_func.known_drift.log_params
        learned_params = torch.exp(learned_log_params).cpu().numpy()
        print("\nLearned Physical Parameters (Drift) vs True Parameters:")
        print(f"                 True     Learned")
        # Ensure TRUE_PARAMS are accessible here
        print(f"  alpha:   {TRUE_PARAMS[0]:<8.4f} {learned_params[0]:<8.4f}")
        print(f"  beta:    {TRUE_PARAMS[1]:<8.4f} {learned_params[1]:<8.4f}")
        print(f"  gamma:   {TRUE_PARAMS[2]:<8.4f} {learned_params[2]:<8.4f}")
        print(f"  delta:   {TRUE_PARAMS[3]:<8.4f} {learned_params[3]:<8.4f}")
    except AttributeError:
        print("\nCould not retrieve learned drift parameters (check model structure).")
    except NameError:
        print("\nCould not retrieve learned drift parameters (TRUE_PARAMS not defined).")
    except Exception as e:
        print(f"\nError retrieving learned drift parameters: {e}")


# --- 6. Plotting Results ---
print("\n--- Generating Plots (SDE Mean Prediction and Uncertainty) ---")
plt.style.use('seaborn-v0_8-darkgrid')
split_time = t_train_np[-1] # Time where split occurs

# Plot 1: Loss Curve
# This is typically plotted right after the training loop using the 'losses' list.
# Example:
# plt.figure(figsize=(10, 5))
# plt.plot(range(1, len(losses) + 1), losses) # Assuming 'losses' is available
# plt.xlabel("Epoch")
# plt.ylabel("MSE Loss (Training Data)")
# plt.title("Training Loss Curve (SDE Synth)")
# plt.yscale('log') # Keep log scale if loss varies greatly
# plt.tight_layout()
# plt.show() # Show it separately or save it
print("Skipping loss curve plot here - assumed plotted after training loop.")


# Plot 2: Time Series Fit and Forecast (Mean Prediction + Uncertainty)
fig_ts, axs_ts = plt.subplots(2, 1, figsize=(12, 9), sharex=True) # Increased height slightly

if metrics_calculated: # Only plot if predictions were successful
    # Define colors
    fit_color_prey = 'deepskyblue'
    forecast_color_prey = 'blue'
    fit_color_predator = 'limegreen'
    forecast_color_predator = 'green'
    uncertainty_color_prey = 'lightblue'
    uncertainty_color_predator = 'lightgreen'

    # Data for plotting forecast continuously from end of fit (MEAN prediction)
    t_forecast_plot = np.concatenate(([t_train_np[-1]], t_test_np))
    u_pred_forecast_plot_prey = np.concatenate(([u_pred_train_np[-1, 0]], u_pred_test_np[:, 0]))
    u_pred_forecast_plot_predator = np.concatenate(([u_pred_train_np[-1, 1]], u_pred_test_np[:, 1]))

    # --- Prey Subplot ---
    # Plot the FIT part of the MEAN prediction
    axs_ts[0].plot(t_train_np, u_pred_train_np[:, 0], color=fit_color_prey, linestyle='-', linewidth=2, label='SDE Mean Pred. (Fit)')
    # Plot the FORECAST part of the MEAN prediction
    axs_ts[0].plot(t_forecast_plot, u_pred_forecast_plot_prey, color=forecast_color_prey, linestyle='-', linewidth=2, label='SDE Mean Pred. (Forecast)')

    # Plot Uncertainty Bands (if N_ENSEMBLE > 1)
    if N_ENSEMBLE > 1 and u_pred_std_np is not None:
        # Split std dev array
        u_pred_train_std_np = u_pred_std_np[:split_idx, :]
        u_pred_test_std_np = u_pred_std_np[split_idx:, :]
        # Combine last train std with test std for continuous forecast band plotting
        u_pred_forecast_plot_std_prey = np.concatenate(([u_pred_train_std_np[-1, 0]], u_pred_test_std_np[:, 0]))

        # Fit uncertainty
        axs_ts[0].fill_between(t_train_np,
                               u_pred_train_np[:, 0] - u_pred_train_std_np[:, 0],
                               u_pred_train_np[:, 0] + u_pred_train_std_np[:, 0],
                               color=uncertainty_color_prey, alpha=0.4, label='Std Dev (Fit)')
        # Forecast uncertainty
        axs_ts[0].fill_between(t_forecast_plot,
                               u_pred_forecast_plot_prey - u_pred_forecast_plot_std_prey,
                               u_pred_forecast_plot_prey + u_pred_forecast_plot_std_prey,
                               color=uncertainty_color_prey, alpha=0.6, label='Std Dev (Forecast)') # Slightly darker alpha

    # Plot data points
    axs_ts[0].plot(t_train_np, u_train_np[:, 0], 'ko', markersize=4, alpha=0.7, label='Training Data (Prey)')
    axs_ts[0].plot(t_test_np, u_test_np[:, 0], 'ro', markersize=4, alpha=0.7, label='Test Data (Prey)')
    axs_ts[0].axvline(split_time, color='gray', linestyle='--', linewidth=1.5, label='Train/Test Split')
    axs_ts[0].set_ylabel("Prey Population")
    axs_ts[0].legend(loc='upper right', fontsize=9)
    axs_ts[0].set_title(f"SDE Fit & Forecast (Mean of {N_ENSEMBLE} Paths) on Synthetic Data")

    # --- Predator Subplot ---
    # Plot the FIT part of the MEAN prediction
    axs_ts[1].plot(t_train_np, u_pred_train_np[:, 1], color=fit_color_predator, linestyle='-', linewidth=2, label='SDE Mean Pred. (Fit)')
    # Plot the FORECAST part of the MEAN prediction
    axs_ts[1].plot(t_forecast_plot, u_pred_forecast_plot_predator, color=forecast_color_predator, linestyle='-', linewidth=2, label='SDE Mean Pred. (Forecast)')

    # Plot Uncertainty Bands (if N_ENSEMBLE > 1)
    if N_ENSEMBLE > 1 and u_pred_std_np is not None:
         # Use split std dev from above
         u_pred_forecast_plot_std_predator = np.concatenate(([u_pred_train_std_np[-1, 1]], u_pred_test_std_np[:, 1]))
         # Fit uncertainty
         axs_ts[1].fill_between(t_train_np,
                                u_pred_train_np[:, 1] - u_pred_train_std_np[:, 1],
                                u_pred_train_np[:, 1] + u_pred_train_std_np[:, 1],
                                color=uncertainty_color_predator, alpha=0.4, label='Std Dev (Fit)')
         # Forecast uncertainty
         axs_ts[1].fill_between(t_forecast_plot,
                                u_pred_forecast_plot_predator - u_pred_forecast_plot_std_predator,
                                u_pred_forecast_plot_predator + u_pred_forecast_plot_std_predator,
                                color=uncertainty_color_predator, alpha=0.6, label='Std Dev (Forecast)')

    # Plot data points
    axs_ts[1].plot(t_train_np, u_train_np[:, 1], 'ko', markersize=4, alpha=0.7, label='Training Data (Predator)')
    axs_ts[1].plot(t_test_np, u_test_np[:, 1], 'ro', markersize=4, alpha=0.7, label='Test Data (Predator)')
    axs_ts[1].axvline(split_time, color='gray', linestyle='--', linewidth=1.5) # Split line only
    axs_ts[1].set_ylabel("Predator Population")
    axs_ts[1].set_xlabel("Time")
    axs_ts[1].legend(loc='upper right', fontsize=9)

    plt.tight_layout(rect=[0, 0.03, 1, 0.96]) # Adjust top margin for title
else:
    axs_ts[0].set_title("SDE Fit & Forecast - Plotting Skipped Due to Prediction Error")
    axs_ts[0].text(0.5, 0.5, 'Error during prediction generation', horizontalalignment='center', verticalalignment='center', transform=axs_ts[0].transAxes)
    axs_ts[1].text(0.5, 0.5, 'Error during prediction generation', horizontalalignment='center', verticalalignment='center', transform=axs_ts[1].transAxes)


# Plot 3: Phase Plot (Mean Fit and Forecast)
plt.figure(3, figsize=(8, 8))
if metrics_calculated: # Only plot if predictions were successful
    plt.plot(u_train_np[:, 0], u_train_np[:, 1], 'ko', markersize=5, alpha=0.7, label='Training Data')
    plt.plot(u_test_np[:, 0], u_test_np[:, 1], 'ro', markersize=5, alpha=0.7, label='Test Data')
    # Plot the full MEAN prediction line
    plt.plot(u_pred_full_np[:, 0], u_pred_full_np[:, 1], 'm-', linewidth=2, label=f'SDE Mean Pred. (N={N_ENSEMBLE})')
    # Mark start/end points
    plt.plot(u0[0].cpu(), u0[1].cpu(), 'kX', markersize=10, label='Start')
    plt.plot(u_pred_train_np[-1, 0], u_pred_train_np[-1, 1], 'ms', markersize=8, label='End of Fit / Start of Forecast')
    plt.xlabel("Prey Population")
    plt.ylabel("Predator Population")
    plt.title(f"Phase Plot: Mean SDE Prediction (N={N_ENSEMBLE}) vs Data")
    plt.legend()
    plt.grid(True)
else:
    plt.title("Phase Plot - Plotting Skipped Due to Prediction Error")
    plt.text(0.5, 0.5, 'Error during prediction generation', horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)
plt.tight_layout()


# --- 7. Visualize Learned Dynamics (Drift Component) ---
print("\n--- Visualizing Learned Drift Dynamics (Based on Mean Trajectory) ---")

# We need the mean prediction as a tensor to evaluate the dynamics
if metrics_calculated and u_pred_full_mean is not None:
    with torch.no_grad():
        # Use u_pred_full_mean tensor (output from mean calculation)
        known_dyn_pred = torch.zeros_like(u_pred_full_mean)
        nn_dyn_pred = torch.zeros_like(u_pred_full_mean)

        # Evaluate dynamics at each point along the MEAN predicted trajectory
        for i in range(len(t_data)):
            # Use the mean state at time t_i as input
            u_i_mean = u_pred_full_mean[i].unsqueeze(0) # Add batch dim
            t_i = t_data[i]

            # Clamp input for safety, similar to training
            u_i_mean_nonneg = torch.relu(u_i_mean)

            # Get contributions from the trained SDE model's drift components
            # Access sub-modules directly
            known_dyn_pred[i] = sde_func.known_drift(t_i, u_i_mean_nonneg).squeeze(0)
            nn_dyn_pred[i] = sde_func.nn_drift(u_i_mean_nonneg).squeeze(0) # NN drift correction

        known_dyn_np = known_dyn_pred.cpu().numpy()
        nn_dyn_np = nn_dyn_pred.cpu().numpy()
        total_learned_drift_np = known_dyn_np + nn_dyn_np # This is the learned dU/dt (drift part)

    # Plot 4: Dynamics Decomposition (Drift only)
    fig_decomp, axs_decomp = plt.subplots(2, 2, figsize=(15, 10))
    # Prey Dynamics plot
    axs_decomp[0, 0].plot(t_data.cpu().numpy(), known_dyn_np[:, 0], 'c--', label='Known Drift (LV)')
    axs_decomp[0, 0].plot(t_data.cpu().numpy(), nn_dyn_np[:, 0], 'm:', label='NN Drift Correction')
    axs_decomp[0, 0].plot(t_data.cpu().numpy(), total_learned_drift_np[:, 0], 'b-', alpha=0.7, label='Total Learned Drift (f)')
    axs_decomp[0, 0].axvline(split_time, color='gray', linestyle=':', linewidth=1) # Indicate split
    axs_decomp[0, 0].set_ylabel("Rate of Change (Prey Drift)")
    axs_decomp[0, 0].legend(); axs_decomp[0, 0].set_title("Drift Decomposition (Prey)")
    axs_decomp[0, 0].axhline(0, color='gray', linestyle='-', linewidth=0.5); axs_decomp[0, 0].grid(True)
    # Predator Dynamics plot
    axs_decomp[1, 0].plot(t_data.cpu().numpy(), known_dyn_np[:, 1], 'y--', label='Known Drift (LV)')
    axs_decomp[1, 0].plot(t_data.cpu().numpy(), nn_dyn_np[:, 1], 'r:', label='NN Drift Correction')
    axs_decomp[1, 0].plot(t_data.cpu().numpy(), total_learned_drift_np[:, 1], 'g-', alpha=0.7, label='Total Learned Drift (f)')
    axs_decomp[1, 0].axvline(split_time, color='gray', linestyle=':', linewidth=1)
    axs_decomp[1, 0].set_ylabel("Rate of Change (Predator Drift)"); axs_decomp[1, 0].set_xlabel("Time")
    axs_decomp[1, 0].legend(); axs_decomp[1, 0].axhline(0, color='gray', linestyle='-', linewidth=0.5); axs_decomp[1, 0].grid(True)
    # Contribution of NN Drift Term (Prey)
    axs_decomp[0, 1].plot(t_data.cpu().numpy(), nn_dyn_np[:, 0], 'm:', label='NN Drift Correction (Prey)')
    axs_decomp[0, 1].axvline(split_time, color='gray', linestyle=':', linewidth=1)
    axs_decomp[0, 1].set_ylabel("NN Output Value"); axs_decomp[0, 1].legend(); axs_decomp[0, 1].set_title("NN Drift Contribution (Prey)")
    axs_decomp[0, 1].axhline(0, color='gray', linestyle='-', linewidth=0.5); axs_decomp[0, 1].grid(True)
    # Contribution of NN Drift Term (Predator)
    axs_decomp[1, 1].plot(t_data.cpu().numpy(), nn_dyn_np[:, 1], 'r:', label='NN Drift Correction (Predator)')
    axs_decomp[1, 1].axvline(split_time, color='gray', linestyle=':', linewidth=1)
    axs_decomp[1, 1].set_ylabel("NN Output Value"); axs_decomp[1, 1].set_xlabel("Time")
    axs_decomp[1, 1].legend(); axs_decomp[1, 1].set_title("NN Drift Contribution (Predator)")
    axs_decomp[1, 1].axhline(0, color='gray', linestyle='-', linewidth=0.5); axs_decomp[1, 1].grid(True)

    plt.suptitle("Learned Drift Dynamics Decomposition (SDE Synth - Mean Trajectory)", y=1.02)
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])

else:
     print("Skipping dynamics decomposition plot due to prediction errors.")


# Show all generated figures at the end
plt.show()

print("\n--- Evaluation and Plotting Script Finished ---")


--- Evaluating SDE Model Fit (Train) and Forecast (Test) ---
Generating 1000 ensemble predictions using SDE solver...

--- Metrics Calculation (Comparing Mean SDE Prediction to Data) ---


NameError: name 'u_train_np' is not defined