Skip to content

Commit

Permalink
Update existing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Aug 4, 2023
1 parent 296f34b commit 1b101d8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 49 deletions.
50 changes: 23 additions & 27 deletions tests/test_eom.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
import pathlib

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from pytest import param as p

from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model

from . import utils_idyntree, utils_models, utils_rng
from .utils_models import Robot


@pytest.mark.parametrize(
"robot, vel_repr",
[
(utils_models.Robot.DoublePendulum, VelRepr.Inertial),
(utils_models.Robot.DoublePendulum, VelRepr.Body),
(utils_models.Robot.DoublePendulum, VelRepr.Mixed),
(utils_models.Robot.Ur10, VelRepr.Inertial),
(utils_models.Robot.Ur10, VelRepr.Body),
(utils_models.Robot.Ur10, VelRepr.Mixed),
(utils_models.Robot.AnymalC, VelRepr.Inertial),
(utils_models.Robot.AnymalC, VelRepr.Body),
(utils_models.Robot.AnymalC, VelRepr.Mixed),
(utils_models.Robot.Cassie, VelRepr.Inertial),
(utils_models.Robot.Cassie, VelRepr.Body),
(utils_models.Robot.Cassie, VelRepr.Mixed),
# (utils_models.Robot.iCub, VelRepr.Inertial),
# (utils_models.Robot.iCub, VelRepr.Body),
# (utils_models.Robot.iCub, VelRepr.Mixed),
p(*[Robot.DoublePendulum, VelRepr.Inertial], id="DoublePendulum-Inertial"),
p(*[Robot.DoublePendulum, VelRepr.Body], id="DoublePendulum-Body"),
p(*[Robot.DoublePendulum, VelRepr.Mixed], id="DoublePendulum-Mixed"),
p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"),
p(*[Robot.Ur10, VelRepr.Body], id="Ur10-Body"),
p(*[Robot.Ur10, VelRepr.Mixed], id="Ur10-Mixed"),
p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"),
p(*[Robot.AnymalC, VelRepr.Body], id="AnymalC-Body"),
p(*[Robot.AnymalC, VelRepr.Mixed], id="AnymalC-Mixed"),
p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"),
p(*[Robot.Cassie, VelRepr.Body], id="Cassie-Body"),
p(*[Robot.Cassie, VelRepr.Mixed], id="Cassie-Mixed"),
],
)
def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
Expand Down Expand Up @@ -102,13 +101,10 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
# Test individual terms of the EoM
# ================================

jit_enabled = True
fn = jax.jit if jit_enabled else lambda x: x

M_jaxsim = fn(model_jaxsim.free_floating_mass_matrix)()
g_jaxsim = fn(model_jaxsim.free_floating_gravity_forces)()
h_jaxsim = fn(model_jaxsim.free_floating_bias_forces)()
J_jaxsim = np.vstack([link.jacobian() for link in model_jaxsim.links()])
M_jaxsim = model_jaxsim.free_floating_mass_matrix()
g_jaxsim = model_jaxsim.free_floating_gravity_forces()
J_jaxsim = jnp.vstack([link.jacobian() for link in model_jaxsim.links()])
h_jaxsim = model_jaxsim.free_floating_bias_forces()

sl = np.s_[0:] if model_jaxsim.floating_base() else np.s_[6:]
assert M_jaxsim[sl, sl] == pytest.approx(M_idt[sl, sl], abs=1e-3)
Expand All @@ -120,13 +116,13 @@ def test_eom(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
# Test the forward dynamics computed with CRB
# ===========================================

J_ff = fn(model_jaxsim.generalized_free_floating_jacobian)()
f_ext = fn(model_jaxsim.external_forces)().flatten()
nud = np.hstack(fn(model_jaxsim.forward_dynamics_crb)(tau=tau))
J_ff = model_jaxsim.generalized_free_floating_jacobian()
f_ext = model_jaxsim.external_forces().flatten()
ν̇ = np.hstack(model_jaxsim.forward_dynamics_crb(tau=tau))
S = np.block(
[np.zeros(shape=(model_jaxsim.dofs(), 6)), np.eye(model_jaxsim.dofs())]
).T

assert h_jaxsim[sl] == pytest.approx(
(S @ tau + J_ff.T @ f_ext - M_jaxsim @ nud)[sl], abs=1e-3
(S @ tau + J_ff.T @ f_ext - M_jaxsim @ ν̇)[sl], abs=1e-3
)
40 changes: 18 additions & 22 deletions tests/test_forward_dynamics.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
import jax
import numpy as np
import pytest
from pytest import param as p

from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model

from . import utils_models, utils_rng
from .utils_models import Robot


@pytest.mark.parametrize(
"robot, vel_repr",
[
(utils_models.Robot.DoublePendulum, VelRepr.Inertial),
(utils_models.Robot.DoublePendulum, VelRepr.Body),
(utils_models.Robot.DoublePendulum, VelRepr.Mixed),
(utils_models.Robot.Ur10, VelRepr.Inertial),
(utils_models.Robot.Ur10, VelRepr.Body),
(utils_models.Robot.Ur10, VelRepr.Mixed),
(utils_models.Robot.AnymalC, VelRepr.Inertial),
(utils_models.Robot.AnymalC, VelRepr.Body),
(utils_models.Robot.AnymalC, VelRepr.Mixed),
(utils_models.Robot.Cassie, VelRepr.Inertial),
(utils_models.Robot.Cassie, VelRepr.Body),
(utils_models.Robot.Cassie, VelRepr.Mixed),
# (utils_models.Robot.iCub, VelRepr.Inertial),
# (utils_models.Robot.iCub, VelRepr.Body),
# (utils_models.Robot.iCub, VelRepr.Mixed),
p(*[Robot.DoublePendulum, VelRepr.Inertial], id="DoublePendulum-Inertial"),
p(*[Robot.DoublePendulum, VelRepr.Body], id="DoublePendulum-Body"),
p(*[Robot.DoublePendulum, VelRepr.Mixed], id="DoublePendulum-Mixed"),
p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"),
p(*[Robot.Ur10, VelRepr.Body], id="Ur10-Body"),
p(*[Robot.Ur10, VelRepr.Mixed], id="Ur10-Mixed"),
p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"),
p(*[Robot.AnymalC, VelRepr.Body], id="AnymalC-Body"),
p(*[Robot.AnymalC, VelRepr.Mixed], id="AnymalC-Mixed"),
p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"),
p(*[Robot.Cassie, VelRepr.Body], id="Cassie-Body"),
p(*[Robot.Cassie, VelRepr.Mixed], id="Cassie-Mixed"),
],
)
def test_aba(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
Expand Down Expand Up @@ -61,15 +59,13 @@ def test_aba(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
tau = model.joint_generalized_forces_targets()

# Compute model acceleration with ABA
jit_enabled = True
fn = jax.jit if jit_enabled else lambda x: x
a_WB_aba, sdd_aba = fn(model.forward_dynamics_aba)(tau=tau)
v̇_WB_aba, s̈_aba = model.forward_dynamics_aba(tau=tau)

# ==============================================
# Compute forward dynamics with dedicated method
# ==============================================

a_WB, sdd = model.forward_dynamics_crb(tau=tau)
v̇_WB, = model.forward_dynamics_crb(tau=tau)

assert sdd.squeeze() == pytest.approx(sdd_aba.squeeze(), abs=0.5)
assert a_WB.squeeze() == pytest.approx(a_WB_aba.squeeze(), abs=0.2)
assert .squeeze() == pytest.approx(s̈_aba.squeeze(), abs=0.5)
assert v̇_WB.squeeze() == pytest.approx(v̇_WB_aba.squeeze(), abs=0.2)

0 comments on commit 1b101d8

Please sign in to comment.