In [37]:
# !python /data/benny_cai/miniconda3/envs/brax/lib/python3.9/site-packages/brax/v1/tools/urdf_converter.py --xml_model_path 'legged_studio_a1_more_colli.urdf' --config_path 'a1'

I0125 20:41:11.544692 140288389826368 urdf_converter.py:58] Loading urdf model from legged_studio_a1_more_colli.urdf


In [1]:
# !python /data/benny_cai/miniconda3/envs/brax/lib/python3.9/site-packages/brax/v1/tools/mujoco_converter.py --xml_model_path 'a1_mjcf.txt' --config_path 'a1_mjcf'

## With interpolation

In [1]:
from brax.positional import pipeline
from brax.io import mjcf

from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
from brax.io import html
import json
import numpy as np
import copy
import transforms3d.quaternions as quat

In [2]:
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp

In [3]:
# Parse the A1 MJCF description (kept in a .txt file) into the brax system `m`.
with open('a1_mjcf.txt', mode='r') as file:
    config = file.read()
m = mjcf.loads(config, asset_path='')

In [4]:
# Load the reference motion clip (JSON content despite the .txt extension):
# 'Frames' holds one row per frame, 'FrameDuration' the spacing between rows.
file = 'a1_origin_motions/jump_backward_flip.txt'
with open(file, 'r') as f:
    data = json.load(f)
frameDuration = data['FrameDuration']
frames = np.array(data['Frames'])
frames.shape

(130, 19)

In [5]:
# compute frame velocities
# Forward finite differences between consecutive frames.
# Frame layout (from the slicing below): [root pos 0:3 | root quaternion 3:7, scalar-last | joints 7:19].
# Velocity layout: [root linear vel (3) | root angular vel (3) | joint vel (12)].
num_frames = frames.shape[0]
frame_vel_size = 18
dt = frameDuration
frames_vels = []

for i in range(num_frames - 1):
    frame0, frame1 = frames[i], frames[i + 1]

    root_pos0, root_pos1 = frame0[0:3], frame1[0:3]
    joints0, joints1 = frame0[7:19], frame1[7:19]
    # transforms3d works with scalar-first (w, x, y, z) quaternions.
    root_rot0 = np.roll(frame0[3:7], 1)
    root_rot1 = np.roll(frame1[3:7], 1)

    root_vel = (root_pos1 - root_pos0) / dt
    root_rot_diff = quat.qmult(root_rot1, quat.qconjugate(root_rot0))
    # q and -q encode the same rotation, but quat2axangle returns 2*acos(w), which is in
    # (pi, 2*pi] when w < 0. A sign flip between consecutive source quaternions would then
    # give a near-full-turn angle, i.e. a spurious ~2*pi/dt spin in the wrong direction.
    # Use the w >= 0 representative so the shortest rotation is always taken.
    if root_rot_diff[0] < 0:
        root_rot_diff = -root_rot_diff
    root_rot_diff_axis, root_rot_diff_angle = quat.quat2axangle(root_rot_diff)
    root_ang_vel = (root_rot_diff_angle / dt) * root_rot_diff_axis
    joints_vel = (joints1 - joints0) / dt
    curr_frame_vel = np.concatenate((root_vel, root_ang_vel, joints_vel))
    frames_vels.append(curr_frame_vel)

# replicate the velocity at the last frame so every frame has one
if num_frames > 1:
    frames_vels.append(frames_vels[-1])

frames_vels = np.array(frames_vels)
if num_frames > 1:
    assert frames_vels.shape == (num_frames, frame_vel_size), frames_vels.shape
frames_vels.shape

(130, 18)

In [6]:
# Resample the motion from `frameDuration` spacing to `targetDuration` spacing.
# Root position, joint angles and the per-frame velocities (`frames_vels`) are linearly
# interpolated; the root orientation is slerped.
# The file stores the root quaternion scalar-last (x, y, z, w), which is scipy's order.
# The output is reordered to scalar-first (w, x, y, z), the order this notebook feeds to brax.
# Sample times are generated directly, so several samples may fall inside one source
# interval. This means targetDuration < frameDuration (upsampling) also works.
targetDuration = 0.05
assert targetDuration > 0, "targetDuration must be positive"

interp_frames = []
interp_frames_vels = []
num_src_frames = frames.shape[0]

if num_src_frames > 1:
    end_time = frameDuration * (num_src_frames - 1)
    # Small tolerance so a sample landing exactly on the last frame is not lost to rounding.
    num_samples = int(np.floor(end_time / targetDuration + 1e-9)) + 1
    for interpolated_idx in range(num_samples):
        time = interpolated_idx * targetDuration
        # Source interval [i, i+1] containing `time`; clamp so the last sample uses the last interval.
        i = min(int(time // frameDuration), num_src_frames - 2)
        # Clamp guards against float noise pushing the fraction just outside [0, 1],
        # which Slerp would reject.
        blend = float(np.clip((time - frameDuration * i) / frameDuration, 0.0, 1.0))
        frame0, frame1 = frames[i], frames[i + 1]

        key_rots = R.from_quat(np.stack((frame0[3:7], frame1[3:7])))
        interp_rots = Slerp([0.0, 1.0], key_rots)(blend)
        root_rot_wxyz = np.roll(interp_rots.as_quat(), 1)

        blend_root_pos = (1.0 - blend) * frame0[0:3] + blend * frame1[0:3]
        blend_joints = (1.0 - blend) * frame0[7:19] + blend * frame1[7:19]
        interp_frames.append(np.hstack((blend_root_pos, root_rot_wxyz, blend_joints)))

        blend_frame_vel = (1.0 - blend) * frames_vels[i] + blend * frames_vels[i + 1]
        interp_frames_vels.append(blend_frame_vel)

In [7]:
# Stack the per-sample rows into one array for the rollout below.
interp_frames = np.asarray(interp_frames)
interp_frames.shape

(44, 19)

In [8]:
# Stack the per-sample velocity rows into one array.
interp_frames_vels = np.asarray(interp_frames_vels)
interp_frames_vels.shape

(44, 18)

In [9]:
# Jitted pipeline entry points. Only `init` is used here: each resampled frame is
# turned into a state directly (no stepping).
jit_env_reset = jax.jit(pipeline.init)
jit_env_step = jax.jit(pipeline.step)

rng = jax.random.PRNGKey(seed=1)
rollout = [
    jit_env_reset(m, frame, frame_vel)
    for frame, frame_vel in zip(interp_frames, interp_frames_vels)
]

In [11]:
# HTML(html.render(m.replace(dt=0.002*25), rollout))

In [15]:
rollout[43].xd_i.ang

Array([[ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ],
       [ 0.      , -6.051734,  0.      ]], dtype=float32)

In [10]:
# serialize data, positional:
# One flat row per state, concatenated in this fixed order:
#   q, qd,
#   x.pos,   x.rot,   xd.vel,   xd.ang,
#   x_i.pos, x_i.rot, xd_i.vel, xd_i.ang,
#   j.pos,   j.rot,   jd.vel,   jd.ang,
#   a_p.pos, a_p.rot, a_c.pos,  a_c.rot,
#   mass
data = []
for roll in rollout:
    pieces = [roll.q, roll.qd]
    for pose, motion in ((roll.x, roll.xd), (roll.x_i, roll.xd_i), (roll.j, roll.jd)):
        pieces.extend((pose.pos.flatten(), pose.rot.flatten(),
                       motion.vel.flatten(), motion.ang.flatten()))
    for anchor in (roll.a_p, roll.a_c):
        pieces.extend((anchor.pos.flatten(), anchor.rot.flatten()))
    pieces.append(roll.mass)
    QP = np.hstack(pieces)
    data.append(QP)

In [11]:
# One row per state -> 2-D array (num_states, row_width).
data = np.asarray(data)
data.shape

(321, 739)

In [12]:
# NOTE(review): the clip serialized above is jump_backward_flip, but it is written to
# `trot.npy`, overwriting whatever is stored there — confirm this target name is intended.
np.save(file='a1_ref_motion/trot.npy', arr=data)

## No interpolation

In [1]:
from brax.positional import pipeline
from brax.io import mjcf

from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
from brax.io import html
import json
import numpy as np

In [2]:
# Parse the A1 MJCF description (kept in a .txt file) into the brax system `m`.
with open('a1_mjcf.txt', mode='r') as file:
    config = file.read()
m = mjcf.loads(config, asset_path='')

In [10]:
# Load the jump08 reference clip (JSON content despite the .txt extension); only the
# per-frame rows are used in this section.
file = 'a1_origin_motions/jump08.txt'
with open(file, 'r') as f:
    data = json.load(f)
frames = np.array(data['Frames'])
frames.shape

(526, 19)

In [11]:
jit_env_reset = jax.jit(pipeline.init)
jit_env_step = jax.jit(pipeline.step)

rng = jax.random.PRNGKey(seed=1)

# The motion file stores the root quaternion scalar-last (x, y, z, w); move w to the
# front, as the original element swap did. Work on a copy: swapping through the
# row views of `frames` mutated it in place, so re-running this cell permuted the
# quaternion a second time and silently corrupted the poses.
frames_wxyz = frames.copy()
frames_wxyz[:, 3:7] = frames[:, [6, 3, 4, 5]]

# Each frame becomes a state via init only, with zero generalized velocity.
zero_qd = jp.zeros(m.qd_size())
rollout = [jit_env_reset(m, frame, zero_qd) for frame in frames_wxyz]

In [15]:
# HTML(html.render(m.replace(dt=0.02), rollout))

In [33]:
# # serialize data, generalized:
# data = []
# for roll in rollout:
#     QP = np.hstack((roll.q, roll.qd, roll.x.pos.flatten(), 
#                     roll.x.rot.flatten(), roll.xd.vel.flatten(), roll.xd.ang.flatten(),
#                     roll.root_com.flatten(),
#                     roll.cinr.transform.pos.flatten(),
#                     roll.cinr.transform.rot.flatten(),
#                     roll.cinr.i.flatten(),
#                     roll.cinr.mass,
#                     roll.cd.vel.flatten(),
#                     roll.cd.ang.flatten(),
#                     roll.cdof.vel.flatten(),
#                     roll.cdof.ang.flatten(),
#                     roll.cdofd.vel.flatten(),
#                     roll.cdofd.ang.flatten(),
#                     roll.mass_mx.flatten(),
#                     roll.mass_mx_inv.flatten(),
#                     roll.con_jac.flatten(),
#                     roll.con_diag,
#                     roll.con_aref,
#                     roll.qf_smooth,
#                     roll.qf_constraint,
#                     roll.qdd
#                    ))
#     data.append(QP)

In [10]:
# serialize data, positional:
# One flat row per state, concatenated in this fixed order:
#   q, qd,
#   x.pos,   x.rot,   xd.vel,   xd.ang,
#   x_i.pos, x_i.rot, xd_i.vel, xd_i.ang,
#   j.pos,   j.rot,   jd.vel,   jd.ang,
#   a_p.pos, a_p.rot, a_c.pos,  a_c.rot,
#   mass
data = []
for roll in rollout:
    pieces = [roll.q, roll.qd]
    for pose, motion in ((roll.x, roll.xd), (roll.x_i, roll.xd_i), (roll.j, roll.jd)):
        pieces.extend((pose.pos.flatten(), pose.rot.flatten(),
                       motion.vel.flatten(), motion.ang.flatten()))
    for anchor in (roll.a_p, roll.a_c):
        pieces.extend((anchor.pos.flatten(), anchor.rot.flatten()))
    pieces.append(roll.mass)
    QP = np.hstack(pieces)
    data.append(QP)

In [11]:
# One row per state -> 2-D array (num_states, row_width).
data = np.asarray(data)
data.shape

(33, 739)

In [12]:
# NOTE(review): the clip serialized in this section is jump08, but it is written to
# `trot.npy`, overwriting whatever is stored there — confirm this target name is intended.
np.save(file='a1_ref_motion/trot.npy', arr=data)

In [31]:
dir(rollout[0])

['T',
 '__add__',
 '__annotations__',
 '__class__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__sub__',
 '__subclasshook__',
 '__truediv__',
 '__weakref__',
 '_flax_dataclass',
 'a_c',
 'a_p',
 'concatenate',
 'contact',
 'index_set',
 'index_sum',
 'j',
 'jd',
 'mass',
 'q',
 'qd',
 'replace',
 'reshape',
 'select',
 'slice',
 'take',
 'tree_replace',
 'vmap',
 'x',
 'x_i',
 'xd',
 'xd_i']

In [39]:
rollout[10].contact

In [34]:
rollout[0].x_i.pos - rollout[0].x.pos

Array([[-8.68828793e-05,  4.12178272e-03, -2.51710415e-04],
       [-3.29609215e-03, -7.08241016e-04, -3.27527523e-05],
       [-1.00523382e-02,  2.26075128e-02, -2.53679752e-02],
       [ 1.30997702e-01,  2.60286033e-03,  1.65096223e-02],
       [-3.32412124e-03,  5.62351197e-04, -3.05473804e-05],
       [-2.12564766e-02, -2.30624303e-02, -1.64915025e-02],
       [ 1.10549971e-01,  1.24979019e-03, -7.22280219e-02],
       [ 3.32412124e-03, -5.62336296e-04,  3.01599503e-05],
       [-2.28128880e-02,  2.30522826e-02, -1.42775476e-02],
       [ 9.70614851e-02,  9.59122926e-03, -8.90327096e-02],
       [ 3.29473615e-03,  7.04083592e-04,  1.25974417e-04],
       [-2.35372782e-02, -2.19015554e-02, -1.48997307e-02],
       [ 1.04690775e-01,  7.94543326e-03, -8.01027641e-02]],      dtype=float32)

In [2]:
# Evaluation trajectories written by a diffmimic training run.
# NOTE(review): absolute, machine-specific path; consider a configurable log directory.
eval_traj_path = '/data/benny_cai/diffmimic/logs/exp_300_32_0.0003_10_0.4_0_threshold_0.4_True_54_60_600_True_0.02_0.5_0.01_0.01_trot/eval_traj_10.npy'
rollouts = np.load(eval_traj_path)

In [3]:
rollouts.shape

(32, 32, 739)

In [5]:
# rollouts[0]

### Try: initial states from finite-difference velocities (no interpolation)

In [1]:
from brax.positional import pipeline
from brax.io import mjcf

from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
from brax.io import html
import json
import numpy as np
import transforms3d.quaternions as quat
from brax.envs.base import PipelineEnv, State

In [2]:
# Parse the A1 MJCF description (kept in a .txt file) into the brax system `m`.
with open('a1_mjcf.txt', mode='r') as file:
    config = file.read()
m = mjcf.loads(config, asset_path='')

ctrl_range:  [[-0.802851  0.802851]
 [-1.0472    4.18879 ]
 [-2.69653  -0.916298]
 [-0.802851  0.802851]
 [-1.0472    4.18879 ]
 [-2.69653  -0.916298]
 [-0.802851  0.802851]
 [-1.0472    4.18879 ]
 [-2.69653  -0.916298]
 [-0.802851  0.802851]
 [-1.0472    4.18879 ]
 [-2.69653  -0.916298]]
actuator_biasprm:  [[   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0. -100.    0.    0.    0.    0.    0.    0.    0.    0.

In [15]:
# Load the reference motion clip (JSON content despite the .txt extension):
# 'Frames' holds one row per frame, 'FrameDuration' the spacing between rows.
file = 'a1_origin_motions/jump_backward_flip.txt'
with open(file, 'r') as f:
    data = json.load(f)
frameDuration = data['FrameDuration']
frames = np.array(data['Frames'])
frames.shape

(130, 19)

In [16]:
# Forward finite differences between consecutive frames.
# Frame layout (from the slicing below): [root pos 0:3 | root quaternion 3:7, scalar-last | joints 7:19].
# Velocity layout: [root linear vel (3) | root angular vel (3) | joint vel (12)].
num_frames = frames.shape[0]
frame_vel_size = 18
dt = frameDuration
frames_vels = []

for i in range(num_frames - 1):
    frame0, frame1 = frames[i], frames[i + 1]

    root_pos0, root_pos1 = frame0[0:3], frame1[0:3]
    joints0, joints1 = frame0[7:19], frame1[7:19]
    # transforms3d works with scalar-first (w, x, y, z) quaternions.
    root_rot0 = np.roll(frame0[3:7], 1)
    root_rot1 = np.roll(frame1[3:7], 1)

    root_vel = (root_pos1 - root_pos0) / dt
    root_rot_diff = quat.qmult(root_rot1, quat.qconjugate(root_rot0))
    # q and -q encode the same rotation, but quat2axangle returns 2*acos(w), which is in
    # (pi, 2*pi] when w < 0. A sign flip between consecutive source quaternions would then
    # give a near-full-turn angle, i.e. a spurious ~2*pi/dt spin in the wrong direction.
    # Use the w >= 0 representative so the shortest rotation is always taken.
    if root_rot_diff[0] < 0:
        root_rot_diff = -root_rot_diff
    root_rot_diff_axis, root_rot_diff_angle = quat.quat2axangle(root_rot_diff)
    root_ang_vel = (root_rot_diff_angle / dt) * root_rot_diff_axis
    joints_vel = (joints1 - joints0) / dt
    curr_frame_vel = np.concatenate((root_vel, root_ang_vel, joints_vel))
    frames_vels.append(curr_frame_vel)

# replicate the velocity at the last frame so every frame has one
if num_frames > 1:
    frames_vels.append(frames_vels[-1])

frames_vels = np.array(frames_vels)
if num_frames > 1:
    assert frames_vels.shape == (num_frames, frame_vel_size), frames_vels.shape
frames_vels.shape

(130, 18)

In [17]:
jit_env_reset = jax.jit(pipeline.init)
jit_env_step = jax.jit(pipeline.step)

# Move the root quaternion's scalar from last to first, as the original element swap did.
# Work on a copy: swapping through the row views of `frames` mutated it in place.
# Re-running this cell (or the velocity cell above, which reads `frames`) then operated
# on already-permuted quaternions.
frames_wxyz = frames.copy()
frames_wxyz[:, 3:7] = frames[:, [6, 3, 4, 5]]

rollout = []
for frame, frame_vel in zip(frames_wxyz, frames_vels):
    # state = jit_env_reset(m, frame, 20*jp.ones(m.qd_size())) # try
    state = jit_env_reset(m, frame, frame_vel)
    rollout.append(state)

In [19]:
# HTML(html.render(m.replace(dt=0.02), rollout))

In [5]:
# for state in rollout:
#     print("##############")
#     print("x: ", state.x.pos[1])
#     print("x_i: ", state.x_i.pos[1])    

In [12]:
rollout[80].x.pos

Array([[-0.23447   ,  0.        ,  0.51898   ],
       [-0.22583362, -0.04700003,  0.7017762 ],
       [-0.22583362, -0.13205008,  0.7017762 ],
       [-0.11724531, -0.13205008,  0.53382194],
       [-0.22583362,  0.04700003,  0.7017762 ],
       [-0.22583362,  0.13205008,  0.7017762 ],
       [-0.11724531,  0.13205008,  0.53382194],
       [-0.24310637, -0.04700003,  0.33618385],
       [-0.24310637, -0.13205008,  0.33618385],
       [-0.33995968, -0.13205008,  0.16119963],
       [-0.24310637,  0.04700003,  0.33618385],
       [-0.24310637,  0.13205008,  0.33618385],
       [-0.33995968,  0.13205008,  0.16119963]], dtype=float32)

In [13]:
rollout[80].x_i.pos

Array([[-0.23397055,  0.0041    ,  0.5189564 ],
       [-0.22602084, -0.04763503,  0.69847035],
       [-0.21371555, -0.10972308,  0.67707115],
       [-0.09109609, -0.13205008,  0.6632668 ],
       [-0.22602084,  0.04763503,  0.69847035],
       [-0.21371555,  0.10972308,  0.67707115],
       [-0.09109609,  0.13205008,  0.6632668 ],
       [-0.24298108, -0.04763503,  0.33949262],
       [-0.25917155, -0.10972308,  0.31384334],
       [-0.23711628, -0.13205008,  0.07835653],
       [-0.24298108,  0.04763503,  0.33949262],
       [-0.25917155,  0.10972308,  0.31384334],
       [-0.23711628,  0.13205008,  0.07835653]], dtype=float32)

In [20]:
rollout[80].xd_i.ang

Array([[-19.03386 ,  20.000011,  20.921593],
       [-18.089993,  20.000011,  40.89932 ],
       [-18.089993,  40.000023,  40.89932 ],
       [-18.089993,  60.000034,  40.89932 ],
       [-18.089993,  20.000011,  40.89932 ],
       [-18.089993,  40.000023,  40.89932 ],
       [-18.089993,  60.000034,  40.89932 ],
       [-18.089993,  20.000011,  40.89932 ],
       [-18.089993,  40.000023,  40.89932 ],
       [-18.089993,  60.00004 ,  40.89932 ],
       [-18.089993,  20.000011,  40.89932 ],
       [-18.089993,  40.000023,  40.89932 ],
       [-18.089993,  60.00004 ,  40.89932 ]], dtype=float32)

In [9]:
rollout[80].xd.ang

Array([[ 0.        , -6.324409  ,  0.        ],
       [ 0.        , -6.324409  ,  0.        ],
       [ 0.        , -5.948884  ,  0.        ],
       [ 0.        , -7.412592  ,  0.        ],
       [ 0.        , -6.324409  ,  0.        ],
       [ 0.        , -5.948884  ,  0.        ],
       [ 0.        , -7.412592  ,  0.        ],
       [ 0.        , -6.324409  ,  0.        ],
       [ 0.        , -0.9296856 ,  0.        ],
       [ 0.        , -0.07485604,  0.        ],
       [ 0.        , -6.324409  ,  0.        ],
       [ 0.        , -0.9296856 ,  0.        ],
       [ 0.        , -0.07485604,  0.        ]], dtype=float32)

In [10]:
rollout[80].xd_i.vel

Array([[-0.27399588,  0.        ,  0.96836567],
       [-1.4093155 ,  0.        ,  1.0186429 ],
       [-1.2832555 ,  0.        ,  1.0919158 ],
       [-1.3906043 ,  0.        ,  1.8596396 ],
       [-1.4093155 ,  0.        ,  1.0186429 ],
       [-1.2832555 ,  0.        ,  1.0919158 ],
       [-1.3906043 ,  0.        ,  1.8596396 ],
       [ 0.8610067 ,  0.        ,  0.91137946],
       [ 0.9027024 ,  0.        ,  0.8956515 ],
       [ 1.0508143 ,  0.        ,  0.8282424 ],
       [ 0.8610067 ,  0.        ,  0.91137946],
       [ 0.9027024 ,  0.        ,  0.8956515 ],
       [ 1.0508143 ,  0.        ,  0.8282424 ]], dtype=float32)

In [11]:
rollout[80].xd.vel

Array([[-0.27414516,  0.        ,  0.965207  ],
       [-1.430223  ,  0.        ,  1.0198269 ],
       [-1.430223  ,  0.        ,  1.0198269 ],
       [-0.43108255,  0.        ,  1.6658062 ],
       [-1.430223  ,  0.        ,  1.0198269 ],
       [-1.430223  ,  0.        ,  1.0198269 ],
       [-0.43108255,  0.        ,  1.6658062 ],
       [ 0.88193274,  0.        ,  0.9105871 ],
       [ 0.88193274,  0.        ,  0.9105871 ],
       [ 1.044613  ,  0.        ,  0.82054394],
       [ 0.88193274,  0.        ,  0.9105871 ],
       [ 0.88193274,  0.        ,  0.9105871 ],
       [ 1.044613  ,  0.        ,  0.82054394]], dtype=float32)

### Random try: actuator overrides and model inspection

In [1]:
from brax.positional import pipeline
from brax.io import mjcf

from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
from brax.io import html
import json
import numpy as np
import transforms3d.quaternions as quat
from brax.envs.base import PipelineEnv, State

In [2]:
# Parse the A1 MJCF description (kept in a .txt file) into the brax system `m`.
with open('a1_mjcf.txt', mode='r') as file:
    config = file.read()

m = mjcf.loads(config, asset_path='')

# Replace the actuator limits loaded from the MJCF for all 12 actuators:
# force range [-50, 50] and bias_qd = -10.
num_actuators = 12
force_range = np.tile(np.array([-50.0, 50.0], dtype=np.float32), (num_actuators, 1))
bias_qd = np.full(num_actuators, -10, dtype=np.float32)
m = m.replace(actuator=m.actuator.replace(force_range=force_range, bias_qd=bias_qd))

In [1]:
# # jit_env_reset = jax.jit(pipeline.init)
# # jit_env_step = jax.jit(pipeline.step)
# jit_env_reset = pipeline.init
# jit_env_step = pipeline.step

# rollout = []
# rng = jax.random.PRNGKey(seed=1)
# state = jit_env_reset(m, m.init_q, jp.zeros(m.qd_size()))
# for i in range(5):
#     rollout.append(state)
#     act = 10 * jp.sin(i / 100) * jp.ones(m.act_size())
#     print(act)
#     state = jit_env_step(m, state, act)

In [4]:
# The cell that builds `rollout` for this section is commented out above. Without it,
# this cell depends on a `rollout` left in the kernel by an earlier section, and fails
# with a bare NameError on a fresh kernel. Fail with an explicit message instead.
assert 'rollout' in globals(), "build `rollout` first (the step loop above is commented out)"
HTML(html.render(m, rollout))

In [3]:
dir(m)

['T',
 '__add__',
 '__annotations__',
 '__class__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__sub__',
 '__subclasshook__',
 '__truediv__',
 '__weakref__',
 '_flax_dataclass',
 '_model',
 'act_size',
 'actuator',
 'ang_damping',
 'baumgarte_erp',
 'collide_scale',
 'concatenate',
 'density',
 'dof',
 'dof_link',
 'dof_ranges',
 'dt',
 'enable_fluid',
 'geom_masks',
 'geoms',
 'get_mjx_model',
 'get_model',
 'gravity',
 'index_set',
 'index_sum',
 'init_q',
 'joint_scale_ang',
 'joint_scale_pos',
 'link',
 'link_names',
 'link_parents',
 'link_types',
 'matrix_inv_iterations',
 'num_links',
 'q_idx',
 'q_size',
 'qd_idx',
 'qd_size',
 'rep

In [34]:
m.link_names

['trunk',
 'FR_hip',
 'FR_thigh',
 'FR_calf',
 'FL_hip',
 'FL_thigh',
 'FL_calf',
 'RR_hip',
 'RR_thigh',
 'RR_calf',
 'RL_hip',
 'RL_thigh',
 'RL_calf']

In [31]:
dir(m.actuator)

['T',
 '__add__',
 '__annotations__',
 '__class__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__sub__',
 '__subclasshook__',
 '__truediv__',
 '__weakref__',
 '_flax_dataclass',
 'bias_q',
 'bias_qd',
 'concatenate',
 'ctrl_range',
 'force_range',
 'gain',
 'gear',
 'index_set',
 'index_sum',
 'q_id',
 'qd_id',
 'replace',
 'reshape',
 'select',
 'slice',
 'take',
 'tree_replace',
 'vmap']

In [3]:
m.actuator.index_set

<bound method Base.index_set of Actuator(q_id=Array([ 7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18], dtype=int32), qd_id=Array([ 6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17], dtype=int32), ctrl_range=Array([[-0.802851,  0.802851],
       [-1.0472  ,  4.18879 ],
       [-2.69653 , -0.916298],
       [-0.802851,  0.802851],
       [-1.0472  ,  4.18879 ],
       [-2.69653 , -0.916298],
       [-0.802851,  0.802851],
       [-1.0472  ,  4.18879 ],
       [-2.69653 , -0.916298],
       [-0.802851,  0.802851],
       [-1.0472  ,  4.18879 ],
       [-2.69653 , -0.916298]], dtype=float32), force_range=array([[-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.],
       [-50.,  50.]], dtype=float32), gain=Array([100., 100., 100., 100., 100., 100., 100., 100., 100., 100., 100.,
       100.], dtype=float32), gear=Array([1., 1.,

In [5]:
m.actuator.q_id

Array([ 7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18], dtype=int32)

In [6]:
m.actuator.qd_id

Array([ 6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17], dtype=int32)

In [7]:
m.actuator.ctrl_range

Array([[-0.802851,  0.802851],
       [-1.0472  ,  4.18879 ],
       [-2.69653 , -0.916298],
       [-0.802851,  0.802851],
       [-1.0472  ,  4.18879 ],
       [-2.69653 , -0.916298],
       [-0.802851,  0.802851],
       [-1.0472  ,  4.18879 ],
       [-2.69653 , -0.916298],
       [-0.802851,  0.802851],
       [-1.0472  ,  4.18879 ],
       [-2.69653 , -0.916298]], dtype=float32)

In [8]:
m.actuator.force_range

Array([[-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5],
       [-33.5,  33.5]], dtype=float32)

In [9]:
m.actuator.gain

Array([100., 100., 100., 100., 100., 100., 100., 100., 100., 100., 100.,
       100.], dtype=float32)

In [10]:
m.actuator.gear

Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [11]:
m.actuator.bias_q

Array([-100., -100., -100., -100., -100., -100., -100., -100., -100.,
       -100., -100., -100.], dtype=float32)

In [12]:
m.actuator.bias_qd

array([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
       -10.], dtype=float32)

In [32]:
m.gravity

Array([ 0.  ,  0.  , -9.81], dtype=float32)

In [39]:
m.ang_damping

Array(0., dtype=float32)

In [40]:
m.baumgarte_erp

Array(0.1, dtype=float32)

In [42]:
m.vel_damping

Array(0., dtype=float32)

In [41]:
m.joint_scale_ang

Array(0.2, dtype=float32)

In [63]:
m.link.inertia

Inertia(transform=Transform(pos=Array([[ 0.00000e+00,  4.10000e-03, -5.00000e-04],
       [-3.31100e-03, -6.35000e-04,  3.10000e-05],
       [-3.23700e-03,  2.23270e-02, -2.73260e-02],
       [ 4.72659e-03,  0.00000e+00, -1.31975e-01],
       [-3.31100e-03,  6.35000e-04,  3.10000e-05],
       [-3.23700e-03, -2.23270e-02, -2.73260e-02],
       [ 4.72659e-03,  0.00000e+00, -1.31975e-01],
       [ 3.31100e-03, -6.35000e-04,  3.10000e-05],
       [-3.23700e-03,  2.23270e-02, -2.73260e-02],
       [ 4.72659e-03,  0.00000e+00, -1.31975e-01],
       [ 3.31100e-03,  6.35000e-04,  3.10000e-05],
       [-3.23700e-03, -2.23270e-02, -2.73260e-02],
       [ 4.72659e-03,  0.00000e+00, -1.31975e-01]], dtype=float32), rot=Array([[-1.8256460e-03,  7.0638072e-01, -6.4208568e-04,  7.0782948e-01],
       [ 5.0752789e-01,  5.0626791e-01,  4.9150690e-01,  4.9449891e-01],
       [ 9.9912524e-01, -2.5639306e-03, -4.0953111e-02, -8.0609117e-03],
       [ 7.0688641e-01,  1.7653009e-02,  1.7653009e-02,  7.068864

In [45]:
dir(m.get_model())

['__class__',
 '__copy__',
 '__deepcopy__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_address',
 'actuator',
 'actuator_acc0',
 'actuator_actadr',
 'actuator_actearly',
 'actuator_actlimited',
 'actuator_actnum',
 'actuator_actrange',
 'actuator_biasprm',
 'actuator_biastype',
 'actuator_cranklength',
 'actuator_ctrllimited',
 'actuator_ctrlrange',
 'actuator_dynprm',
 'actuator_dyntype',
 'actuator_forcelimited',
 'actuator_forcerange',
 'actuator_gainprm',
 'actuator_gaintype',
 'actuator_gear',
 'actuator_group',
 'actuator_length0',
 'actuator_lengthrange',
 'actuator_plugin',
 'actuator_trnid',
 'actuator_trntype',
 'actuator_user',
 'body',
 'body_bvhadr',
 'body_

In [48]:
m.get_model().actuator_gaintype

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [25]:
import numpy as np
import brax
import jax.numpy as jnp
from brax.v1 import QP

In [26]:
def deserialize_qp(nparray) -> brax.v1.QP:
    """Rebuild a brax v1 ``QP`` from a flattened trajectory array.

    The last axis is laid out field-by-field for ``N`` bodies:
    ``[pos (N*3) | rot (N*4) | vel (N*3) | ang (N*3)]``. Any leading axes
    (e.g. time) are kept as batch dimensions.

    Args:
      nparray: array whose last dimension is ``13 * N``.

    Returns:
      QP whose ``pos``/``vel``/``ang`` have shape ``batch + (N, 3)`` and whose
      ``rot`` has shape ``batch + (N, 4)``.

    Raises:
      ValueError: if the last dimension is not a multiple of 13.
    """
    width = nparray.shape[-1]
    # Without this check, floor division would silently drop trailing columns and read
    # every field from misaligned offsets.
    if width % 13 != 0:
        raise ValueError(
            f"expected last dimension to be a multiple of 13 "
            f"(pos 3 + rot 4 + vel 3 + ang 3 per body), got {width}")
    num_bodies = width // 13
    batch_dims = nparray.shape[:-1]
    slices = [num_bodies * x for x in [0, 3, 7, 10, 13]]
    pos = jnp.reshape(nparray[..., slices[0]:slices[1]], batch_dims + (num_bodies, 3))
    rot = jnp.reshape(nparray[..., slices[1]:slices[2]], batch_dims + (num_bodies, 4))
    vel = jnp.reshape(nparray[..., slices[2]:slices[3]], batch_dims + (num_bodies, 3))
    ang = jnp.reshape(nparray[..., slices[3]:slices[4]], batch_dims + (num_bodies, 3))
    return QP(pos=pos, rot=rot, vel=vel, ang=ang)

In [27]:
# Load a humanoid backflip demo trajectory and split it into QP fields.
# NOTE(review): absolute, machine-specific path.
demo_path = "/data/benny_cai/diffmimic/data/demo_humanoid/backflip_fps_30_20s.npy"
demo_traj = jnp.array(np.load(demo_path))
traj = deserialize_qp(demo_traj)

In [28]:
traj.pos[0]

Array([[-9.17549152e-03, -5.00592403e-04,  9.06869233e-01],
       [ 9.38418088e-04, -1.45998541e-02,  1.14238191e+00],
       [ 9.36635770e-03, -2.80493423e-02,  1.36571264e+00],
       [-1.27840545e-02, -2.12204292e-01,  1.37511373e+00],
       [ 2.42230911e-02, -2.33315229e-01,  1.10364866e+00],
       [-1.50721595e-02,  1.53341085e-01,  1.39721406e+00],
       [ 3.83804031e-02,  2.33637244e-01,  1.13991308e+00],
       [-8.54627974e-03, -8.52319375e-02,  9.01769638e-01],
       [-4.61073928e-02, -8.45008269e-02,  4.81901020e-01],
       [-8.92889574e-02, -8.38480517e-02,  7.43125677e-02],
       [-9.80470423e-03,  8.42307508e-02,  9.11968768e-01],
       [-1.86487865e-02,  9.42028016e-02,  4.90633547e-01],
       [-4.56155650e-02,  1.04159504e-01,  8.17728490e-02],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],      dtype=float32)

In [29]:
traj.pos[1]

Array([[-2.2039369e-02, -7.6524785e-04,  9.2189753e-01],
       [-9.9811275e-03, -1.5006463e-02,  1.1573101e+00],
       [ 2.2919516e-03, -2.8476859e-02,  1.3804612e+00],
       [-1.8280035e-02, -2.1280058e-01,  1.3901373e+00],
       [ 1.7896939e-02, -2.6992917e-01,  1.1237992e+00],
       [-2.3010259e-02,  1.5270777e-01,  1.4124612e+00],
       [ 2.7267911e-02,  2.7206999e-01,  1.1701119e+00],
       [-2.0650990e-02, -8.5481696e-02,  9.1670150e-01],
       [-4.5604337e-02, -8.8515274e-02,  4.9590564e-01],
       [-7.5220779e-02, -9.1538496e-02,  8.7118223e-02],
       [-2.3427747e-02,  8.3951198e-02,  9.2709351e-01],
       [-1.7760465e-02,  9.7381786e-02,  5.0579965e-01],
       [-4.0345814e-02,  1.1069594e-01,  9.6769013e-02],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00]], dtype=float32)

In [30]:
(traj.pos[2] - traj.pos[1]) * 30

Array([[-0.48628327, -0.01009895,  0.17979383],
       [-0.3943685 ,  0.00349361,  0.17529488],
       [-0.18615729,  0.0177104 ,  0.1614368 ],
       [-0.14068264,  0.01433417,  0.1889205 ],
       [-0.52025527, -1.4238203 ,  0.59458137],
       [-0.19190976,  0.01434401,  0.17582059],
       [-0.7501001 ,  1.4015625 ,  0.93684554],
       [-0.46897197, -0.00994742,  0.18294454],
       [ 0.11530571, -0.1641532 ,  0.16388476],
       [ 0.79951924, -0.3124246 ,  0.13542674],
       [-0.5035946 , -0.01025051,  0.17664671],
       [ 0.16844228,  0.12190618,  0.20849705],
       [ 0.27730715,  0.260166  ,  0.20824805],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32)

In [31]:
traj.vel[0]

Array([[-0.469554  , -0.00797463,  0.12001988],
       [-0.4673155 , -0.00796742,  0.35616028],
       [-0.4634365 , -0.00799305,  0.58002067],
       [-0.48660314, -0.19117548,  0.60001004],
       [-0.49555668, -0.20990266,  0.3260072 ],
       [-0.48768312,  0.17504293,  0.6000707 ],
       [-0.4965914 ,  0.19719788,  0.3263222 ],
       [-0.46926656, -0.09286114,  0.12001975],
       [-0.45745462, -0.0944721 , -0.30135766],
       [-0.4536336 , -0.09587395, -0.71120745],
       [-0.46984145,  0.07691189,  0.12002002],
       [-0.45746323,  0.07876797, -0.30134013],
       [-0.45653316,  0.0800641 , -0.71120703],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32)

In [32]:
traj.vel[1]

Array([[-0.49148667, -0.00529936,  0.06291912],
       [-0.48878595, -0.00525952,  0.29905465],
       [-0.48426947, -0.00529149,  0.5229031 ],
       [-0.5074469 , -0.18846576,  0.54295456],
       [-0.5190404 , -0.20764497,  0.26908198],
       [-0.5083911 ,  0.17775302,  0.5430259 ],
       [-0.51970327,  0.20125186,  0.26947835],
       [-0.49124068, -0.090186  ,  0.06293063],
       [-0.47741437, -0.09173141, -0.35838574],
       [-0.47450975, -0.09296793, -0.7682436 ],
       [-0.4917327 ,  0.07958728,  0.06290761],
       [-0.4775209 ,  0.08145858, -0.35839462],
       [-0.47712597,  0.08254933, -0.768263  ],
       [ 0.        ,  0.        ,  0.        ]], dtype=float32)

In [36]:
traj.ang[30]

Array([[-1.4424235e-01, -8.2465286e+00, -3.3158797e-01],
       [-9.6879974e-02, -8.4499722e+00, -9.2401201e-01],
       [ 2.8955668e-01, -8.2368488e+00, -4.8967814e-01],
       [-1.2309293e+00, -8.4750319e+00, -5.7500887e-01],
       [-1.9217244e+00, -8.2494555e+00, -5.8249003e-01],
       [ 6.5171719e-01, -8.8281488e+00,  7.6852733e-01],
       [ 1.6910141e+00, -8.6498804e+00,  5.4413885e-01],
       [-5.8439988e-01, -1.1167219e+01,  6.6201389e-02],
       [-3.3837408e-01, -1.0865814e+01, -4.5214754e-01],
       [-9.9272108e-01, -1.2105953e+01, -3.3946621e-01],
       [-7.5331284e-03, -1.1034703e+01, -3.3094007e-01],
       [-8.1315078e-02, -1.0634511e+01,  7.3646523e-02],
       [ 4.9581072e-01, -1.1932925e+01, -1.1791527e-01],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00]], dtype=float32)