In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from rpasim.ode import AB
from rpasim.ode.classic_control import PopulationDynamics, Lorenz, FlightControl
from rpasim.plot.ode import plot_trajectory
from rpasim.plot.env import plot_env_trajectory
from rpasim.env import DifferentiableEnv
import matplotlib.pyplot as plt

from rpasim.style import set_style
# Apply rpasim's shared style once, before any figures below are created
# (presumably matplotlib rcParams / theme settings — confirm in rpasim.style).
set_style()

# Test All Implemented ODEs

This notebook plots every implemented ODE in two ways:
1. Direct ODE integration, plotted with `plot_trajectory`
2. Simulation through `DifferentiableEnv`, plotted with `plot_env_trajectory`, followed by the episode's total reward

In [None]:
# Define ODE configurations.
# Each entry is a dict with the keys:
#   ode          - ODE instance to integrate and to drive the environment with
#   x0           - initial state tensor
#   T            - simulation horizon
#   reward_fn    - callable mapping a state tensor to a scalar reward
#   state_limits - per-dimension (low, high) bounds, or None for no limits
_UNBOUNDED = (float('-inf'), float('inf'))

ode_configs = [
    dict(
        ode=AB(),
        x0=torch.tensor([1.0, 1.0]),
        T=100.0,
        # Penalize distance of the last state component from 4.0
        reward_fn=lambda state: -torch.norm(state[-1] - 4.0),
        state_limits=None,
    ),
    dict(
        ode=PopulationDynamics(),
        x0=torch.tensor([50.0, 10.0]),
        T=100.0,
        # Penalize distance from the target state (100, 20)
        reward_fn=lambda state: -torch.norm(state - torch.tensor([100.0, 20.0])),
        state_limits=None,
    ),
    dict(
        ode=Lorenz(),
        x0=torch.tensor([1.0, 1.0, 1.0]),
        T=10.0,
        reward_fn=lambda state: -torch.norm(state),  # Q = I, minimize distance from origin
        state_limits=None,
    ),
    dict(
        ode=FlightControl(),
        x0=torch.tensor([0.0, 0.0, 0.0]),
        T=13.0,
        reward_fn=lambda state: -25.0 * state[0]**2,  # Q=25, track x1 to zero
        # Only the first state dimension is bounded; the other two are free
        state_limits=[(-0.2, 0.4), _UNBOUNDED, _UNBOUNDED],
    ),
]

In [None]:
# Plot all ODEs: for each config, (1) integrate the ODE directly and plot it,
# (2) roll it out through DifferentiableEnv, plot that trajectory, and print the
# total reward collected.
N_STEPS = 1000  # integration steps for the direct-ODE trajectory plot
SEPARATOR = '=' * 60

for config in ode_configs:
    print(f"\n{SEPARATOR}")
    print(config['ode'])
    print(SEPARATOR)

    # 1. Plot using ODE trajectory
    print("\nDirect ODE Integration:")
    fig, axes = plot_trajectory(
        config['ode'],
        config['x0'],
        config['T'],
        n_steps=N_STEPS,
    )
    fig.suptitle(f"{config['ode'].name} - ODE Trajectory")
    plt.show()
    # Close explicitly so figures created in this loop do not accumulate on
    # backends where plt.show() does not close them.
    plt.close(fig)

    # 2. Plot using environment
    print("\nEnvironment Simulation:")
    env = DifferentiableEnv(
        initial_ode=config['ode'],
        reward_fn=config['reward_fn'],
        initial_state=config['x0'],
        time_horizon=config['T'],
        state_limits=config.get('state_limits'),  # None means no limits
    )

    # Run environment: a single step with the same ODE and duration T.
    # NOTE(review): the action appears to be (ode, duration) so this covers the
    # full horizon in one step — confirm against DifferentiableEnv.step.
    env.reset()
    env.step((config['ode'], config['T']))

    # Plot
    fig, axes = plot_env_trajectory(env)
    fig.suptitle(f"{config['ode'].name} - Environment Trajectory")
    plt.show()
    plt.close(fig)

    # NOTE(review): element 2 of get_trajectory() is assumed to hold the
    # per-step rewards — confirm against DifferentiableEnv.get_trajectory.
    print(f"\nTotal reward: {env.get_trajectory()[2].sum():.2f}")