Skip to content

Commit

Permalink
Merge pull request #15 from ami-iit/feature/fk_with_jax.lax.scan
Browse files Browse the repository at this point in the history
Implement Forward Kinematics with `jax.lax.scan`
  • Loading branch information
diegoferigo committed Sep 21, 2022
2 parents 7414921 + 61215ed commit fbf8e15
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions src/jaxsim/physics/algos/forward_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np

Expand All @@ -21,26 +24,48 @@ def forward_kinematics_model(
r = jnp.vstack(x_fb[4:7])
W_X_0 = jnp.linalg.inv(Plucker.from_rot_and_trans(Quaternion.to_dcm(qn), r))

# This is the 6D velocity transform from i-th link frame to the world frame
W_X_i = jnp.zeros(shape=[model.NB, 6, 6])
W_X_i = W_X_i.at[0].set(W_X_0)

i_X_pre = model.joint_transforms(q=q)
pre_X_λi = model.tree_transforms

# This is the parent-to-child 6D velocity transforms of all links
i_X_λi = jnp.zeros_like(i_X_pre)

for i in np.arange(start=1, stop=model.NB):
# Parent array mapping: i -> λ(i).
# Exception: λ(0) must not be used, it's initialized to -1.
λ = model.parent

PropagateKinematicsCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
propagate_kinematics_carry = (i_X_λi, W_X_i)

def propagate_kinematics(
carry: PropagateKinematicsCarry, i: jtp.Int
) -> Tuple[PropagateKinematicsCarry, None]:

i_X_λi, W_X_i = carry

i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

W_X_i_i = W_X_i[model.parent[i]] @ jnp.linalg.inv(i_X_λi[i])
W_X_i_i = W_X_i[λ[i]] @ jnp.linalg.inv(i_X_λi[i])
W_X_i = W_X_i.at[i].set(W_X_i_i)

return (i_X_λi, W_X_i), None

(_, W_X_i), _ = jax.lax.scan(
f=propagate_kinematics,
init=propagate_kinematics_carry,
xs=np.arange(start=1, stop=model.NB),
)

return jnp.stack([Plucker.to_transform(X) for X in list(W_X_i)])


def forward_kinematics(
model: PhysicsModel, body_index: int, q: jtp.Vector, xfb: jtp.Vector
model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector, xfb: jtp.Vector
) -> jtp.Matrix:

return forward_kinematics_model(model=model, q=q, xfb=xfb)[body_index]

0 comments on commit fbf8e15

Please sign in to comment.