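# Dynamics linearization

This notebook uses JAX automatic differentiation to compute the Jacobians $\partial f/\partial x$ and $\partial f/\partial u$ of the dynamics $\dot{x} = f(x, u)$.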

In [None]:
import jax
import jax.numpy as jnp

**Define the dynamics**

In [None]:
def f(x, u):
	"""Evaluate the linear dynamics x_dot = A x + B u.

	A is the 3x3 identity and B is a fixed 3x2 matrix, so x must have
	leading dimension 3 and u must have leading dimension 2.

	Args:
		x: state.
		u: control input.

	Returns:
		The state derivative A @ x + B @ u.
	"""
	state_matrix = jnp.eye(3)
	input_matrix = jnp.array([[1, 2], [3, 4], [5, 6]])
	drift = state_matrix @ x
	actuation = input_matrix @ u
	return drift + actuation

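**Linearize the dynamics**

Take the Jacobians of $f$ with respect to $x$ and $u$ at the operating point $x = \mathbf{1}$, $u = \mathbf{1}$. Because $f$ is already linear, these Jacobians reproduce the matrices $A$ and $B$ exactly, whatever operating point is chosen.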

In [None]:
# Operating point about which f is linearized.
x = jnp.ones(3)
u = jnp.ones(2)

# jax.jacobian is an alias of jax.jacrev (reverse mode); with argnums=(0, 1)
# the returned function yields the pair (df/dx, df/du).
f_dxu = jax.jacrev(f, argnums=(0, 1))
A, B = f_dxu(x, u)
print(A, B, sep="\n")

[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
[[1. 2.]
 [3. 4.]
 [5. 6.]]
