Skip to content

Commit

Permalink
Merge pull request #62 from flferretti/motor_dynamics
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Dec 6, 2023
2 parents 293161c + f82e351 commit eb99cab
Show file tree
Hide file tree
Showing 6 changed files with 628 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,27 @@ def position_limit(self, dof: int = None) -> tuple[jtp.Float, jtp.Float]:

return jnp.array(low, dtype=float), jnp.array(high, dtype=float)

# =============
# Motor methods
# =============
@functools.partial(oop.jax_tf.method_ro)
def motor_inertia(self) -> jtp.Vector:
""""""

return jnp.array(self.joint_description.motor_inertia, dtype=float)

@functools.partial(oop.jax_tf.method_ro)
def motor_gear_ratio(self) -> jtp.Vector:
""""""

return jnp.array(self.joint_description.motor_gear_ratio, dtype=float)

@functools.partial(oop.jax_tf.method_ro)
def motor_viscous_friction(self) -> jtp.Vector:
""""""

return jnp.array(self.joint_description.motor_viscous_friction, dtype=float)

# =================
# Multi-DoF methods
# =================
Expand Down
99 changes: 99 additions & 0 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,13 +1137,31 @@ def forward_dynamics_crb(
τ = jnp.atleast_1d(τ.squeeze())
τ = jnp.vstack(τ) if τ.size > 0 else jnp.empty(shape=(0, 1))

# Extract motor parameters from the physics model
GR = self.motor_gear_ratios()
IM = self.motor_inertias()
KV = jnp.diag(self.motor_viscous_frictions())

# Compute auxiliary quantities
Γ = jnp.diag(GR)
K̅ᵥ = Γ.T @ KV @ Γ

# Compute terms of the floating-base EoM
M = self.free_floating_mass_matrix()
h = jnp.vstack(self.free_floating_bias_forces())
J = self.generalized_free_floating_jacobian()
f_ext = jnp.vstack(self.external_forces().flatten())
S = jnp.block([jnp.zeros(shape=(self.dofs(), 6)), jnp.eye(self.dofs())]).T

# Configure the slice for fixed/floating base robots
sl = np.s_[0:] if self.physics_model.is_floating_base else np.s_[6:]
sl_m = np.s_[M.shape[0] - self.dofs() :]

# Add the motor related terms to the EoM
M = M.at[sl_m, sl_m].set(M[sl_m, sl_m] + jnp.diag(Γ.T @ IM @ Γ))
h = h.at[sl_m].set(h[sl_m] + K̅ᵥ @ self.joint_velocities()[:, None])
S = S.at[sl_m].set(S[sl_m])

# Compute the generalized acceleration by inverting the EoM
ν̇ = jax.lax.select(
pred=self.floating_base(),
Expand Down Expand Up @@ -1479,6 +1497,87 @@ def integrate(
},
)

# ==============
# Motor dynamics
# ==============

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_inertias(
self, inertias: jtp.Vector, joint_names: tuple[str, ...] = None
) -> None:
joint_names = joint_names or self.joint_names()

if inertias.size != len(joint_names):
raise ValueError("Wrong arguments size", inertias.size, len(joint_names))

self.physics_model._joint_motor_inertia.update(
dict(zip(self.physics_model._joint_motor_inertia, inertias))
)

logging.info("Setting attribute `motor_inertias`")

@functools.partial(oop.jax_tf.method_rw, jit=False)
def set_motor_gear_ratios(
self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] = None
) -> None:
joint_names = joint_names or self.joint_names()

if gear_ratios.size != len(joint_names):
raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names))

# Check on gear ratios if motor_inertias are not zero
for idx, gr in enumerate(gear_ratios):
if gr != 0 and self.motor_inertias()[idx] == 0:
raise ValueError(
f"Zero motor inertia with non-zero gear ratio found in position {idx}"
)

self.physics_model._joint_motor_gear_ratio.update(
dict(zip(self.physics_model._joint_motor_gear_ratio, gear_ratios))
)

logging.info("Setting attribute `motor_gear_ratios`")

@functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
def set_motor_viscous_frictions(
self, viscous_frictions: jtp.Vector, joint_names: tuple[str, ...] = None
) -> None:
joint_names = joint_names or self.joint_names()

if viscous_frictions.size != len(joint_names):
raise ValueError(
"Wrong arguments size", viscous_frictions.size, len(joint_names)
)

self.physics_model._joint_motor_viscous_friction.update(
dict(
zip(
self.physics_model._joint_motor_viscous_friction,
viscous_frictions,
)
)
)

logging.info("Setting attribute `motor_viscous_frictions`")

@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_inertias(self) -> jtp.Vector:
return jnp.array(
[*self.physics_model._joint_motor_inertia.values()], dtype=float
)

@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_gear_ratios(self) -> jtp.Vector:
return jnp.array(
[*self.physics_model._joint_motor_gear_ratio.values()], dtype=float
)

@functools.partial(oop.jax_tf.method_ro, jit=False)
def motor_viscous_frictions(self) -> jtp.Vector:
return jnp.array(
[*self.physics_model._joint_motor_viscous_friction.values()], dtype=float
)

# ===============
# Private methods
# ===============
Expand Down
4 changes: 4 additions & 0 deletions src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class JointDescription(JaxsimDataclass):
position_limit: Tuple[float, float] = (0.0, 0.0)
initial_position: Union[float, npt.NDArray] = 0.0

motor_inertia: float = 0.0
motor_viscous_friction: float = 0.0
motor_gear_ratio: float = 1.0

def __post_init__(self):
if self.axis is not None:
with self.mutable_context(
Expand Down
Loading

0 comments on commit eb99cab

Please sign in to comment.