Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to compute the jacobian derivative of collidable points #213

Merged
merged 8 commits into from
Jul 22, 2024

Conversation

flferretti
Copy link
Collaborator

@flferretti flferretti commented Jul 19, 2024

This PR is basically a copy-paste of the logic introduced by @xela-95 in #208. The only difference is that here the computation is vmapped over all the collidable points.


📚 Documentation preview 📚: https://jaxsim--213.org.readthedocs.build//213/

@flferretti flferretti self-assigned this Jul 19, 2024
@flferretti flferretti force-pushed the feature/jacobian_derivate_collidable_points branch 2 times, most recently from 4e40904 to ddd08e8 Compare July 19, 2024 15:06
@flferretti flferretti marked this pull request as ready for review July 19, 2024 15:07
@flferretti flferretti requested a review from xela-95 July 19, 2024 15:08
@flferretti flferretti force-pushed the feature/jacobian_derivate_collidable_points branch 5 times, most recently from 285b2e3 to 70dc778 Compare July 19, 2024 16:17
Copy link
Member

@xela-95 xela-95 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @flferretti!! For me it's good to go! 🚀

Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Thanks for the new functionality. I left some minor suggestion. In addition of what written here below.


I had a hunch of a possible speed-up. Right now you compute for no reason the $J$ and $\dot{J}$ of links multiple times, once for each collidable point. Usually, many points belong to a single link, there is no need to re-compute the link Jacobian every time.

I suggest to compute the Jacobians of all links first, and then only get the one corresponding to the link you need. For a simple sphere with 250 points, the following refactor switched from 110 ms to 2.5 ms of runtime.

Suggestion
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian_derivative(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the derivative of the free-floating jacobian of the contact points.
    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian derivative.
    Returns:
        The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the contact points.
    Note:
        The input representation of the free-floating jacobian derivative is the active
        velocity representation.
    """

    output_vel_repr = (
        output_vel_repr if output_vel_repr is not None else data.velocity_representation
    )

    # Get the index of the parent link and the position of the collidable point.
    parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
    L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point)
    contact_idxs = jnp.arange(L_p_Ci.shape[0])

    # =====================================================
    # Compute quantities to adjust the input representation
    # =====================================================

    def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
        In = jnp.eye(model.dofs())
        T = jax.scipy.linalg.block_diag(X, In)
        return T

    def compute_Ṫ(model: js.model.JaxSimModel, : jtp.Matrix) -> jtp.Matrix:
        On = jnp.zeros(shape=(model.dofs(), model.dofs()))
         = jax.scipy.linalg.block_diag(, On)
        return 

    # Compute the operator to change the representation of ν, and its
    # time derivative.
    match data.velocity_representation:
        case VelRepr.Inertial:
            W_H_W = jnp.eye(4)
            W_X_W = Adjoint.from_transform(transform=W_H_W)
            W_Ẋ_W = jnp.zeros((6, 6))

            T = compute_T(model=model, X=W_X_W)
             = compute_Ṫ(model=model, =W_Ẋ_W)

        case VelRepr.Body:
            W_H_B = data.base_transform()
            W_X_B = Adjoint.from_transform(transform=W_H_B)
            B_v_WB = data.base_velocity()
            B_vx_WB = Cross.vx(B_v_WB)
            W_Ẋ_B = W_X_B @ B_vx_WB

            T = compute_T(model=model, X=W_X_B)
             = compute_Ṫ(model=model, =W_Ẋ_B)

        case VelRepr.Mixed:
            W_H_B = data.base_transform()
            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
            W_X_BW = Adjoint.from_transform(transform=W_H_BW)
            BW_v_WB = data.base_velocity()
            BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
            BW_vx_W_BW = Cross.vx(BW_v_W_BW)
            W_Ẋ_BW = W_X_BW @ BW_vx_W_BW

            T = compute_T(model=model, X=W_X_BW)
             = compute_Ṫ(model=model, =W_Ẋ_BW)

        case _:
            raise ValueError(data.velocity_representation)

    # =====================================================
    # Compute quantities to adjust the output representation
    # =====================================================

    with data.switch_velocity_representation(VelRepr.Inertial):
        # Compute the Jacobian of the parent link in inertial representation.
        W_J_WL_W = js.model.generalized_free_floating_jacobian(
            model=model,
            data=data,
            output_vel_repr=VelRepr.Inertial,
        )
        # Compute the Jacobian derivative of the parent link in inertial representation.
        W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
            model=model,
            data=data,
            output_vel_repr=VelRepr.Inertial,
        )

    def compute_O_J̇_WC_I(
        L_p_C: jtp.Vector,
        contact_idx: jtp.Int,
    ) -> jtp.Matrix:

        parent_link_idx = parent_link_idxs[contact_idx]

        match output_vel_repr:
            case VelRepr.Inertial:
                O_X_W = W_X_W = Adjoint.from_transform(transform=jnp.eye(4))
                O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))

            case VelRepr.Body:
                W_H_L = js.link.transform(
                    model=model, data=data, link_index=parent_link_idx
                )
                L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
                W_H_C = W_H_L @ L_H_C
                O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
                with data.switch_velocity_representation(VelRepr.Inertial):
                    W_nu = data.generalized_velocity()
                W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
                W_vx_WC = Cross.vx(W_v_WC)
                O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC

            case VelRepr.Mixed:
                W_H_L = js.link.transform(
                    model=model, data=data, link_index=parent_link_idx
                )
                L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
                W_H_C = W_H_L @ L_H_C
                W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
                CW_H_W = Transform.inverse(W_H_CW)
                O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
                with data.switch_velocity_representation(VelRepr.Mixed):
                    CW_J_WC_CW = jacobian(
                        model=model,
                        data=data,
                        output_vel_repr=VelRepr.Mixed,
                    )[contact_idx]
                    CW_v_WC = CW_J_WC_CW @ data.generalized_velocity()
                W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
                W_vx_W_CW = Cross.vx(W_v_W_CW)
                O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW

            case _:
                raise ValueError(output_vel_repr)

        O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))
        O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T
        O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T
        O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ 

        return O_J̇_WC_I

    O_J̇_WC = jax.vmap(compute_O_J̇_WC_I)(L_p_Ci, contact_idxs)

    return O_J̇_WC

src/jaxsim/api/contact.py Outdated Show resolved Hide resolved
tests/test_api_contact.py Outdated Show resolved Hide resolved
tests/test_api_contact.py Outdated Show resolved Hide resolved
tests/test_api_contact.py Outdated Show resolved Hide resolved
tests/test_api_contact.py Outdated Show resolved Hide resolved
@flferretti
Copy link
Collaborator Author

Thanks a lot @xela-95 and @diegoferigo! I also added generalized_free_floating_jacobian_derivative in api.model as it was required to address Diego's suggestion.

src/jaxsim/api/model.py Outdated Show resolved Hide resolved
@diegoferigo
Copy link
Member

I also added generalized_free_floating_jacobian_derivative in api.model as it was required to address Diego's suggestion.

Yep, I didn't realize the I had it just locally from old experiments :)

@flferretti flferretti force-pushed the feature/jacobian_derivate_collidable_points branch from 1da04cb to 88c5d4a Compare July 22, 2024 10:35
@flferretti flferretti force-pushed the feature/jacobian_derivate_collidable_points branch from 885d7d2 to 3ce64f3 Compare July 22, 2024 12:59
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last minor suggestions, not necessary for merging.

src/jaxsim/api/contact.py Outdated Show resolved Hide resolved
src/jaxsim/api/contact.py Outdated Show resolved Hide resolved
src/jaxsim/api/contact.py Outdated Show resolved Hide resolved
src/jaxsim/api/contact.py Outdated Show resolved Hide resolved
src/jaxsim/api/contact.py Outdated Show resolved Hide resolved
flferretti and others added 2 commits July 22, 2024 15:30
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
@flferretti flferretti force-pushed the feature/jacobian_derivate_collidable_points branch from c3bbe35 to 7f82bac Compare July 22, 2024 13:30
@flferretti flferretti merged commit 99eba39 into main Jul 22, 2024
24 checks passed
@flferretti flferretti deleted the feature/jacobian_derivate_collidable_points branch July 22, 2024 13:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants