<a href="https://colab.research.google.com/github/UW-CTRL/lmc-exercises/blob/main/01_dynamics%2BJAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dynamics, trajectories, and JAX
In this problem, you will set up a dynamical system, generate trajectories from it, and learn how to use JAX.



## 1. Dynamics and trajectories

First, let's import the relevant packages.

In [None]:
import abc
from typing import Callable
import jax.numpy as jnp
import matplotlib.pyplot as plt
import functools
import jax
import numpy as np

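*NOTE:* We are using `jax.numpy`; the next section of this exercise goes into JAX in more detail. In the meantime, just use `jax.numpy` (`jnp`) as if it were the regular `numpy` (`np`). One difference to keep in mind is that JAX arrays are immutable: instead of in-place assignment such as `x[i] = v`, use `x = x.at[i].set(v)`.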

Now, let's construct a base class that defines the dynamics of a system.

In [None]:
class Dynamics(metaclass=abc.ABCMeta):
    """Abstract base class for dynamical systems."""

    dynamics_func: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray]
    state_dim: int
    control_dim: int

    def __init__(
        self,
        dynamics_func: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray],
        state_dim: int,
        control_dim: int,
    ):
        """Initializes the Dynamics object.

        Args:
            dynamics_func: A callable representing the dynamics function.
            state_dim: The dimension of the state space.
            control_dim: The dimension of the control space.
        """
        self.dynamics_func = dynamics_func
        self.state_dim = state_dim
        self.control_dim = control_dim

    def __call__(
        self, state: jnp.ndarray, control: jnp.ndarray, time: float = 0.0
    ) -> jnp.ndarray:
        """Evaluates the dynamics function at a given state, control, and time.

        Args:
            state: The current state.
            control: The current control input.
            time: The current time (optional, defaults to 0).

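        Returns:
            The output of the dynamics function: the state time derivative for
            continuous-time dynamics, or the next state for discrete-time dynamics.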
        """
        return self.dynamics_func(state, control, time)

### (a) Obtaining discrete-time dynamics

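Given continuous-time dynamics $\dot{\mathbf{x}} = f(\mathbf{x}, \mathbf{u}, t)$, we can obtain discrete-time dynamics $\mathbf{x}_{k+1} = f_d(\mathbf{x}_k, \mathbf{u}_k, t_k)$ by integrating over a time step $\Delta t$ while holding the control constant over the step (zero-order hold).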
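For reference, writing the continuous-time dynamics as $\dot{\mathbf{x}} = f(\mathbf{x}, \mathbf{u}, t)$ and holding $\mathbf{u}_k$ constant over each step, the explicit Euler update and the classic fourth-order Runge-Kutta (RK4) update, one common choice of Runge-Kutta scheme, are:

$$
\begin{aligned}
\text{Euler:}\quad \mathbf{x}_{k+1} &= \mathbf{x}_k + \Delta t\, f(\mathbf{x}_k, \mathbf{u}_k, t_k) \\[4pt]
\text{RK4:}\quad \mathbf{k}_1 &= f(\mathbf{x}_k, \mathbf{u}_k, t_k) \\
\mathbf{k}_2 &= f\!\left(\mathbf{x}_k + \tfrac{\Delta t}{2}\mathbf{k}_1, \mathbf{u}_k, t_k + \tfrac{\Delta t}{2}\right) \\
\mathbf{k}_3 &= f\!\left(\mathbf{x}_k + \tfrac{\Delta t}{2}\mathbf{k}_2, \mathbf{u}_k, t_k + \tfrac{\Delta t}{2}\right) \\
\mathbf{k}_4 &= f\!\left(\mathbf{x}_k + \Delta t\,\mathbf{k}_3, \mathbf{u}_k, t_k + \Delta t\right) \\
\mathbf{x}_{k+1} &= \mathbf{x}_k + \tfrac{\Delta t}{6}\left(\mathbf{k}_1 + 2\mathbf{k}_2 + 2\mathbf{k}_3 + \mathbf{k}_4\right)
\end{aligned}
$$

Euler's global error scales as $O(\Delta t)$ and RK4's as $O(\Delta t^4)$, at the cost of four evaluations of $f$ per step instead of one.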



#### (i) Implement Euler integration to obtain the discrete-time dynamics.

In [None]:
# TODO in class
def euler_integrate(
    dynamics: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray], dt: float
) -> Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray]:
    """
    Implement Euler integration for discrete-time dynamics.

    Args:
        dynamics: A callable representing the continuous-time dynamics function.
        dt: The time step for integration.

    Returns:
        A callable representing the discrete-time dynamics using Euler integration.
    """

    # zero-order hold: the control u is held constant over [t, t + dt]
    def integrator(x: jnp.ndarray, u: jnp.ndarray, t: float) -> jnp.ndarray:
        # TODO: Implement Euler integration here
        raise NotImplementedError # remove this
    return integrator

#### (ii) Implement Runge-Kutta integration (e.g., the classic fourth-order scheme, RK4) to obtain the discrete-time dynamics.

In [None]:
# TODO in class
def runge_kutta_integrate(
    dynamics: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray], dt: float
) -> Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray]:
    """
    Implement Runge-Kutta integration for discrete-time dynamics.

    Args:
        dynamics: A callable representing the continuous-time dynamics function.
        dt: The time step for integration.

    Returns:
        A callable representing the discrete-time dynamics using Runge-Kutta integration.
    """

    # zero-order hold: the control u is held constant over [t, t + dt]
    def integrator(x: jnp.ndarray, u: jnp.ndarray, t: float) -> jnp.ndarray:
        # TODO: Implement Runge-Kutta integration here
        raise NotImplementedError # remove this
    return integrator
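Whatever integrators you write, a quick sanity check is to run them on an ODE with a known closed-form solution. The sketch below is not tied to the `Dynamics` class: it uses the hypothetical toy system $\dot{x} = -x$ (exact solution $x(t) = e^{-t}x(0)$) and hypothetical helpers `euler_step` and `rk4_step`, assuming "Runge-Kutta" means the classic fourth-order scheme.

```python
import jax.numpy as jnp


def decay(x, u, t):
    """Toy continuous-time dynamics x_dot = -x (the control u is ignored)."""
    return -x


def euler_step(f, x, u, t, dt):
    # One explicit Euler step with u held constant over [t, t + dt]
    return x + dt * f(x, u, t)


def rk4_step(f, x, u, t, dt):
    # One classic fourth-order Runge-Kutta step with u held constant
    k1 = f(x, u, t)
    k2 = f(x + 0.5 * dt * k1, u, t + 0.5 * dt)
    k3 = f(x + 0.5 * dt * k2, u, t + 0.5 * dt)
    k4 = f(x + dt * k3, u, t + dt)
    return x + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)


def rollout(step, f, x_start, step_size, num_steps):
    # Apply `step` repeatedly and return the final state
    x, t = x_start, 0.0
    for _ in range(num_steps):
        x = step(f, x, None, t, step_size)
        t += step_size
    return x


x_init = jnp.array([1.0])
step_size, num_steps = 0.1, 10  # integrate from t = 0 to t = 1
exact = jnp.exp(-1.0) * x_init  # closed-form x(1)

euler_err = float(jnp.abs(rollout(euler_step, decay, x_init, step_size, num_steps) - exact)[0])
rk4_err = float(jnp.abs(rollout(rk4_step, decay, x_init, step_size, num_steps) - exact)[0])
print(f"Euler error: {euler_err:.2e}, RK4 error: {rk4_err:.2e}")
```

With $\Delta t = 0.1$ over one second, the Euler error is on the order of $10^{-2}$, while the RK4 error is several orders of magnitude smaller, which previews the comparison in part (d).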

### (b) Setting up unicycle dynamics
Using the `Dynamics` class, construct the continuous-time dynamics for the dynamically extended unicycle model.

$$
    \dot{\mathbf{x}} = \begin{bmatrix}
        \dot{x} \\ \dot{y} \\ \dot{\theta} \\ \dot{v}
    \end{bmatrix} = \begin{bmatrix}
        v\cos\theta \\ v\sin\theta \\ \omega \\ a
    \end{bmatrix}, \qquad u=(\omega, a)
$$
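To illustrate the expected `(state, control, time)` signature without giving away the unicycle, here is a hypothetical damped pendulum with state $(\theta, \dot{\theta})$ and a torque input. The parameters `g`, `length`, and `damping` are illustrative assumptions. The function unpacks the state and control vectors and returns the state derivative as a `jnp` array with the same ordering as the state.

```python
import jax.numpy as jnp


def pendulum_dynamics_func(state, control, time=0.0):
    """Hypothetical damped pendulum: state = (theta, theta_dot), control = (torque,)."""
    g, length, damping = 9.81, 1.0, 0.1  # illustrative parameters
    theta, theta_dot = state  # unpack the state vector
    (torque,) = control  # unpack the control vector
    theta_ddot = -(g / length) * jnp.sin(theta) - damping * theta_dot + torque
    # Time derivative of the state, in the same order as the state
    return jnp.array([theta_dot, theta_ddot])


print(pendulum_dynamics_func(jnp.array([0.0, 1.0]), jnp.array([0.5])))
```

Wrapping it works the same way as for the unicycle, e.g. `Dynamics(pendulum_dynamics_func, 2, 1)`.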

In [None]:
# TODO with peers


def unicycle_dynamics_func(state, control, time=0):
    """
    Define the unicycle dynamics equations here.

    Args:
      state: The current state of the unicycle (x, y, theta, v).
      control: The control input (omega, a).
      time: The current time (optional).

    Returns:
      The time derivative of the state (dot_x, dot_y, dot_theta, dot_v).
    """
    # TODO: Implement the unicycle dynamics
    pass

### (c) Simulate discrete-time dynamics

Now we can construct *discrete-time* dynamics using the different integration schemes and simulate them over some horizon.

In [None]:
dt = 0.1

state_dim = 4
control_dim = 2
continuous_dynamics = Dynamics(unicycle_dynamics_func, state_dim, control_dim)

discrete_dynamics_euler = Dynamics(
    euler_integrate(continuous_dynamics, dt), state_dim, control_dim
)
discrete_dynamics_rk = Dynamics(
    runge_kutta_integrate(continuous_dynamics, dt), state_dim, control_dim
)

In [None]:
# TODO in class

def simulate_dynamics(
    dynamics: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray],
    initial_state: jnp.ndarray,
    control_sequence: jnp.ndarray,
    time_horizon: float,
    dt: float,
) -> jnp.ndarray:
    """
    Simulate the discrete-time dynamics over a given time horizon.

    Args:
        dynamics: A callable representing the discrete-time dynamics function.
        initial_state: The initial state of the system.
        control_sequence: A sequence of control inputs for each time step.
        time_horizon: The total time duration for the simulation.
        dt: The time step size.

    Returns:
        An array containing the state trajectory over the time horizon.
    """
    # TODO: Implement dynamics simulation here
    raise NotImplementedError # remove this


### (d) Compare integration schemes


#### (i) Simulate trajectories

Simulate your dynamics over 5 seconds for different values of $\Delta t$ (for example, with a constant control input) and compare the trajectories.

On the same plot, show the simulated trajectories for the following cases:
- Discrete-time dynamics with Euler integration, $\Delta t = 0.01$
- Discrete-time dynamics with Euler integration, $\Delta t = 0.5$
- Discrete-time dynamics with RK integration, $\Delta t = 0.01$
- Discrete-time dynamics with RK integration, $\Delta t = 0.5$



In [None]:
# TODO with peers



#### (ii) Discuss choice of integration schemes
How does the choice of integration scheme and time step size influence the resulting trajectories?

In [None]:
# DISCUSS with peers

## JAX

So far, everything we have written could have been done with regular NumPy. So why use JAX?

JAX is particularly powerful for control applications because it supports **automatic differentiation** (`jax.grad`, `jax.jacobian`), **just-in-time (JIT) compilation** (`jax.jit`), and **automatic vectorization** (`jax.vmap`). Automatic differentiation provides derivatives accurate to floating-point precision for linearizing dynamics, optimizing control policies, and estimating system parameters, while JIT compilation and `vmap` significantly speed up the repeated numerical computations that dominate simulation and model predictive control. Together, these make it practical to develop and deploy complex control algorithms efficiently.



### (a) `jax.vmap`

Suppose now that you want to simulate *many* trajectories. Rather than wrapping `simulate_dynamics` in a for loop over initial states and control sequences, we can use `jax.vmap`, a *vectorizing map* that applies a function (here, `simulate_dynamics`) to *batched* inputs in a single call.

An example of how to use `jax.vmap` is shown below. The `in_axes` argument specifies, for each positional argument, the axis along which it is batched; `None` means the argument is not batched and the same value is passed to every call.

In [None]:
# TODO in class
def foo(x, y, z):
    return x + y + z


N = 1000
x = jnp.array(np.random.randn(N))
y = jnp.array(np.random.randn(N))
z = jnp.array(np.random.randn(N))

xs = jnp.array(np.random.randn(N, N))
ys = jnp.array(np.random.randn(N, N))
zs = jnp.array(np.random.randn(N, N))

foo(x, y, z)  # non-vectorized version
# vectorized over all inputs; axis 0 is the batch dimension for every input
jax.vmap(foo, in_axes=[0, 0, 0])(xs, ys, zs)

# x is not batched; ys and zs are batched along axis 0
jax.vmap(foo, in_axes=[None, 0, 0])(x, ys, zs)

# y is not batched; xs and zs are batched along axis 0
jax.vmap(foo, in_axes=[0, None, 0])(xs, y, zs)

# z is not batched; xs and ys are batched along axis 0
jax.vmap(foo, in_axes=[0, 0, None])(xs, ys, z)

# x and y are not batched; zs is batched along axis 0
jax.vmap(foo, in_axes=[None, None, 0])(x, y, zs)

# vectorized over all inputs; xs is batched along axis 1 (its columns),
# while ys and zs are batched along axis 0
jax.vmap(foo, in_axes=[1, 0, 0])(xs, ys, zs)
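A quick way to check what an `in_axes` specification does is to compare `jax.vmap` against an explicit Python loop over the batch. The function `weighted_sum` and the arrays below are hypothetical; note that `zs_b` stores its examples along axis 1, so its `in_axes` entry is `1`, and `y_shared` is not batched.

```python
import jax
import jax.numpy as jnp
import numpy as np


def weighted_sum(x, y, z):
    """Hypothetical per-example function on unbatched 1-D inputs."""
    return jnp.dot(x, y) + jnp.sum(z)


rng = np.random.default_rng(0)
xs_b = jnp.array(rng.standard_normal((8, 3)))  # 8 examples stacked along axis 0
y_shared = jnp.array(rng.standard_normal(3))  # one vector shared by every example
zs_b = jnp.array(rng.standard_normal((3, 8)))  # 8 examples stacked along axis 1

batched = jax.vmap(weighted_sum, in_axes=(0, None, 1))(xs_b, y_shared, zs_b)

# Equivalent explicit loop over the batch
looped = jnp.stack([weighted_sum(xs_b[i], y_shared, zs_b[:, i]) for i in range(8)])
print(batched.shape)  # → (8,)
```

The batch dimension of the output is axis 0 by default; `jax.vmap` also accepts an `out_axes` argument to change that.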


Use `jax.vmap` to apply `simulate_dynamics` to the following batch of initial states and control sequences, simulating all of the trajectories in a single call.

In [None]:
bs = 1024
time_horizon = 2  # seconds
dt = 0.1
n_steps = int(time_horizon / dt)

initial_states = jnp.array(np.random.rand(bs, state_dim))
control_sequences = jnp.array(np.random.rand(bs, n_steps, control_dim))
dynamics = Dynamics(
    runge_kutta_integrate(continuous_dynamics, dt), state_dim, control_dim
)


In [None]:
# TODO with peers
# use `jax.vmap` to simulate all trajectories and plot


### (b) `jax.jit`

Bleh! You'll notice that this takes a while to run, and the run time only grows as you increase the simulation duration or the number of trajectories.
If only we could compile the code to reduce the computation time. With JAX, you can! The `jax.jit` function performs just-in-time compilation: on the first call, JAX traces the function for the given input shapes and dtypes, compiles it with XLA, and reuses the compiled version on later calls with the same shapes and dtypes.

There are a number of ways to use `jax.jit`, and it can get a bit tricky as your code becomes more complex (for example, non-array arguments such as functions must be marked as static). It is best to read the JAX documentation for more information.
But for relatively simple functions, you can usually just apply `jax.jit` without any fuss and get a significant speedup.
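To see when JAX actually traces and compiles, the hypothetical `rollout_sum` below increments a Python counter, a side effect that only runs during tracing. Marking the callable and the loop length as static (`static_argnums=(0, 2)`) is required here because neither is an array: a function is not a valid traced argument, and `range(n)` needs a concrete Python int.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how many times JAX traces rollout_sum


def rollout_sum(step_fn, x, n):
    """Hypothetical function: apply step_fn n times to x, then sum the result."""
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing, not on cached calls
    for _ in range(n):  # n must be a concrete Python int, hence static below
        x = step_fn(x)
    return x.sum()


# step_fn (a Python callable) and n (the loop length) are marked static
rollout_sum_jit = jax.jit(rollout_sum, static_argnums=(0, 2))


def halve(x):
    return 0.5 * x


rollout_sum_jit(halve, jnp.ones(3), 4).block_until_ready()  # first call: trace + compile
rollout_sum_jit(halve, jnp.zeros(3), 4).block_until_ready()  # same shape/dtype: cached
rollout_sum_jit(halve, jnp.ones(5), 4).block_until_ready()  # new input shape: traces again
print(trace_count)  # → 2
```

Separately, JAX dispatches work asynchronously, which is why the timing cells below call `.block_until_ready()`: without it, the timer can stop before the computation has actually finished.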

In [None]:
%timeit jax.vmap(simulate_dynamics, in_axes=(None, 0, 0, None, None))(discrete_dynamics_rk, initial_states, control_sequences, time_horizon, dt).block_until_ready()


In [None]:
# method 1: apply jax.jit directly to the jax.vmap-ed function
# static_argnums=0 marks the first argument (the dynamics callable) as static, since it is a Python object rather than an array
sim_jit = jax.jit(
    jax.vmap(simulate_dynamics, in_axes=[None, 0, 0, None, None]), static_argnums=0
)


In [None]:
# time the run
%timeit sim_jit(discrete_dynamics_rk, initial_states, control_sequences, time_horizon, dt).block_until_ready()

In [None]:
# method 2: apply jax.jit over the simulate function and then apply jax.vmap
sim_jit = jax.jit(simulate_dynamics, static_argnums=0)
sim_jit_vmap = jax.vmap(sim_jit, in_axes=[None, 0, 0, None, None])


In [None]:
%timeit sim_jit_vmap(discrete_dynamics_rk, initial_states, control_sequences, time_horizon, dt).block_until_ready()

In [None]:
# method 3: apply jax.jit as a decorator when defining the simulate function, then apply jax.vmap
@functools.partial(jax.jit, static_argnames=("dynamics",))
def simulate_dynamics(
    dynamics: Callable[[jnp.ndarray, jnp.ndarray, float], jnp.ndarray],
    initial_state: jnp.ndarray,
    control_sequence: jnp.ndarray,
    time_horizon: float,
    dt: float,
) -> jnp.ndarray:
    """
    Simulate the discrete-time dynamics over a given time horizon.

    Args:
        dynamics: A callable representing the discrete-time dynamics function.
        initial_state: The initial state of the system.
        control_sequence: A sequence of control inputs for each time step.
        time_horizon: The total time duration for the simulation.
        dt: The time step size.

    Returns:
        An array containing the state trajectory over the time horizon.
    """
    xs = [initial_state]
    time = 0.0
    # This Python loop is unrolled when jax.jit traces the function;
    # time_horizon is implied by len(control_sequence) and is not used directly.
    for u in control_sequence:
        xs.append(dynamics(xs[-1], u, time))
        time += dt
    return jnp.stack(xs)


In [None]:
sim_jit_vmap = jax.vmap(simulate_dynamics, in_axes=[None, 0, 0, None, None])


In [None]:
%timeit sim_jit_vmap(discrete_dynamics_rk, initial_states, control_sequences, time_horizon, dt).block_until_ready()
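One caveat with the loop-based `simulate_dynamics`: under `jax.jit`, the Python `for` loop is unrolled during tracing, so the compiled program, and the compile time, grow with the number of steps. A common alternative is `jax.lax.scan`, which traces the loop body once. The sketch below is a hypothetical `simulate_dynamics_scan`; unlike the original, it drops the `time_horizon` argument, since the horizon is implied by `len(control_sequence)`.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("dynamics",))
def simulate_dynamics_scan(dynamics, initial_state, control_sequence, dt):
    """Hypothetical scan-based rollout.

    Returns the states x_0, ..., x_N with shape (N + 1, state_dim),
    where N = len(control_sequence).
    """

    def step(carry, u):
        x, t = carry
        x_next = dynamics(x, u, t)
        # (new carry, per-step output); scan stacks the outputs along axis 0
        return (x_next, t + dt), x_next

    t0 = jnp.zeros((), dtype=initial_state.dtype)
    _, xs = jax.lax.scan(step, (initial_state, t0), control_sequence)
    # Prepend the initial state to match the for-loop version's output
    return jnp.concatenate([initial_state[None], xs], axis=0)
```

Batching works as before, e.g. `jax.vmap(simulate_dynamics_scan, in_axes=(None, 0, 0, None))(discrete_dynamics_rk, initial_states, control_sequences, dt)`.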

### (c) `jax.grad`

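JAX's `jax.grad` function provides automatic differentiation, letting you compute gradients of your functions that are exact up to floating-point precision (unlike finite-difference approximations). Related transformations include `jax.jacobian` for vector-valued functions and `jax.hessian` for second derivatives.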
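To see the difference from numerical differentiation, the sketch below compares `jax.grad` with central finite differences and a hand-derived gradient on a hypothetical function `demo_fn`; the step size `fd_eps` is an arbitrary choice. Autodiff matches the analytic gradient to floating-point precision, while the finite-difference estimate carries truncation and round-off error that depend on `fd_eps`.

```python
import jax
import jax.numpy as jnp


def demo_fn(x):
    # Scalar-valued function of a 2-vector: sin(x0) * x1^2
    return jnp.sin(x[0]) * x[1] ** 2


x_demo = jnp.array([0.3, 2.0])

# Gradient via automatic differentiation
g_auto = jax.grad(demo_fn)(x_demo)

# Central finite differences: approximate, and sensitive to the choice of fd_eps
fd_eps = 1e-3
g_fd = jnp.array(
    [(demo_fn(x_demo + fd_eps * e) - demo_fn(x_demo - fd_eps * e)) / (2 * fd_eps) for e in jnp.eye(2)]
)

# Hand-derived gradient: [cos(x0) * x1^2, 2 * sin(x0) * x1]
g_exact = jnp.array([jnp.cos(x_demo[0]) * x_demo[1] ** 2, 2 * jnp.sin(x_demo[0]) * x_demo[1]])

print("autodiff error:", jnp.max(jnp.abs(g_auto - g_exact)))
print("finite-difference error:", jnp.max(jnp.abs(g_fd - g_exact)))
```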

In [None]:
# TODO in class
def test_function_scalar(x, y):
    return y * x[0] ** 2 + jnp.sin(x[1])


# df/dx = [2 * y * x[0], cos(x[1])]
# df/dy = x[0] ** 2

x_input = jnp.array([2.0, 1.0])
y_input = 1.0

# compute gradient
jax.grad(test_function_scalar, argnums=[0, 1])(x_input, y_input)

# compute jacobian
jax.jacobian(test_function_scalar, argnums=[0, 1])(x_input, y_input)

# compute hessian
jax.hessian(test_function_scalar, argnums=[0, 1])(x_input, y_input)


In [None]:
# TODO in class
def test_function_vector(x, y):
    return jnp.array([y * x[0] ** 2, jnp.sin(x[1])])


x_input = jnp.array([2.0, 1.0])
y_input = 1.0

# compute jacobian: d/dx = [[2 * y * x[0], 0], [0, cos(x[1])]], d/dy = [x[0] ** 2, 0]
jax.jacobian(test_function_vector, argnums=[0, 1])(x_input, y_input)


#### (i) Linearize dynamics analytically

Linearize the unicycle dynamics about a point $(\mathbf{x}_0, \mathbf{u}_0)$. That is, for linearized dynamics of the form $\dot{\mathbf{x}} \approx A\mathbf{x} + B\mathbf{u} + C$, give expressions for $A$, $B$, and $C$.
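The expression for $C$ follows from a first-order Taylor expansion of the dynamics $\dot{\mathbf{x}} = f(\mathbf{x}, \mathbf{u}, t)$ about $(\mathbf{x}_0, \mathbf{u}_0)$:

$$
\begin{aligned}
f(\mathbf{x}, \mathbf{u}, t) &\approx f(\mathbf{x}_0, \mathbf{u}_0, t) + A(\mathbf{x} - \mathbf{x}_0) + B(\mathbf{u} - \mathbf{u}_0) \\
&= A\mathbf{x} + B\mathbf{u} + \underbrace{f(\mathbf{x}_0, \mathbf{u}_0, t) - A\mathbf{x}_0 - B\mathbf{u}_0}_{C},
\qquad A = \left.\frac{\partial f}{\partial \mathbf{x}}\right|_{(\mathbf{x}_0, \mathbf{u}_0)}, \quad B = \left.\frac{\partial f}{\partial \mathbf{u}}\right|_{(\mathbf{x}_0, \mathbf{u}_0)}
\end{aligned}
$$

For the unicycle, $A$ is $4 \times 4$ and $B$ is $4 \times 2$.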

In [None]:
# TODO with peers
def linearized_unicycle_dynamics_analytic(
    state: jnp.ndarray, control: jnp.ndarray, time: float
):
    """
    Compute the linearized dynamics matrices A, B, and offset vector C
    for the unicycle model analytically around a given state and control point.

    Args:
      state: The state point (x, y, theta, v) around which to linearize.
      control: The control point (omega, a) around which to linearize.
      time: The time at which to evaluate the linearized dynamics.

    Returns:
      A tuple containing:
        A: The state matrix (∂f/∂x evaluated at the point).
        B: The control matrix (∂f/∂u evaluated at the point).
        C: The offset vector (f(x0, u0) - A*x0 - B*u0).
    """
    # TODO: Implement the analytical linearization here
    # Calculate A, B, and C based on the unicycle dynamics
    # A = ...
    # B = ...
    # C = ...

    raise NotImplementedError # remove this


#### (ii) Linearize dynamics using JAX

Using the autodiff capabilities of JAX, we can linearize the dynamics.
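The pattern is to differentiate the dynamics with respect to the state and the control in one call by passing `argnums=(0, 1)` to `jax.jacobian`, then form the offset from the Taylor expansion. The sketch below applies it to a hypothetical pendulum rather than the unicycle, so the names `pendulum`, `x0`, and `u0` are illustrative only.

```python
import jax
import jax.numpy as jnp


def pendulum(state, control, time=0.0):
    """Hypothetical pendulum dynamics, used only to illustrate jax.jacobian."""
    theta, theta_dot = state
    return jnp.array([theta_dot, -9.81 * jnp.sin(theta) + control[0]])


x0 = jnp.array([0.5, 0.0])  # linearization state (theta, theta_dot)
u0 = jnp.array([0.0])  # linearization control (torque)

# Jacobians with respect to the state (argnum 0) and the control (argnum 1)
A, B = jax.jacobian(pendulum, argnums=(0, 1))(x0, u0, 0.0)

# Offset so that f(x, u) ≈ A x + B u + C near (x0, u0)
C = pendulum(x0, u0, 0.0) - A @ x0 - B @ u0
print(A.shape, B.shape, C.shape)  # → (2, 2) (2, 1) (2,)
```

Because `jax.jacobian` only needs a callable with this argument order, the same call should also work on a `Dynamics` object, including a discrete-time one.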

In [None]:
# TODO with peers

@functools.partial(jax.jit, static_argnames=("dynamics",))
def linearized_dynamics_jax(
    dynamics: Dynamics, state: jnp.ndarray, control: jnp.ndarray, time: float
):
    """
    Compute the linearized dynamics matrices A, B, and offset vector C
    for the unicycle model using jax.jacobian around a given state and control point.

    Args:
      dynamics: A Dynamics object representing the unicycle dynamics.
      state: The state point (x, y, theta, v) around which to linearize.
      control: The control point (omega, a) around which to linearize.
      time: The time at which to evaluate the linearized dynamics.

    Returns:
      A tuple containing:
        A: The state matrix (∂f/∂x evaluated at the point).
        B: The control matrix (∂f/∂u evaluated at the point).
        C: The offset vector (f(x0, u0) - A*x0 - B*u0).
    """
    # TODO: Implement the linearization using jax.jacobian here
    # A = ...
    # B = ...
    # C = ...
    raise NotImplementedError # remove this


#### (iii) Verify your answers match


In [None]:
# Define a sample state and control point
state_point = jnp.array([1.0, 2.0, 0.5, 3.0])  # Example state: x, y, theta, v
control_point = jnp.array([0.1, 0.2])  # Example control: omega, a
time = 0.0

dynamics = continuous_dynamics

# Compute linearized dynamics using both methods
A_analytic, B_analytic, C_analytic = linearized_unicycle_dynamics_analytic(
    state_point, control_point, time
)
A_jax, B_jax, C_jax = linearized_dynamics_jax(
    dynamics, state_point, control_point, time
)

# Verify if the results match
A_match = jnp.allclose(A_analytic, A_jax)
B_match = jnp.allclose(B_analytic, B_jax)
C_match = jnp.allclose(C_analytic, C_jax)

print(f"A matrices match: {A_match}")
print(f"B matrices match: {B_match}")
print(f"C vectors match: {C_match}")

# Optional: print the computed matrices/vectors
print("\nAnalytic:")
print("A:\n", A_analytic)
print("B:\n", B_analytic)
print("C:\n", C_analytic)
print("\nJAX:")
print("A:\n", A_jax)
print("B:\n", B_jax)
print("C:\n", C_jax)

#### (iv) Use JAX to linearize the discrete-time dynamics

In [None]:
# TODO in class


## Dynamics + JAX

Now that you know how to construct dynamics, simulate them for a given control sequence, and leverage the power of JAX, complete the following tasks.

### (a) Linearize dynamics along trajectory

Given a control sequence, **simulate** the trajectory, and **linearize** the dynamics at each point along the trajectory.
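One way to structure this is to simulate first, then linearize at every pair $(\mathbf{x}_k, \mathbf{u}_k)$ with a single `jax.vmap` over `jax.jacobian`. The sketch below uses a hypothetical discrete-time toy system `toy_step` (not the unicycle) with an assumed time step `DT`; the final state $\mathbf{x}_N$ is dropped because no control is applied there.

```python
import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical time step of the toy system


def toy_step(x, u, t):
    """Hypothetical discrete-time dynamics x_{k+1} = f_d(x_k, u_k, t_k)."""
    return jnp.array([x[0] + DT * jnp.cos(x[1]), x[1] + DT * u[0]])


n_toy = 5
us = jnp.linspace(-1.0, 1.0, n_toy)[:, None]  # control sequence, shape (n_toy, 1)
ts = DT * jnp.arange(n_toy)  # time at which each control is applied

# Simulate: xs[k] is the state at which us[k] is applied
xs = [jnp.zeros(2)]
for k in range(n_toy):
    xs.append(toy_step(xs[-1], us[k], ts[k]))
xs = jnp.stack(xs)  # shape (n_toy + 1, 2)

# Linearize at every (x_k, u_k, t_k) at once by vmapping the Jacobian
jac = jax.jacobian(toy_step, argnums=(0, 1))
As, Bs = jax.vmap(jac)(xs[:-1], us, ts)  # shapes (n_toy, 2, 2) and (n_toy, 2, 1)

# Offsets C_k = f_d(x_k, u_k, t_k) - A_k x_k - B_k u_k
f0 = jax.vmap(toy_step)(xs[:-1], us, ts)  # equals xs[1:]
Cs = f0 - jnp.einsum("kij,kj->ki", As, xs[:-1]) - jnp.einsum("kij,kj->ki", Bs, us)
```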

In [None]:
# Define initial state and a random control sequence
initial_state = jnp.zeros(state_dim)
time_horizon = 5  # seconds
dt = 0.1
n_steps = int(time_horizon / dt)
times = dt * jnp.arange(n_steps)  # time at which each control input is applied
np.random.seed(0)
control_sequence = jnp.array(np.random.randn(n_steps, control_dim))

dynamics = Dynamics(
    runge_kutta_integrate(continuous_dynamics, dt), state_dim, control_dim
)

In [None]:
# TODO by self
# This should take ~2-3 lines of code

# simulate

# linearize


### (b) Linearize the dynamics along multiple trajectories


Now suppose you are given multiple trajectories (a batch of initial states and control sequences) and want to linearize the dynamics along each of them.
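Batching over trajectories adds one more level of `jax.vmap`: an inner map over time steps and an outer map over the batch, with `jax.jit` around the whole thing. In the sketch below, `toy_step`, `DT`, and the stand-in arrays are hypothetical; in the exercise, the states would come from the batched simulation. Since all trajectories share the same time grid, the outer `in_axes` entry for the times is `None`.

```python
import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical time step of the toy system


def toy_step(x, u, t):
    """Hypothetical discrete-time dynamics x_{k+1} = f_d(x_k, u_k, t_k)."""
    return jnp.array([x[0] + DT * jnp.cos(x[1]), x[1] + DT * u[0]])


jac = jax.jacobian(toy_step, argnums=(0, 1))

# Inner vmap: over the time steps of one trajectory (x: (T, 2), u: (T, 1), t: (T,)).
# Outer vmap: over the batch; the times are shared, so their in_axes entry is None.
batched_jac = jax.jit(jax.vmap(jax.vmap(jac), in_axes=(0, 0, None)))

batch_size, horizon = 3, 4
xs_batch = jnp.ones((batch_size, horizon, 2))  # stand-in for simulated states
us_batch = jnp.zeros((batch_size, horizon, 1))  # stand-in for control sequences
ts = DT * jnp.arange(horizon)
As, Bs = batched_jac(xs_batch, us_batch, ts)  # shapes (3, 4, 2, 2) and (3, 4, 2, 1)
```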


In [None]:
# Define a batch of initial states and random control sequences
bs = 64
time_horizon = 5  # seconds
dt = 0.1
n_steps = int(time_horizon / dt)

times = dt * jnp.arange(n_steps)  # time at which each control input is applied
np.random.seed(0)

initial_states = jnp.array(np.random.rand(bs, state_dim))
control_sequences = jnp.array(np.random.rand(bs, n_steps, control_dim))

dynamics = Dynamics(
    runge_kutta_integrate(continuous_dynamics, dt), state_dim, control_dim
)


In [None]:
# TODO by self
# try to make this as fast as possible using jax.jit!

# simulate

# linearize
