# PODS Optimization vs Shooting: Robustness to Model Mismatch

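This notebook compares two optimization methods for parameter recovery in Partially Observed Dynamical Systems (PODS): homotopy-based PODS optimization and shooting. It visualizes the PODS convergence process and tests how robust each method is when the fitted model is misspecified.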


In [None]:
import sys
import os
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import pandas as pd
import importlib

# Add current directory to path
sys.path.append(os.getcwd())

from data.generate_data import WarfarinTMDDDataset
import optimization_utils
importlib.reload(optimization_utils)
from optimization_utils import WarfarinModel_true, WarfarinModel_miss, run_homotopy_optimization, run_shooting_optimization, calc_param_error

## 1. Setup: Ground Truth & Observations


In [None]:
# Fix the torch RNG so the observation noise (and any torch randomness inside
# the data generator) is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# True parameters
# We increase kon and koff to make the binding/unbinding dynamics more significant,
# so that removing the koff term (Model Mismatch) has a visible impact.
true_params = {
    'CL': 0.1, 'V1': 1.0, 'Q': 0.5, 'V2': 2.0,
    'kon': 0.1, 'koff': 0.5, 'kdeg': 0.001, 'ksyn': 0.002,
    'initial_state': [10.0, 0.0, 0.0, 5.0],
    't_start': 0, 't_end': 50, 'num_points': 200
}

# Kinetic parameters the optimizers start from (perturbed) and try to recover.
# The remaining entries of `true_params` are simulation settings (initial state,
# time grid) and must NOT be perturbed: scaling every numeric value would also
# turn t_end=50 into 75.0 and num_points=200 into 300.0.
PARAM_NAMES = ['CL', 'V1', 'Q', 'V2', 'kon', 'koff', 'kdeg', 'ksyn']
PERTURB_FACTOR = 1.5  # initial guess = 150% of each true kinetic parameter

# Generate data
generator = WarfarinTMDDDataset(generation_parameters=true_params)
gt_data = generator.generate()
time_points = generator.get_time_points()

# Observations: only states x1 and x3 are measured.
obs_indices = [0, 2]  # x1, x3
# Absolute (additive) Gaussian noise std, applied equally to every observed
# state. It is NOT relative to the signal magnitude.
noise_std = 0.01
observed_data = gt_data[:, obs_indices].clone()
observed_data += torch.randn_like(observed_data) * noise_std

# Init params for optimization: every kinetic parameter scaled by PERTURB_FACTOR.
init_perturbed = {name: true_params[name] * PERTURB_FACTOR for name in PARAM_NAMES}

print(f"Observed Shape: {observed_data.shape}")

## 2. Convergence Visualization (Homotopy)

We visualize how the estimated trajectory $X_{\mathrm{est}}$ evolves as $\tau$ increases, i.e. as the effective dynamics constraint is progressively tightened.

In [None]:
# Fit the correctly specified model with the homotopy optimizer and keep the
# per-tau trajectory history that the convergence plot below consumes.
model_viz = WarfarinModel_true(init_perturbed)
X_final, loss_hist, traj_history = run_homotopy_optimization(
    model_viz,
    observed_data,
    obs_indices,
    time_points,
    verbose=False,
)

# Plotting
def plot_convergence(traj_history, gt_data, time_points, title="Trajectory Evolution over Tau"):
    taus = sorted(traj_history.keys())
    n_taus = len(taus)
    colors = cm.viridis(np.linspace(0, 1, n_taus))
    
    plt.figure(figsize=(15, 10))
    labels = ['x1 (Central)', 'x2 (Peripheral)', 'x3 (Bound)', 'x4 (Target)']
    
    for i in range(4):
        plt.subplot(2, 2, i+1)
        # Plot Truth
        plt.plot(time_points, gt_data[:, i], 'k--', linewidth=2, label='Ground Truth')
        
        # Plot Evolution
        for j, tau in enumerate(taus):
            X_tau = traj_history[tau]
            label = f'Tau={tau}' if i==0 else None
            plt.plot(time_points, X_tau[:, i], color=colors[j], alpha=0.7, label=label)
            
        plt.title(labels[i])
        if i==0: plt.legend()
    
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

# Render the homotopy trajectory evolution against the ground truth.
plot_convergence(
    traj_history,
    gt_data,
    time_points,
    title="Trajectory Evolution over Tau",
)

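## 3. PODS vs Shooting: Misspecified Model

Both methods fit the misspecified `WarfarinModel_miss` (which drops the $k_{\mathrm{off}}$ unbinding term) to the same noisy observations, starting from the same perturbed initial parameters.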


In [None]:
# Fit the misspecified model with both methods from the same perturbed start,
# then compare parameter error and reconstructed trajectories.

# --- Homotopy (Misspecified) ---
model_wrong_pods = WarfarinModel_miss(init_perturbed)
X_pods_wrong, _, _ = run_homotopy_optimization(model_wrong_pods, observed_data, obs_indices, time_points, verbose=False)
err_pods_wrong = calc_param_error(model_wrong_pods, true_params)

# --- Shooting (Misspecified) ---
model_wrong_shoot = WarfarinModel_miss(init_perturbed)
X_shoot_wrong, _ = run_shooting_optimization(model_wrong_shoot, observed_data, obs_indices, time_points, verbose=False)
err_shoot_wrong = calc_param_error(model_wrong_shoot, true_params)

print("Misspecified Model Results:")
print(f"  PODS Param Error: {err_pods_wrong:.2f}%")
print(f"  Shooting Param Error: {err_shoot_wrong:.2f}%")

# Visual comparison.
# Detach before converting: `.numpy()` raises on a tensor that requires grad
# (or lives on GPU). An optimizer's returned trajectory may well be such a tensor.
X_p = X_pods_wrong.detach().cpu().numpy() if torch.is_tensor(X_pods_wrong) else np.asarray(X_pods_wrong)
X_s = X_shoot_wrong.detach().cpu().numpy() if torch.is_tensor(X_shoot_wrong) else np.asarray(X_shoot_wrong)

# Same state labels as the convergence plot, so the two figures read alike.
labels = ['x1 (Central)', 'x2 (Peripheral)', 'x3 (Bound)', 'x4 (Target)']
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
for i, ax in enumerate(axes.flat):
    ax.plot(time_points, gt_data[:, i], 'k--', label='Truth')
    ax.plot(time_points, X_p[:, i], 'r-', label='PODS (Wrong Model)')
    ax.plot(time_points, X_s[:, i], 'g:', label='Shooting (Wrong Model)')
    ax.set_title(labels[i])
    ax.set_xlabel('Time')
    ax.set_ylabel(labels[i])
    ax.legend()
fig.suptitle("Recovery under Model Mismatch")
fig.tight_layout()
plt.show()