### Linearization of the Pendulum model.


In [None]:

from typing import Tuple
import os
import jax
import jax.numpy as jnp
import chex
import mujoco
import mujoco.mjx as mjx

import mjnax
from mjnax.mjxenv import MjxStateType, MjxModelType


# JAX uses 32-bit floats by default. Enable 64-bit (float64) support so that
# the simulation step and its Jacobians are evaluated in double precision.
# This flag should be set at startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)


@jax.jit
def linearization(mjx_model: MjxModelType,
                  mjx_state: MjxStateType,
                  ctrl_input: chex.Array
                  ) -> Tuple[chex.Array,
                             chex.Array]:
    """Linearize one MJX simulation step around a state and control input.

    The dynamics are treated as the discrete-time map ``x_next = f(x, u)``,
    where ``x = concat(qpos, qvel)``, ``u`` is the control input and ``f`` is
    a single ``mjx.step`` call. The returned Jacobians give the first-order
    approximation ``dx_next ~= A @ dx + B @ du`` around ``(x, u)``.

    NOTE: ``qpos`` is differentiated in its raw coordinates, exactly as it
    is stored in ``mjx_state``.

    Args:
        mjx_model (MjxModelType): mjx model
        mjx_state (MjxStateType): mjx state supplying the operating point
            (``qpos``/``qvel``) and all other fields passed to ``mjx.step``.
        ctrl_input (chex.Array): control input at the operating point.

    Returns:
        Tuple[chex.Array, chex.Array]:
            - A matrix, d(next state)/d(state), with shape
              ``(n, n)`` where ``n = len(qpos) + len(qvel)``.
            - B matrix, d(next state)/d(input), with shape
              ``(n,) + ctrl_input.shape``.
    """
    n_qpos = mjx_state.qpos.shape[0]
    state = jnp.concatenate([mjx_state.qpos, mjx_state.qvel])

    def step(state: chex.Array, ctrl_input: chex.Array) -> chex.Array:
        """Advance the simulation by one step from a flat state.

        Args:
            state (chex.Array): concatenated qpos and qvel.
            ctrl_input (chex.Array): control input.

        Returns:
            chex.Array: concatenated next qpos and next qvel.
        """
        qpos = state[:n_qpos]
        qvel = state[n_qpos:]
        # Every other field of the data (time, act, ...) is taken from the
        # operating-point state captured by this closure.
        _mjx_state = mjx_state.replace(ctrl=ctrl_input, qpos=qpos, qvel=qvel)
        next_mjx_state = mjx.step(mjx_model, _mjx_state)
        return jnp.concatenate([next_mjx_state.qpos, next_mjx_state.qvel])

    # Differentiate w.r.t. both arguments in one call. With a tuple
    # ``argnums`` JAX returns one Jacobian per argument, i.e. ``(A, B)``, and
    # both share a single forward pass. (The previous ``argnums=[0]`` /
    # ``argnums=[1]`` list form wrapped each result in a 1-tuple, so callers
    # received ``(A,)`` and ``(B,)`` instead of the matrices.)
    a_matrix, b_matrix = jax.jacobian(step, argnums=(0, 1))(state, ctrl_input)
    return a_matrix, b_matrix

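> The first call to `linearization` triggers JIT compilation, so it may take a while; later calls with the same input shapes and dtypes reuse the compiled function.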

### Linearize the pendulum

In [None]:
pendulum_xml: str = "assets/pendulum.xml"
absolute_path = os.path.join(mjnax.__path__[0], pendulum_xml)

# Build the MuJoCo model, move it to MJX and allocate the matching MJX data.
mj_model = mujoco.MjModel.from_xml_path(absolute_path)
mjx_model = mjx.put_model(mj_model)
mjx_state = mjx.make_data(mjx_model)

# Linearize about qpos = 0 with zero control, instead of the model's default
# initial pose (noted as 30 degrees for this pendulum; confirm in the XML).
# Sizing the arrays from the model (nq positions, nu actuators) keeps this
# cell valid for models with more than one joint or actuator.
mjx_state = mjx_state.replace(qpos=jnp.zeros(mj_model.nq))
a_matrix, b_matrix = linearization(mjx_model, mjx_state,
                                   jnp.zeros(mj_model.nu))

a_matrix, b_matrix