# `JAXsim` Showcase: PD Controller

<a target="_blank" href="https://colab.research.google.com/github/flferretti/jaxsim/blob/example/notebook/examples/PD_controller.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

First, we install the necessary packages and import them.

In [None]:
from IPython.display import clear_output

# Install JAX (with CUDA 12 support), JAXsim, and Gazebo
!pip install -U -q "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://github.com/ami-iit/jaxsim@new_api
!apt -qq update && apt install -qq --no-install-recommends gazebo
clear_output()

import jax
import jax.numpy as jnp
from jaxsim import logging

logging.set_logging_level(logging.LoggingLevel.INFO)
logging.info(f"Running on {jax.devices()}")

We will use a simple cartpole model for this example. The cartpole is a planar (2D) system made of a cart that slides horizontally and a pole hinged to the cart that rotates about its pivot. Its state is given by the cart position, the pole angle, the cart velocity, and the pole angular velocity. The control input is the horizontal force applied to the cart.
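To make the state and the input concrete, here is a minimal sketch of the classic frictionless cartpole equations of motion written in plain JAX. This is not how JAXsim computes the dynamics (JAXsim builds them from the URDF); the masses, the pole half-length, and the convention that $\theta = 0$ is the upright position are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative parameters (assumed values, NOT read from the cartpole URDF)
GRAVITY = 9.81           # m/s^2
CART_MASS = 1.0          # kg
POLE_MASS = 0.1          # kg
POLE_HALF_LENGTH = 0.5   # m, distance from the pivot to the pole's center of mass


def cartpole_dynamics(state: jax.Array, force) -> jax.Array:
    """Time derivative of the state [x, theta, x_dot, theta_dot].

    theta = 0 is the upright position; `force` is the horizontal force on the cart.
    """
    _, theta, x_dot, theta_dot = state
    total_mass = CART_MASS + POLE_MASS
    sin_t, cos_t = jnp.sin(theta), jnp.cos(theta)

    # Classic frictionless cartpole model: solve for the pole angular acceleration
    # first, then use it to get the cart acceleration.
    temp = (force + POLE_MASS * POLE_HALF_LENGTH * theta_dot**2 * sin_t) / total_mass
    theta_ddot = (GRAVITY * sin_t - cos_t * temp) / (
        POLE_HALF_LENGTH * (4.0 / 3.0 - POLE_MASS * cos_t**2 / total_mass)
    )
    x_ddot = temp - POLE_MASS * POLE_HALF_LENGTH * theta_ddot * cos_t / total_mass
    return jnp.array([x_dot, theta_dot, x_ddot, theta_ddot])


# One explicit Euler step from a slightly tilted, motionless configuration
state = jnp.array([0.0, 0.05, 0.0, 0.0])
dt = 0.01
next_state = state + dt * cartpole_dynamics(state, jnp.array(0.0))
print(next_state)
```

With no force applied, a positive tilt yields a positive angular acceleration, so the pole keeps falling: the upright position is an unstable equilibrium, which is why a controller is needed.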

In [None]:
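import requests

url = "https://raw.githubusercontent.com/flferretti/jaxsim/example/notebook/examples/assets/cartpole.urdf"

# Download the cartpole URDF. Despite its name, `model_urdf_path` holds the
# URDF content as a string, which is passed to the model builder and the visualizer below.
response = requests.get(url)
response.raise_for_status()  # Raise an error if the download failed
model_urdf_path = response.text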
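Since the downloaded string is the URDF itself, a quick sanity check is to parse it with Python's standard `xml.etree.ElementTree` and list the joints it defines. The sketch below uses a hypothetical, stripped-down URDF string (the link and joint names are made up) so that it runs offline; in the notebook you would pass `model_urdf_path` instead.

```python
import xml.etree.ElementTree as ET

# Hypothetical minimal URDF, used only so this snippet runs offline.
# In the notebook, pass the downloaded `model_urdf_path` string instead.
EXAMPLE_URDF = """
<robot name="cartpole">
  <link name="rail"/>
  <link name="cart"/>
  <link name="pole"/>
  <joint name="linear" type="prismatic">
    <parent link="rail"/>
    <child link="cart"/>
    <axis xyz="1 0 0"/>
    <limit lower="-1.0" upper="1.0" effort="100.0" velocity="10.0"/>
  </joint>
  <joint name="pivot" type="continuous">
    <parent link="cart"/>
    <child link="pole"/>
    <axis xyz="0 1 0"/>
  </joint>
</robot>
"""


def list_joints(urdf_text: str) -> list[tuple[str, str]]:
    """Return the (name, type) pairs of the top-level <joint> elements of a URDF."""
    root = ET.fromstring(urdf_text.strip())
    return [(joint.get("name"), joint.get("type")) for joint in root.findall("joint")]


print(list_joints(EXAMPLE_URDF))  # → [('linear', 'prismatic'), ('pivot', 'continuous')]
```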

JAXsim offers a simple high-level API to extract the quantities needed in most robotics applications. Let's import it and build the cartpole model from the URDF description.

In [None]:
from jaxsim.high_level.model import Model

In [None]:
model = Model.build_from_model_description(model_description=model_urdf_path, is_urdf=True)

The visualization is done using the [`meshcat-viz-python`](https://github.com/ami-iit/meshcat-viz-python) package. Let's import it, create a `MeshcatWorld` object, and embed its viewer in the notebook.

In [None]:
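try:
    from meshcat_viz.world import MeshcatWorld
except ImportError:
    %pip install -q git+https://github.com/ami-iit/meshcat-viz-python
    from meshcat_viz.world import MeshcatWorld

from IPython.display import IFrame, display

world = MeshcatWorld()
world.open()

# Explicitly display the viewer: an IFrame is only rendered automatically
# when it is the last expression of the cell
display(IFrame(src=world.web_url, width="100%", height="500px"))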

world.insert_model(
    model_description=model_urdf_path, model_name="Cartpole", is_urdf=True
)

Let's see how the model behaves when not controlled:

In [None]:
for _ in range(200):
    model.integrate(0.01)
    world.update_model(model_name="Cartpole", base_position=model.base_position(), joint_positions=model.joint_positions())

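Let's now define the PD controller. We will use the following control law, where $K_p$ and $K_d$ are the proportional and derivative gains, $q$ and $\dot{q}$ are the joint positions and velocities, and $q_d$ and $\dot{q}_d$ are their desired values: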

\begin{align} \tau &= K_p \left( q_d - q \right) + K_d \left( \dot{q}_d - \dot{q} \right) \end{align}
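As a worked example of this control law, the sketch below evaluates $\tau$ for two joints regulated to zero position and zero velocity, using illustrative gains $K_p = 10$ and $K_d = 6$ (assumed values), and then uses `jax.vmap` to evaluate it on a batch of states at once. The name `pd_torque` is hypothetical, chosen so that it does not shadow the notebook's `pd_controller`.

```python
import jax
import jax.numpy as jnp

KP = 10.0  # illustrative proportional gain (assumed)
KD = 6.0   # illustrative derivative gain (assumed)


def pd_torque(q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array) -> jax.Array:
    # tau = K_p (q_d - q) + K_d (q_dot_d - q_dot), applied elementwise per joint
    return KP * (q_d - q) + KD * (q_dot_d - q_dot)


# Two joints, both regulated to zero position and zero velocity
q = jnp.array([0.1, -0.2])
q_dot = jnp.array([-0.2, 0.5])
tau = pd_torque(q, jnp.zeros(2), q_dot, jnp.zeros(2))
print(tau)

# Evaluate the same law on a batch of 3 states: batch over q and q_dot,
# share the same targets (in_axes=None) across the batch
batched_pd = jax.jit(jax.vmap(pd_torque, in_axes=(0, None, 0, None)))
q_batch = jnp.stack([q, 2 * q, jnp.zeros(2)])
q_dot_batch = jnp.stack([q_dot, 2 * q_dot, jnp.zeros(2)])
tau_batch = batched_pd(q_batch, jnp.zeros(2), q_dot_batch, jnp.zeros(2))
print(tau_batch)
```

For the first joint, $\tau = 10\,(0 - 0.1) + 6\,(0 - (-0.2)) = -1 + 1.2 = 0.2$; for the second, $\tau = 10 \cdot 0.2 + 6 \cdot (-0.5) = -1$. Because the law is linear, doubling the errors doubles the output, and zero error gives zero output.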

In [None]:
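KP, KD = 10.0, 6.0  # Proportional and derivative gains (illustrative values)

def pd_controller(q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array) -> jax.Array:
    return KP * (q_d - q) + KD * (q_dot_d - q_dot)

Now, we can use the `pd_controller` function at every simulation step to compute the generalized forces to apply to the cartpole joints. Our aim is to stabilize the cartpole in the upright position, so we set the desired positions `q_d` and the desired velocities `q_dot_d` to 0.

In [None]:
q_d = q_dot_d = jnp.zeros_like(model.joint_positions())  # Upright pole, cart at rest at the origin
for _ in range(200):
    tau = pd_controller(model.joint_positions(), q_d, model.joint_velocities(), q_dot_d)
    model.set_joint_generalized_force_targets(tau)  # Assumed high-level setter for the joint forces
    model.integrate(0.01)
    world.update_model(model_name="Cartpole", base_position=model.base_position(), joint_positions=model.joint_positions())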
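The loop above steps the simulator from Python. When the whole closed loop is written in JAX, it can instead be compiled into a single `jax.lax.scan`. The sketch below shows this pattern on a toy 1-DoF unit mass driven by the same PD law; the plant, the gains, and the semi-implicit Euler step are assumptions for illustration, not the cartpole and not JAXsim's integrator. The step size matches the 0.01 s used in the notebook.

```python
import jax
import jax.numpy as jnp

KP, KD = 10.0, 6.0   # illustrative PD gains (assumed)
DT = 0.01            # integration step [s]
MASS = 1.0           # toy 1-DoF unit mass (assumption, not the cartpole)
NUM_STEPS = 1000     # 10 s of simulated time


def closed_loop_step(state, _):
    """One semi-implicit Euler step of MASS * q_ddot = tau under PD control."""
    q, q_dot = state
    tau = KP * (0.0 - q) + KD * (0.0 - q_dot)  # regulate to q_d = 0, q_dot_d = 0
    q_dot = q_dot + DT * tau / MASS             # update the velocity first...
    q = q + DT * q_dot                          # ...then the position with the new velocity
    return (q, q_dot), q                        # carry the state, record the position


def simulate(q0, q_dot0) -> jax.Array:
    """Roll out the closed loop with lax.scan and return the position trajectory."""
    init = (jnp.asarray(q0, dtype=jnp.float32), jnp.asarray(q_dot0, dtype=jnp.float32))
    _, trajectory = jax.lax.scan(closed_loop_step, init, xs=None, length=NUM_STEPS)
    return trajectory


trajectory = jax.jit(simulate)(1.0, 0.0)
print(trajectory[-1])
```

With $K_p = 10$ and $K_d = 6$ on a unit mass, the closed-loop characteristic polynomial is $s^2 + 6s + 10$, with roots $s = -3 \pm i$, so the position decays roughly like $e^{-3t}$ and is essentially zero by the end of the 10 s rollout.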