# Auto-differentiation in Python via JAX

In [1]:
import jax
import jax.numpy as jnp

JAX can automatically differentiate Python functions written with its NumPy-like API (`jax.numpy`). For example, consider the simple function:
\begin{equation}
f(x) = \frac{1}{2}\lVert x \rVert^2,
\end{equation}
where $x \in \mathbb{R}^n$, so that $\nabla f(x) = x$. We can differentiate this with JAX as follows:

In [2]:
def f(x):
    return jnp.sum(x**2)/2

grad_f = jax.grad(f)         # jax.grad returns a new function that evaluates the gradient of f
x = jnp.array([0., 1., 2.])  # use a floating-point JAX array; jax.grad rejects integer inputs

print('x:        ', x)
print('f(x):     ', f(x))
print('grad_f(x):', grad_f(x))

x:         [0. 1. 2.]
f(x):      2.5
grad_f(x): [0. 1. 2.]
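
Because $\nabla f(x) = x$ for this function, the printed gradient can be checked against the input itself. The cell below is a minimal sketch of that check, not part of the original notebook; it uses `jax.value_and_grad`, which returns the function value and the gradient from a single call.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2) / 2

x = jnp.array([0., 1., 2.])

# value_and_grad(f) returns a function that yields (f(x), grad f(x)) together
value, grad = jax.value_and_grad(f)(x)

# For f(x) = 0.5 * ||x||^2 the analytic gradient is x itself
assert jnp.allclose(grad, x)
assert jnp.allclose(value, 2.5)
```

Comparing against a known closed-form gradient like this is a cheap sanity check before differentiating functions whose gradients are harder to derive by hand.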


### Example: Inverted Pendulum Dynamics Linearization

The dynamics of the inverted pendulum are given by:
\begin{equation}
    \dot{x}_1 = x_2, \quad \dot{x}_2 = \frac{g}{l} \sin(x_1) + \frac{1}{ml^2}u, 
\end{equation}
where $x = [\theta, \dot{\theta}]^\top$ is the state (so $x_1 = \theta$ and $x_2 = \dot{\theta}$), $\theta$ is the angle of rotation relative to the vertical axis, $l$ is the rod length, $m$ is the mass at the end of the rod, $g$ is the acceleration due to gravity, and $u$ is the control torque about the axis of rotation.
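
For reference, here is a sketch of what linearization means for these dynamics. Writing them as $\dot{x} = f(x, u)$, a first-order Taylor expansion about an equilibrium $(\bar{x}, \bar{u})$ with $f(\bar{x}, \bar{u}) = 0$ gives a linear model in the deviations $\delta x = x - \bar{x}$ and $\delta u = u - \bar{u}$ (the symbols $\bar{x}$, $\bar{u}$, $\delta x$, $\delta u$, $A$, and $B$ are notation assumed here for illustration):
\begin{equation}
    \dot{\delta x} \approx A\,\delta x + B\,\delta u, \quad
    A = \frac{\partial f}{\partial x}\bigg|_{(\bar{x}, \bar{u})} = \begin{bmatrix} 0 & 1 \\ \frac{g}{l}\cos(\bar{x}_1) & 0 \end{bmatrix}, \quad
    B = \frac{\partial f}{\partial u}\bigg|_{(\bar{x}, \bar{u})} = \begin{bmatrix} 0 \\ \frac{1}{ml^2} \end{bmatrix}.
\end{equation}
At the upright equilibrium $\bar{x} = 0$, $\bar{u} = 0$, the lower-left entry of $A$ reduces to $g/l$.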

We can linearize these dynamics using JAX as shown below:

In [3]:
def inverted_pendulum_dynamics(x, u, g=9.81, m=1, l=1):
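    """
    Evaluate the inverted pendulum dynamics dx/dt = f(x, u).

    x = [θ, dθ/dt] is the state and u is the control torque.
    """
    θ, dθ_dt = x
    # Angular acceleration: (g/l) sin(θ) + u / (m l^2)
    dx_dt = jnp.array([dθ_dt, (g/l)*jnp.sin(θ) + u/(m*l**2)])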
    return dx_dt

# Linearize around the stationary upright position with zero control
# (i.e. the pendulum is perfectly balanced)
f_jac = jax.jacobian(inverted_pendulum_dynamics, argnums=(0, 1))
x = jnp.array([0., 0.])
u = 0.                
A, B = f_jac(x, u)  # A = ∂f/∂x and B = ∂f/∂u, evaluated at the equilibrium point

In [4]:
print(A)
print(B)

[[0.   1.  ]
 [9.81 0.  ]]
[0. 1.]
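
With the default parameters ($m = l = 1$), the torque coefficient $1/(ml^2)$ equals 1, so the printed `B` does not show how the result depends on $m$ and $l$. The cell below is a hedged sketch that repeats the linearization with hypothetical parameter values and compares the result with the analytic Jacobians $A = \begin{bmatrix} 0 & 1 \\ g/l & 0 \end{bmatrix}$ and $B = [0, 1/(ml^2)]^\top$ at the upright equilibrium. The function name `pendulum` and the values `m=2.0`, `l=0.5` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def pendulum(x, u, g=9.81, m=1.0, l=1.0):
    """Inverted pendulum dynamics, restated so this cell is self-contained."""
    theta, dtheta_dt = x
    return jnp.array([dtheta_dt, (g / l) * jnp.sin(theta) + u / (m * l**2)])

# Hypothetical parameters, chosen so that m * l**2 != 1
g, m, l = 9.81, 2.0, 0.5

# Differentiate with respect to x (argument 0) and u (argument 1)
f_jac = jax.jacobian(lambda x, u: pendulum(x, u, g=g, m=m, l=l), argnums=(0, 1))
A, B = f_jac(jnp.array([0., 0.]), 0.)

# Analytic Jacobians at the upright equilibrium (theta = 0, u = 0)
A_expected = jnp.array([[0., 1.], [g / l, 0.]])
B_expected = jnp.array([0., 1. / (m * l**2)])

assert jnp.allclose(A, A_expected)
assert jnp.allclose(B, B_expected)
```

Because `u` is a scalar, `B` comes back as a length-2 vector rather than a $2 \times 1$ matrix, which matches the shape of the output printed above.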
