Skip to content

Commit

Permalink
Put ABA and RNEA with motors in separate files
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 6, 2023
1 parent 5a8a905 commit c706b21
Show file tree
Hide file tree
Showing 4 changed files with 513 additions and 112 deletions.
83 changes: 18 additions & 65 deletions src/jaxsim/physics/algos/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def aba(
f_ext: jtp.Matrix = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
"""
Articulated Body Algorithm (ABA) algorithm with motor dynamics for forward dynamics.
Articulated Body Algorithm (ABA) algorithm for forward dynamics.
"""

x_fb, q, qd, _, tau, f_ext = utils.process_inputs(
Expand All @@ -35,25 +35,13 @@ def aba(
S = model.motion_subspaces(q=q)
λ = model.parent_array()

# Extract motor parameters from the physics model
Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
IM = jnp.array(
[jnp.eye(6) * m for m in [*model._joint_motor_inertia.values()]] * model.NB
)
K̅ᵥ = Γ.T * jnp.array([*model._joint_motor_viscous_friction.values()]) * Γ
m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

# Initialize buffers
v = jnp.array([jnp.zeros([6, 1])] * model.NB)
MA = jnp.array([jnp.zeros([6, 6])] * model.NB)
pA = jnp.array([jnp.zeros([6, 1])] * model.NB)
c = jnp.array([jnp.zeros([6, 1])] * model.NB)
i_X_λi = jnp.zeros_like(i_X_pre)

m_v = jnp.array([jnp.zeros([6, 1])] * model.NB)
m_c = jnp.array([jnp.zeros([6, 1])] * model.NB)
pR = jnp.array([jnp.zeros([6, 1])] * model.NB)

# Base pose B_X_W and velocity
base_quat = jnp.vstack(x_fb[0:4])
base_pos = jnp.vstack(x_fb[4:7])
Expand Down Expand Up @@ -93,36 +81,27 @@ def aba(
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
]

pass_1_carry = (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0)
pass_1_carry = (i_X_λi, v, c, MA, pA, i_X_0)

# Pass 1
def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
ii = i - 1
i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0 = carry
i_X_λi, v, c, MA, pA, i_X_0 = carry

# Compute parent-to-child transform
i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

# Propagate link velocity
vJ = S[i] * qd[ii] * (qd.size != 0)
m_vJ = m_S[i] * qd[ii] * (qd.size != 0)
vJ = S[i] * qd[ii] if qd.size != 0 else S[i] * 0

v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)

m_v_i = i_X_λi[i] @ v[λ[i]] + m_vJ
m_v = m_v.at[i].set(v_i)

c_i = Cross.vx(v[i]) @ vJ
c = c.at[i].set(c_i)
m_c_i = Cross.vx(m_v[i]) @ m_vJ
m_c = m_c.at[i].set(m_c_i)

# Initialize articulated-body inertia
MA_i = jnp.array(M[i])
Expand All @@ -136,72 +115,46 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]:
pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i])
pA = pA.at[i].set(pA_i)

pR_i = Cross.vx_star(m_v[i]) @ IM[i] @ m_v[i] - K̅ᵥ[i] * m_v[i]
pR = pR.at[i].set(pR_i)

return (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), None
return (i_X_λi, v, c, MA, pA, i_X_0), None

(i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), _ = jax.lax.scan(
(i_X_λi, v, c, MA, pA, i_X_0), _ = jax.lax.scan(
f=loop_body_pass1,
init=pass_1_carry,
xs=np.arange(start=1, stop=model.NB),
)

U = jnp.zeros_like(S)
m_U = jnp.zeros_like(S)
d = jnp.zeros(shape=(model.NB, 1))
u = jnp.zeros(shape=(model.NB, 1))
m_u = jnp.zeros(shape=(model.NB, 1))

Pass2Carry = Tuple[
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
jtp.MatrixJax,
]

pass_2_carry = (U, m_U, d, u, m_u, MA, pA)
pass_2_carry = (U, d, u, MA, pA)

# Pass 2
def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]:
ii = i - 1
U, m_U, d, u, m_u, MA, pA = carry
U, d, u, MA, pA = carry

# Compute intermediate results
u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
u = u.at[i].set(u_i.squeeze())

has_motors = Γ[i] != 1

m_u_i = (
tau[ii] / Γ[i] * has_motors - m_S[i].T @ pR[i]
if tau.size != 0
else -m_S[i].T @ pR[i]
)
m_u = m_u.at[i].set(m_u_i.squeeze())

U_i = MA[i] @ S[i]
U = U.at[i].set(U_i)

m_U_i = IM[i] @ m_S[i]
m_U = m_U.at[i].set(m_U_i)

d_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i]
d_i = S[i].T @ U[i]
d = d.at[i].set(d_i.squeeze())

u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i]
u = u.at[i].set(u_i.squeeze())

# Compute the articulated-body inertia and bias forces of this link
Ma = MA[i] + IM[i] - U[i] / d[i] @ U[i].T - m_U[i] / d[i] @ m_U[i].T
pa = (
pA[i]
+ pR[i]
+ Ma[i] @ c[i]
+ IM[i] @ m_c[i]
+ U[i] / d[i] * u[i]
+ m_U[i] / d[i] * m_u[i]
)
Ma = MA[i] - U[i] / d[i] @ U[i].T
pa = pA[i] + Ma @ c[i] + U[i] * u[i] / d[i]

# Propagate them to the parent, handling the base link
def propagate(
Expand All @@ -224,9 +177,9 @@ def propagate(
operand=(MA, pA),
)

return (U, m_U, d, u, m_u, MA, pA), None
return (U, d, u, MA, pA), None

(U, m_U, d, u, m_u, MA, pA), _ = jax.lax.scan(
(U, d, u, MA, pA), _ = jax.lax.scan(
f=loop_body_pass2,
init=pass_2_carry,
xs=np.flip(np.arange(start=1, stop=model.NB)),
Expand All @@ -253,8 +206,8 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]:
a_i = i_X_λi[i] @ a[λ[i]] + c[i]

# Compute joint accelerations
qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / d[i]
qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd
qdd_ii = (u[i] - U[i].T @ a_i) / d[i]
qdd = qdd.at[i - 1].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd

a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i
a = a.at[i].set(a_i)
Expand Down
Loading

0 comments on commit c706b21

Please sign in to comment.