# `JAXsim` Showcase: PD Controller

<a target="_blank" href="https://colab.research.google.com/github/flferretti/jaxsim/blob/example/notebook/examples/PD_controller.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

First, we install the necessary packages and import them.

In [1]:
from IPython.display import clear_output

# Install JAX (with CUDA 12 support), JAXsim, and Gazebo
!pip install -U -q "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://github.com/ami-iit/jaxsim@new_api
!apt -qq update && apt install -qq --no-install-recommends gazebo
clear_output()

import jax
import jax.numpy as jnp
from jaxsim import logging

logging.set_logging_level(logging.LoggingLevel.INFO)
logging.info(f"Running on {jax.devices()}")

[34mjaxsim[85289][0m [1;30mINFO[0m Running on [cuda(id=0)]


We will use a simple cartpole model for this example. The cartpole is a planar system made of a cart that slides along a horizontal rail and a pole hinged to the cart that rotates in the vertical plane. Its state is given by the cart position, the pole angle, the cart velocity, and the pole angular velocity. In the classic cartpole problem the only control input is the horizontal force applied to the cart; in this example, instead, the PD controller commands generalized forces on both joints.
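For intuition about the state and input described above, here is a minimal sketch of the classic textbook cart-pole dynamics (the frictionless formulation used, for example, by the Gym CartPole benchmark), written in JAX. It rests on stated assumptions: the pole is a uniform rod, the only input is a horizontal force on the cart, and the physical parameters are hypothetical values that are not read from the URDF. JAXsim does not use these equations; it simulates the multibody model built from the URDF.

```python
import jax
import jax.numpy as jnp

# Hypothetical physical parameters (the classic Gym CartPole values),
# NOT taken from the cartpole.urdf used in this notebook.
GRAVITY = 9.8      # [m/s^2]
MASS_CART = 1.0    # [kg]
MASS_POLE = 0.1    # [kg]
HALF_LENGTH = 0.5  # [m] distance from the hinge to the pole's center of mass


@jax.jit
def cartpole_dynamics(state: jax.Array, force: jax.Array) -> jax.Array:
    """Time derivative of the state [x, theta, x_dot, theta_dot].

    x is the cart position, theta the pole angle measured from the upright
    position, and force the horizontal force applied to the cart.
    """
    _, theta, x_dot, theta_dot = state
    total_mass = MASS_CART + MASS_POLE
    sin_t, cos_t = jnp.sin(theta), jnp.cos(theta)

    # Textbook equations of motion for a uniform rod hinged on a sliding cart
    temp = (force + MASS_POLE * HALF_LENGTH * theta_dot**2 * sin_t) / total_mass
    theta_ddot = (GRAVITY * sin_t - cos_t * temp) / (
        HALF_LENGTH * (4.0 / 3.0 - MASS_POLE * cos_t**2 / total_mass)
    )
    x_ddot = temp - MASS_POLE * HALF_LENGTH * theta_ddot * cos_t / total_mass

    return jnp.array([x_dot, theta_dot, x_ddot, theta_ddot])


upright = cartpole_dynamics(jnp.zeros(4), jnp.array(0.0))
tilted = cartpole_dynamics(jnp.array([0.0, 0.1, 0.0, 0.0]), jnp.array(0.0))
print(upright, tilted)
```

With zero force at the upright configuration all derivatives vanish, while a small tilt makes the pole accelerate away from the vertical, which is why the uncontrolled cartpole falls.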

In [2]:
import requests

# URDF description of the cartpole
url = "https://raw.githubusercontent.com/flferretti/jaxsim/example/notebook/examples/assets/cartpole.urdf"

response = requests.get(url, timeout=30)
response.raise_for_status()  # stop here if the download failed

# Despite its name, this variable holds the URDF content as a string
model_urdf_path = response.text

JAXsim offers a simple high-level API to build a model from its description and to extract the quantities needed in most robotic applications.

In [3]:
from jaxsim.high_level.model import Model

model = Model.build_from_model_description(
    model_description=model_urdf_path, is_urdf=True
)

[34mjaxsim[85289][0m [1;30mINFO[0m Combining the pose of base link 'rail' with the pose of joint 'world_to_rail'
[34mjaxsim[85289][0m [1;30mINFO[0m The kinematic graph doesn't need to be reduced


Let's reset the cartpole joints to random positions, sampled uniformly between -1 and 1.

In [4]:
random_positions = jax.random.uniform(
    minval=-1.0, maxval=1.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)
)

model.reset_joint_positions(positions=random_positions)
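Because JAX random functions are pure, `jax.random.PRNGKey(0)` makes the reset above reproducible: the same key always yields the same sample. To draw a different random state at every reset, the usual JAX idiom is to split the key and use a fresh subkey for each draw. In the sketch below, `dofs = 2` is an assumption standing in for `model.dofs()` (one slider and one hinge joint).

```python
import jax

key = jax.random.PRNGKey(0)
dofs = 2  # assumption: stands in for model.dofs() of the cartpole

# Reusing the same key reproduces exactly the same sample
a = jax.random.uniform(key, shape=(dofs,), minval=-1.0, maxval=1.0)
b = jax.random.uniform(key, shape=(dofs,), minval=-1.0, maxval=1.0)

# Splitting the key gives a new, independent subkey for the next draw
key, subkey = jax.random.split(key)
c = jax.random.uniform(subkey, shape=(dofs,), minval=-1.0, maxval=1.0)

print(a, b, c)
```

Keeping the updated `key` around and splitting it again before each reset yields a new random configuration every time while the whole sequence stays reproducible.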

The visualization is done using the [`meshcat-viz-python`](https://github.com/ami-iit/meshcat-viz-python) package. Let's import it (installing it first if needed), create a `MeshcatWorld` object, and insert the cartpole model into it.

In [None]:
try:
    from meshcat_viz.world import MeshcatWorld
except ImportError:
    # Install meshcat-viz-python only if it is not already available
    !pip install -q git+https://github.com/ami-iit/meshcat-viz-python
    clear_output()
    from meshcat_viz.world import MeshcatWorld

# Create the Meshcat world and render the viewer
world = MeshcatWorld()
world.meshcat_visualizer.render_static()

# Load the same cartpole description into the visualizer
world.insert_model(
    model_description=model_urdf_path, model_name="Cartpole", is_urdf=True
)

Let's see how the model behaves when not controlled:

In [None]:
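from jaxsim.simulation.ode_integration import IntegratorType

# Simulate 200 steps of 0.01 s each (2 s in total) without any control input
for _ in range(200):
    # Each call integrates the model over 0.01 s starting from its current state
    model.integrate(t0=0.0, tf=0.01, integrator_type=IntegratorType.EulerSemiImplicit)

    # Show the new configuration in the visualizer
    world.update_model(
        model_name="Cartpole",
        base_position=model.base_position(),
        joint_positions=model.joint_positions(),
        joint_names=model.joint_names(),
    )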
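The loop above relies on `IntegratorType.EulerSemiImplicit`. As its name suggests, semi-implicit (symplectic) Euler updates the velocity first and then uses the new velocity to update the position. The generic sketch below does not reproduce JAXsim's internal implementation and runs on a hypothetical harmonic oscillator rather than on the cartpole; it shows why the ordering matters: the energy of the semi-implicit scheme stays close to its initial value, while explicit Euler slowly injects energy into the system.

```python
import jax.numpy as jnp


def oscillator_acceleration(q, v):
    """Hypothetical test system: unit-frequency harmonic oscillator, q_ddot = -q."""
    return -q


def semi_implicit_euler_step(q, v, acceleration_fn, dt):
    # Update the velocity first, then the position using the *new* velocity
    v_next = v + dt * acceleration_fn(q, v)
    q_next = q + dt * v_next
    return q_next, v_next


def explicit_euler_step(q, v, acceleration_fn, dt):
    # For comparison: both updates use the values at the beginning of the step
    q_next = q + dt * v
    v_next = v + dt * acceleration_fn(q, v)
    return q_next, v_next


def energy_after(step_fn, num_steps=1000, dt=0.01):
    q, v = jnp.array(1.0), jnp.array(0.0)
    for _ in range(num_steps):
        q, v = step_fn(q, v, oscillator_acceleration, dt)
    return 0.5 * (q**2 + v**2)  # the exact solution conserves the value 0.5


semi_implicit_energy = energy_after(semi_implicit_euler_step)
explicit_energy = energy_after(explicit_euler_step)
print(f"semi-implicit: {float(semi_implicit_energy):.4f}")
print(f"explicit: {float(explicit_energy):.4f}")
```

For this oscillator, explicit Euler multiplies the energy by (1 + dt²) at every step, so the drift accumulates, whereas the semi-implicit scheme keeps it bounded.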

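Let's now define the PD controller. We will use the following control law:

\begin{align} \tau &= K_p \left( q_d - q \right) + K_d \left( \dot{q}_d - \dot{q} \right) \end{align}

where $q$ and $\dot{q}$ are the joint positions and velocities, $q_d$ and $\dot{q}_d$ their desired values, $\tau$ the joint generalized forces, and $K_p$ and $K_d$ the proportional and derivative gains.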

In [None]:
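# Define the PD gains (the same scalar gains are used for every joint)
KP = 10.0
KD = 1.0


def pd_controller(
    q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array
) -> jax.Array:
    """Compute the joint generalized forces of a PD controller."""
    return KP * (q_d - q) + KD * (q_dot_d - q_dot)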
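`KP` and `KD` above are scalars, so both joints receive the same gains even though the cart coordinate is a length and the pole coordinate is an angle. A common variant uses diagonal gain matrices, that is, one gain per joint. The sketch below implements this with element-wise products and compiles the controller with `jax.jit`; the gain values and the name `pd_controller_per_joint` are hypothetical, and the gains were not tuned for this cartpole.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-joint gains, ordered as [cart slider, pole hinge]
KP_VEC = jnp.array([10.0, 50.0])
KD_VEC = jnp.array([1.0, 5.0])


@jax.jit
def pd_controller_per_joint(
    q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array
) -> jax.Array:
    # Element-wise products are equivalent to diagonal K_p and K_d matrices
    return KP_VEC * (q_d - q) + KD_VEC * (q_dot_d - q_dot)


tau = pd_controller_per_joint(
    q=jnp.array([0.1, -0.2]),
    q_d=jnp.zeros(2),
    q_dot=jnp.array([0.0, 1.0]),
    q_dot_d=jnp.zeros(2),
)
print(tau)
```

Storing the gains as vectors avoids building full matrices and keeps the same call signature as `pd_controller`, so it could be swapped into the control loop unchanged.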

Now we can use the `pd_controller` function to compute the generalized forces to apply to the cartpole joints. Our aim is to stabilize the cartpole in the upright position, so we set both the desired positions `q_d` and the desired velocities `q_dot_d` to zero.

In [None]:
for _ in range(200):
    # Compute the PD action towards zero joint positions and velocities
    model.set_joint_generalized_force_targets(
        forces=pd_controller(
            q=model.joint_positions(),
            q_d=jnp.array([0.0, 0.0]),
            q_dot=model.joint_velocities(),
            q_dot_d=jnp.array([0.0, 0.0]),
        )
    )

    # Log the generalized forces stored in the model input
    logging.info(f"Joint generalized forces: {model.data.model_input.tau}")

    # Integrate for 0.01 s and update the visualization
    model.integrate(t0=0.0, tf=0.01, integrator_type=IntegratorType.EulerSemiImplicit)
    world.update_model(
        model_name="Cartpole",
        base_position=model.base_position(),
        joint_positions=model.joint_positions(),
        joint_names=model.joint_names(),
    )
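The Python loop above calls into JAXsim and the visualizer at every step, which is convenient for interactive playback. When only the closed-loop trajectory is needed, a common JAX pattern is to write one control-plus-integration step as a pure function and roll it out with `jax.lax.scan`, so the whole horizon is compiled at once. The sketch below uses a hypothetical plant in place of the cartpole (each joint is a unit-mass double integrator, `q_ddot = tau`), since it does not assume that JAXsim's high-level `Model` can be traced inside `scan`; `step` and `rollout` are illustrative names.

```python
from functools import partial

import jax
import jax.numpy as jnp

KP, KD = 10.0, 1.0
DT = 0.01  # [s], the same step used in the notebook loops


def pd_controller(q, q_d, q_dot, q_dot_d):
    return KP * (q_d - q) + KD * (q_dot_d - q_dot)


def step(state, _):
    # Hypothetical plant: every joint is a unit-mass double integrator
    q, q_dot = state
    tau = pd_controller(q, jnp.zeros_like(q), q_dot, jnp.zeros_like(q_dot))
    q_dot = q_dot + DT * tau  # semi-implicit Euler: velocity first...
    q = q + DT * q_dot        # ...then position with the updated velocity
    return (q, q_dot), (q, tau)


@partial(jax.jit, static_argnames="num_steps")
def rollout(q0, q_dot0, num_steps):
    # scan carries the state and stacks the per-step outputs along axis 0
    (q_final, q_dot_final), (q_traj, tau_traj) = jax.lax.scan(
        step, (q0, q_dot0), xs=None, length=num_steps
    )
    return q_final, q_dot_final, q_traj, tau_traj


q_final, q_dot_final, q_traj, tau_traj = rollout(
    jnp.array([0.5, -0.8]), jnp.zeros(2), num_steps=1000
)
print(q_final, q_dot_final)
```

With these gains the toy closed loop is underdamped but stable, so after 10 s both joints settle close to zero. The stacked `q_traj` could then be replayed frame by frame, for example with the same `world.update_model` call used above.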