<a href="https://colab.research.google.com/github/karenl7/AA548-spr2024/blob/dev/homework/hw1_starter_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [None]:
class DynUnicycle:
    '''
    Dynamically extended unicycle, x_dot = f(x, u), with

        px_dot    = v * cos(theta)
        py_dot    = v * sin(theta)
        theta_dot = omega
        v_dot     = a

    Assumed layouts (NOTE(review): inferred from this notebook's setup cell,
    where initial_state = [0, 0, pi/4, 2.] and the control columns are [w, a]
    with "a is zero (constant speed)" -- confirm against the problem statement):
        state   = [px, py, theta, v]   (n = 4)
        control = [omega, a]           (m = 2)
    '''

    # analytic_step switches from the closed-form turning solution to a Taylor
    # expansion in omega when |omega * dt| is below this value. The closed form
    # divides by omega and omega**2, so it is undefined at omega == 0 and loses
    # float32 precision for very gentle turns.
    _SMALL_TURN = 0.1

    def state_derivative(self, state, control):
        '''
        Computes x_dot where x_dot = f(x, u)
        Inputs:
            state     : A jax.numpy array of size (n,)
            control   : A jax.numpy array of size (m,)

        Output:
            state_derivative : A jax.numpy array of size (n,)
        '''
        state = jnp.asarray(state)
        control = jnp.asarray(control)
        theta, v = state[2], state[3]
        omega, accel = control[0], control[1]
        return jnp.stack([v * jnp.cos(theta), v * jnp.sin(theta), omega, accel])

    def euler_step(self, state, control, dt):
        '''
        Computes x(k+1) using euler integration
        Inputs:
            state     : A jax.numpy array of size (n,)
            control   : A jax.numpy array of size (m,)
            dt        : time step, a float

        Output:
            next_state : A jax.numpy array of size (n,)
        '''
        state = jnp.asarray(state)
        # Forward Euler: f(x_k, u_k) is treated as constant over the whole step.
        return state + dt * self.state_derivative(state, control)

    def analytic_step(self, state, control, dt):
        '''
        Computes x(k+1) using analytic expression from integration.
        Assumes zero-order hold
        Inputs:
            state     : A jax.numpy array of size (n,)
            control   : A jax.numpy array of size (m,)
            dt        : time step, a float

        Output:
            next_state : A jax.numpy array of size (n,)

        With u held constant, theta(s) = theta + omega*s and v(s) = v + a*s are
        exact, and the position is the integral of (v + a s)(cos, sin)(theta + omega s)
        over [0, dt]. For |omega*dt| < _SMALL_TURN a third-order Taylor expansion
        in omega is used instead (exact for omega == 0). Both branches are
        computed and one is selected with jnp.where, so the function is jit-able
        and its derivatives stay finite at omega == 0.
        '''
        state = jnp.asarray(state)
        control = jnp.asarray(control)
        px, py, theta, v = state[0], state[1], state[2], state[3]
        omega, accel = control[0], control[1]

        theta_next = theta + omega * dt
        v_next = v + accel * dt

        small_turn = jnp.abs(omega * dt) < self._SMALL_TURN
        # Replace omega by 1 on the small-turn branch so the closed form below
        # never divides by zero; that branch's result is discarded there anyway.
        w = jnp.where(small_turn, 1.0, omega)

        # Closed form (integration by parts), valid for omega != 0.
        dx_turn = ((v_next * jnp.sin(theta_next) - v * jnp.sin(theta)) / w
                   + accel * (jnp.cos(theta_next) - jnp.cos(theta)) / w**2)
        dy_turn = (-(v_next * jnp.cos(theta_next) - v * jnp.cos(theta)) / w
                   + accel * (jnp.sin(theta_next) - jnp.sin(theta)) / w**2)

        # Taylor branch: p_k = integral over [0, dt] of (v + a s) * s**(k-1) ds.
        p1 = v * dt + accel * dt**2 / 2.0
        p2 = v * dt**2 / 2.0 + accel * dt**3 / 3.0
        p3 = v * dt**3 / 3.0 + accel * dt**4 / 4.0
        p4 = v * dt**4 / 4.0 + accel * dt**5 / 5.0
        along = p1 - 0.5 * omega**2 * p3             # from cos(omega s) ~ 1 - (omega s)^2 / 2
        across = omega * p2 - omega**3 * p4 / 6.0    # from sin(omega s) ~ omega s - (omega s)^3 / 6
        dx_taylor = jnp.cos(theta) * along - jnp.sin(theta) * across
        dy_taylor = jnp.sin(theta) * along + jnp.cos(theta) * across

        dx = jnp.where(small_turn, dx_taylor, dx_turn)
        dy = jnp.where(small_turn, dy_taylor, dy_turn)
        return jnp.stack([px + dx, py + dy, theta_next, v_next])

    def RK4_step(self, state, control, dt):
        '''
        Computes x(k+1) using Runge-Kutta integration.
        Assumes zero-order hold
        Inputs:
            state     : A jax.numpy array of size (n,)
            control   : A jax.numpy array of size (m,)
            dt        : time step, a float

        Output:
            next_state : A jax.numpy array of size (n,)
        '''
        state = jnp.asarray(state)

        def f(s):
            # Zero-order hold: the same control is used at every RK4 stage.
            return self.state_derivative(s, control)

        k1 = f(state)
        k2 = f(state + 0.5 * dt * k1)
        k3 = f(state + 0.5 * dt * k2)
        k4 = f(state + dt * k3)
        return state + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

    def linearize_continuous_time_analytic(self, state, control):
        '''
        Linearizes the continuous time dynamics using analytic expression, so that
        near (state, control): f(x, u) ~ A @ x + B @ u + C[:, 0]
        Inputs:
            state     : A jax.numpy array of size (n,)
            control   : A jax.numpy array of size (m,)

        Outputs:
            A : A jax.numpy array of size (n,n), df/dx
            B : A jax.numpy array of size (n,m), df/du
            C : A jax.numpy array of size (n,1), f(x0, u0) - A x0 - B u0
        '''
        state = jnp.asarray(state)
        control = jnp.asarray(control)
        theta, v = state[2], state[3]
        c, s = jnp.cos(theta), jnp.sin(theta)
        # Only the position rates depend on the state (through theta and v).
        A = jnp.array([[0.0, 0.0, -v * s, c],
                       [0.0, 0.0, v * c, s],
                       [0.0, 0.0, 0.0, 0.0],
                       [0.0, 0.0, 0.0, 0.0]])
        # omega drives theta_dot and a drives v_dot directly.
        B = jnp.array([[0.0, 0.0],
                       [0.0, 0.0],
                       [1.0, 0.0],
                       [0.0, 1.0]], dtype=A.dtype)
        C = self.state_derivative(state, control) - A @ state - B @ control
        return A, B, C.reshape(-1, 1)

def discrete_time_simulate(discrete_time_dyn, initial_state, control_sequence):
    '''
    Propagates states through discrete_time_dyn using control_sequence, starting at initial_state.
    Inputs:
        discrete_time_dyn : A function that takes in a state and control, and returns the next state.
                            It is run inside jax.lax.scan, so it must be traceable by JAX.
        initial_state     : A jax.numpy array of size (n,)
        control_sequence  : A sequence of control inputs. A jax.numpy array of size (T,m) where T is the number of control inputs

    Output:
        A sequence of states from executing the control sequence. A jax.numpy array of size (T+1, n).
        Row 0 is initial_state and row k+1 is the state after applying control_sequence[k].
        An empty control_sequence (T = 0) returns just initial_state[None, :].
    '''
    initial_state = jnp.asarray(initial_state)
    # The scan carry must keep a single dtype, so integer states are promoted to float up front.
    initial_state = initial_state.astype(jnp.result_type(initial_state, float))
    control_sequence = jnp.asarray(control_sequence)

    def step(state, control):
        next_state = discrete_time_dyn(state, control)
        return next_state, next_state  # (new carry, value recorded for this step)

    _, next_states = jax.lax.scan(step, initial_state, control_sequence)
    return jnp.concatenate([initial_state[None, :], next_states], axis=0)

def linearize_autodiff(function_name, state, control):
    '''
    Linearizes any dynamics using jax autodiff, so that near (state, control)
        function_name(x, u) ~ A @ x + B @ u + C[:, 0]
    Inputs:
        function_name: the function (a callable, not its name) to be linearized. Takes state and
                       control as inputs and returns a jax.numpy array of size (n,).
        state     : A jax.numpy array of size (n,)
        control   : A jax.numpy array of size (m,)

    Outputs:
        A : A jax.numpy array of size (n,n), Jacobian with respect to state
        B : A jax.numpy array of size (n,m), Jacobian with respect to control
        C : A jax.numpy array of size (n,1), affine term f(x0, u0) - A x0 - B u0
    '''
    state = jnp.asarray(state)
    control = jnp.asarray(control)
    # jax.jacfwd only differentiates floating-point inputs, so promote integer arrays.
    state = state.astype(jnp.result_type(state, float))
    control = control.astype(jnp.result_type(control, float))

    A, B = jax.jacfwd(function_name, argnums=(0, 1))(state, control)
    C = function_name(state, control) - A @ state - B @ control
    return A, B, C.reshape(-1, 1)


# Problem 1

The plotting code is provided below, but you need to fill in the following functions before it will run:
- `state_derivative`
- `euler_step`
- `analytic_step`
- `RK4_step`
- `discrete_time_simulate`


In [None]:
# generating initial state and control sequence
# Initial state for the Problem 1 simulations: a size [4,] array.
initial_state = jnp.asarray([0, 0, jnp.pi / 4, 2.0])
def generate_control_sequence(dt, horizon=5):
    '''
    Builds the Problem 1 control sequence.
    Inputs:
        dt      : time step, a float
        horizon : total duration; the sequence has round(horizon / dt) rows

    Output:
        A jax.numpy array of size (T, 2). Column 0 (w) follows sin over [0, pi]
        and column 1 (a) is all zeros (constant speed).
    '''
    num_steps = round(horizon / dt)
    phase = jnp.linspace(0, jnp.pi, num_steps)
    turn_rate = jnp.sin(phase)           # w varies sinusoidally over time
    acceleration = jnp.zeros(num_steps)  # a is zero (constant speed)
    return jnp.stack((turn_rate, acceleration), axis=-1)

In [None]:
## Feel free to edit this code to tailor it to your implementation
unicycle = DynUnicycle()

plt.figure(figsize=(15, 4))

dt_array = [0.01, 0.1, 0.5]

# (legend label, integrator) pairs, in the order the curves are drawn.
INTEGRATORS = (
    ("Analytic hold", unicycle.analytic_step),
    ("Euler", unicycle.euler_step),
    ("RK4", unicycle.RK4_step),
)

for i, dt in enumerate(dt_array):
    controls = generate_control_sequence(dt, 5)
    plt.subplot(1, 3, i + 1)

    for label, step_fn in INTEGRATORS:
        # dt and step_fn are looked up when jit traces, which happens inside
        # discrete_time_simulate during this same iteration.
        stepper = jax.jit(lambda s, c: step_fn(s, c, dt))
        traj = discrete_time_simulate(stepper, initial_state, controls)
        plt.plot(traj[:, 0], traj[:, 1], label=label)
        plt.scatter(traj[:, 0], traj[:, 1], s=2)

    plt.grid()
    plt.legend()
    plt.title("dt=%.3f" % dt)

# plt.savefig(...) # you can save fig

# Problem 2

In [None]:
# Operating point (x0, u0) used by the Problem 2 linearizations below.
x0 = jnp.asarray([0.0, 0.0, jnp.pi / 4, 2.0])
u0 = jnp.asarray([0.1, 1.0])

### 2(c)  

You need to fill in the `linearize_autodiff` function, and also `linearize_continuous_time_analytic` to test your analytic solution.

Some code to print out values from linearizing the continuous time dynamics about a point

In [None]:
## Feel free to edit this code to tailor it to your implementation

decimal_places = 2


def _print_linearization(title, A, B, C):
    """Print `title`, then A, B and C rounded to `decimal_places`."""
    print(title)
    for name, matrix in (("A", A), ("B", B), ("C", C)):
        print(name + "\n", round(matrix, decimal_places))


A, B, C = unicycle.linearize_continuous_time_analytic(x0, u0)
_print_linearization("Linearization from analytic expression", A, B, C)

print("\n\n")

A, B, C = linearize_autodiff(unicycle.state_derivative, x0, u0)
_print_linearization("Linearization using autodiff on continuous time dynamics", A, B, C)


### 2(d)

Some code to print out values from linearizing the zero-order hold dynamics about a point

In [None]:
## Feel free to edit this code to tailor it to your implementation

decimal_places = 2

# Step size for the zero-order-hold linearizations in this cell. It used to
# come implicitly from the `dt` left over from the Problem 1 plotting loop
# (its last value, 0.5), which raises NameError if that cell has not run and
# silently changes whenever dt_array is edited.
# NOTE(review): 0.5 reproduces the previous behavior; confirm the intended step.
zoh_dt = 0.5


def _print_linearization(title, A, B, C):
    """Print `title`, then A, B and C rounded to `decimal_places`."""
    print(title)
    for name, matrix in (("A", A), ("B", B), ("C", C)):
        print(name + "\n", round(matrix, decimal_places))


A, B, C = linearize_autodiff(lambda s, c: unicycle.analytic_step(s, c, zoh_dt), x0, u0)
_print_linearization("Linearization using autodiff on analytic step ", A, B, C)

A, B, C = linearize_autodiff(lambda s, c: unicycle.RK4_step(s, c, zoh_dt), x0, u0)
_print_linearization("Linearization using autodiff on RK4 step ", A, B, C)


# Problem 3

In [None]:
def cart_pole(state, control):
    '''Cart-pole continuous-time dynamics, x_dot = f(x, u).

    Inputs:
        state   : A jax.numpy array of size (4,), ordered (x, theta, xdot, thetadot)
                  as in the Problem 3 setup cell
        control : A jax.numpy array of size (1,)

    Output:
        state_derivative : presumably a jax.numpy array of size (4,) -- TODO confirm

    Raises NotImplementedError until implemented, instead of silently returning
    None (which made the later vmap/print cells fail with unrelated errors).
    '''
    raise NotImplementedError("cart_pole: fill in the cart-pole dynamics (Problem 3)")

def cart_pole_linearized(state, control, state_eq, control_eq):
    '''Linearize cart-pole continuous-time dynamics about equilibrium point.

    Inputs:
        state, control       : point at which the linearized model is evaluated
        state_eq, control_eq : equilibrium the dynamics are linearized about

    NOTE(review): the expected return (the linearized f evaluated at
    (state, control), or the (A, B, C) matrices) is not specified here -- confirm.

    Raises NotImplementedError until implemented, instead of silently returning None.
    '''
    raise NotImplementedError("cart_pole_linearized: fill in the linearized cart-pole model (Problem 3)")

def linearization_error(state, control, state_eq, control_eq):
    '''Computes linearization error.

    Inputs:
        state, control       : point at which the error is measured
        state_eq, control_eq : equilibrium used for the linearization

    Output:
        A scalar (the print cell formats each result with "%.3f"). Presumably a
        norm of the difference between the true and linearized dynamics at
        (state, control) -- TODO confirm.

    Raises NotImplementedError until implemented, instead of silently returning
    None (which made jax.vmap return None and the print loop fail).
    '''
    raise NotImplementedError("linearization_error: fill in the linearization error (Problem 3)")


In [None]:
## Feel free to edit this code to tailor it to your implementation

# x, theta, xdot, thetadot
state_eq = jnp.array([0.0, jnp.pi, 0.0, 0.0])  # equilibrium state
control_eq = jnp.zeros(1)                       # equilibrium control

# Test points, one per row, in the same (x, theta, xdot, thetadot) order.
states = jnp.array([
    [0, 0.99 * jnp.pi, 0, 0],
    [0, 0.9 * jnp.pi, 0, 0],
    [0, 0.85 * jnp.pi, 0, 0],
    [0, 0.5 * jnp.pi, 0, 0],
    [0, 0, 0, 0],
    [1, jnp.pi, 0, 0],
])
# One scalar control per test point, shaped (6, 1) to match control_eq's size.
controls = jnp.array([-0.0, -10.0, 0.0, 0.0, 0.0, 10.0]).reshape(-1, 1)

# Map over the test points; the equilibrium is shared by all of them (in_axes=None).
errors = jax.vmap(linearization_error, in_axes=(0, 0, None, None))(
    states, controls, state_eq, control_eq
)


In [None]:
## Feel free to edit this code to tailor it to your implementation

for i, e in enumerate(errors):
    state_i, control_i = states[i], controls[i]
    print(f"State = {jnp.around(state_i, decimals=3)}^T")
    print(f"Input = {jnp.around(control_i, decimals=3)}")
    print("Linearization error = %.3f" % e)
    dynamics_norm = jnp.linalg.norm(cart_pole(state_i, control_i))
    print("Norm of f(x,u) = %.3f\n" % dynamics_norm)