Optimize Jacobian algorithm #121

Merged (3 commits) on Mar 27, 2024
Conversation

@diegoferigo (Member) commented Mar 26, 2024

This PR enhances the computation of the Jacobians of all links. Before this PR, we used jax.vmap to compute in parallel the free-floating left-trivialized Jacobian ${}^L J_{W,L/B}$ of each link, and then adjusted the input and output representations to match the desired ones.

This PR, instead of calling a parallelized version of jaxsim.rbda.jacobian, computes only once the full doubly-left Jacobian ${}^B J_{W,\text{\_}/B}$ (note the $\text{\_}$ instead of $L$), defined as the free-floating Jacobian with all columns filled. The free-floating doubly-left Jacobian ${}^B J_{W,L/B}$ of the i-th link is then obtained by filtering the columns of the full Jacobian with the support parent array $\kappa(i)$ of the link, which zeroes the columns corresponding to all links not part of the path $\pi_B(L)$.

All of this allows computing the Jacobians of all links by vmapping only the column filtering over $\kappa(i)$; the expensive algorithm that computes the full Jacobian is therefore executed only once. The vmapped operation is just a column filtering, so its cost is almost zero.
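
For illustration, here is a minimal JAX sketch of the column-filtering idea (filter_link_jacobian and support_masks are made-up names for this example, not the actual jaxsim internals; the 0/1 masks would be derived from the support parent array $\kappa(i)$):

import jax
import jax.numpy as jnp

def filter_link_jacobian(J_full, joint_mask):
    # Keep the 6 base columns and zero the joint columns that are not on the
    # support path of the link (joint_mask contains 0/1 entries).
    mask = jnp.concatenate([jnp.ones(6), joint_mask])
    return J_full * mask

# The expensive step (building the full doubly-left Jacobian) runs only once;
# the vmapped step is a cheap element-wise product over all links.
jacobians_of_all_links = jax.vmap(filter_link_jacobian, in_axes=(None, 0))

# Toy usage with 3 joints and 2 links (numbers are illustrative only).
J_full = jnp.ones((6, 6 + 3))
support_masks = jnp.array([[1.0, 0.0, 0.0],
                           [1.0, 1.0, 0.0]])
J_all = jacobians_of_all_links(J_full, support_masks)  # shape (2, 6, 9)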

This enhancement should speed up the computation of jaxsim.api.model.forward_dynamics_crb. While at this stage it is always better to use jaxsim.api.model.forward_dynamics_aba, the CRB version might become useful in all experiments that extend the forward dynamics with e.g. actuators, friction models, or any other stateless second-order dynamics elements¹. In these cases, operating on ABA could become quite difficult, and prototyping with the CRB version might be pretty helpful.
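
As a rough sketch of that use case (hypothetical code, not the actual jaxsim API: the joint_forces keyword and the friction model below are assumptions to be adapted to the real signature of jaxsim.api.model.forward_dynamics_crb), a stateless viscous friction term could be folded into the CRB-based forward dynamics like this:

import jaxsim.api as js

def forward_dynamics_crb_with_friction(model, data, tau, s_dot, kv=0.1):
    # Illustrative stateless viscous friction: subtract kv * joint velocities
    # from the commanded joint torques before evaluating the dynamics.
    tau_total = tau - kv * s_dot

    # NOTE: passing the joint torques through a `joint_forces` keyword is an
    # assumption about the signature of jaxsim.api.model.forward_dynamics_crb;
    # check the actual API before relying on it.
    return js.model.forward_dynamics_crb(model, data, joint_forces=tau_total)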

cc @traversaro @DanielePucci


📚 Documentation preview 📚: https://jaxsim--121.org.readthedocs.build//121/

Footnotes

  1. Under the assumption that they do not extend the integrated state vector. In that case, things become more complex, since the integrator also has to be properly updated.

@diegoferigo self-assigned this on Mar 26, 2024
@diegoferigo force-pushed the optimize_jacobian_algo branch 2 times, most recently from 84588d5 to e9302ad on March 26, 2024 at 17:01
@traversaro (Contributor) left a comment

Cool! Did you do any benchmark on the difference before and after this change?

@diegoferigo (Member, Author) commented Mar 26, 2024

I was running them while you were typing your comment :)

Below is a benchmark on a 58-DoF ErgoCub model on CPU.

Before

In [2]: %timeit -r1 -n1 _ = js.model.generalized_free_floating_jacobian(model, data0)
28 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [3]: %timeit -r 10 -n 1000 _ = js.model.generalized_free_floating_jacobian(model, data0)
1.37 ms ± 82.6 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)

In [4]: %timeit -r1 -n1 _ = js.model.forward_dynamics_crb(model, data0)
1min 27s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [5]: %timeit -r10 -n1000 _ = js.model.forward_dynamics_crb(model, data0)
1.99 ms ± 69.1 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)

After

In [2]: %timeit -r1 -n1 _ = js.model.generalized_free_floating_jacobian(model, data0)
9.5 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [3]: %timeit -r 10 -n 1000 _ = js.model.generalized_free_floating_jacobian(model, data0)
185 µs ± 16.4 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)

In [4]: %timeit -r1 -n1 _ = js.model.forward_dynamics_crb(model, data0)
55.6 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [5]: %timeit -r10 -n1000 _ = js.model.forward_dynamics_crb(model, data0)
896 µs ± 51.5 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)

Overview

  • Compiling generalized_free_floating_jacobian is now 3x faster.
  • Compiling forward_dynamics_crb is now 1.5x faster.
  • Running generalized_free_floating_jacobian is now 7x faster.
  • Running forward_dynamics_crb is now 2x faster.
Benchmark script:
import jaxsim.api as js

import resolve_robotics_uri_py
import rod

# Find the urdf file.
urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
    uri="model://ergoCubSN001/model.urdf"
)

# Build the ROD model.
rod_sdf = rod.Sdf.load(sdf=urdf_path)

# Build the model.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=rod_sdf.model,
)

# Create random data.
data0 = js.data.random_model_data(
    model=model,
    base_pos_bounds=((0, 0, 0.85), (0, 0, 0.85)),
    joint_vel_bounds=(0, 0),
    base_vel_lin_bounds=(0, 0),
    base_vel_ang_bounds=(0, 0),
)

%timeit -r1 -n1 _ = js.model.generalized_free_floating_jacobian(model, data0)

%timeit -r 10 -n 1000 _ = js.model.generalized_free_floating_jacobian(model, data0)

%timeit -r1 -n1 _ = js.model.forward_dynamics_crb(model, data0)

%timeit -r10 -n1000 _ = js.model.forward_dynamics_crb(model, data0)

@diegoferigo (Member, Author) commented:

In other words, on such a large model, with this change we can compute FD with CRB in less than 1 ms with JAX running on CPU. Not too bad. As a comparison, the equivalent computation with ABA runs about 2.6x faster:

In [15]: %timeit _ = js.model.forward_dynamics_aba(model, data0)
346 µs ± 6.74 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

@diegoferigo (Member, Author) commented:

cc @ami-iit/vertical_control-oriented-learning

Successfully merging this pull request may close the issue: Optimize vectorization of free-floating Jacobian