In [1]:
from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp

import brax
from brax.envs import env
from brax.io import json
from brax.io import html

In [2]:
# !python /data/benny_cai/miniconda3/envs/diffmimic/lib/python3.9/site-packages/brax/tools/urdf_converter.py --xml_model_path 'legged_studio_a1_more_colli.urdf' --config_path 'a1'

I0122 12:29:31.939205 139787942938432 urdf_converter.py:58] Loading urdf model from legged_studio_a1_more_colli.urdf


In [1]:
# !python /data/benny_cai/miniconda3/envs/diffmimic/lib/python3.9/site-packages/brax/tools/mujoco_converter.py --xml_model_path 'a1_mjcf.txt' --config_path 'a1_mjcf'

In [14]:
with open('a1', 'r') as file:
    content = file.read()

In [15]:
class A1(env.Env):
    def __init__(self, system_config):
        super().__init__(config=system_config)

    def reset(self, rng: jp.ndarray) -> env.State:
        qp = self.sys.default_qp()
        pos = qp.pos
        info = self.sys.info(qp)
        obs = jp.zeros(1)
        reward, done, zero = jp.zeros(3)
        metrics = {
            'torsoIsUp': zero,
            'torsoHeight': zero
        }
        info = {'rng': rng}
        return env.State(qp, obs, reward, done, metrics, info)

    def step(self, state: env.State, action: jp.ndarray) -> env.State:
        qp, info = self.sys.step(state.qp, action)
        return state.replace(qp=qp)

In [16]:
a1 = A1(content)

In [17]:
env_state = a1.reset(rng=jax.random.PRNGKey(0))
rollout = []

for i in range(10):
    print(i)
    rollout.append(env_state.qp)
    action = 10 * jp.sin(i / 100) * jp.ones(a1.action_size)
    env_state = a1.step(env_state, action)

0
1
2
3
4
5
6
7
8
9


In [18]:
HTML(html.render(a1.sys, rollout))

Convert a1 motion to QP list:

In [33]:
import pickle

file = 'a1_ref_motion/trot.pkl'
with open(file, 'rb') as f:
    data = pickle.load(f)

In [34]:
from brax import QP
import numpy as np

### For visualization only

In [2]:
rollout = []
for roll in data:
    qp = QP(pos=np.vstack((np.array(roll['pos']), [0.0, 0.0, 0.0])), rot=np.vstack((np.array(roll['rot']), [1.0, 0.0, 0.0, 0.0])), 
            vel=np.vstack((np.array(roll['vel']), [0.0, 0.0, 0.0])), ang=np.vstack((np.array(roll['ang']), [0.0, 0.0, 0.0])))
    rollout.append(qp)

In [14]:
# env_state = a1.reset(rng=jax.random.PRNGKey(0))
# HTML(html.render(a1.sys, rollout))

### For training

In [38]:
# serialize
rollout = []
for roll in data:
    p = np.vstack((np.array(roll['pos']), [0.0, 0.0, 0.0])).flatten()
    r = np.vstack((np.array(roll['rot']), [1.0, 0.0, 0.0, 0.0])).flatten()
    v = np.vstack((np.array(roll['vel']), [0.0, 0.0, 0.0])).flatten()
    a = np.vstack((np.array(roll['ang']), [0.0, 0.0, 0.0])).flatten()
    qp = np.hstack((p, r, v, a))
    rollout.append(qp)
rollout = np.array(rollout)

In [39]:
rollout.shape

(198, 182)

In [50]:
rollout[0]

array([ 0.        ,  0.        ,  0.43000001,  0.183     , -0.047     ,
        0.43000001,  0.183     , -0.13205001,  0.43000001,  0.183     ,
       -0.13205001,  0.23      ,  0.183     ,  0.047     ,  0.43000001,
        0.183     ,  0.13205001,  0.43000001,  0.183     ,  0.13205001,
        0.23      , -0.183     , -0.047     ,  0.43000001, -0.183     ,
       -0.13205001,  0.43000001, -0.183     , -0.13205001,  0.23      ,
       -0.183     ,  0.047     ,  0.43000001, -0.183     ,  0.13205001,
        0.43000001, -0.183     ,  0.13205001,  0.23      ,  0.        ,
        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
        0.        ,  1.        ,  0.        ,  0.        ,  0.        ,
        1.        ,  0.        ,  0.        ,  0.        ,  1.        ,
        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
        0.        ,  1.        ,  0.        ,  0.        ,  0.  

In [40]:
np.save('a1_ref_motion/trot.npy', rollout)

In [41]:
file_path = 'a1_ref_motion/trot.npy'
rollout = np.load(file_path)

In [42]:
rollout.shape

(198, 182)

In [44]:
import brax
from jax import numpy as jnp
def deserialize_qp(nparray) -> brax.QP:
    """
    Get QP from a trajectory numpy array
    """
    num_bodies = nparray.shape[-1] // 13    # pos (,3) rot (,4) vel (,3) ang (,3)
    batch_dims = nparray.shape[:-1]
    slices = [num_bodies * x for x in [0, 3, 7, 10, 13]]
    pos = jnp.reshape(nparray[..., slices[0]:slices[1]], batch_dims + (num_bodies, 3))
    rot = jnp.reshape(nparray[..., slices[1]:slices[2]], batch_dims + (num_bodies, 4))
    vel = jnp.reshape(nparray[..., slices[2]:slices[3]], batch_dims + (num_bodies, 3))
    ang = jnp.reshape(nparray[..., slices[3]:slices[4]], batch_dims + (num_bodies, 3))
    return QP(pos=pos, rot=rot, vel=vel, ang=ang)

In [45]:
qp = deserialize_qp(rollout)

In [46]:
qp.pos.shape

(198, 14, 3)

In [51]:
qp.pos

array([[[ 0.00000000e+00,  0.00000000e+00,  4.30000007e-01],
        [ 1.82999998e-01, -4.69999984e-02,  4.30000007e-01],
        [ 1.82999998e-01, -1.32050008e-01,  4.30000007e-01],
        ...,
        [-1.82999998e-01,  1.32050008e-01,  4.30000007e-01],
        [-1.82999998e-01,  1.32050008e-01,  2.30000004e-01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[ 1.00316916e-04,  4.21822733e-06,  4.29954261e-01],
        [ 1.83102295e-01, -4.69879508e-02,  4.30069596e-01],
        [ 1.83105931e-01, -1.32037967e-01,  4.30073619e-01],
        ...,
        [-1.82905316e-01,  1.32046387e-01,  4.29866165e-01],
        [-1.84300900e-01,  1.32110357e-01,  2.29871035e-01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[ 2.72760342e-04,  1.11565705e-05,  4.29863662e-01],
        [ 1.83277845e-01, -4.69680168e-02,  4.30188805e-01],
        [ 1.83287486e-01, -1.32018030e-01,  4.30200934e-01],
        ...,
        [-1.82742149e-01,  1.32040292e-01,

In [47]:
qp.rot.shape

(198, 14, 4)

In [49]:
qp.ang.shape

(198, 14, 3)

In [15]:
a1.sys.default_qp().ang.shape

(14, 3)

In [11]:
a1.sys.default_qp()

QP(pos=array([[ 0.        ,  0.        ,  0.41992316],
       [ 0.183     , -0.047     ,  0.41992316],
       [ 0.183     , -0.13205   ,  0.41992316],
       [ 0.17751757, -0.13205   ,  0.21999831],
       [ 0.183     ,  0.047     ,  0.41992316],
       [ 0.183     ,  0.13205   ,  0.41992316],
       [ 0.17751757,  0.13205   ,  0.21999831],
       [-0.183     , -0.047     ,  0.41992316],
       [-0.183     , -0.13205   ,  0.41992316],
       [-0.18848243, -0.13205   ,  0.21999831],
       [-0.183     ,  0.047     ,  0.41992316],
       [-0.183     ,  0.13205   ,  0.41992316],
       [-0.18848243,  0.13205   ,  0.21999831],
       [ 0.        ,  0.        ,  0.        ]]), rot=array([[ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 9.99906050e-01,  3.04364422e-18,  1.37073550e-02,
         0.00000000e+00],
       [ 9.99997886e-01, -4.56560494e-19, -2.05616567e-03,
 

In [7]:
dir(a1)

['__abstractmethods__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_abc_impl',
 'action_size',
 'observation_size',
 'reset',
 'step',
 'sys',
 'unwrapped']

In [6]:
a1.action_size

12

In [12]:
dir(a1.sys)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__pytree_ignore__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_pbd_info',
 '_pbd_step',
 '_spring_info',
 '_spring_step',
 'actuators',
 'body',
 'colliders',
 'config',
 'default_angle',
 'default_qp',
 'forces',
 'info',
 'integrator',
 'joints',
 'num_actuators',
 'num_bodies',
 'num_forces_dof',
 'num_joint_dof',
 'num_joints',
 'step',
 'zero_info']

In [15]:
dir(a1.sys.body)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__pytree_ignore__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'active',
 'idx',
 'impulse',
 'index',
 'inertia',
 'mass']

In [20]:
a1.sys.body.index

{'base': 0,
 'FR_hip': 1,
 'FR_thigh': 2,
 'FR_calf': 3,
 'FL_hip': 4,
 'FL_thigh': 5,
 'FL_calf': 6,
 'RR_hip': 7,
 'RR_thigh': 8,
 'RR_calf': 9,
 'RL_hip': 10,
 'RL_thigh': 11,
 'RL_calf': 12,
 'floor': 13}

In [21]:
a1.sys.default_qp()

QP(pos=array([[ 0.        ,  0.        ,  0.42000001],
       [ 0.1805    , -0.047     ,  0.42000001],
       [ 0.1805    , -0.1308    ,  0.42000001],
       [ 0.1805    , -0.1308    ,  0.22      ],
       [ 0.1805    ,  0.047     ,  0.42000001],
       [ 0.1805    ,  0.1308    ,  0.42000001],
       [ 0.1805    ,  0.1308    ,  0.22      ],
       [-0.1805    , -0.047     ,  0.42000001],
       [-0.1805    , -0.1308    ,  0.42000001],
       [-0.1805    , -0.1308    ,  0.22      ],
       [-0.1805    ,  0.047     ,  0.42000001],
       [-0.1805    ,  0.1308    ,  0.42000001],
       [-0.1805    ,  0.1308    ,  0.22      ],
       [ 0.        ,  0.        ,  0.        ]]), rot=array([[1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0.,