In [1]:
from brax.positional import pipeline
from brax.io import mjcf

from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
from brax.io import html
import json
import numpy as np
import transforms3d.quaternions as quat
from brax.envs.base import PipelineEnv, State

In [14]:
class HumanoidEnv(PipelineEnv):
    def __init__(self):
        path = 'amp_humanoid.xml'
        sys = mjcf.load(path)
        super().__init__(sys=sys, backend='positional', n_frames=1)

    def reset(self, q: jp.ndarray, qd: jp.ndarray) -> State:
        """Resets the environment to an initial state."""
        pipeline_state = self.pipeline_init(q, qd)
        return pipeline_state

    def step(self, pipeline_state: State, action: jp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics."""
        pipeline_state = self.pipeline_step(pipeline_state, action)
        return pipeline_state

In [15]:
env = HumanoidEnv()

In [16]:
env.sys.actuator.index_set

<bound method Base.index_set of Actuator(q_id=Array([ 7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 23, 22,
       24, 25, 26, 27, 28, 30, 29, 31, 32, 33, 34], dtype=int32), qd_id=Array([ 6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 21,
       23, 24, 25, 26, 27, 29, 28, 30, 31, 32, 33], dtype=int32), ctrl_range=Array([[-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.],
       [-1.,  1.]], dtype=float32), force_range=Array([[-inf,  inf],
       [-inf,  inf],
       [-inf,  inf],
       [-inf,  inf],
       [-i

In [17]:
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

q = env.sys.init_q
qd = jp.zeros((env.sys.qd_size(),))

rollout = []
state = jit_env_reset(q, qd)
rollout.append(state)
for i in range(500):
    if i % 100 == 0:
        print(f'step: {i}')
    state = jit_env_step(state, np.zeros(env.action_size, dtype=float))
    rollout.append(state)

step: 0
step: 100
step: 200
step: 300
step: 400


In [18]:
# HTML(html.render(env.sys, rollout))

In [10]:
m.dt

Array(0.00555, dtype=float32, weak_type=True)

In [8]:
m.link_names

['pelvis',
 'torso',
 'head',
 'right_upper_arm',
 'right_lower_arm',
 'left_upper_arm',
 'left_lower_arm',
 'right_thigh',
 'right_shin',
 'right_foot',
 'left_thigh',
 'left_shin',
 'left_foot']

In [9]:
m.actuator.gain

Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)