# `JAXsim` Showcase: PD Controller

<a target="_blank" href="https://colab.research.google.com/github/ami-iit/jaxsim/blob/main/examples/PD_controller.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

First, we install the necessary packages and import them.

In [None]:
# @title Imports and setup
from IPython.display import clear_output, HTML, display
import sys

# True when running on Google Colab, detected by checking whether the
# `google.colab` module is loaded in the current interpreter.
IS_COLAB = "google.colab" in sys.modules

# On Colab only: install JAXsim (pip) and Gazebo (apt), then clear the verbose
# installation output from the cell.
if IS_COLAB:
    !{sys.executable} -m pip install -U -q jaxsim
    !apt -qq update && apt install -qq --no-install-recommends gazebo
    clear_output()

import jax
import jax.numpy as jnp
from jaxsim import logging

# Enable INFO-level JAXsim logging and report the devices JAX will run on.
logging.set_logging_level(logging.LoggingLevel.INFO)
logging.info(f"Running on {jax.devices()}")

We will use a simple cartpole model for this example. The cartpole is a planar model made of a cart that moves horizontally and a pole hinged to the cart. Its state is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. In the classic cartpole problem the only control input is a horizontal force applied to the cart; in this example, however, the PD controller defined below commands both joints (a force on the cart and a torque on the pole).

In [None]:
# @title Fetch the URDF file
import requests

url = "https://raw.githubusercontent.com/ami-iit/jaxsim/main/examples/assets/cartpole.urdf"

# Download the cartpole URDF description. A timeout prevents the cell from
# hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
if response.status_code == 200:
    model_urdf_string = response.text
else:
    logging.error("Failed to fetch data")
    # Stop here: continuing would leave `model_urdf_string` undefined and the
    # next cell would fail with an unrelated-looking NameError.
    raise RuntimeError(
        f"Failed to fetch the URDF from {url} (HTTP status {response.status_code})"
    )

JAXsim offers a simple high-level API to extract the quantities needed in most robotics applications.

In [None]:
from jaxsim.high_level.model import Model

# Parse the URDF text fetched above into a high-level JAXsim model.
model = Model.build_from_model_description(
    is_urdf=True,
    model_description=model_urdf_string,
)

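Let's reset the cartpole joint positions to random values, sampled uniformly in $[-1, 1)$ with a fixed seed so that the example is reproducible.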

In [None]:
# One joint position per degree of freedom, sampled uniformly in [-1, 1) with a
# fixed PRNG seed so the notebook is reproducible.
random_positions = jax.random.uniform(
    key=jax.random.PRNGKey(0),
    shape=(model.dofs(),),
    minval=-1.0,
    maxval=1.0,
)

model.reset_joint_positions(positions=random_positions)

The visualization uses the MuJoCo package, which makes it easy to render the animations on Google Colab as well. If you are not interested in the rendering details, just execute the next cell without worrying about its contents.

In [None]:
# @title Set up MuJoCo renderer
!{sys.executable} -m pip install -U -q mujoco
!{sys.executable} -m pip install -q mediapy

import mediapy as media
import tempfile
import xml.etree.ElementTree as ET
import numpy as np

import distutils.util
import os
import subprocess

if IS_COLAB:
    if subprocess.run("ffmpeg -version", shell=True).returncode:
        !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
        clear_output()

    if subprocess.run("nvidia-smi").returncode:
        raise RuntimeError(
            "Cannot communicate with GPU. "
            "Make sure you are using a GPU Colab runtime. "
            "Go to the Runtime menu and select Choose runtime type."
        )

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    NVIDIA_ICD_CONFIG_PATH = "/usr/share/glvnd/egl_vendor.d/10_nvidia.json"
    if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
        with open(NVIDIA_ICD_CONFIG_PATH, "w") as f:
            f.write(
                """{
      "file_format_version" : "1.0.0",
      "ICD" : {
         "library_path" : "libEGL_nvidia.so.0"
     }
  }
  """
            )

%env MUJOCO_GL=egl

try:
    import mujoco
except Exception as e:
    # Raise the explanatory error and chain the original failure as its cause,
    # so the traceback shows both the hint and the underlying problem.
    raise RuntimeError(
        "Something went wrong during installation. Check the shell output above "
        "for more information.\n"
        "If using a hosted Colab runtime, make sure you enable GPU acceleration "
        'by going to the Runtime menu and selecting "Choose runtime type".'
    ) from e


def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):
    """Compile a model description in MuJoCo and add a fixed camera named "side".

    The description is compiled by MuJoCo, exported back to MJCF with
    ``mujoco.mj_saveLastXML``, and edited as XML to append a ``<camera>``
    element to ``<worldbody>``. The edited MJCF is then compiled again.

    Args:
        xml_string: Model description accepted by
            ``mujoco.MjModel.from_xml_string`` (a URDF in this notebook).
        camera_pos: Iterable with the camera position, written to the MJCF
            ``pos`` attribute.
        camera_xyaxes: Iterable written to the MJCF ``xyaxes`` attribute.

    Returns:
        The ``mujoco.MjModel`` compiled from the MJCF with the added camera.

    Raises:
        ValueError: If the exported MJCF has no ``<worldbody>`` element.
    """

    def to_mjcf_string(list_to_str):
        # MJCF encodes vectors as space-separated numbers.
        return " ".join(map(str, list_to_str))

    # Compile the description passed by the caller.
    mj_model_raw = mujoco.MjModel.from_xml_string(xml_string)

    # Export the compiled model as MJCF. A temporary directory guarantees the
    # file is cleaned up and can be reopened by name on every platform.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path_temp_xml = os.path.join(tmp_dir, "model.xml")
        mujoco.mj_saveLastXML(path_temp_xml, mj_model_raw)
        tree = ET.parse(path_temp_xml)

    worldbody_elem = tree.getroot().find("worldbody")
    if worldbody_elem is None:
        raise ValueError("The MJCF exported by MuJoCo has no <worldbody> element")

    # Add a fixed camera to the world body.
    camera_elem = ET.Element("camera")
    camera_elem.set("name", "side")
    camera_elem.set("pos", to_mjcf_string(camera_pos))
    camera_elem.set("xyaxes", to_mjcf_string(camera_xyaxes))
    camera_elem.set("mode", "fixed")
    worldbody_elem.append(camera_elem)

    # Recompile the edited MJCF.
    mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding="unicode")
    mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)
    return mj_model


def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):
    """Reorder joint positions from JAXsim joint order to MuJoCo ``qpos`` order.

    Args:
        jaxsim_jointpos: 1-D sequence/array of joint positions in JAXsim order.
            It is not modified.
        mjmodel: MuJoCo model, used to look up each joint's ``qposadr`` by name.
        jaxsimmodel: JAXsim model whose ``joints()`` provide names and indices.

    Returns:
        A new NumPy array whose entry ``i`` is the position of the joint whose
        MuJoCo qpos address is ``i``.

    Raises:
        KeyError: If some qpos address has no matching JAXsim joint.
    """
    # Map each MuJoCo qpos address to the position index in the JAXsim vector.
    # NOTE(review): `index() - 1` assumes JAXsim joint indices start at 1 — confirm.
    mujocoqposaddr2jaxindex = {
        mjmodel.joint(jaxjnt.name()).qposadr[0]: jaxjnt.index() - 1
        for jaxjnt in jaxsimmodel.joints()
    }

    source = np.asarray(jaxsim_jointpos)
    # Gather into a NEW array via integer indexing: permuting the input in place
    # would overwrite entries that still have to be read (and mutate the caller's data).
    order = [mujocoqposaddr2jaxindex[i] for i in range(len(source))]
    return source[order]


# Build the MuJoCo model used for rendering, with a fixed "side" camera placed at
# the position below and oriented by the MJCF `xyaxes` values below.
# To get a good camera location, you can use "Copy camera" functionality in MuJoCo GUI
mj_model = load_mujoco_model_with_camera(
    model_urdf_string,
    [3.954, 3.533, 2.343],
    [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],
)
# Renderer for `mj_model` producing frames of 480 (height) x 640 (width) pixels.
renderer = mujoco.Renderer(mj_model, height=480, width=640)


def get_image(camera, mujocojointpos) -> np.ndarray:
    """Render one frame of the MuJoCo scene for a given joint configuration.

    Args:
        camera: Camera identifier forwarded to ``renderer.update_scene``
            (this notebook passes the name "side").
        mujocojointpos: Joint positions in MuJoCo ``qpos`` order.

    Returns:
        The image returned by the module-level ``renderer.render()``.
    """
    # A fresh MjData is created per frame; only qpos is set, and mj_forward
    # propagates it to the quantities the renderer reads.
    mj_data = mujoco.MjData(mj_model)
    mj_data.qpos = mujocojointpos
    mujoco.mj_forward(mj_model, mj_data)

    renderer.update_scene(mj_data, camera=camera)
    frame = renderer.render()
    return frame

Let's see how the model behaves when not controlled:

In [None]:
from jaxsim.simulation.ode_integration import IntegratorType

# Uncontrolled rollout: no force targets are set before this loop. One frame is
# rendered before each integration step.
timestep = 0.01
sim_images = []
for _ in range(300):
    sim_images.append(
        get_image(
            camera="side",
            mujocojointpos=from_jaxsim_to_mujoco_pos(
                jaxsim_jointpos=np.array(model.joint_positions()),
                mjmodel=mj_model,
                jaxsimmodel=model,
            ),
        )
    )
    # Advance the model by `timestep` (presumably from its current state).
    model.integrate(
        t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit
    )

media.show_video(sim_images, fps=1 / timestep)

Let's now define the PD controller. We will use the following equations:

\begin{align} 
\mathbf{M}\ddot{s} + \underbrace{\mathbf{C}\dot{s} + \mathbf{G}}_{\mathbf{H}} = \tau \\
\tau = \mathbf{H} - \mathbf{K}_p(s - s_d) - \mathbf{K}_d(\dot{s} - \dot{s}_d)
\end{align}

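where $\mathbf{M}$ is the mass matrix, $\mathbf{C}$ is the Coriolis and centrifugal matrix, $\mathbf{G}$ is the gravity vector, $\mathbf{H}$ collects the bias forces (Coriolis, centrifugal and gravity terms), $\mathbf{K}_p$ is the proportional gain matrix, $\mathbf{K}_d$ is the derivative gain matrix, $s$ is the joint position vector, $\dot{s}$ is the joint velocity vector, $\ddot{s}$ is the joint acceleration vector, and $s_d$ and $\dot{s}_d$ are the desired joint position and velocity vectors, respectively. Note that $\mathbf{H}$ depends on both $s$ and $\dot{s}$, so it has to be evaluated at the current state.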

In [None]:
# Define the PD gains (scalars: the same gain multiplies every joint error)
KP = 10.0
KD = 6.0

# Bias forces H = C(s, s_dot) s_dot + G(s) (Coriolis/centrifugal + gravity),
# evaluated ONCE at the model state when this cell runs; it is not updated
# afterwards. The first 6 entries are dropped — presumably the floating-base
# part of the free-floating bias force vector, keeping the joint entries (TODO confirm).
H = model.free_floating_bias_forces()[6:]


def pd_controller(
    q: jax.Array,
    q_d: jax.Array,
    q_dot: jax.Array,
    q_dot_d: jax.Array,
    h: "jax.Array | None" = None,
    kp: "float | jax.Array | None" = None,
    kd: "float | jax.Array | None" = None,
) -> jax.Array:
    """Compute PD joint generalized forces with bias-force compensation.

    Implements ``tau = H + Kp (q_d - q) + Kd (q_dot_d - q_dot)``, which is the
    law ``tau = H - Kp (s - s_d) - Kd (s_dot - s_dot_d)``.

    Args:
        q: Current joint positions.
        q_d: Desired joint positions.
        q_dot: Current joint velocities.
        q_dot_d: Desired joint velocities.
        h: Joint bias forces (Coriolis/centrifugal + gravity). When ``None``
            they are recomputed from the current state of the global ``model``,
            because they depend on the configuration and velocity.
        kp: Proportional gain (scalar or per-joint array); defaults to ``KP``.
        kd: Derivative gain (scalar or per-joint array); defaults to ``KD``.

    Returns:
        The joint generalized forces to apply.
    """
    if h is None:
        # Recompute every call: a value cached at the initial state becomes
        # stale as soon as the system moves.
        h = model.free_floating_bias_forces()[6:]
    # Gains are read at call time, so later edits of KP/KD take effect.
    if kp is None:
        kp = KP
    if kd is None:
        kd = KD
    return h + kp * (q_d - q) + kd * (q_dot_d - q_dot)

Now we can use the `pd_controller` function to compute the generalized forces (a force on the cart and a torque on the pole) to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set both the desired position `q_d` and the desired velocity `q_dot_d` to zero.

In [None]:
# Controlled rollout: at every step the PD controller computes joint generalized
# forces that drive the joints towards the desired (zero) configuration.
sim_images = []
timestep = 0.01

# Desired joint positions and velocities, one entry per degree of freedom.
# They are constant, so they are built once instead of at every step.
q_desired = jnp.zeros(model.dofs())
q_dot_desired = jnp.zeros(model.dofs())

for _ in range(300):
    sim_images.append(
        get_image(
            "side",
            from_jaxsim_to_mujoco_pos(
                np.array(model.joint_positions()), mj_model, model
            ),
        )
    )
    model.set_joint_generalized_force_targets(
        forces=pd_controller(
            q=model.joint_positions(),
            q_d=q_desired,
            q_dot=model.joint_velocities(),
            q_dot_d=q_dot_desired,
        )
    )
    model.integrate(
        t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit
    )

media.show_video(sim_images, fps=1 / timestep)