In [8]:
from dataclasses import dataclass
import jax
import jax.numpy as jnp
import numpy as np
import optax
import equinox as eqx
from helpers import random_initial_state

import exciting_environments as excenvs

@dataclass()
class EnvParameters:
    """Settings used to build the training environment with ``excenvs.make``.

    ``train_dpc_controller`` forwards ``name`` positionally and the remaining
    fields as keyword arguments of the same name.
    """

    # Environment identifier handed to excenvs.make.
    name: str
    # Number of parallel rollouts; also sizes the reference-state batch.
    batch_size: int
    # Forwarded as `l=`; presumably a pendulum length — confirm with the env.
    l: float
    # Forwarded as `m=`; presumably a mass — confirm with the env.
    m: float
    # Forwarded as `tau=`; presumably the simulation step size — confirm.
    tau: float
    # Forwarded as `max_torque=` to the environment.
    max_torque: float

@dataclass()
class DPCParameters:
    """Hyper-parameters for the DPC policy network and its optimizer."""

    # Widths of the policy MLP from input to output, e.g. [in, hidden, ..., out].
    layer_sizes: list
    # Step size passed to the optax optimizer.
    learning_rate: float
    # Optimizer name: 'adam', 'rmsprop' or 'sgd'.
    optimizer_type: str
    # Number of training iterations (one gradient step each).
    epochs: int

@dataclass()
class TrainingResults:
    """Per-epoch records collected while training a DPC policy."""

    # One scalar loss value per epoch.
    losses: list
    # One array of rolled-out states per epoch.
    state_trajectories: list
    # One array of policy actions per epoch.
    action_trajectories: list

def _make_optimizer(optimizer_type, learning_rate):
    """Return the optax optimizer selected by name.

    Args:
        optimizer_type: one of 'adam', 'rmsprop' or 'sgd'.
        learning_rate: step size handed to the optax constructor.

    Raises:
        ValueError: if ``optimizer_type`` is not one of the supported names.
    """
    if optimizer_type == 'adam':
        return optax.adam(learning_rate)
    if optimizer_type == 'rmsprop':
        return optax.rmsprop(learning_rate, decay=0.9, eps=1e-8)
    if optimizer_type == 'sgd':
        return optax.sgd(learning_rate, momentum=0.9)
    raise ValueError(f"Unknown optimizer type: {optimizer_type}")


def train_dpc_controller(env_params, dpc_params, key, horizon=3000):
    """Train a differentiable-predictive-control (DPC) policy by backprop through the env.

    Each epoch draws a fresh batch of initial states and rolls the policy through the
    environment's explicit-Euler ODE step for ``horizon`` steps. It then takes one
    optimizer step on the mean squared distance between the predicted states and an
    all-zero reference state.

    Args:
        env_params: EnvParameters used to build the environment via ``excenvs.make``.
        dpc_params: DPCParameters (network layer sizes, learning rate, optimizer name,
            number of epochs). The policy input is the state concatenated with the
            2-component reference, so ``layer_sizes[0]`` presumably must be
            state_dim + 2 — NOTE(review): confirm against the environment.
        key: JAX PRNG key. It is split internally for network init and for each
            epoch's batch of initial states; it is never reused after splitting.
        horizon: number of simulation steps per rollout (default 3000, the value
            previously hard-coded).

    Returns:
        ``(policy, TrainingResults)``:
            - ``policy``: the trained policy.
            - ``TrainingResults``: per epoch, the loss plus the state/action
              trajectories of the *updated* policy on that epoch's batch.

    Raises:
        ValueError: if ``dpc_params.optimizer_type`` is unknown, or if
            ``dpc_params.layer_sizes`` has fewer than two entries.
    """
    if len(dpc_params.layer_sizes) < 2:
        # Fail fast: otherwise PolicyNetwork has no layers and indexing layers[-1]
        # raises an opaque IndexError inside jit.
        raise ValueError("layer_sizes needs at least an input and an output size")
    # Validate the optimizer choice before building the (potentially costly) env.
    optimizer = _make_optimizer(dpc_params.optimizer_type, dpc_params.learning_rate)

    env = excenvs.make(
        env_params.name,
        batch_size=env_params.batch_size,
        l=env_params.l,
        m=env_params.m,
        tau=env_params.tau,
        max_torque=env_params.max_torque,
    )

    class PolicyNetwork(eqx.Module):
        """MLP with tanh on every hidden layer and on the output (actions in [-1, 1])."""

        layers: list[eqx.nn.Linear]

        def __init__(self, layer_sizes, key):
            layers = []
            for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
                key, subkey = jax.random.split(key)
                layers.append(eqx.nn.Linear(fan_in, fan_out, use_bias=True, key=subkey))
            # Assign once rather than mutating a list already stored on the module.
            self.layers = layers

        def __call__(self, x):
            for layer in self.layers[:-1]:
                x = jnp.tanh(layer(x))
            return jnp.tanh(self.layers[-1](x))

    # Zero target with 2 components per batch element (same values/dtype as the
    # previous tile of [[0, 0]] cast to float32).
    ref_state = jnp.zeros((env_params.batch_size, 2), dtype=jnp.float32)

    @eqx.filter_jit
    def loss_fn(policy, initial_state, ref_state):
        """Roll the policy out for ``horizon`` steps; return (mse, states, actions).

        ``states`` and ``actions`` are stacked by ``lax.scan`` along a leading time axis.
        """
        def step(state, _):
            # The policy sees the current state concatenated with the reference.
            policy_input = jnp.concatenate([state, ref_state], axis=-1)
            action = jax.vmap(policy)(policy_input)
            next_state = jax.vmap(env._ode_exp_euler_step)(
                state, action, env.env_state_normalizer, env.action_normalizer, env.static_params
            )
            return next_state, (next_state, action)

        _, (predicted_states, actions) = jax.lax.scan(step, initial_state, None, length=horizon)
        # ref_state broadcasts across the leading time axis.
        mse = jnp.mean((predicted_states - ref_state) ** 2)
        return mse, predicted_states, actions

    @eqx.filter_value_and_grad
    def compute_loss(policy, initial_state):
        mse_loss, _, _ = loss_fn(policy, initial_state, ref_state)
        return mse_loss

    key, init_key = jax.random.split(key)
    policy = PolicyNetwork(dpc_params.layer_sizes, key=init_key)
    # Only array leaves are trainable; this is the equinox idiom for optax state.
    opt_state = optimizer.init(eqx.filter(policy, eqx.is_array))

    @eqx.filter_jit
    def update_state(policy, initial_state, opt_state):
        """One gradient step; returns (loss, updated_policy, updated_opt_state)."""
        loss, grads = compute_loss(policy, initial_state)
        updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(policy, eqx.is_array))
        policy = eqx.apply_updates(policy, updates)
        return loss, policy, opt_state

    losses = []
    state_trajectories = []
    action_trajectories = []

    for _ in range(dpc_params.epochs):
        # Advance the key every epoch. Previously `key` was never reassigned, so each
        # epoch trained on the exact same batch of initial states.
        key, batch_key = jax.random.split(key)
        batch_initial_states = random_initial_state(batch_key, env_params.batch_size)
        loss, policy, opt_state = update_state(policy, batch_initial_states, opt_state)
        losses.append(loss.item())

        # Re-roll with the updated policy to record its trajectories on this batch.
        # NOTE(review): each entry holds `horizon` steps of the whole batch, so memory
        # grows linearly with epochs; consider subsampling for long runs.
        _, predicted_states, actions = loss_fn(policy, batch_initial_states, ref_state)
        state_trajectories.append(predicted_states)
        action_trajectories.append(actions)

    results = TrainingResults(
        losses=losses,
        state_trajectories=state_trajectories,
        action_trajectories=action_trajectories,
    )
    return policy, results


In [2]:
#import exciting_environments as excenvs

In [3]:
#env = excenvs.make('Pendulum-v0', batch_size=1, l=1, m=1, tau=1e-2, max_torque=2)

In [2]:
#from helpers import random_initial_state, plot_results