Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update notebook examples to the functional API #132

Merged
merged 3 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 92 additions & 112 deletions examples/PD_controller.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"JAXsim offers a simple high-level API in order to extract quantities needed in most robotic applications. "
"JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n",
"\n",
"- `model`: an object that defines the dynamics of the system.\n",
"- `data`: an object that contains the state of the system.\n",
"- `integrator`: an object that defines the integration method.\n",
"- `integrator_state`: an object that contains the state of the integrator."
]
},
{
Expand All @@ -77,11 +82,23 @@
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.high_level.model import Model\n",
"import jaxsim.api as js\n",
"from jaxsim import integrators\n",
"\n",
"dt = 0.01\n",
"\n",
"model = Model.build_from_model_description(\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_urdf_string, is_urdf=True\n",
")"
")\n",
"data = js.data.JaxSimModelData.build(model=model)\n",
"integrator = integrators.fixed_step.RungeKutta4SO3.build(\n",
" dynamics=js.ode.wrap_system_dynamics_for_integration(\n",
" model=model,\n",
" data=data,\n",
" system_dynamics=js.ode.system_dynamics,\n",
" ),\n",
")\n",
"integrator_state = integrator.init(x0=data.state, t0=0.0, dt=dt)"
]
},
{
Expand All @@ -101,7 +118,7 @@
" minval=-1.0, maxval=1.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\n",
")\n",
"\n",
"model.reset_joint_positions(positions=random_positions)"
"data = data.reset_joint_positions(positions=random_positions)"
]
},
{
Expand All @@ -118,17 +135,11 @@
"outputs": [],
"source": [
"# @title Set up MuJoCo renderer\n",
"!{sys.executable} -m pip install -U -q mujoco\n",
"!{sys.executable} -m pip install -q mediapy\n",
"\n",
"import mediapy as media\n",
"import tempfile\n",
"import xml.etree.ElementTree as ET\n",
"import numpy as np\n",
"from jaxsim.mujoco.visualizer import MujocoVisualizer\n",
"from jaxsim.mujoco import RodModelToMjcf, MujocoModelHelper, MujocoVideoRecorder\n",
"from jaxsim.mujoco.loaders import UrdfToMjcf\n",
"\n",
"import distutils.util\n",
"import os\n",
"import subprocess\n",
"\n",
"if IS_COLAB:\n",
" if subprocess.run(\"ffmpeg -version\", shell=True).returncode:\n",
Expand Down Expand Up @@ -171,66 +182,28 @@
" 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n",
" )\n",
"\n",
"camera = {\n",
" \"name\":\"cartpole_camera\",\n",
" \"mode\":\"fixed\",\n",
" \"pos\":\"3.954 3.533 2.343\",\n",
" \"xyaxes\":\"-0.594 0.804 -0.000 -0.163 -0.120 0.979\",\n",
" \"fovy\":\"60\",\n",
"}\n",
"\n",
"def load_mujoco_model_with_camera(xml_string, camera_pos, camera_xyaxes):\n",
" def to_mjcf_string(list_to_str):\n",
" return \" \".join(map(str, list_to_str))\n",
"\n",
" mj_model_raw = mujoco.MjModel.from_xml_string(model_urdf_string)\n",
" path_temp_xml = tempfile.NamedTemporaryFile(mode=\"w+\")\n",
" mujoco.mj_saveLastXML(path_temp_xml.name, mj_model_raw)\n",
" # Add camera in mujoco model\n",
" tree = ET.parse(path_temp_xml)\n",
" for elem in tree.getroot().iter(\"worldbody\"):\n",
" worldbody_elem = elem\n",
" camera_elem = ET.Element(\"camera\")\n",
" # Set attributes\n",
" camera_elem.set(\"name\", \"side\")\n",
" camera_elem.set(\"pos\", to_mjcf_string(camera_pos))\n",
" camera_elem.set(\"xyaxes\", to_mjcf_string(camera_xyaxes))\n",
" camera_elem.set(\"mode\", \"fixed\")\n",
" worldbody_elem.append(camera_elem)\n",
"\n",
" # Save new model\n",
" mujoco_xml_with_camera = ET.tostring(tree.getroot(), encoding=\"unicode\")\n",
" mj_model = mujoco.MjModel.from_xml_string(mujoco_xml_with_camera)\n",
" return mj_model\n",
"\n",
"\n",
"def from_jaxsim_to_mujoco_pos(jaxsim_jointpos, mjmodel, jaxsimmodel):\n",
" mujocoqposaddr2jaxindex = {}\n",
" for jaxjnt in jaxsimmodel.joints():\n",
" jntname = jaxjnt.name()\n",
" mujocoqposaddr2jaxindex[mjmodel.joint(jntname).qposadr[0]] = jaxjnt.index() - 1\n",
"\n",
" mujoco_jointpos = jaxsim_jointpos\n",
" for i in range(0, len(mujoco_jointpos)):\n",
" mujoco_jointpos[i] = jaxsim_jointpos[mujocoqposaddr2jaxindex[i]]\n",
"\n",
" return mujoco_jointpos\n",
"\n",
"\n",
"# To get a good camera location, you can use \"Copy camera\" functionality in MuJoCo GUI\n",
"mj_model = load_mujoco_model_with_camera(\n",
" model_urdf_string,\n",
" [3.954, 3.533, 2.343],\n",
" [-0.594, 0.804, -0.000, -0.163, -0.120, 0.979],\n",
")\n",
"renderer = mujoco.Renderer(mj_model, height=480, width=640)\n",
"mjcf_string, assets = UrdfToMjcf.convert(urdf=model.built_from, cameras=camera)\n",
"\n",
"mj_model_helper = MujocoModelHelper.build_from_xml(\n",
" mjcf_description=mjcf_string, assets=assets\n",
")\n",
"\n",
"def get_image(camera, mujocojointpos) -> np.ndarray:\n",
" \"\"\"Renders the environment state.\"\"\"\n",
" # Copy joint data in mjdata state\n",
" d = mujoco.MjData(mj_model)\n",
" d.qpos = mujocojointpos\n",
"\n",
" # Forward kinematics\n",
" mujoco.mj_forward(mj_model, d)\n",
"\n",
" # use the mjData object to update the renderer\n",
" renderer.update_scene(d, camera=camera)\n",
" return renderer.render()"
"# Create the video recorder.\n",
"recorder = MujocoVideoRecorder(\n",
" model=mj_model_helper.model,\n",
" data=mj_model_helper.data,\n",
" fps=int(1 / 0.010),\n",
" width=320 * 4,\n",
" height=240 * 4,\n",
")"
]
},
{
Expand All @@ -246,24 +219,27 @@
"metadata": {},
"outputs": [],
"source": [
"from jaxsim.simulation.ode_integration import IntegratorType\n",
"\n",
"sim_images = []\n",
"timestep = 0.01\n",
"for _ in range(300):\n",
" sim_images.append(\n",
" get_image(\n",
" \"side\",\n",
" from_jaxsim_to_mujoco_pos(\n",
" np.array(model.joint_positions()), mj_model, model\n",
" ),\n",
" )\n",
"import mediapy as media\n",
"\n",
"for _ in range(500):\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=None,\n",
" link_forces=None,\n",
" )\n",
" model.integrate(\n",
" t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n",
"\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions(), joint_names=model.joint_names()\n",
" )\n",
"\n",
"media.show_video(sim_images, fps=1 / timestep)"
" recorder.record_frame(camera_name=\"cartpole_camera\")\n",
"\n",
"media.show_video(recorder.frames, fps=1 / dt)\n",
"recorder.frames = []"
]
},
{
Expand All @@ -290,13 +266,17 @@
"KP = 10.0\n",
"KD = 6.0\n",
"\n",
"# Compute the gravity compensation term\n",
"H = model.free_floating_bias_forces()[6:]\n",
"\n",
"\n",
"def pd_controller(\n",
" q: jax.Array, q_d: jax.Array, q_dot: jax.Array, q_dot_d: jax.Array\n",
" data: js.data.JaxSimModelData, q_d: jax.Array, q_dot_d: jax.Array\n",
") -> jax.Array:\n",
"\n",
" # Compute the gravity compensation term\n",
" H = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n",
"\n",
" q = data.joint_positions()\n",
" q_dot = data.joint_velocities()\n",
"\n",
" return H + KP * (q_d - q) + KD * (q_dot_d - q_dot)"
]
},
Expand All @@ -313,31 +293,31 @@
"metadata": {},
"outputs": [],
"source": [
"sim_images = []\n",
"timestep = 0.01\n",
"\n",
"for _ in range(300):\n",
" sim_images.append(\n",
" get_image(\n",
" \"side\",\n",
" from_jaxsim_to_mujoco_pos(\n",
" np.array(model.joint_positions()), mj_model, model\n",
" ),\n",
" )\n",
"for _ in range(500):\n",
" control_torques = pd_controller(\n",
" data=data,\n",
" q_d=jnp.array([0.0, 0.0]),\n",
" q_dot_d=jnp.array([0.0, 0.0]),\n",
" )\n",
" model.set_joint_generalized_force_targets(\n",
" forces=pd_controller(\n",
" q=model.joint_positions(),\n",
" q_d=jnp.array([0.0, 0.0]),\n",
" q_dot=model.joint_velocities(),\n",
" q_dot_d=jnp.array([0.0, 0.0]),\n",
" )\n",
"\n",
" data, integrator_state = js.model.step(\n",
" dt=dt,\n",
" model=model,\n",
" data=data,\n",
" integrator=integrator,\n",
" integrator_state=integrator_state,\n",
" joint_forces=control_torques,\n",
" link_forces=None,\n",
" )\n",
" model.integrate(\n",
" t0=0.0, tf=timestep, integrator_type=IntegratorType.EulerSemiImplicit\n",
"\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions(), joint_names=model.joint_names()\n",
" )\n",
"\n",
"media.show_video(sim_images, fps=1 / timestep)"
" recorder.record_frame(camera_name=\"cartpole_camera\")\n",
"\n",
"media.show_video(recorder.frames, fps=1 / dt)\n",
"recorder.frames = []"
]
}
],
Expand Down Expand Up @@ -370,7 +350,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
Loading
Loading