## Problem 2


In this write-up we analyze the solution to Problem 2 of Assignment 1. For each part of the exercise I report:
1. **Exercise Text:** the problem statement, as given in the original homework.
2. **Implementation:** the necessary Python code.
3. **Explanation:** a thorough explanation of each step, with a focus on common errors.

---------------------------------------


We continue to consider the dynamically-extended unicycle model and investigate a way to linearize the dynamics around any state. 
First, we will perform the linearization analytically, and then leverage modern computational tools, which are especially helpful when the dynamics are complicated!
We can efficiently compute gradients via automatic differentiation, and JAX is an automatic differentiation library.

### (a) Linearize dynamics analytically
Linearize the dynamics given in Problem 1 part (a) about a point $(\mathbf{x}_0, \mathbf{u}_0)$. That is, for linearized dynamics of the form $\dot{\mathbf{x}} \approx A\mathbf{x}+ B\mathbf{u} + C$, give expressions for $A$, $B$, and $C$. 

### Solution 

We start by defining:

$$
\mathbf{x} = \begin{bmatrix}x \\\\ y \\\\ \theta \\\\ v\end{bmatrix},\quad
\mathbf{u} = \begin{bmatrix}\omega \\\\ a\end{bmatrix}
$$

and

$$
f(\mathbf{x},\mathbf{u}) =
\begin{bmatrix}
v\cos\theta \\\\
v\sin\theta \\\\
\omega \\\\
a
\end{bmatrix}.
$$

Here $(x,y)$ is the position, $\theta$ the orientation (heading), $v$ the linear speed, and $\omega$ (turn rate) and $a$ (forward acceleration) are the two control inputs.

Component by component, the model states:

1. $\dot{x} = v\cos\theta$  
If you move forward at speed $v$ while pointing at angle $\theta$, your horizontal velocity component is $v\cos\theta$.

2. $\dot{y} = v\sin\theta$  
Similarly, the vertical velocity is $v\sin\theta$.

3. $\dot{\theta} = \omega$  
The heading changes at the commanded turn rate $\omega$.

4. $\dot{v} = a$  
The input $a$ is the forward acceleration, i.e. it directly sets the rate of change of the speed $v$.

Now, we want to derive the linear approximation:

$$
\dot{\mathbf{x}} \approx A\,\mathbf{x} + B\,\mathbf{u} + C
$$

around an operating point $(\mathbf{x}_0, \mathbf{u}_0)$.

The first step is to define the point of linearization:

$$
\mathbf{x}_0 = \begin{bmatrix}x_0 \\\\ y_0 \\\\ \theta_0 \\\\ v_0\end{bmatrix}, \quad
\mathbf{u}_0 = \begin{bmatrix}\omega_0 \\\\ a_0\end{bmatrix}.
$$

Following the standard linearization procedure, the matrix $A$ is the Jacobian of $f$ with respect to $\mathbf{x}$, evaluated at the operating point:

$$
A = \left.\frac{\partial f}{\partial \mathbf{x}}\right|_{(\mathbf{x}_0,\mathbf{u}_0)} =
\begin{bmatrix}
0 & 0 & -v\sin\theta & \cos\theta \\\\
0 & 0 &  v\cos\theta & \sin\theta \\\\
0 & 0 &  0           & 0          \\\\
0 & 0 &  0           & 0
\end{bmatrix}_{(\mathbf{x}_0, \mathbf{u}_0)}.
$$

The same can be done for $B$, by computing the Jacobian of $f$ with respect to $\mathbf{u}$:

$$
B = \left.\frac{\partial f}{\partial \mathbf{u}}\right|_{(\mathbf{x}_0,\mathbf{u}_0)} =
\begin{bmatrix}
0 & 0 \\\\
0 & 0 \\\\
1 & 0 \\\\
0 & 1
\end{bmatrix}.
$$

At this point, to calculate $C$, we can write down the matching condition:

$$
A\,\mathbf{x}_0 + B\,\mathbf{u}_0 + C = f(\mathbf{x}_0,\mathbf{u}_0),
$$

so that the linear model exactly reproduces the nonlinear dynamics at the chosen point. We can now write:

$$
C = f(\mathbf{x}_0,\mathbf{u}_0) - A\,\mathbf{x}_0 - B\,\mathbf{u}_0,
$$

which corresponds to the expression for the vector $C$.
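Carrying out the subtraction term by term makes $C$ explicit. With $A$ as above, $A\mathbf{x}_0$ only involves $\theta_0$ and $v_0$, and $B\mathbf{u}_0 = [0,\,0,\,\omega_0,\,a_0]^T$, so the control entries cancel:

$$
C =
\begin{bmatrix}
v_0\cos\theta_0 \\\\ v_0\sin\theta_0 \\\\ \omega_0 \\\\ a_0
\end{bmatrix}
-
\begin{bmatrix}
-v_0\theta_0\sin\theta_0 + v_0\cos\theta_0 \\\\ v_0\theta_0\cos\theta_0 + v_0\sin\theta_0 \\\\ 0 \\\\ 0
\end{bmatrix}
-
\begin{bmatrix}
0 \\\\ 0 \\\\ \omega_0 \\\\ a_0
\end{bmatrix}
=
\begin{bmatrix}
v_0\theta_0\sin\theta_0 \\\\ -v_0\theta_0\cos\theta_0 \\\\ 0 \\\\ 0
\end{bmatrix}.
$$

So $C$ is generally not equal to $f(\mathbf{x}_0,\mathbf{u}_0)$: its position entries vanish only when $v_0\theta_0 = 0$, and its last two entries are always zero.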

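The reason why $A$ and $B$ are defined as above is that, in the first-order Taylor expansion of $f$ around $(\mathbf{x}_0,\mathbf{u}_0)$,

$$
f(\mathbf{x}_0 + \Delta\mathbf{x},\, \mathbf{u}_0 + \Delta\mathbf{u})
\approx f(\mathbf{x}_0,\mathbf{u}_0)
+ \underbrace{\left.\frac{\partial f}{\partial \mathbf{x}}\right|_{(\mathbf{x}_0,\mathbf{u}_0)}}_{A}\,\Delta\mathbf{x}
+ \underbrace{\left.\frac{\partial f}{\partial \mathbf{u}}\right|_{(\mathbf{x}_0,\mathbf{u}_0)}}_{B}\,\Delta\mathbf{u},
$$

the matrices of partial derivatives capture how small perturbations $\Delta\mathbf{x}$ and $\Delta\mathbf{u}$ influence the time derivative $\dot{\mathbf{x}}$. Substituting $\Delta\mathbf{x} = \mathbf{x} - \mathbf{x}_0$ and $\Delta\mathbf{u} = \mathbf{u} - \mathbf{u}_0$ gives $\dot{\mathbf{x}} \approx A\mathbf{x} + B\mathbf{u} + \big(f(\mathbf{x}_0,\mathbf{u}_0) - A\mathbf{x}_0 - B\mathbf{u}_0\big)$, and the term in parentheses is exactly $C$.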
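A quick numerical way to convince yourself that the first-order expansion is right is to check that the error of the affine model shrinks quadratically: halving the perturbation should roughly quarter the error. The sketch below re-declares the unicycle dynamics from Problem 1 so it runs on its own; the perturbation `direction` and the helper name `affine_model` are hypothetical choices made only for illustration.

```python
import jax
import jax.numpy as jnp

def f(x, u):
    # Dynamically-extended unicycle: state [x, y, theta, v], control [omega, a]
    return jnp.array([x[3] * jnp.cos(x[2]), x[3] * jnp.sin(x[2]), u[0], u[1]])

x0 = jnp.array([0.0, 0.0, jnp.pi / 4, 2.0])
u0 = jnp.array([0.1, 1.0])

# Jacobians at the operating point: argnums=0 -> w.r.t. x, argnums=1 -> w.r.t. u
A = jax.jacobian(f, argnums=0)(x0, u0)
B = jax.jacobian(f, argnums=1)(x0, u0)
C = f(x0, u0) - A @ x0 - B @ u0

def affine_model(x, u):
    return A @ x + B @ u + C

# Hypothetical perturbation direction, chosen only for illustration
direction = jnp.array([0.3, -0.2, 1.0, 0.5])

errors = []
for eps in [0.1, 0.05, 0.025]:
    x = x0 + eps * direction
    err = float(jnp.linalg.norm(f(x, u0) - affine_model(x, u0)))
    errors.append(err)
    print(f"eps = {eps:<6} error = {err:.2e}")

# Ratios close to 4 mean the error is second order in eps
ratios = [errors[i] / errors[i + 1] for i in range(len(errors) - 1)]
print(ratios)
```

If the ratios came out close to 2 instead, the linear terms (i.e. $A$ or $B$) would be wrong.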


### (a) CODE IMPLEMENTATION

We start by importing the packages needed for the exercise. These are the standard imports provided with the assignment.

```python
import abc
from typing import Callable
import jax
import jax.numpy as jnp
from jax import jacrev
import matplotlib.pyplot as plt
import numpy as np
import functools
import cvxpy as cp
```

The cell below contains the functions from Problem 1 of HW1 that are needed to run the rest of the code.

```python
class Dynamics(metaclass=abc.ABCMeta):
    dynamics_func: Callable
    state_dim: int
    control_dim: int

    def __init__(self, dynamics_func, state_dim, control_dim):
        self.dynamics_func = dynamics_func
        self.state_dim = state_dim
        self.control_dim = control_dim

    def __call__(self, state, control, time=0):
        return self.dynamics_func(state, control, time)


def dynamic_unicycle_ode(state, control, time):
    x = state[0]
    y = state[1]
    theta = state[2]
    v = state[3]
    omega = control[0]
    a = control[1]
    dx = v * jnp.cos(theta)
    dy = v * jnp.sin(theta)
    dtheta = omega
    dv = a
    return jnp.array([dx, dy, dtheta, dv])


state_dim = 4
control_dim = 2
continuous_dynamics = Dynamics(dynamic_unicycle_ode, state_dim, control_dim)


def euler_integrate(dynamics, dt):
    # zero-order hold
    def integrator(x, u, t):
        derivative = dynamics(x, u, t)
        new_state = x + dt * derivative
        return new_state
    return integrator


def runge_kutta_integrator(dynamics, dt=0.1):
    # zero-order hold
    def integrator(x, u, t):
        k1 = dynamics(x, u, t)
        k2 = dynamics(x + 0.5 * dt * k1, u, t)
        k3 = dynamics(x + 0.5 * dt * k2, u, t)
        k4 = dynamics(x + dt * k3, u, t)
        return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return integrator
```


```python
def linearize_unicycle_continuous_time_analytic(state, control, time):
    '''
    Linearizes the continuous time dynamics of the dynamic unicycle using analytic expression
    Inputs:
        state     : A jax.numpy array of size (n,)
        control   : A jax.numpy array of size (m,)
        time      : A real scalar

    Outputs:
        A : A jax.numpy array of size (n,n)
        B : A jax.numpy array of size (n,m)
        C : A jax.numpy array of size (n,1)
    '''
    theta = state[2]
    v = state[3]
    A = jnp.array([
        [0, 0, -v * jnp.sin(theta), jnp.cos(theta)],
        [0, 0,  v * jnp.cos(theta), jnp.sin(theta)],
        [0, 0, 0, 0],
        [0, 0, 0, 0]
    ])
    B = jnp.array([
        [0, 0],
        [0, 0],
        [1, 0],
        [0, 1]
    ])
    f = dynamic_unicycle_ode(state, control, time)  # defined in Problem 1 part (a)
    C = f - A @ state - B @ control
    C = C.reshape((4, 1))
    return A, B, C
```


The theory above explains how the code is built; below I clarify a few implementation details that commonly raise questions.

Why do we use `theta = state[2]` and `v = state[3]`?

The state vector is defined as:
$$
\mathbf{x} = \begin{bmatrix}x,\; y,\; \theta,\; v\end{bmatrix}^T,
$$  
so the indexing corresponds to:  
- `state[0]` → $x$  
- `state[1]` → $y$  
- `state[2]` → $\theta$, the heading angle  
- `state[3]` → $v$, the forward velocity  

This allows us to extract $\theta$ and $v$ from their correct positions to construct the matrix $A$, which depends only on these two state components.

Why do we reshape `C`?

After computing  
$$
C = f - A\,\text{state} - B\,\text{control},
$$  
we obtain a JAX array with shape `(4,)`, i.e. a 1-D array that is neither a row nor a column vector. The docstring asks for a column vector of shape `(n, 1)`, so we reshape it to `(4, 1)`.  
The function still runs without this line, but the shape matters downstream: adding a `(4, 1)` column to a `(4,)` vector such as `A @ x` silently broadcasts to a `(4, 4)` matrix, so whichever convention is chosen must be used consistently.
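One reason the shape of `C` deserves attention is NumPy-style broadcasting, which JAX follows: adding a `(4, 1)` column to a `(4,)` vector does not raise an error but produces a `(4, 4)` array. The snippet below reproduces this with placeholder matrices (an identity `A` and a zero `C`, chosen only for illustration) and shows two ways to keep the shapes consistent.

```python
import jax.numpy as jnp

A = jnp.eye(4)             # placeholder A, shape (4, 4)
x = jnp.ones(4)            # state as a 1-D array, shape (4,)
C_col = jnp.zeros((4, 1))  # C stored as a column vector, shape (4, 1)

# Pitfall: (4,) + (4, 1) broadcasts to (4, 4) without any error
wrong = A @ x + C_col
print(wrong.shape)      # (4, 4)

# Fix 1: flatten C back to a 1-D vector
right_1d = A @ x + C_col[:, 0]
print(right_1d.shape)   # (4,)

# Fix 2: keep everything as column vectors
right_col = A @ x.reshape(-1, 1) + C_col
print(right_col.shape)  # (4, 1)
```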

Why do we call `f = dynamic_unicycle_ode(state, control, time)`?

The function `dynamic_unicycle_ode` returns the actual vector of state derivatives at a given point. It is defined in Problem 1 part (a):

$$
f(\mathbf{x}, \mathbf{u}) =
\begin{bmatrix}
v\cos\theta \\\\
v\sin\theta \\\\
\omega \\\\
a
\end{bmatrix}.
$$

This gives the exact nonlinear $\dot{\mathbf{x}}$ at the current $(\mathbf{x}, \mathbf{u})$.  
From it we compute $C = f - A\mathbf{x}_0 - B\mathbf{u}_0$, as explained before.

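Note: one might be tempted to unpack `omega = control[0]` and `a = control[1]` and write `C` directly as $[v\cos\theta,\; v\sin\theta,\; \omega,\; a]^T$, but that vector is $f(\mathbf{x}_0,\mathbf{u}_0)$, not $C$: it omits the $-A\mathbf{x}_0 - B\mathbf{u}_0$ correction. Doing the subtraction by hand, the $\omega$ and $a$ entries cancel, and a correct closed-form alternative is:

```python
C = jnp.array([
    [ v * theta * jnp.sin(theta)],
    [-v * theta * jnp.cos(theta)],
    [0.0],
    [0.0]
])
```

I think it is more intuitive, and less error-prone, to compute everything from the nonlinear model as in the function above.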


### (b) Evaluate linearized dynamics (analytic)
Using your answer from 2(a), evaluate $A$, $B$, and $C$ for $\mathbf{x}_0 = [0, 0, \frac{\pi}{4}, 2.]^T$ and $\mathbf{u}_0 = [0.1, 1.]^T$. Give your answer to 2 decimal places.

```python
x0 = jnp.array([0, 0, jnp.pi/4, 2])
u0 = jnp.array([0.1, 1.0])
time = 0.0
A, B, C = linearize_unicycle_continuous_time_analytic(x0, u0, time)
print("A =", jnp.round(A, 2))
print("B =", jnp.round(B, 2))
print("C =", jnp.round(C, 2))
```

```
A = [[ 0.    0.   -1.41  0.71]
 [ 0.    0.    1.41  0.71]
 [ 0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.  ]]
B = [[0 0]
 [0 0]
 [1 0]
 [0 1]]
C = [[ 1.11]
 [-1.11]
 [ 0.  ]
 [ 0.  ]]
```


EXPLANATION: In this part we define the operating-point vectors `x0` and `u0` as `jnp.array`s, and then call the previously defined function with `time = 0`.
Calling `linearize_unicycle_continuous_time_analytic`, we obtain the matrices A, B, C following the derivation in part (a).

We also use `jnp.round` to round the results to two decimal places, as requested by the exercise statement.
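As a hand check of the printed values, substitute $\theta_0 = \pi/4$ and $v_0 = 2$ into the expressions from part (a), using $\sin\frac{\pi}{4} = \cos\frac{\pi}{4} = \frac{\sqrt{2}}{2} \approx 0.71$:

$$
-v_0\sin\theta_0 = -\sqrt{2} \approx -1.41,\qquad
v_0\cos\theta_0 = \sqrt{2} \approx 1.41,\qquad
\cos\theta_0 = \sin\theta_0 \approx 0.71,
$$

and, since $\mathbf{x}_0 = [0,\,0,\,\frac{\pi}{4},\,2]^T$, the first two entries of $C = f(\mathbf{x}_0,\mathbf{u}_0) - A\mathbf{x}_0 - B\mathbf{u}_0$ are

$$
C_1 = \sqrt{2} - \left(-\sqrt{2}\cdot\tfrac{\pi}{4} + \sqrt{2}\right) = \tfrac{\sqrt{2}\,\pi}{4} \approx 1.11,
\qquad
C_2 = \sqrt{2} - \left(\sqrt{2}\cdot\tfrac{\pi}{4} + \sqrt{2}\right) = -\tfrac{\sqrt{2}\,\pi}{4} \approx -1.11,
$$

while $C_3 = \omega_0 - \omega_0 = 0$ and $C_4 = a_0 - a_0 = 0$, matching the output above.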


### (c) Linearize dynamics using JAX autodiff
Time to test out JAX's autodifferentiation capabilities! JAX has an [Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) that provides more details about the various autodiff functions, forward vs. backward autodiff, Jacobians, Hessians, and so forth. You are strongly encouraged to read through it.

Using Jax and its built-in `jax.jacobian` function, fill in the `linearize_autodiff` function that takes in a dynamics function, and a state and control to linearize about, and returns the $A$, $B$, and $C$ matrices describing the linearized dynamics. Test your function using the continuous-time dynamics with $\mathbf{x}_0 = [0.0, 0.0, \frac{\pi}{4}, 2.0]^T$ and $\mathbf{u}_0 = [0.1, 1.0]^T$ and use the provided test code to verify that the outputs you get from your function are the same as the values you get from `linearize_unicycle_continuous_time_analytic`.

Side note for the curious: the `jnp.allclose` function tests if all corresponding elements of two arrays are within a small tolerance of each other. When working with finite-precision machine arithmetic, you can almost never test two numbers for exact equality directly, because different rounding errors in different computations very often result in very slightly different values even when the two calculations should theoretically result in the same number. For this reason, real numbers in software (which on almost all modern hardware are represented in [IEEE 754 floating-point format](https://en.wikipedia.org/wiki/IEEE_754)) are usually considered to be equal if they are close enough that their difference could be reasonably explained by rounding errors.
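A short illustration of this side note, using plain Python floats (64-bit IEEE 754): `0.1 + 0.2` is not exactly equal to `0.3`, yet a tolerance-based comparison treats them as equal. `math.isclose` is from the standard library; `jnp.allclose` uses the same default tolerances as NumPy (`rtol=1e-05`, `atol=1e-08`).

```python
import math
import jax.numpy as jnp

a = 0.1 + 0.2
b = 0.3

print(a == b)              # False: the two roundings differ in the last bit
print(a)                   # 0.30000000000000004
print(math.isclose(a, b))  # True: comparison with a relative tolerance

# Element-wise tolerance check on arrays, as used in the test code for this part
x = jnp.array([a, 1.0])
y = jnp.array([b, 1.0])
print(bool(jnp.allclose(x, y)))  # True
```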

```python
def linearize_autodiff1(function_name, state, control, time):
    '''
    Linearizes any dynamics using jax autodiff.
    Inputs:
        function_name: the function to be linearized. Takes state, control, and time as inputs.
        state     : A jax.numpy array of size (n,); the state to linearize about
        control   : A jax.numpy array of size (m,); the control to linearize about
        time      : A real scalar; the time to linearize about

    Outputs:
        A : A jax.numpy array of size (n,n)
        B : A jax.numpy array of size (n,m)
        C : A jax.numpy array of size (n,1)
    '''
    # Jacobian w.r.t. the state, with control and time held fixed
    f1 = lambda x: function_name(x, control, time)
    A = jacrev(f1)(state)

    # Jacobian w.r.t. the control, with state and time held fixed
    f2 = lambda u: function_name(state, u, time)
    B = jacrev(f2)(control)

    f = function_name(state, control, time)

    C = f - A @ state - B @ control
    C = C.reshape((-1, 1))  # column vector of shape (n, 1)

    return A, B, C
```

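In this code, we build two partial functions (`f1` and `f2` in the code):

1. $f_x(\mathbf{x}) = f(\mathbf{x}, \mathbf{u}_0, t_0)$, which depends only on the state $\mathbf{x}$.
2. $f_u(\mathbf{u}) = f(\mathbf{x}_0, \mathbf{u}, t_0)$, which depends only on the control $\mathbf{u}$.

We then call `jacrev`, JAX's reverse-mode Jacobian (`jax.jacobian` is an alias for it):

- `A = jacrev(f_x)(x₀)` produces the full $\partial f/\partial \mathbf{x}$ matrix at our linearization point.
- `B = jacrev(f_u)(u₀)` produces the full $\partial f/\partial \mathbf{u}$ matrix.

Next, we evaluate $f_0 = f(\mathbf{x}_0, \mathbf{u}_0, t_0)$ and form $C = f_0 - A\mathbf{x}_0 - B\mathbf{u}_0$, reshaped to a column vector to match the expected dimensions.

In this way the linearized model $\dot{\mathbf{x}} \approx A\mathbf{x} + B\mathbf{u} + C$ matches the nonlinear dynamics exactly at the point we are considering, $(\mathbf{x}_0, \mathbf{u}_0)$.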
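Instead of building two partial functions, one can ask JAX for both Jacobians in a single call through the `argnums` argument of `jax.jacfwd` / `jax.jacrev`: with `argnums=(0, 1)` the result is a tuple `(df/dstate, df/dcontrol)`. The sketch below uses forward mode (`jacfwd`) and compares it with reverse mode; the name `linearize_argnums` is hypothetical, and the dynamics are re-declared so the snippet runs on its own. For Jacobians this small, the choice of mode makes no practical difference.

```python
import jax
import jax.numpy as jnp

def dynamic_unicycle_ode(state, control, time):
    # Same continuous-time model as in Problem 1
    theta, v = state[2], state[3]
    omega, a = control[0], control[1]
    return jnp.array([v * jnp.cos(theta), v * jnp.sin(theta), omega, a])

def linearize_argnums(dynamics, state, control, time):
    # Differentiate w.r.t. arguments 0 (state) and 1 (control) in one call;
    # time is passed through but not differentiated
    A, B = jax.jacfwd(dynamics, argnums=(0, 1))(state, control, time)
    C = dynamics(state, control, time) - A @ state - B @ control
    return A, B, C.reshape(-1, 1)

x0 = jnp.array([0.0, 0.0, jnp.pi / 4, 2.0])
u0 = jnp.array([0.1, 1.0])

A_fwd, B_fwd, C_fwd = linearize_argnums(dynamic_unicycle_ode, x0, u0, 0.0)

# Reverse mode on the same function, for comparison
A_rev, B_rev = jax.jacrev(dynamic_unicycle_ode, argnums=(0, 1))(x0, u0, 0.0)
print(bool(jnp.allclose(A_fwd, A_rev)), bool(jnp.allclose(B_fwd, B_rev)))
print(C_fwd.shape)  # (4, 1)
```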

Below is another solution that some may find more intuitive: it uses named inner functions instead of lambdas and calls `jax.jacobian` directly. The results are exactly the same; it is just a different way of writing it.

```python
def linearize_autodiff2(function_name, state, control, time):
    def dynamics_state(x):
        return function_name(x, control, time)
    def dynamics_control(u):
        return function_name(state, u, time)

    A = jax.jacobian(dynamics_state)(state)
    B = jax.jacobian(dynamics_control)(control)

    f0 = function_name(state, control, time)
    C = (f0 - A @ state - B @ control).reshape(-1, 1)

    return A, B, C
```

Let's now test the implementation. A correct implementation should print `True` for all three checks, meaning that the matrices obtained with JAX autodiff match the analytic ones from part (a) up to floating-point tolerance. I run both `linearize_autodiff1` and `linearize_autodiff2`, defined above, to show that they can be used interchangeably.

```python
# test code:
state = jnp.array([0.0, 0.0, jnp.pi/4, 2.])
control = jnp.array([0.1, 1.])
time = 0.0

A_autodiff1, B_autodiff1, C_autodiff1 = linearize_autodiff1(continuous_dynamics, state, control, time)
A_autodiff2, B_autodiff2, C_autodiff2 = linearize_autodiff2(continuous_dynamics, state, control, time)
A_analytic, B_analytic, C_analytic = linearize_unicycle_continuous_time_analytic(state, control, time)

print('A matrices match:', jnp.allclose(A_autodiff1, A_analytic))
print('B matrices match:', jnp.allclose(B_autodiff1, B_analytic))
print('C matrices match:', jnp.allclose(C_autodiff1, C_analytic))
print('A matrices match:', jnp.allclose(A_autodiff2, A_analytic))
print('B matrices match:', jnp.allclose(B_autodiff2, B_analytic))
print('C matrices match:', jnp.allclose(C_autodiff2, C_analytic))
```

```
A matrices match: True
B matrices match: True
C matrices match: True
A matrices match: True
B matrices match: True
C matrices match: True
```
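The test above compares autodiff against the hand-derived Jacobians. A third check that relies on neither is a central finite-difference approximation, $\partial f/\partial z_j \approx \big(f(\mathbf{z}+h\mathbf{e}_j) - f(\mathbf{z}-h\mathbf{e}_j)\big)/(2h)$. The helper `finite_difference_jacobian` below is a hypothetical illustration, not part of the assignment; since JAX computes in 32-bit floats by default, the step `h = 1e-3` and tolerance `atol = 1e-3` are deliberately loose choices.

```python
import jax
import jax.numpy as jnp

def dynamic_unicycle_ode(state, control, time):
    # Same continuous-time model as in Problem 1
    theta, v = state[2], state[3]
    return jnp.array([v * jnp.cos(theta), v * jnp.sin(theta), control[0], control[1]])

def finite_difference_jacobian(g, z, h=1e-3):
    # Central differences: one column of the Jacobian per input coordinate
    cols = []
    for j in range(z.shape[0]):
        e = jnp.zeros_like(z).at[j].set(h)
        cols.append((g(z + e) - g(z - e)) / (2 * h))
    return jnp.stack(cols, axis=1)

x0 = jnp.array([0.0, 0.0, jnp.pi / 4, 2.0])
u0 = jnp.array([0.1, 1.0])

A_fd = finite_difference_jacobian(lambda x: dynamic_unicycle_ode(x, u0, 0.0), x0)
B_fd = finite_difference_jacobian(lambda u: dynamic_unicycle_ode(x0, u, 0.0), u0)

A_ad = jax.jacrev(lambda x: dynamic_unicycle_ode(x, u0, 0.0))(x0)
B_ad = jax.jacrev(lambda u: dynamic_unicycle_ode(x0, u, 0.0))(u0)

print(bool(jnp.allclose(A_fd, A_ad, atol=1e-3)), bool(jnp.allclose(B_fd, B_ad, atol=1e-3)))
```

Finite differences are only approximate and cost two function evaluations per input coordinate, which is why autodiff is preferred in practice; here they serve purely as a sanity check.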


### (d) Linearize discrete-time dynamics
Assuming your answer from 2(c) matched 2(b) and that you are convinced of the power of automatic differentiation, use your `linearize_autodiff` function on `discrete_dynamics_euler` and `discrete_dynamics_rk` with $\mathbf{x}_0 = [0.0, 0.0, \frac{\pi}{4}, 2.0]^T$ and $\mathbf{u}_0 = [0.1, 1.0]^T$. (Imagine trying to differentiate the expressions analytically! It would be tedious!)

Let $\Delta t=0.1$.

```python
x0 = jnp.array([0.0, 0.0, jnp.pi/4, 2.0])
u0 = jnp.array([0.1, 1.0])
dt = 0.1
time = 0.0  # the dynamics are time-invariant; dt is the step size, not the time

discrete_dynamics_euler = Dynamics(
    euler_integrate(continuous_dynamics, dt), state_dim, control_dim
)
discrete_dynamics_rk = Dynamics(
    runge_kutta_integrator(continuous_dynamics, dt), state_dim, control_dim
)

A_e, B_e, C_e = linearize_autodiff1(discrete_dynamics_euler, x0, u0, time)
A_r, B_r, C_r = linearize_autodiff1(discrete_dynamics_rk, x0, u0, time)

print("A_euler =\n", jnp.round(A_e, 2))
print("B_euler =\n", jnp.round(B_e, 2))
print("C_euler =\n", jnp.round(C_e, 2))
print("A_rk =\n", jnp.round(A_r, 2))
print("B_rk =\n", jnp.round(B_r, 2))
print("C_rk =\n", jnp.round(C_r, 2))
```

```
A_euler =
 [[ 1.    0.   -0.14  0.07]
 [ 0.    1.    0.14  0.07]
 [ 0.    0.    1.    0.  ]
 [ 0.    0.    0.    1.  ]]
B_euler =
 [[0.         0.        ]
 [0.         0.        ]
 [0.09999999 0.        ]
 [0.         0.09999999]]
C_euler =
 [[ 0.11]
 [-0.11]
 [-0.  ]
 [-0.  ]]
A_rk =
 [[ 1.          0.         -0.14999999  0.07      ]
 [ 0.          1.          0.14        0.07      ]
 [ 0.          0.          1.          0.        ]
 [ 0.          0.          0.          1.        ]]
B_rk =
 [[-0.01        0.        ]
 [ 0.01        0.        ]
 [ 0.09999999  0.        ]
 [ 0.          0.09999999]]
C_rk =
 [[ 0.12]
 [-0.11]
 [-0.  ]
 [-0.  ]]
```


We applied `linearize_autodiff1` to two discrete-time approximations of the unicycle model: Euler integration and 4th-order Runge–Kutta (RK4). We used:
$$
x_0 = [0,\;0,\;\tfrac{\pi}{4},\;2]^T,\qquad
u_0 = [0.1,\;1.0]^T,\qquad\Delta t=0.1.
$$
The workflow is:
  1. Define `discrete_dynamics_euler` via `euler_integrate(continuous_dynamics, dt)` and `discrete_dynamics_rk` via `runge_kutta_integrator(continuous_dynamics, dt)`, using the integrators from Problem 1.
  2. Call `linearize_autodiff1()` to compute $A$, $B$, and $C$.
  3. Round all entries to two decimal places with `jnp.round`, as asked in the exercise.

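About the results:
- **Euler:** the system is approximated by a single straight step along the initial slope, $\mathbf{x}_{k+1} = \mathbf{x}_k + \Delta t\, f(\mathbf{x}_k,\mathbf{u}_k)$. Differentiating this map gives exactly $A_d = I + \Delta t\,A$, $B_d = \Delta t\,B$ and $C_d = \Delta t\,C$ in terms of the continuous-time matrices of part (b), which is what the printout shows (e.g. $0.1 \times (-1.41) \approx -0.14$ and $0.1 \times 1.11 \approx 0.11$). Being a first-order method, it is the less accurate of the two.

- **RK4:** it combines four slope evaluations within the step, so it tracks the true flow over $\Delta t$ much more closely, and its linearization inherits that accuracy. One visible difference: the RK4 $B$ has small nonzero entries ($-0.01$ and $0.01$) in the $\omega$ column of the position rows, because turning during the step already shifts the position, an effect the Euler $B$ misses entirely.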
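Both observations can be checked numerically. The sketch below re-declares the model and the two integrators so it runs on its own. It then (i) verifies the Euler identities $A_d = I + \Delta t\,A$, $B_d = \Delta t\,B$, $C_d = \Delta t\,C$, and (ii) compares both discrete Jacobians against a reference obtained by chaining 50 RK4 substeps of size $\Delta t/50$, used here as a stand-in for the exact flow map. The reference construction and the helper names (`euler_step`, `rk4_step`, `rk4_substeps`, `linearize`, `deviation`) are assumptions made for illustration, not the assignment's functions.

```python
import jax
import jax.numpy as jnp

def f(x, u, t):
    # Continuous-time unicycle from Problem 1: state [x, y, theta, v], control [omega, a]
    return jnp.array([x[3] * jnp.cos(x[2]), x[3] * jnp.sin(x[2]), u[0], u[1]])

def euler_step(dynamics, dt):
    # One forward-Euler step with the control held constant
    return lambda x, u, t: x + dt * dynamics(x, u, t)

def rk4_step(dynamics, dt):
    # One classical RK4 step with the control held constant
    def step(x, u, t):
        k1 = dynamics(x, u, t)
        k2 = dynamics(x + 0.5 * dt * k1, u, t)
        k3 = dynamics(x + 0.5 * dt * k2, u, t)
        k4 = dynamics(x + dt * k3, u, t)
        return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return step

def rk4_substeps(dynamics, dt, n):
    # Hypothetical reference map: n RK4 substeps of size dt / n
    small_step = rk4_step(dynamics, dt / n)
    def step(x, u, t):
        for _ in range(n):
            x = small_step(x, u, t)
        return x
    return step

def linearize(g, x, u, t=0.0):
    # Jacobians w.r.t. state and control in one call, plus the affine offset
    A, B = jax.jacobian(g, argnums=(0, 1))(x, u, t)
    return A, B, g(x, u, t) - A @ x - B @ u

x0 = jnp.array([0.0, 0.0, jnp.pi / 4, 2.0])
u0 = jnp.array([0.1, 1.0])
dt = 0.1

A_c, B_c, C_c = linearize(f, x0, u0)
A_e, B_e, C_e = linearize(euler_step(f, dt), x0, u0)
A_r, B_r, C_r = linearize(rk4_step(f, dt), x0, u0)
A_ref, B_ref, _ = linearize(rk4_substeps(f, dt, 50), x0, u0)

# (i) Euler identities, up to float32 rounding
print(bool(jnp.allclose(A_e, jnp.eye(4) + dt * A_c, atol=1e-5)),
      bool(jnp.allclose(B_e, dt * B_c, atol=1e-5)),
      bool(jnp.allclose(C_e, dt * C_c, atol=1e-5)))

# (ii) worst-case deviation of each discrete linearization from the reference
def deviation(A, B):
    return float(jnp.max(jnp.abs(A - A_ref)) + jnp.max(jnp.abs(B - B_ref)))

err_euler = deviation(A_e, B_e)
err_rk4 = deviation(A_r, B_r)
print(f"Euler: {err_euler:.1e}   RK4: {err_rk4:.1e}")
```

The Euler deviation is dominated by entries such as $\partial x/\partial\omega$, which Euler sets to zero, while the RK4 deviation is expected to be at the level of float32 rounding.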




### (e) Applying `vmap` to linearize over multiple points
Now, try to linearize your dynamics over multiple state-control values using `vmap`!

```python
key = jax.random.PRNGKey(42)  # Set a fixed seed
key_states, key_controls = jax.random.split(key)  # independent keys for states and controls
n_samples = 1000
state_dim = 4  # 4-dimensional state
ctrl_dim = 2  # 2-dimensional control

time = 0.0
random_states = jax.random.normal(key_states, shape=(n_samples, state_dim))
random_controls = jax.random.normal(key_controls, shape=(n_samples, ctrl_dim))

def linearize(x, u):
    return linearize_autodiff1(discrete_dynamics_rk, x, u, time)

lin = jax.vmap(linearize, in_axes=(0, 0))
Alin, Blin, Clin = lin(random_states, random_controls)

print("A linearized has shape:", Alin.shape)
print("B linearized has shape:", Blin.shape)
print("C linearized has shape:", Clin.shape)
```

```
A linearized has shape: (1000, 4, 4)
B linearized has shape: (1000, 4, 2)
C linearized has shape: (1000, 4, 1)
```


We define the function `linearize(x, u)` to apply our `linearize_autodiff1` method to a single state-control pair $(x, u)$.  
Specifically, `linearize(x, u)`:
- takes one pair of inputs,
- linearizes the discrete-time dynamics (using the RK4 integrator defined above),
- returns the matrices $A$, $B$, and $C$.

Then, instead of calling this function manually 1000 times, we vectorize it:

```python
lin = jax.vmap(linearize, in_axes=(0, 0))
Alin, Blin, Clin = lin(random_states, random_controls)
```

Also:
- `jax.vmap` vectorizes `linearize`: conceptually it applies the function to each row of `random_states` paired with the corresponding row of `random_controls`, but as a single batched computation rather than a Python loop.
- `in_axes=(0, 0)` tells JAX to map both the $x$ and $u$ inputs along their first axis (the batch dimension).

Running the cell, we see that the outputs are the linearized matrices A, B, C with shapes (1000, 4, 4), (1000, 4, 2) and (1000, 4, 1) respectively: one linearization per sample, stacked along a new leading axis.
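`in_axes` does not have to map every argument: passing `None` for an argument reuses it unchanged in every call, which is convenient, for example, to linearize a batch of states about one fixed control. The sketch below uses a stand-in continuous-time model and a hypothetical helper `linearize_at` (not the notebook's own functions), and wraps the batched function in `jax.jit` so it is compiled once.

```python
import jax
import jax.numpy as jnp

def f(x, u):
    # Continuous-time unicycle: state [x, y, theta, v], control [omega, a]
    return jnp.array([x[3] * jnp.cos(x[2]), x[3] * jnp.sin(x[2]), u[0], u[1]])

def linearize_at(x, u):
    # Linearize about a single (x, u) pair
    A, B = jax.jacobian(f, argnums=(0, 1))(x, u)
    C = f(x, u) - A @ x - B @ u
    return A, B, C.reshape(-1, 1)

states = jax.random.normal(jax.random.PRNGKey(0), (1000, 4))  # batch of 1000 states
u_fixed = jnp.array([0.1, 1.0])                                # one shared control

# Map over the batch axis of the states only; u_fixed is reused for every state
batched = jax.jit(jax.vmap(linearize_at, in_axes=(0, None)))
A_b, B_b, C_b = batched(states, u_fixed)
print(A_b.shape, B_b.shape, C_b.shape)  # (1000, 4, 4) (1000, 4, 2) (1000, 4, 1)
```

Since $B$ does not depend on the state for this model, every slice `B_b[i]` is the same constant matrix.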