# RO47019: Intelligent Control Systems Practical Assignment
* Period: 2024-2025, Q4
* Course homepage: https://brightspace.tudelft.nl/d2l/home/682445
* Instructor: Cosimo Della Santina (C.DellaSantina@tudelft.nl)
* Teaching assistant: Niels Stienen (N.L.Stienen@student.tudelft.nl)
* (c) TU Delft, 2025

Make sure you fill in any place that says `YOUR CODE HERE` or `YOUR ANSWER HERE` and remove `raise NotImplementedError()` afterwards. Moreover, if you see an empty cell, please **do not** delete it, instead run that cell as you would run all other cells. Finally, please **do not** add any extra cells to this notebook or change the existing cells unless you are explicitly asked to do so.

Please fill in your name(s) and other required details below:

In [1]:
# Please fill in your names, student numbers, netID, and emails below.
STUDENT_1_NAME = "Nicola Visentin"
STUDENT_1_STUDENT_NUMBER = "6354815"
STUDENT_1_NETID = "nvisentin"
STUDENT_1_EMAIL = "N.Visentin@student.tudelft.nl"

In [2]:
# Note: this block is a check that you have filled in the above information.
# It will throw an AssertionError until all fields are filled
assert STUDENT_1_NAME != ""
assert STUDENT_1_STUDENT_NUMBER != ""
assert STUDENT_1_NETID != ""
assert STUDENT_1_EMAIL != ""

### General announcements

* Do *not* share your solutions (also after the course is finished), and do *not* copy solutions from others. By submitting your solutions, you claim that you alone are responsible for this code.

* Please post your questions regarding this assignment in the correct support forum on Brightspace so that everybody can benefit from the response. Note that posting any code related to solution attempts is **not** allowed. If you have a question that you want to ask directly, please use the scheduled Q&A hours to ask the TA or, if that is not possible, send an email to the instructor or TA.

* In various places, this notebook contains a line that throws a `NotImplementedError` exception. These are the locations where the assignment requires you to adapt the code. The lines only serve as a reminder that you have not yet adapted that particular piece of code, which becomes visible when you execute all the cells. Once your solution code has replaced these lines, no exceptions should be thrown anymore.

* This [Jupyter notebook](https://jupyter.org/) uses `nbgrader` to help us with automated tests. `nbgrader` makes various cells in this notebook "uneditable" or "unremovable" and gives them a special id in the cell metadata. When we run our checks, the system verifies that these cell ids exist and determines the number of points and which checks must be run. While there are ways to edit the metadata and work around the restrictions to delete or modify these special cells, you should not do that, since our nbgrader backend will then be unable to parse your notebook and award you points for the assignment.

* Please note that the above-mentioned _read-only_ protection only works in Jupyter Notebook; it does not work if you open this notebook in another editor (e.g., VSCode, PyCharm, etc.). Therefore, we recommend that you only use Jupyter Notebook for this course. If you use any other editor, you may accidentally delete cells, modify the tests, etc., which would cause you to lose points.

* If you edit a function that is imported in another notebook, you need to **restart the kernel** of the notebook where you are using the function. Otherwise, the changes will not be effective.

* **IMPORTANT**: Please make sure that your code executes without any errors before submitting the notebook. An easy way to ensure this is to use the validation script as described in the README.

# Linearization

**Author:** Maximilian Stölzle (M.W.Stolzle@tudelft.nl)

This notebook contains functions that linearize the system and then discretize it. Please complete the corresponding assignment in the `task_2d-1_linearization` notebook.

In [3]:
from functools import partial
import jax

jax.config.update("jax_platforms", "cpu")  # set default device to 'cpu'
jax.config.update("jax_enable_x64", True)  # double precision
from jax import Array, debug, jacfwd, jit, lax, vmap
import jax.numpy as jnp
from jax.scipy import linalg
from pathlib import Path
from typing import Callable, Dict, Tuple

from jax_double_pendulum.dynamics import (
    continuous_forward_dynamics,
    continuous_inverse_dynamics,
    discrete_forward_dynamics,
    continuous_linear_state_space_representation,
)

In [4]:
@partial(
    jit,
    static_argnums=0,
    static_argnames=("continuous_forward_dynamics_fn",),
)
def continuous_state_space_dynamics(
    continuous_forward_dynamics_fn: Callable,
    x: Array,
    tau: Array,
    *args_dynamics,
) -> Tuple[Array, Array]:
    """
    Compute the continuous forward dynamics of the system in decoupled form
    Args:
        continuous_forward_dynamics_fn: function to compute the continuous forward dynamics of the system
            Must have the signature th_dd = continuous_forward_dynamics_fn(th, th_d, tau, *args_dynamics)
        x: system state of shape (4, ) consisting of link angles and link angular velocities
        tau: link torques of shape (2, )
        *args_dynamics: additional arguments to pass to continuous_forward_dynamics_fn

    Returns:
        dx_dt: time derivative of system state of shape (4, ).
            Corresponds to link angular velocities and link angular accelerations
        y: system output of shape (2, ). Corresponds to the link angles.

    """

    # Extract th and th_d from x
    th = x[:2]
    th_d = x[2:]

    # Compute th_dd
    th_dd = continuous_forward_dynamics_fn(th, th_d, tau, *args_dynamics)

    # Put everything together in dx_dt and y
    dx_dt = jnp.concatenate([th_d, th_dd])
    y = th

    return dx_dt, y
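The wrapper above only reshuffles the state: it splits `x` into angles and velocities, calls the second-order forward dynamics, and restacks `[th_d, th_dd]`. A minimal sketch of the same pattern, assuming a hypothetical toy plant (two decoupled unit-mass spring-damper links, not the double pendulum), shows the shapes and ordering that result:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


# Hypothetical toy forward dynamics (NOT the double pendulum): two decoupled
# unit-mass spring-damper links with stiffness k and damping c.
def toy_forward_dynamics(th, th_d, tau, k=2.0, c=0.5):
    return tau - k * th - c * th_d


def state_space(fd_fn, x, tau, *args):
    # Same pattern as continuous_state_space_dynamics: split x, call fd_fn, restack.
    th, th_d = x[:2], x[2:]
    th_dd = fd_fn(th, th_d, tau, *args)
    return jnp.concatenate([th_d, th_dd]), th


x = jnp.array([0.1, -0.2, 0.0, 1.0])  # [th_1, th_2, th_d_1, th_d_2]
tau = jnp.array([1.0, 0.0])
dx_dt, y = state_space(toy_forward_dynamics, x, tau)
print(dx_dt)  # [0.  1.  0.8 -0.1]
print(y)      # [ 0.1 -0.2]
```

The first half of `dx_dt` is simply the velocity part of `x`, and the output `y` is the angle part, which is why `C` will later come out as `[I, 0]` and `D` as zero.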

In [5]:
@partial(
    jit,
    static_argnums=0,
    static_argnames=("continuous_forward_dynamics_fn",),
)
def continuous_linear_state_space_representation_autograd(
    continuous_forward_dynamics_fn: Callable,
    th_eq: Array,
    th_d_eq: Array = jnp.zeros((2,)),
    tau_eq: Array = jnp.zeros((2,)),
    *args_dynamics,
) -> Tuple[Array, Array, Array, Array]:
    """
    Linearize the system about the specified state (th, th_d) and input tau
    and return the linearized system in state space representation

    Args:
        continuous_forward_dynamics_fn: function to compute the continuous forward dynamics of the system
            Must have the signature th_dd = continuous_forward_dynamics_fn(th, th_d, tau, *args)
        th_eq: equilibrium link angles of double pendulum of shape (2, )
        th_d_eq: equilibrium link angular velocities of double pendulum of shape (2, )
        tau_eq: equilibrium link torques of double pendulum of shape (2, )
        *args_dynamics: additional arguments to pass to continuous_state_space_dynamics.
            The same additional arguments are then in turn later passed to continuous_forward_dynamics_fn.

    Returns:
        A: state transition matrix of shape (4, 4)
        B: input matrix of shape (4, 2)
        C: output matrix of shape (2, 4)
        D: feed-through matrix of shape (2, 2)
    """
    # Hint: use `jacfwd` on `continuous_state_space_dynamics` to get the Jacobians of the state derivative
    # and of the output with respect to the state and the input.

    # First, compute the state x_eq
    x_eq = jnp.concatenate([th_eq, th_d_eq])

    # To use `jacfwd`, bind `continuous_forward_dynamics_fn` so that the remaining function has the signature
    #     dx_dt, y = continuous_state_space_dynamics_new(x, tau, *args_dynamics)
    continuous_state_space_dynamics_new = partial(continuous_state_space_dynamics, continuous_forward_dynamics_fn)

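    # Differentiate both outputs (dx_dt, y) w.r.t. both inputs (x, tau). jacfwd returns one entry per output,
    # each holding one Jacobian per argnum: ((d dx_dt/dx, d dx_dt/dtau), (dy/dx, dy/dtau)) = ((A, B), (C, D))
    continuous_state_space_dynamics_new_jacobian = jax.jacfwd(continuous_state_space_dynamics_new, argnums=(0, 1))
    (A, B), (C, D) = continuous_state_space_dynamics_new_jacobian(x_eq, tau_eq, *args_dynamics)

    return A, B, C, D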
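Because `continuous_state_space_dynamics` returns a tuple and `argnums=(0, 1)` selects two inputs, `jacfwd` produces a nested tuple `((A, B), (C, D))`. A quick sanity check, assuming a hypothetical linear toy plant `th_dd = tau - k * th - c * th_d` whose Jacobians are known in closed form, confirms that the unpacking order matches the state-space matrices:

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_forward_dynamics(th, th_d, tau, k, c):
    # Hypothetical linear toy dynamics, so the exact Jacobians are known.
    return tau - k * th - c * th_d


def state_space(fd_fn, x, tau, *args):
    th, th_d = x[:2], x[2:]
    return jnp.concatenate([th_d, fd_fn(th, th_d, tau, *args)]), th


fn = partial(state_space, toy_forward_dynamics)
# Two outputs (dx_dt, y) x two argnums (x, tau) -> ((A, B), (C, D)).
# The trailing k=2.0, c=0.5 are passed through but not differentiated.
(A, B), (C, D) = jax.jacfwd(fn, argnums=(0, 1))(jnp.zeros(4), jnp.zeros(2), 2.0, 0.5)

I2, Z = jnp.eye(2), jnp.zeros((2, 2))
print(jnp.allclose(A, jnp.block([[Z, I2], [-2.0 * I2, -0.5 * I2]])))  # True
```

For a nonlinear plant the same call simply evaluates these Jacobians at `(x_eq, tau_eq)`.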

In [6]:
@jit
def cont2discrete_zoh(
    dt: Array,
    A: Array,
    B: Array,
    C: Array,
    D: Array,
) -> Tuple[Array, Array, Array, Array]:
    """
    Discretize continuous-time system using zero-order hold.
    Please refer to the Scipy documentation of the `cont2discrete` function for some inspiration:
        https://github.com/scipy/scipy/blob/v1.9.3/scipy/signal/_lti_conversion.py#L335-L532

    Args:
        dt: time step of the discrete-time system
        A: continuous-time state transition matrix of shape (4, 4)
        B: continuous-time input matrix of shape (4, 2)
        C: continuous-time output matrix of shape (2, 4)
        D: continuous-time feed-through matrix of shape (2, 2)

    Returns:
        Ad: discrete-time state transition matrix of shape (4, 4)
        Bd: discrete-time input matrix of shape (4, 2)
        Cd: discrete-time output matrix of shape (2, 4)
        Dd: discrete-time feed-through matrix of shape (2, 2)

    """

    n, m = B.shape

    # Create augmented matrix for computing expm
    M = jnp.zeros((n + m, n + m))
    M = M.at[:n, :n].set(A)
    M = M.at[:n, n:n + m].set(B)

    # Matrix exponential
    M_exp = jax.scipy.linalg.expm(M * dt)

    # Extract Ad and Bd
    Ad = M_exp[:n, :n]
    Bd = M_exp[:n, n:n + m]

    # Cd and Dd
    Cd = C
    Dd = D

    return Ad, Bd, Cd, Dd
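The augmented-matrix trick rests on the identity `expm([[A, B], [0, 0]] * dt) = [[Ad, Bd], [0, I]]`, with `Ad = expm(A * dt)` and `Bd` equal to the integral of `expm(A * s) @ B` over `s` in `[0, dt]`; under zero-order hold, `C` and `D` are unchanged. The sketch below (a standalone helper named `zoh`, introduced here only for illustration) checks this against two systems with known closed-form results: a unit-mass double integrator and a stable scalar system.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def zoh(dt, A, B):
    n, m = B.shape
    # Augmented matrix [[A, B], [0, 0]]; its exponential is [[Ad, Bd], [0, I]].
    M = jnp.zeros((n + m, n + m)).at[:n, :n].set(A).at[:n, n:].set(B)
    M_exp = expm(M * dt)
    return M_exp[:n, :n], M_exp[:n, n:]


dt = 0.1
# Unit-mass double integrator: Ad = [[1, dt], [0, 1]], Bd = [[dt^2 / 2], [dt]].
A = jnp.array([[0.0, 1.0], [0.0, 0.0]])
B = jnp.array([[0.0], [1.0]])
Ad, Bd = zoh(dt, A, B)
print(Ad)  # approximately [[1, 0.1], [0, 1]]
print(Bd)  # approximately [[0.005], [0.1]]

# Scalar system x_dot = -x + 2 u: Ad = exp(-dt), Bd = 2 * (1 - exp(-dt)).
Ad1, Bd1 = zoh(dt, jnp.array([[-1.0]]), jnp.array([[2.0]]))
```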

In [7]:
@jit
def linearized_discrete_forward_dynamics(
    Ad: Array,
    Bd: Array,
    Cd: Array,
    Dd: Array,
    th_eq: Array,
    th_d_eq: Array,
    tau_eq: Array,
    dt: float,
    th: Array,
    th_d: Array,
    tau: Array,
) -> Tuple[Array, Array, Array]:
    """
    Compute the discrete forward dynamics of the linearized system.
    Should use the linear, discrete-time state-space description to compute the state of the system at the next
    time-step.
    Args:
        Ad: discrete-time state transition matrix of shape (4, 4)
        Bd: discrete-time input matrix of shape (4, 2)
        Cd: discrete-time output matrix of shape (2, 4)
        Dd: discrete-time feed-through matrix of shape (2, 2)
        th_eq: equilibrium link angles of shape (2, )
        th_d_eq: equilibrium link angular velocities of shape (2, )
        tau_eq: equilibrium link torques of shape (2, )
        dt: time step between the current and the next state [s]
        th: current link angles of double pendulum of shape (2, )
        th_d: current link angular velocities of double pendulum of shape (2, )
        tau: link torques of double pendulum of shape (2, )
    Returns:
        th_next: link angles at the next time step of double pendulum of shape (2, )
        th_d_next: link angular velocities at the next time step of double pendulum of shape (2, )
        th_dd: link angular accelerations between current and next time step of double pendulum of shape (2, )
    """
    # Compute the state (th_next, th_d_next) at the next timestep and the corresponding acceleration `th_dd`
    
    # First, build the state delta_x = x - x_eq
    x = jnp.concatenate([th, th_d])
    x_eq = jnp.concatenate([th_eq, th_d_eq])
    delta_x = x - x_eq

    # Then the input delta_tau
    delta_tau = tau - tau_eq

    # Then apply the discrete linearized dynamics
    delta_x_next = Ad @ delta_x + Bd @ delta_tau

    # Now sum again the reference as x_next = x_eq + delta_x_next
    x_next = x_eq + delta_x_next

    # Extract states
    th_next = x_next[:2]
    th_d_next = x_next[2:]

    # Compute th_dd with finite differences
    th_dd = (th_d_next - th_d) / dt

    return th_next, th_d_next, th_dd
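The linearized model acts on deviations from the linearization point: `x_next = x_eq + Ad @ (x - x_eq) + Bd @ (tau - tau_eq)`. A minimal sketch, assuming a hypothetical 1-DOF unit-mass double integrator (so `Ad` and `Bd` are known exactly and the linearization is exact), shows that the finite-difference acceleration then recovers the applied torque:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def linearized_step(Ad, Bd, x_eq, tau_eq, x, tau, dt):
    n_q = x.shape[0] // 2  # number of generalized coordinates
    # Propagate the deviation from the linearization point, then shift back.
    x_next = x_eq + Ad @ (x - x_eq) + Bd @ (tau - tau_eq)
    th_next, th_d_next = x_next[:n_q], x_next[n_q:]
    # Finite-difference acceleration over one time step.
    th_dd = (th_d_next - x[n_q:]) / dt
    return th_next, th_d_next, th_dd


dt = 0.01
# Hypothetical 1-DOF unit-mass double integrator with exact ZOH matrices.
Ad = jnp.array([[1.0, dt], [0.0, 1.0]])
Bd = jnp.array([[0.5 * dt**2], [dt]])
x_eq = jnp.array([0.3, 0.0])
tau_eq = jnp.array([0.0])
x = jnp.array([0.5, 0.2])
tau = jnp.array([1.5])

th_next, th_d_next, th_dd = linearized_step(Ad, Bd, x_eq, tau_eq, x, tau, dt)
print(th_next, th_d_next, th_dd)  # approximately [0.502075] [0.215] [1.5]
```

For the double pendulum the result is only accurate close to `(th_eq, th_d_eq, tau_eq)`, since the neglected higher-order terms grow with the deviation.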

In [8]:
# import feedback controller from controllers.ipynb
from ipynb.fs.full.controllers import ctrl_fb_pd


def closed_loop_fb_continuous_forward_dynamics(
    rp: Dict,
    th: Array,
    th_d: Array,
    tau_ext: Array,
    th_des: Array,
    th_d_des: Array,
    kp_fb: Array = jnp.zeros((2, 2)),
    kd_fb: Array = jnp.zeros((2, 2)),
) -> Array:
    """
    Adds a feedback control term to the continuous forward dynamics
    Args:
        rp: dictionary of robot parameters used for evaluating the continuous forward dynamics
        th: link angles of shape (2, )
        th_d: link angular velocities of shape (2, )
        tau_ext: external torques of shape (2, ) applied to the system in addition to the feedback torques
        th_des: desired link angles of shape (2, )
        th_d_des: desired link angular velocities of shape (2, )
        kp_fb: proportional gains of the parallel feedback controller of shape (2, 2)
        kd_fb: derivative gains of the parallel feedback controller of shape (2, 2)

    Returns:
        th_dd: link angular accelerations of shape (2, )

    """

    # Compute the PD feedback torque, which per controllers.ipynb is
    #     tau_fb = kp_fb @ (th_des - th) + kd_fb @ (th_d_des - th_d)
    tau_fb = ctrl_fb_pd(th, th_d, th_des, th_d_des, kp_fb, kd_fb)

    # Evaluate the open-loop dynamics under the sum of feedback and external torques
    th_dd = continuous_forward_dynamics(rp, th, th_d, tau_fb + tau_ext)

    return th_dd
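Wrapping the PD law into the dynamics changes what the linearization sees: the feedback gains enter the closed-loop state matrix. A minimal sketch, assuming a hypothetical plant of two decoupled unit inertias (`th_dd` equals the total torque) and the PD law stated in the comment above, shows that `jacfwd` then returns `A = [[0, I], [-kp, -kd]]`. It also illustrates why `kp_fb` and `kd_fb` are `(2, 2)` matrices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def pd_torque(th, th_d, th_des, th_d_des, kp, kd):
    # Assumed PD law, mirroring the comment in the cell above (kp, kd of shape (2, 2)).
    return kp @ (th_des - th) + kd @ (th_d_des - th_d)


def closed_loop_toy(th, th_d, tau_ext, th_des, th_d_des, kp, kd):
    # Hypothetical plant: two decoupled unit inertias, th_dd = total torque.
    return pd_torque(th, th_d, th_des, th_d_des, kp, kd) + tau_ext


def dxdt(x, tau_ext, th_des, th_d_des, kp, kd):
    th, th_d = x[:2], x[2:]
    th_dd = closed_loop_toy(th, th_d, tau_ext, th_des, th_d_des, kp, kd)
    return jnp.concatenate([th_d, th_dd])


kp = jnp.diag(jnp.array([10.0, 5.0]))
kd = jnp.diag(jnp.array([2.0, 1.0]))
zeros2 = jnp.zeros(2)
# Jacobian w.r.t. the state only (argnums=0 by default).
A = jax.jacfwd(dxdt)(jnp.zeros(4), zeros2, zeros2, zeros2, kp, kd)

I2, Z = jnp.eye(2), jnp.zeros((2, 2))
print(jnp.allclose(A, jnp.block([[Z, I2], [-kp, -kd]])))  # True
```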

In [9]:
@jit
def linearize_closed_loop_fb_system_about_trajectory(
    rp: Dict,
    traj_ts: Dict[str, Array],
    kp_fb: Array = jnp.zeros((2, 2)),
    kd_fb: Array = jnp.zeros((2, 2)),
) -> Tuple[Array, Array, Array, Array, Array]:
    """
    Linearize the nonlinear double pendulum system about a given trajectory and return the
    linearized system in state space representation.
    Args:
        rp: dictionary of robot parameters used for evaluating the continuous forward dynamics
        traj_ts: dictionary of time series of trajectories
        kp_fb: proportional gains of the parallel feedback controller of shape (2, 2)
        kd_fb: derivative gains of the parallel feedback controller of shape (2, 2)

    Returns:
        tau_eq_ts: time series of equilibrium torques of shape (N, 2)
        Ad_ts: time series of discrete-time state transition matrices of shape (N, 4, 4)
        Bd_ts: time series of discrete-time input matrices of shape (N, 4, 2)
        Cd_ts: time series of discrete-time output matrices of shape (N, 2, 4)
        Dd_ts: time series of discrete-time feed-through matrices of shape (N, 2, 2)

    """
    ## number of time-steps
    
    N = traj_ts["t_ts"].shape[0]

    ## compute the equilibrium torque to follow the trajectory using inverse dynamics
    ## Hint: you can access `th_ts`, `th_d_ts`, and `th_dd_ts` in the `traj_ts` dictionary.

    # Extract timeseries of reference trajectory
    th_eq_ts = traj_ts["th_ts"] 
    th_d_eq_ts = traj_ts["th_d_ts"] 
    th_dd_eq_ts = traj_ts["th_dd_ts"] 

    # Compute required torques all along the reference trajectory
    _continuous_inverse_dynamics = partial(continuous_inverse_dynamics, rp)
    __continuous_inverse_dynamics = jax.vmap(_continuous_inverse_dynamics)
    tau_eq_ts = __continuous_inverse_dynamics(th_eq_ts, th_d_eq_ts, th_dd_eq_ts)
    # which is equivalent to:
    #     _continuous_inverse_dynamics = jax.vmap(continuous_inverse_dynamics, in_axes=(None,0,0,0))
    #     tau_eq_ts = _continuous_inverse_dynamics(rp, th_eq_ts, th_d_eq_ts, th_dd_eq_ts)

    ## wrap closed_loop_fb_continuous_forward_dynamics so that it conforms to the interface
    ##   th_dd = closed_loop_fb_continuous_forward_dynamics_fn(th, th_d, tau_ext, th_des, th_d_des)
    
    closed_loop_fb_continuous_forward_dynamics_fn = partial(
        closed_loop_fb_continuous_forward_dynamics,
        rp,
        kp_fb=kp_fb,
        kd_fb=kd_fb,
    )

    ## wrap continuous_linear_state_space_representation_autograd so that it conforms to the interface
    ##   A, B, C, D = cl_lsp_autograd_fn(th_eq, th_d_eq, tau_eq, th_des, th_d_des)
    
    cl_lsp_autograd_fn = partial(
        continuous_linear_state_space_representation_autograd,
        closed_loop_fb_continuous_forward_dynamics_fn,
    )

    ## linearize the closed-loop system at each time-step
    
    _cl_lsp_autograd_fn = jax.vmap(cl_lsp_autograd_fn)
    A_ts, B_ts, C_ts, D_ts = _cl_lsp_autograd_fn(th_eq_ts, th_d_eq_ts, tau_eq_ts, th_eq_ts, th_d_eq_ts)
    
    ## compute the time step
    
    dt = jnp.mean(traj_ts["t_ts"][1:] - traj_ts["t_ts"][:-1])

    ## discretize the state space system using the zero-order hold method

    _cont2discrete_zoh = partial(cont2discrete_zoh, dt)
    __cont2discrete_zoh = jax.vmap(_cont2discrete_zoh)
    Ad_ts, Bd_ts, Cd_ts, Dd_ts = __cont2discrete_zoh(A_ts, B_ts, C_ts, D_ts)
    # which is equivalent to:
    #    _cont2discrete_zoh = jax.vmap(cont2discrete_zoh, in_axes=(None, 0, 0, 0, 0))
    #    Ad_ts, Bd_ts, Cd_ts, Dd_ts = _cont2discrete_zoh(dt, A_ts, B_ts, C_ts, D_ts)

    return tau_eq_ts, Ad_ts, Bd_ts, Cd_ts, Dd_ts
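The trajectory linearization above follows two steps: inverse dynamics gives the feed-forward torque along the reference, then `vmap` evaluates a per-time-step Jacobian, with `in_axes=None` broadcasting shared parameters such as `rp`. A minimal sketch, assuming a hypothetical plant of two decoupled unit-inertia pendulums (`th_dd = tau - (g / l) * sin(th)`), illustrates the pattern; each `A_ts[k]` then carries `-(g / l) * cos(th_k)` in its lower-left block:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def fwd(params, th, th_d, tau):
    # Hypothetical plant: two decoupled frictionless pendulums with unit inertia.
    return tau - params["g_over_l"] * jnp.sin(th)


def inv(params, th, th_d, th_dd):
    # Exact inverse of `fwd`: torque required to realise th_dd at (th, th_d).
    return th_dd + params["g_over_l"] * jnp.sin(th)


def linearize(params, th_eq, th_d_eq, tau_eq):
    def f(x, tau):
        th, th_d = x[:2], x[2:]
        return jnp.concatenate([th_d, fwd(params, th, th_d, tau)])

    # Single output, two argnums -> (A, B).
    return jax.jacfwd(f, argnums=(0, 1))(jnp.concatenate([th_eq, th_d_eq]), tau_eq)


params = {"g_over_l": jnp.array(9.81)}
t_ts = jnp.linspace(0.0, 1.0, 11)
th_ts = jnp.stack([jnp.sin(t_ts), 0.5 * t_ts], axis=1)  # (N, 2)
th_d_ts = jnp.stack([jnp.cos(t_ts), 0.5 * jnp.ones_like(t_ts)], axis=1)
th_dd_ts = jnp.stack([-jnp.sin(t_ts), jnp.zeros_like(t_ts)], axis=1)

# in_axes=None passes the (unbatched) parameter dict unchanged to every time step.
tau_eq_ts = jax.vmap(inv, in_axes=(None, 0, 0, 0))(params, th_ts, th_d_ts, th_dd_ts)
A_ts, B_ts = jax.vmap(linearize, in_axes=(None, 0, 0, 0))(params, th_ts, th_d_ts, tau_eq_ts)
print(A_ts.shape, B_ts.shape)  # (11, 4, 4) (11, 4, 2)
```

The batched matrices can then be discretized with a single `vmap` over the ZOH routine, exactly as done at the end of the cell above.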
