Skip to content

Commit

Permalink
Implement Forward Kinematics with jax.lax.scan
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Sep 20, 2022
1 parent db580c8 commit d58d6f8
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions src/jaxsim/physics/algos/forward_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np

Expand All @@ -21,21 +24,43 @@ def forward_kinematics_model(
r = jnp.vstack(x_fb[4:7])
W_X_0 = jnp.linalg.inv(Plucker.from_rot_and_trans(Quaternion.to_dcm(qn), r))

# This is the 6D velocity transform from i-th link frame to the world frame
W_X_i = jnp.zeros(shape=[model.NB, 6, 6])
W_X_i = W_X_i.at[0].set(W_X_0)

i_X_pre = model.joint_transforms(q=q)
pre_X_λi = model.tree_transforms

# This is the parent-to-child 6D velocity transforms of all links
i_X_λi = jnp.zeros_like(i_X_pre)

for i in np.arange(start=1, stop=model.NB):
# Parent array mapping: i -> λ(i).
# Exception: λ(0) must not be used, it's initialized to -1.
λ = model.parent

PropagateKinematicsCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
propagate_kinematics_carry = (i_X_λi, W_X_i)

def propagate_kinematics(
carry: PropagateKinematicsCarry, i: jtp.Int
) -> Tuple[PropagateKinematicsCarry, None]:

i_X_λi, W_X_i = carry

i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

W_X_i_i = W_X_i[model.parent[i]] @ jnp.linalg.inv(i_X_λi[i])
W_X_i_i = W_X_i[λ[i]] @ jnp.linalg.inv(i_X_λi[i])
W_X_i = W_X_i.at[i].set(W_X_i_i)

return (i_X_λi, W_X_i), None

(_, W_X_i), _ = jax.lax.scan(
f=propagate_kinematics,
init=propagate_kinematics_carry,
xs=np.arange(start=1, stop=model.NB),
)

return jnp.stack([Plucker.to_transform(X) for X in list(W_X_i)])


Expand Down

0 comments on commit d58d6f8

Please sign in to comment.