In [1]:
from etils import epath
import mujoco
import numpy as np
import mujoco_warp as mjwarp
from mujoco_warp import put_model, put_data

In [2]:
import mujoco_warp
print(mujoco_warp.__file__)

/home/andrew/Documents/Projects/2025-mujoco/mujoco_warp_prs/2025_03_24-csr/mujoco_warp/__init__.py


In [3]:
def fixture(fname: str, keyframe: int = -1, sparse: bool = True):
  path = epath.resource_path("mujoco_warp") / "test_data" / fname
  mjm = mujoco.MjModel.from_xml_path(path.as_posix())
  mjm.opt.jacobian = sparse
  mjd = mujoco.MjData(mjm)
  if keyframe > -1:
    mujoco.mj_resetDataKeyframe(mjm, mjd, keyframe)
  # give the system a little kick to ensure we have non-identity rotations
  mjd.qvel = np.random.uniform(-0.01, 0.01, mjm.nv)
  mujoco.mj_step(mjm, mjd, 3)  # let dynamics get state significantly non-zero
  mujoco.mj_forward(mjm, mjd)
  m = put_model(mjm)
  d = put_data(mjm, mjd)
  return mjm, mjd, m, d

_TOLERANCE = 5e-5

def _assert_eq(a, b, name):
  tol = _TOLERANCE * 10  # avoid test noise
  err_msg = f"mismatch: {name}"
  np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)

In [4]:
mjm, mjd, m, d = fixture("pendula.xml", sparse=True)

Warp 1.7.0.dev20250324 initialized:
   Git commit: f3814e7e5459e5fd13032cf0fddb3daddd510f30
   CUDA Toolkit 12.8, Driver 12.8
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA GeForce RTX 4060 Ti" (8 GiB, sm_89, mempool enabled)
   Kernel cache:
     /home/andrew/.cache/warp/1.7.0.dev20250324


In [16]:
loop_helper = m.M_loop_helper.numpy()
loop_helper_starts = m.M_loop_helper_starts.numpy()
for k in range(m.nv):
    beg, end = loop_helper_starts[k], loop_helper_starts[k + 1]
    for ind in range(end - beg):
        print(loop_helper[beg + ind])

[96  0]
[96  1]
[96  2]
[96  3]
[96  4]
[96  5]
[96  6]
[95  0]
[95  1]
[95  2]
[95  3]
[95  4]
[95  5]
[94  0]
[94  1]
[94  2]
[94  3]
[94  4]
[93  0]
[93  1]
[93  2]
[93  3]
[92  0]
[92  1]
[92  2]
[91  0]
[91  1]
[90  0]
[88  0]
[88  1]
[88  2]
[88  3]
[88  4]
[88  5]
[87  0]
[87  1]
[87  2]
[87  3]
[87  4]
[86  0]
[86  1]
[86  2]
[86  3]
[85  0]
[85  1]
[85  2]
[84  0]
[84  1]
[83  0]
[81  0]
[81  1]
[81  2]
[81  3]
[81  4]
[80  0]
[80  1]
[80  2]
[80  3]
[79  0]
[79  1]
[79  2]
[78  0]
[78  1]
[77  0]
[75  0]
[75  1]
[75  2]
[75  3]
[74  0]
[74  1]
[74  2]
[73  0]
[73  1]
[72  0]
[70  0]
[70  1]
[70  2]
[69  0]
[69  1]
[68  0]
[66  0]
[66  1]
[65  0]
[63  0]
[60  0]
[60  1]
[60  2]
[59  0]
[59  1]
[58  0]
[56  0]
[56  1]
[56  2]
[55  0]
[55  1]
[54  0]
[52  0]
[52  1]
[51  0]
[49  0]
[46  0]
[46  1]
[45  0]
[43  0]
[40  0]
[40  1]
[40  2]
[39  0]
[39  1]
[38  0]
[36  0]
[36  1]
[35  0]
[33  0]
[30  0]
[25  0]
[25  1]
[24  0]
[22  0]
[19  0]
[19  1]
[19  2]
[19  3]
[19  4]
[18  0]


In [12]:
loop_helper_starts

array([  0,  28,  49,  64,  74,  80,  83,  84,  84,  90,  96,  99, 100,
       100, 103, 104, 104, 110, 113, 114, 114, 115, 115, 115, 115, 118,
       119, 119, 134, 144, 150, 153, 154, 154], dtype=int32)

In [15]:
loop_helper

array([[96,  0],
       [96,  1],
       [96,  2],
       [96,  3],
       [96,  4],
       [96,  5],
       [96,  6],
       [95,  0],
       [95,  1],
       [95,  2],
       [95,  3],
       [95,  4],
       [95,  5],
       [94,  0],
       [94,  1],
       [94,  2],
       [94,  3],
       [94,  4],
       [93,  0],
       [93,  1],
       [93,  2],
       [93,  3],
       [92,  0],
       [92,  1],
       [92,  2],
       [91,  0],
       [91,  1],
       [90,  0],
       [88,  0],
       [88,  1],
       [88,  2],
       [88,  3],
       [88,  4],
       [88,  5],
       [87,  0],
       [87,  1],
       [87,  2],
       [87,  3],
       [87,  4],
       [86,  0],
       [86,  1],
       [86,  2],
       [86,  3],
       [85,  0],
       [85,  1],
       [85,  2],
       [84,  0],
       [84,  1],
       [83,  0],
       [81,  0],
       [81,  1],
       [81,  2],
       [81,  3],
       [81,  4],
       [80,  0],
       [80,  1],
       [80,  2],
       [80,  3],
       [79,  0

In [7]:
m.M_diags.numpy()

array([ 0,  2,  5,  9, 14, 20, 21, 23, 26, 27, 28, 29, 31, 32, 34, 37, 41,
       42, 44, 47, 48, 50, 53, 57, 61, 62, 64, 67, 71, 76, 82, 89, 97],
      dtype=int32)

In [8]:
m.nM

98

In [12]:
diag_inds = np.zeros(m.nM, dtype=np.int32)

i = 0
for k in range(m.nv):
    rowadr_k = mjd.M_rowadr[k]
    diag_k = mjd.M_rowadr[k] + mjd.M_rownnz[k]
    for _ in range(rowadr_k, diag_k):
        diag_inds[i] = diag_k
        i += 1

# import warp as wp
# wp_diag_inds = wp.array(diag_inds, dtype=wp.int32, ndim=1)

In [13]:
diag_inds

array([ 1,  3,  3,  6,  6,  6, 10, 10, 10, 10, 15, 15, 15, 15, 15, 21, 21,
       21, 21, 21, 21, 22, 24, 24, 27, 27, 27, 28, 29, 30, 32, 32, 33, 35,
       35, 38, 38, 38, 42, 42, 42, 42, 43, 45, 45, 48, 48, 48, 49, 51, 51,
       54, 54, 54, 58, 58, 58, 58, 62, 62, 62, 62, 63, 65, 65, 68, 68, 68,
       72, 72, 72, 72, 77, 77, 77, 77, 77, 83, 83, 83, 83, 83, 83, 90, 90,
       90, 90, 90, 90, 90, 98, 98, 98, 98, 98, 98, 98, 98], dtype=int32)