In [None]:
import os
import sys
one_level_up_dir = os.path.abspath(os.path.join(os.getcwd(), "../../"))
sys.path.append(one_level_up_dir)
import jax._src.random as prng
import jax
from jax import config, jit
import jax.numpy as jnp
import jax.random as jr
import diffrax as dfx
import matplotlib.pyplot as plt
import optax as ox
import equinox as eqx
from gpdx.systems.nonlinear_dynamics import VanDerPol
from gpdx.control.trajectory_optimizers import *
from gpdx.control.cost_functions import QuadraticCost
from gpdx.control.mpc import *
# from gpdx.dataset import DiffEqDataset
from gpdx.nn.node import NeuralODE, EnsembleNeuralODE
from gpdx.nn.nnvectorfield import NeuralVectorField, EnsembleNeuralVectorField
from gpdx.control.mpc import ensembleIndirectMPC
from gpdx.control.trajectory_optimizers import EnsemblePMPForward

from gpdx.fit import *
# Enable float64 throughout JAX (presumably for ODE-solver / optimal-control accuracy).
config.update("jax_enable_x64", True)
# Debug aid: raise as soon as any operation produces a NaN. Slows execution; disable for long runs.
config.update("jax_debug_nans", True)
# NOTE(review): jax._src.random is a private JAX module and KeyArray may not exist in newer
# JAX releases (jax.Array is the public key type) — confirm against the pinned JAX version.
PRNGKey = prng.KeyArray
# Global seed for the notebook; later random draws derive from `key` via jr.split.
key = jr.PRNGKey(1239) 
key, subkey = jr.split(key)
# Auto-reload edited gpdx modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Define the control problem and generate data to train the neural ODE (NODE)

In [None]:
from gpdx.control.control_task import CartPoleTask
# Cart-pole control task. The cells below read its real system, time grid (t0, tf, Delta_t),
# control bounds (lb, ub), state/termination costs and target state x_star.
env = CartPoleTask()


## Data generation

In [None]:
## data generation
num_trials = 1 # RL trials (iterations of the train -> control -> augment loop below)
t0 = env.t0
tf = env.tf
dt = env.Delta_t  # observation / control interval of the task
dt0_dense = 5e-3 # approximate spacing of the dense time grid on which controls are sampled

# NOTE(review): the "really this is 50 trials" remark below does not follow from the visible
# code (the data cell builds num_initial_trials one-second trajectories) — confirm its meaning.
num_initial_trials = 250 # initial dataset - note that this is of 1 sec, so really this is 50 trials.
num_obs = int((tf-t0)/dt)  # observations per trial over [t0, tf); int() truncates
D_in, D_out = env.D_sys+ env.D_control, env.D_sys  # vector-field input: state + control; output: state dim
# NOTE(review): this takes the sqrt of `measurement_noise_std`; if that attribute is already a
# standard deviation the result is std**0.5 rather than std — confirm what the env provides.
noise = jnp.array([jnp.sqrt(env.measurement_noise_std) for _ in range(D_out)]) # sigma^2 = 0.05 in Hegde experiment 4.1

## ode params
# NOTE(review): `stepsize_controller` and `internal_solver` are not referenced in the visible
# cells; train_network builds the NODE with dfx.Dopri5() and dfx.ConstantStepSize().
stepsize_controller=dfx.PIDController(rtol=1e-4, atol=1e-5, jump_ts=None)
internal_solver = dfx.Dopri5() 
dt0_internal = 0.025  # step size passed as dt0 to the NODE and to the MPC simulation

## neural network
ensemble_size = 5  # number of NODE ensemble members
data_per_ensemble = num_initial_trials  # NOTE(review): not referenced in the visible cells
hidden_dim = 32
layer_sizes = (D_in, hidden_dim, hidden_dim, D_out)  # MLP with two hidden layers
activation = jax.nn.elu

## training
num_iters = 3_500  # optimiser iterations per call to train_network
init_obs_noise = 0.5  # initial observation-noise scale (inverse-softplus'd in train_network)
batch_size = -1 # -1 or num_trials for no batching
lr = 0.0015  # Adam learning rate
log_rate = 20  # forwarded to fit_node

# MPC 
maxiter = 25  # passed as maxiter to every trajectory optimiser
H = 1.  # MPC prediction horizon, in the same time units as env.Delta_t

In [None]:
# Exploratory preview of a uniform state perturbation (display only).
# NOTE(review): reuses the global `key` without splitting, and these bounds differ from the ones
# used by get_initial_condition in the data-generation cell.
jr.uniform(key, (env.D_sys,), minval=jnp.array([-1, -3, -5, -5]), maxval=jnp.array([2, 9, 5, 10]))

In [None]:
# Generate the initial training dataset by simulating the real cart-pole system under random
# controls, starting from randomised initial states.
real_system = env.real_system

# Split trials into segments of 1 second.
tf = 1.
num_obs = int((tf-t0)/dt)

# Observation grid: num_obs points with spacing dt on [t0, tf - dt], identical for every trial.
# (The sort is a no-op for this regular grid; it only matters for randomly sampled times.)
ts_uniform = jnp.concatenate([jnp.linspace(t0, tf-env.Delta_t, num_obs)[None] for _ in range(num_initial_trials)], axis=0)
ts = jnp.sort(ts_uniform, axis=1)
# Dense grid on which the controls are defined.
# NOTE(review): int((tf - t0) / dt0_dense) points on [t0, tf] inclusive have spacing
# (tf - t0) / (N - 1), slightly larger than dt0_dense.
ts_dense = jnp.concatenate([jnp.linspace(env.t0, tf, int( ((1/dt0_dense))*(tf-t0)))[None] for _ in range(num_initial_trials)], axis=0)
# NOTE(review): not referenced by any active code (only by former sinusoidal control generators).
freqs = jnp.arange(1,num_initial_trials+1)
indices = jnp.linspace(0.2, 1.5, num_initial_trials)

# Independent uniform controls within the actuator bounds at every dense time point. A freshly
# split key is used (the original sampled with `key` itself and then split that same key again).
key, us_key = jr.split(key)
us = jr.uniform(key=us_key, shape=(num_initial_trials, ts_dense.shape[-1]), minval=env.lb, maxval=env.ub)[..., None]


def get_initial_condition(key):
    """Sample an initial state: the task's initial condition plus a uniform box perturbation.

    The key is split so the perturbation is independent of any randomness inside
    env.get_initial_condition (the original passed the same key to both).
    """
    base_key, perturb_key = jr.split(key)
    perturbation = jr.uniform(perturb_key, (env.D_sys,), minval=jnp.array([-1, -1, -5, -5]), maxval=jnp.array([3.5, 4, 5, 4]))
    return env.get_initial_condition(base_key) + perturbation


# Simulate noisy observations from the real system.
key, data_key = jr.split(key)
data, true_y0s = real_system.generate_synthetic_data(data_key,
                                           num_initial_trials,
                                           dt0=dt0_dense,
                                           ts=ts,
                                           us=us,
                                           obs_stddev=noise,
                                           ts_dense=ts_dense,
                                           x0_distribution=get_initial_condition,
                                           standardize_at_initialisation=False,
                                          )
# NOTE(review): explicit __post_init__ call kept from the original; presumably recomputes derived fields.
data.__post_init__()

# Restore tf / num_obs to the full task horizon for the later cells.
tf = env.tf
num_obs = int((tf-t0)/dt)

In [None]:
print(env.x_star)

# Data are generated with standardize_at_initialisation=False, so the raw arrays are already
# full scale (use data.inverse_standardize / inverse_standardize_us if that changes).
ys_full_scale = data.ys
us_full_scale = data.us
# Assumes the state ordering [x, theta, x_dot, theta_dot] used by the trial plots — TODO confirm.
state_labels = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$']

# Time series of the first two state dimensions for every trial.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for i in range(data.ts.shape[0]):
    ax.plot(data.ts[i], ys_full_scale[i, :, 0])
    ax.plot(data.ts[i], ys_full_scale[i, :, 1])
ax.set_ylabel('x')
ax.set_xlabel('t')
plt.show()

# Phase plane of state dims 1 and 3. The target is drawn on THIS axes and in the same dims as
# the trajectories. Previously it went onto the already-shown first figure, using dims (0, 1),
# and the title named the wrong dims.
phase_dims = (1, 3)
fig, ax = plt.subplots()
ax.set_title(f'phase plane ({state_labels[phase_dims[0]]} and {state_labels[phase_dims[1]]})')
for i in range(data.ts.shape[0]):
    ax.plot(ys_full_scale[i, :, phase_dims[0]], ys_full_scale[i, :, phase_dims[1]])
ax.scatter(env.x_star[phase_dims[0]], env.x_star[phase_dims[1]], color='gold')
ax.set_xlabel(state_labels[phase_dims[0]])
ax.set_ylabel(state_labels[phase_dims[1]])
plt.show()

# Control signals of (at most) the first 30 trials; guarded for smaller datasets.
num_u_plots = min(30, data.ts_dense.shape[0])
fig, ax = plt.subplots()
ax.set_title('u(t)')
for i in range(num_u_plots):
    ax.plot(data.ts_dense[i], us_full_scale[i, :, 0])
plt.show()

In [None]:
# Sanity check of the dataset layout; expected roughly (trials, obs, D_sys), (trials, obs),
# (trials, dense, D_control), (trials, dense) — confirm against DiffEqDataset.
data.ys.shape, data.ts.shape, data.us.shape, data.ts_dense.shape

## Train neural ODE

In [None]:
from gpdx.nn.nnvectorfield import EnsembleNeuralVectorField
from gpdx.nn.node import EnsembleNeuralODE
from gpdx.fit import fit_node
from gpdx.nn.node import mse_loss_ensemble


def train_network(data, key, num_iters):
    """Train an ensemble of neural ODEs on `data` and plot its loss history.

    Each ensemble member receives its own dataset drawn with `get_batch` (one PRNG
    key per member). The ensemble vector field is created in a single
    initialisation call. The initial-state mean starts from the first observation
    of each trajectory.

    Args:
        data: training DiffEqDataset.
        key: JAX PRNG key driving resampling, initialisation and fitting.
        num_iters: optimiser iterations forwarded to `fit_node`.

    Returns:
        The fitted EnsembleNeuralODE.
    """
    key_rest, resample_key = jr.split(key)
    member_keys = jr.split(resample_key, ensemble_size)

    # One dataset per ensemble member: get_batch (from gpdx.fit) is vmapped over the keys only.
    draw_member_data = jax.vmap(get_batch, in_axes=(None, None, 0, None))
    ensemble_datasets = draw_member_data(data, data.n, member_keys, False)
    print(ensemble_datasets.n, ensemble_datasets.ys.shape)

    init_key, fit_key = jr.split(key_rest)
    ensemble_vectorfield = EnsembleNeuralVectorField(
        ensemble_size=ensemble_size,
        layer_sizes=layer_sizes,
        activation=activation,
        D_sys=real_system.D_sys,
        D_control=real_system.D_control,
        key=init_key,
    )

    # Initial-state distribution p(x0): mean taken from each trajectory's first observation
    # (offset by -1), raw diagonal parameters all set to -1.
    first_observations = ensemble_datasets.ys[:, :, 0, :]
    ensemble_node = EnsembleNeuralODE(
        ensemble_size=ensemble_size,
        # log(exp(s) - 1) is the inverse softplus of the initial noise scale s.
        obs_noise_raw=jnp.log(jnp.exp(init_obs_noise) - 1.),
        x0_mean=first_observations - 1,
        x0_diag_raw=jnp.zeros_like(first_observations) - 1.,
        vectorfield=ensemble_vectorfield,
        solver=dfx.Dopri5(),
        dt0=dt0_internal,
        stepsize_controller=dfx.ConstantStepSize(),
        D_sys=real_system.D_sys,
        D_control=real_system.D_control,
    )

    opt_ensemble_node, history = fit_node(
        model=ensemble_node,
        objective=mse_loss_ensemble,
        train_data=ensemble_datasets,
        optim=ox.adam(learning_rate=lr),
        key=fit_key,
        num_iters=num_iters,
        batch_size=batch_size,
        log_rate=log_rate,
    )

    fig, ax = plt.subplots()
    ax.plot(history)
    ax.set_xlabel('Iterations')
    ax.set_ylabel('Ensemble MSE')
    ax.set_title(f'Final loss: {history[-1]}')
    plt.show()

    return opt_ensemble_node


In [None]:
def _plot_trial_results(ts_dense, X, R, cost_title):
    """Plot the running cost and the state trajectories (with targets) of one MPC trial.

    Args:
        ts_dense: dense time grid returned by the MPC simulation.
        X: state trajectory on `ts_dense`; column i is state dimension i.
        R: accumulated cost on `ts_dense`; R[-1] is the integrated trial cost.
        cost_title: title of the cost figure.
    """
    plt.plot(ts_dense, R)
    plt.title(cost_title)
    plt.ylabel('Cost ')
    plt.xlabel('t')
    plt.show()

    labels = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$']
    for i in range(X.shape[-1]):
        dim_color = f'C{i*2}'
        plt.plot(ts_dense, X[:, i], label=labels[i], color=dim_color)
        # Dotted target in the same colour as its trajectory. Previously dim 3 was plotted
        # in C6 but its target was drawn in C4, the colour of dim 2.
        plt.axhline(y=env.x_star[i], color=dim_color, linestyle=':')
    plt.title(f'Integrated Cost: {R[-1]}')
    plt.ylabel('x')
    plt.xlabel('t')
    plt.legend()
    plt.show()


def _run_pmp_trial(opt_ensemble_node, key, solver_cls):
    """Run one closed-loop indirect-MPC trial on the real system and plot it.

    Args:
        opt_ensemble_node: trained EnsembleNeuralODE used as the MPC internal model.
        key: JAX PRNG key; x0 is drawn from the second output of its first split.
        solver_cls: ensemble PMP trajectory-optimiser class, constructed with the
            keyword arguments below.

    Returns:
        (ts, ts_dense, X, Y, U, R) as returned by ensembleIndirectMPC.simulate.
    """
    key, subkey = jr.split(key)
    # Independent key for the simulator. Previously `subkey` was used both for the initial
    # condition and for simulate(), correlating the two random draws.
    key, sim_key = jr.split(key)
    n_segments = 2
    print(f'Running PMP with {n_segments} segments')
    # Only its length (control steps over the horizon H) is used, to shape the bounds;
    # PMP itself needs no initial control guess.
    us_init = jnp.arange(0, H, env.Delta_t)
    bound_shape = (us_init.shape[0], real_system.D_control)
    # Standardization hooks (data.standardize*, data._original_*_std) are currently disabled.
    pmp_solver = solver_cls(f=opt_ensemble_node.vectorfield,
                            D_sys=real_system.D_sys,
                            D_control=real_system.D_control,
                            ensemble_size=ensemble_size,
                            n_segments=n_segments,
                            state_cost=env.state_cost,
                            termination_cost=env.termination_cost,
                            maxiter=maxiter,
                            # TODO: scale the bounds with the u standardization once re-enabled.
                            lb=env.lb*jnp.ones(bound_shape),
                            ub=env.ub*jnp.ones(bound_shape),
                            )

    ensemble_indirect_mpc = ensembleIndirectMPC(traj_optimizer=pmp_solver,
                                                real_system=real_system,
                                                internal_system=opt_ensemble_node,
                                                state_cost=env.state_cost,
                                                termination_cost=env.termination_cost,
                                                verbose=True,
                                                )
    # NOTE(review): N = int((tf - t0) / Delta_t) points spanning [t0, tf] inclusive have spacing
    # (tf - t0) / (N - 1), slightly larger than Delta_t — confirm simulate() expects this grid.
    trial_ts = jnp.linspace(env.t0, env.tf, int( ((1/env.Delta_t))*(env.tf-env.t0)))

    ts, ts_dense, X, Y, U, R = ensemble_indirect_mpc.simulate(
        x0=env.get_initial_condition(subkey),
        ts=trial_ts,
        Delta_t=env.Delta_t,
        x_star=env.x_star,
        dt0_internal=dt0_internal,
        dt0_dense=dt0_dense,
        H=H,
        obs_noise=0.,
        key=sim_key,
    )

    _plot_trial_results(ts_dense, X, R, f'Indirect approach with NODE: Integrated Cost: {R[-1]}')
    return ts, ts_dense, X, Y, U, R


def run_single_trial(opt_ensemble_node, key:jr.PRNGKey, data:DiffEqDataset):
    """Indirect MPC trial using EnsemblePMPForward (PMP on the ensemble posterior).

    `data` is unused while the standardization hooks are disabled; it is kept for
    interface compatibility. Returns (ts, ts_dense, X, Y, U, R).
    """
    return _run_pmp_trial(opt_ensemble_node, key, EnsemblePMPForward)


def run_single_trial_meanHamiltonian(opt_ensemble_node, key:jr.PRNGKey, data:DiffEqDataset):
    """Indirect MPC trial using EnsemblePMPForwardMeanHamiltonian.

    `data` is unused while the standardization hooks are disabled; it is kept for
    interface compatibility. Returns (ts, ts_dense, X, Y, U, R).
    """
    return _run_pmp_trial(opt_ensemble_node, key, EnsemblePMPForwardMeanHamiltonian)



In [None]:
def split_dataset(dataset:DiffEqDataset, timepoints_per_split:int):
    """Split the first trajectory of `dataset` into consecutive, non-overlapping segments.

    Only trajectory 0 is used. Each segment has `timepoints_per_split` observation
    times and the proportional number of dense (control) time points. A trailing
    remainder that does not fill a whole segment is dropped.

    Args:
        dataset: DiffEqDataset whose first trajectory is split.
        timepoints_per_split: observation time points per segment; must be positive
            and no larger than the trajectory length.

    Returns:
        DiffEqDataset holding one trajectory per segment (combined with `+`).

    Raises:
        ValueError: if no full segment fits. The original returned None here, which
            only failed later on the caller's `.__post_init__()`.
    """
    ys, us, ts, ts_dense = dataset.ys[0], dataset.us[0], dataset.ts[0], dataset.ts_dense[0]
    n_obs, n_dense = ts.shape[0], ts_dense.shape[0]
    if timepoints_per_split <= 0:
        raise ValueError(f'timepoints_per_split must be positive, got {timepoints_per_split}')
    num_new_datasets = n_obs // timepoints_per_split
    if num_new_datasets == 0:
        raise ValueError(
            f'timepoints_per_split={timepoints_per_split} exceeds the trajectory length {n_obs}')
    # Exact integer arithmetic: the former int(tps * (n_dense / n_obs)) could round down
    # one point too far when the float ratio is slightly below an integer.
    dense_timepoints_per_split = (timepoints_per_split * n_dense) // n_obs

    segments = []
    for n in range(num_new_datasets):
        obs = slice(n * timepoints_per_split, (n + 1) * timepoints_per_split)
        dense = slice(n * dense_timepoints_per_split, (n + 1) * dense_timepoints_per_split)
        segments.append(DiffEqDataset(ys=ys[obs][None], us=us[dense][None], ts=ts[obs][None],
                                      ts_dense=ts_dense[dense][None],
                                      standardize_at_initialisation=False))

    new_dataset = segments[0]
    for segment in segments[1:]:
        new_dataset = new_dataset + segment
    return new_dataset


## Run RL trials

In [None]:
# Inspect the shapes of the observation-time and control arrays before the RL loop.
data.ts.shape, data.us.shape

In [None]:
# Outer RL loop: retrain the NODE ensemble on the current dataset each trial. Running the MPC
# trial and augmenting the dataset are currently commented out, so `trial_costs` stays empty.
trial_costs = []
for i in range(num_trials):
    print(f'Starting trial {i+1}/{num_trials}')

    # train deep NODE ensemble 
    # NOTE: `subkey` and `opt_ensemble_node` from the last iteration are reused by the
    # method-comparison cells further down (hidden state across cells).
    key, subkey = jr.split(key)
    opt_ensemble_node = train_network(data, subkey, num_iters = num_iters)#num_iters + 200*i)

    # run trial
    # key, subkey = jr.split(key)
    # ts, ts_dense, X, Y, U, R = run_single_trial(opt_ensemble_node=opt_ensemble_node, key=subkey, data=data)
    # trial_costs.append(R[-1])

    # augment dataset
    # new_trial_dataset = DiffEqDataset(ts[None,...], Y[None,...], U[None,...], ts_dense=ts_dense[None,...], standardize_at_initialisation=True)
    # new_trial_dataset.__post_init__()

    # # split into multiple shorter trajectories.
    # new_trial_datasets = split_dataset(new_trial_dataset, 40)
    # new_trial_datasets.__post_init__()
    
    # # add with previous data.
    # data = data + new_trial_datasets
    # data.__post_init__()

    
trial_costs = jnp.array(trial_costs)
if trial_costs.shape[0] > 0:
    plt.figure()
    plt.title('Costs over time')
    plt.xlabel('Trial')
    plt.ylabel('Cost')
    # 1-based x-axis, matching the "Starting trial i/N" numbering.
    plt.plot(jnp.arange(1, trial_costs.shape[0] + 1), trial_costs)
    plt.show()
else:
    print('No trial costs recorded; skipping the cost plot.')


In [None]:
# data.standardize_us(jnp.zeros((2,1))), data._original_us_mean

In [None]:
# data.ys.shape, data.us.shape, new_trial_dataset.us.shape

In [None]:
# print(new_trial_datasets.ys.shape, new_trial_datasets.ts.shape, new_trial_datasets.us.shape, new_trial_datasets.ts_dense.shape)
# data.ys.shape, data.ts.shape, data.us.shape, data.ts_dense.shape

In [None]:
def run_single_trial_sqp(opt_ensemble_node, key:jr.PRNGKey):
    """Run one closed-loop direct-MPC trial with an SLSQP trajectory optimiser and plot it.

    Args:
        opt_ensemble_node: trained EnsembleNeuralODE used as the MPC internal model.
        key: JAX PRNG key. x0 is drawn from the second output of its first split, the
            same derivation the PMP trials use, so runs with the same key start alike.
            Previously `key` was ignored and env.get_initial_condition() was called with
            no key.

    Returns:
        (ts, ts_dense, X, Y, U, R) as returned by ensembleDirectMPC.simulate.
    """
    key, subkey = jr.split(key)

    # Only its shape (control steps over the horizon H, D_control) is used, to build the bounds.
    us_init = jnp.zeros((jnp.arange(0, H, env.Delta_t).shape[0], env.D_control))

    sqp_solver = SLSQP(
        lb=env.lb*jnp.ones_like(us_init),
        ub=env.ub*jnp.ones_like(us_init),
        maxiter=maxiter,)

    ensemble_direct_mpc = ensembleDirectMPC(
        traj_optimizer=sqp_solver,
        real_system=env.real_system,
        internal_system=opt_ensemble_node,
        state_cost=env.state_cost,
        termination_cost=env.termination_cost,
        verbose=True,
        )
    # NOTE(review): N points on [t0, tf] inclusive have spacing (tf - t0) / (N - 1), not Delta_t.
    trial_ts = jnp.linspace(env.t0, env.tf, int( ((1/env.Delta_t))*(env.tf-env.t0)))

    ts, ts_dense, X, Y, U, R = ensemble_direct_mpc.simulate(x0=env.get_initial_condition(subkey),
        ts=trial_ts,
        Delta_t=env.Delta_t,
        dt0_dense=dt0_dense,
        x_star=env.x_star,
        H=H,
        )

    trial_cost = R[-1]
    plt.plot(ts_dense, R)
    plt.title(f'SQP: Direct approach with NODE: Integrated Cost: {trial_cost}')
    plt.ylabel('Cost ')
    plt.xlabel('t')
    plt.show()

    labels = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$']
    for i in range(X.shape[-1]):
        dim_color = f'C{i*2}'
        plt.plot(ts_dense, X[:, i], label=labels[i], color=dim_color)
        # Target in the same colour as its trajectory (previously dim 3's target used C4, not C6).
        plt.axhline(y=env.x_star[i], color=dim_color, linestyle=':')
    plt.title(f'Integrated Cost: {R[-1]}')
    plt.ylabel('x')
    plt.xlabel('t')
    plt.legend()
    plt.show()

    return ts, ts_dense, X, Y, U, R


In [None]:
def run_single_trial_bfgs(opt_ensemble_node, key:jr.PRNGKey):
    """Run one closed-loop direct-MPC trial with an L-BFGS-B trajectory optimiser and plot it.

    Args:
        opt_ensemble_node: trained EnsembleNeuralODE used as the MPC internal model.
        key: JAX PRNG key. x0 is drawn from the second output of its first split, the
            same derivation the PMP trials use, so runs with the same key start alike.
            Previously `key` was ignored and env.get_initial_condition() was called with
            no key.

    Returns:
        (ts, ts_dense, X, Y, U, R) as returned by ensembleDirectMPC.simulate.
    """
    key, subkey = jr.split(key)

    # Only its shape (control steps over the horizon H, D_control) is used, to build the bounds.
    us_init = jnp.zeros((jnp.arange(0, H, env.Delta_t).shape[0], env.D_control))

    bfgs_solver = LBFGSB(
        lb=env.lb*jnp.ones_like(us_init),
        ub=env.ub*jnp.ones_like(us_init),
        maxiter=maxiter,)

    ensemble_direct_mpc = ensembleDirectMPC(
        traj_optimizer=bfgs_solver,
        real_system=env.real_system,
        internal_system=opt_ensemble_node,
        state_cost=env.state_cost,
        termination_cost=env.termination_cost,
        verbose=True,
        )
    # NOTE(review): N points on [t0, tf] inclusive have spacing (tf - t0) / (N - 1), not Delta_t.
    trial_ts = jnp.linspace(env.t0, env.tf, int( ((1/env.Delta_t))*(env.tf-env.t0)))

    ts, ts_dense, X, Y, U, R = ensemble_direct_mpc.simulate(x0=env.get_initial_condition(subkey),
        ts=trial_ts,
        Delta_t=env.Delta_t,
        dt0_dense=dt0_dense,
        x_star=env.x_star,
        H=H,
        )

    trial_cost = R[-1]
    plt.plot(ts_dense, R)
    plt.title(f'BFGS: Direct approach with NODE: Integrated Cost: {trial_cost}')
    plt.ylabel('Cost ')
    plt.xlabel('t')
    plt.show()

    labels = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$']
    for i in range(X.shape[-1]):
        dim_color = f'C{i*2}'
        plt.plot(ts_dense, X[:, i], label=labels[i], color=dim_color)
        # Target in the same colour as its trajectory (previously dim 3's target used C4, not C6).
        plt.axhline(y=env.x_star[i], color=dim_color, linestyle=':')
    plt.title(f'Integrated Cost: {R[-1]}')
    plt.ylabel('x')
    plt.xlabel('t')
    plt.legend()
    plt.show()

    return ts, ts_dense, X, Y, U, R


In [None]:
def run_single_trial_cem(opt_ensemble_node, key:jr.PRNGKey):
    """Run one closed-loop direct-MPC trial with an iCEM trajectory optimiser and plot it.

    Args:
        opt_ensemble_node: trained EnsembleNeuralODE used as the MPC internal model.
        key: JAX PRNG key. x0 is drawn from the second output of its first split, the
            same derivation the PMP trials use, so runs with the same key start alike.
            Previously `key` was ignored and env.get_initial_condition() was called with
            no key.

    Returns:
        (ts, ts_dense, X, Y, U, R) as returned by ensembleDirectMPC.simulate.
    """
    key, subkey = jr.split(key)

    # Only its shape (control steps over the horizon H, D_control) is used, to build the bounds.
    us_init = jnp.zeros((jnp.arange(0, H, env.Delta_t).shape[0], env.D_control))
    pop_size = 700
    cem_solver = iCEM(
        pop_size=pop_size,
        elite_size=int(0.13*pop_size),  # keep the top 13% of the population as elites
        beta=1,
        lb=env.lb*jnp.ones_like(us_init),
        ub=env.ub*jnp.ones_like(us_init),
        maxiter=maxiter,)

    ensemble_direct_mpc = ensembleDirectMPC(
        traj_optimizer=cem_solver,
        real_system=env.real_system,
        internal_system=opt_ensemble_node,
        state_cost=env.state_cost,
        termination_cost=env.termination_cost,
        verbose=True,
        )
    # NOTE(review): N points on [t0, tf] inclusive have spacing (tf - t0) / (N - 1), not Delta_t.
    trial_ts = jnp.linspace(env.t0, env.tf, int( ((1/env.Delta_t))*(env.tf-env.t0)))

    ts, ts_dense, X, Y, U, R = ensemble_direct_mpc.simulate(x0=env.get_initial_condition(subkey),
        ts=trial_ts,
        Delta_t=env.Delta_t,
        dt0_dense=dt0_dense,
        x_star=env.x_star,
        H=H,
        )

    trial_cost = R[-1]
    plt.plot(ts_dense, R)
    plt.title(f'CEM: direct approach with NODE: Integrated Cost: {trial_cost}')
    plt.ylabel('Cost ')
    plt.xlabel('t')
    plt.show()

    labels = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$']
    for i in range(X.shape[-1]):
        dim_color = f'C{i*2}'
        plt.plot(ts_dense, X[:, i], label=labels[i], color=dim_color)
        # Target in the same colour as its trajectory (previously dim 3's target used C4, not C6).
        plt.axhline(y=env.x_star[i], color=dim_color, linestyle=':')
    plt.title(f'Integrated Cost: {R[-1]}')
    plt.ylabel('x')
    plt.xlabel('t')
    plt.legend()
    plt.show()

    return ts, ts_dense, X, Y, U, R


In [1]:
def run_single_trial_adam(opt_ensemble_node, key:jr.PRNGKey):
    """Run one closed-loop direct-MPC trial with an Adam trajectory optimiser and plot it.

    Args:
        opt_ensemble_node: trained EnsembleNeuralODE used as the MPC internal model.
        key: JAX PRNG key. x0 is drawn from the second output of its first split, the
            same derivation the PMP trials use, so runs with the same key start alike.
            Previously `key` was ignored and env.get_initial_condition() was called with
            no key.

    Returns:
        (ts, ts_dense, X, Y, U, R) as returned by ensembleDirectMPC.simulate.
    """
    key, subkey = jr.split(key)

    # Only its shape (control steps over the horizon H, D_control) is used, to build the bounds.
    us_init = jnp.zeros((jnp.arange(0, H, env.Delta_t).shape[0], env.D_control))
    adam_solver = Adam(
        step_size=0.01,
        lb=env.lb*jnp.ones_like(us_init),
        ub=env.ub*jnp.ones_like(us_init),
        maxiter=maxiter,)

    ensemble_direct_mpc = ensembleDirectMPC(
        traj_optimizer=adam_solver,
        real_system=env.real_system,
        internal_system=opt_ensemble_node,
        state_cost=env.state_cost,
        termination_cost=env.termination_cost,
        verbose=True,
        )
    # NOTE(review): N points on [t0, tf] inclusive have spacing (tf - t0) / (N - 1), not Delta_t.
    trial_ts = jnp.linspace(env.t0, env.tf, int( ((1/env.Delta_t))*(env.tf-env.t0)))

    ts, ts_dense, X, Y, U, R = ensemble_direct_mpc.simulate(x0=env.get_initial_condition(subkey),
        ts=trial_ts,
        Delta_t=env.Delta_t,
        dt0_dense=dt0_dense,
        x_star=env.x_star,
        H=H,
        )

    trial_cost = R[-1]
    plt.plot(ts_dense, R)
    plt.title(f'Adam: direct approach with NODE: Integrated Cost: {trial_cost}')
    plt.ylabel('Cost ')
    plt.xlabel('t')
    plt.show()

    labels = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$']
    for i in range(X.shape[-1]):
        dim_color = f'C{i*2}'
        plt.plot(ts_dense, X[:, i], label=labels[i], color=dim_color)
        # Target in the same colour as its trajectory (previously dim 3's target used C4, not C6).
        plt.axhline(y=env.x_star[i], color=dim_color, linestyle=':')
    plt.title(f'Integrated Cost: {R[-1]}')
    plt.ylabel('x')
    plt.xlabel('t')
    plt.legend()
    plt.show()

    return ts, ts_dense, X, Y, U, R


NameError: name 'jr' is not defined

In [None]:
# Display the PRNG key passed to every method-comparison cell below (last set in the RL-trial loop).
subkey

## PMP

### Mean Hamiltonian

In [None]:
# Indirect MPC with the mean-Hamiltonian PMP variant; the same `subkey` is used by every comparison cell.
_ = run_single_trial_meanHamiltonian(opt_ensemble_node=opt_ensemble_node, key=subkey, data=data)

### Mean Posterior

In [None]:
# Indirect MPC with EnsemblePMPForward; same `subkey` as the other comparison cells.
_ = run_single_trial(opt_ensemble_node=opt_ensemble_node, key=subkey, data=data)

## CEM



In [None]:
# Direct MPC with iCEM; same `subkey` as the other comparison cells.
_ = run_single_trial_cem(opt_ensemble_node=opt_ensemble_node, key=subkey)#, data=data)

## SQP

In [None]:
# Direct MPC with SLSQP; same `subkey` as the other comparison cells.
_ = run_single_trial_sqp(opt_ensemble_node=opt_ensemble_node, key=subkey)#, data=data)

## Adam

In [None]:
# Direct MPC with Adam; same `subkey` as the other comparison cells.
_ = run_single_trial_adam(opt_ensemble_node=opt_ensemble_node, key=subkey)#, data=data)

## BFGS

In [None]:
# Direct MPC with L-BFGS-B; same `subkey` as the other comparison cells.
_ = run_single_trial_bfgs(opt_ensemble_node=opt_ensemble_node, key=subkey)#, data=data)

### Colors
Method 1: Blue (`#1f77b4`, Matplotlib's default `tab:blue`, a muted blue).

Method 2: Red (`#d62728`, `tab:red`, a subdued red).

Method 3: Green (`#2ca02c`, `tab:green`, a soft green).

Use `alpha=0.2` for confidence-interval bands.


## Linewidth
Use `linewidth=1.5` for all method curves.

## Background grid lines
`plt.grid(True, linestyle='--', alpha=0.5)`


In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# # Dummy data
# trials = np.arange(1, 21)
# methods = {
#     "Method 1": np.exp(-trials/5) * 20 + np.random.normal(0, 5, (10, 20)),
#     "Method 2": np.exp(-trials/4) * 20 + np.random.normal(0, 5, (10, 20)),
#     "Method 3": np.exp(-trials/3) * 20 + np.random.normal(0, 5, (10, 20))
# }
# analytical_max = np.full(20, -5)  # Example baseline

# # Colors
# colors = ["#1f77b4", "#d62728", "#2ca02c"]

# # Plot
# plt.figure(figsize=(6, 4))
# for i, (name, data) in enumerate(methods.items()):
#     mean = data.mean(axis=0)
#     std = data.std(axis=0)
#     ci = 1.96 * std / np.sqrt(10)  # 95% CI
#     plt.plot(trials, mean, label=name, color=colors[i], linewidth=1.5, alpha=0.8)
#     plt.fill_between(trials, mean - ci, mean + ci, color=colors[i], alpha=0.2)

# plt.plot(trials, analytical_max, "--", color="#7f7f7f", label="True Model Minimum")
# plt.xlabel("Number of Trials", fontsize=12)
# plt.ylabel("Cost", fontsize=12)
# plt.grid(True, linestyle="--", alpha=0.5)
# plt.legend(loc="upper right", fontsize=10)
# plt.tight_layout()

# # Save for IEEE
# # plt.savefig("cartpole_cost.pdf", format="pdf", bbox_inches="tight")
# plt.show()


# bar_width = 0.2

# # Plot
# plt.figure(figsize=(6, 4))
# for i, (name, costs) in enumerate(methods.items()):
#     mean = costs.mean(axis=0)
#     std = costs.std(axis=0)
#     ci = 1.96 * std / np.sqrt(10)  # 95% CI
#     # Bars with error bars
#     plt.bar(trials + i * bar_width - bar_width, mean, width=bar_width, 
#             color=colors[i], alpha=0.8, label=name, yerr=ci, capsize=2, error_kw={"ecolor": "black"})
#     # Interpolated mean line
#     plt.plot(trials, mean, color=colors[i], linewidth=1.5)

# plt.plot(trials, analytical_max, "--", color="#7f7f7f", label="Analytical Maximum", linewidth=1.5)
# plt.xlabel("Number of Trials", fontsize=12)
# plt.ylabel("Cost", fontsize=12)
# plt.xticks(np.arange(0, 20, 5))
# plt.grid(True, linestyle="--", alpha=0.5)
# plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1), fontsize=10)
# plt.tight_layout()