In [2]:
import agents_env
from agents_env.agent_replay_motion import HumanoidReplay
from utils.SimpleConverter import SimpleConverter
from utils.util_data import *

In [3]:

from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap

In [4]:
envs.register_environment('humanoidReplay', HumanoidReplay)

In [5]:
trajectory = SimpleConverter('motions/humanoid3d_punch.txt')
trajectory.load_mocap()
model_path = 'models/final_humanoid.xml'

In [6]:
data_mocap_matrix = jp.asarray(trajectory.data)
data_pos_mocap = jp.asarray(trajectory.data_pos)
data_vel_mocap = jp.asarray(trajectory.data_vel)

data_dict_mocap = trajectory.duration_dict

In [7]:
data_dict_mocap

{0: [0.0, 0.0333333015],
 1: [0.03333330154418945, 0.0333333015],
 2: [0.0666666030883789, 0.0333333015],
 3: [0.09999990463256836, 0.0333333015],
 4: [0.1333332061767578, 0.0333333015],
 5: [0.16666650772094727, 0.0333333015],
 6: [0.19999980926513672, 0.0333333015],
 7: [0.23333311080932617, 0.0333333015],
 8: [0.2666664123535156, 0.0333333015],
 9: [0.2999997138977051, 0.0333333015],
 10: [0.33333301544189453, 0.0333333015],
 11: [0.366666316986084, 0.0333333015],
 12: [0.39999961853027344, 0.0333333015],
 13: [0.4333329200744629, 0.0333333015],
 14: [0.46666622161865234, 0.0333333015],
 15: [0.4999995231628418, 0.0333333015],
 16: [0.5333328247070312, 0.0333333015],
 17: [0.5666661262512207, 0.0333333015],
 18: [0.5999994277954102, 0.0333333015],
 19: [0.6333327293395996, 0.0333333015],
 20: [0.6666660308837891, 0.0333333015],
 21: [0.6999993324279785, 0.0333333015],
 22: [0.733332633972168, 0.0333333015],
 23: [0.7666659355163574, 0.0333333015],
 24: [0.7999992370605469, 0.0333333

In [8]:
env_name = 'humanoidReplay'
env = envs.get_environment(env_name=env_name,
                           reference_trajectory_qpos=data_pos_mocap,
                           reference_trajectory_qvel = data_vel_mocap,
                            duration_trajectory=trajectory.total_time,
                            dict_duration= data_dict_mocap,
                           model_path=model_path)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [9]:
env.reference_trajectory_qpos.shape

(65, 35)

In [10]:
def visualizer(rollout):
    media.show_video(env.render(rollout, camera='back'), fps=1.0/env.dt)

In [26]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

for k in range(len(env.dict_duration)):
    print(state.pipeline_state.x.pos)
    
    print('rotation')
    print(state.pipeline_state.x.rot)
    print('end')
    #print(state.pipeline_state.q)
    
    qpos = state.pipeline_state.qpos
    tmp_pos = env.reference_trajectory_qpos[k]
    updated = qpos.at[:].set(tmp_pos[:])
    
    state = state.tree_replace({'pipeline_state.qpos': updated})
    ctrl = -0.1 * jp.ones(env.sys.nv)
    state = jit_step(state, ctrl)
    
    #print(state.pipeline_state.time)
    rollout.append(state.pipeline_state)

[[ 0.          0.          0.9       ]
 [ 0.          0.          1.136151  ]
 [ 0.          0.          1.360045  ]
 [-0.02405    -0.18311     1.379651  ]
 [-0.02405    -0.18311     1.1048629 ]
 [-0.02405     0.18311     1.379651  ]
 [-0.02405     0.18311     1.1048629 ]
 [ 0.         -0.084887    0.9       ]
 [ 0.         -0.084887    0.47845396]
 [ 0.         -0.084887    0.06858397]
 [ 0.          0.084887    0.9       ]
 [ 0.          0.084887    0.47845396]
 [ 0.          0.084887    0.06858397]]
rotation
[[1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]]
end
[[ 0.0000000e+00  0.0000000e+00  7.5521302e-01]
 [ 2.2348549e-02 -7.3908614e-03  9.9018794e-01]
 [ 1.1783108e-01 -4.9806852e-03  1.1926868e+00]
 [ 9.8556340e-02 -1.8673937e-01  1.2256161e+00]
 [-1.0900904e-01 -3.2336295e-01  1.1083171e+00]
 [ 1.1034919e-01  1.7915639e-01  1.2157006e

In [12]:
visualizer(rollout)

0
This browser does not support the video tag.


Inital positions and rotations


In [13]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

In [14]:
state.pipeline_state.x.pos.shape

(13, 3)

In [15]:
#it has the same ros as x.pos
body_initial_config = {name:np.array(state.pipeline_state.x.pos[i]) for i,name in enumerate(BODIES[1:])}
body_initial_config

{'root': array([0. , 0. , 0.9], dtype=float32),
 'chest': array([0.      , 0.      , 1.136151], dtype=float32),
 'neck': array([0.      , 0.      , 1.360045], dtype=float32),
 'right_shoulder': array([-0.02405 , -0.18311 ,  1.379651], dtype=float32),
 'right_elbow': array([-0.02405  , -0.18311  ,  1.1048629], dtype=float32),
 'left_shoulder': array([-0.02405 ,  0.18311 ,  1.379651], dtype=float32),
 'left_elbow': array([-0.02405  ,  0.18311  ,  1.1048629], dtype=float32),
 'right_hip': array([ 0.      , -0.084887,  0.9     ], dtype=float32),
 'right_knee': array([ 0.        , -0.084887  ,  0.47845396], dtype=float32),
 'right_ankle': array([ 0.        , -0.084887  ,  0.06858397], dtype=float32),
 'left_hip': array([0.      , 0.084887, 0.9     ], dtype=float32),
 'left_knee': array([0.        , 0.084887  , 0.47845396], dtype=float32),
 'left_ankle': array([0.        , 0.084887  , 0.06858397], dtype=float32)}

In [25]:
state.pipeline_state.x.rot

Array([[1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.]], dtype=float32)

In [17]:
#it has the same ros as x.pos
body_initial_config_rot = {name:np.array(state.pipeline_state.x.rot[i]) for i,name in enumerate(BODIES[1:])}
body_initial_config_rot

{'root': array([1., 0., 0., 0.], dtype=float32),
 'chest': array([1., 0., 0., 0.], dtype=float32),
 'neck': array([1., 0., 0., 0.], dtype=float32),
 'right_shoulder': array([1., 0., 0., 0.], dtype=float32),
 'right_elbow': array([1., 0., 0., 0.], dtype=float32),
 'left_shoulder': array([1., 0., 0., 0.], dtype=float32),
 'left_elbow': array([1., 0., 0., 0.], dtype=float32),
 'right_hip': array([1., 0., 0., 0.], dtype=float32),
 'right_knee': array([1., 0., 0., 0.], dtype=float32),
 'right_ankle': array([1., 0., 0., 0.], dtype=float32),
 'left_hip': array([1., 0., 0., 0.], dtype=float32),
 'left_knee': array([1., 0., 0., 0.], dtype=float32),
 'left_ankle': array([1., 0., 0., 0.], dtype=float32)}

Exploring the joints hierchacy 

In [18]:

with open('models/final_humanoid_no_gravity.xml', 'r') as file:
    xml = file.read()
# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)

In [19]:
env.sys.link_parents

(-1, 0, 1, 1, 3, 1, 5, 0, 7, 8, 0, 10, 11)

In [20]:
#create a dictionary with the index
index_to_body = {index:name for index,name in enumerate(BODIES[1:])}
index_to_body

{0: 'root',
 1: 'chest',
 2: 'neck',
 3: 'right_shoulder',
 4: 'right_elbow',
 5: 'left_shoulder',
 6: 'left_elbow',
 7: 'right_hip',
 8: 'right_knee',
 9: 'right_ankle',
 10: 'left_hip',
 11: 'left_knee',
 12: 'left_ankle'}

In [21]:
bodies_joints = dict()
#start from the root
bodies_joints[BODIES[1]] = [None,0]

for index,name in enumerate(BODIES[2:]):
    parent_id = mj_model.body(name).parentid
    #the parentid is an array so we want the fist index only
    parent_id = parent_id[0]
    bodies_joints[name] = [index_to_body[parent_id-1],parent_id-1]

bodies_joints

{'root': [None, 0],
 'chest': ['root', 0],
 'neck': ['chest', 1],
 'right_shoulder': ['chest', 1],
 'right_elbow': ['right_shoulder', 3],
 'left_shoulder': ['chest', 1],
 'left_elbow': ['left_shoulder', 5],
 'right_hip': ['root', 0],
 'right_knee': ['right_hip', 7],
 'right_ankle': ['right_knee', 8],
 'left_hip': ['root', 0],
 'left_knee': ['left_hip', 10],
 'left_ankle': ['left_knee', 11]}

In [22]:
trajectory.all_states

[{'root_pos': array([0.      , 0.      , 0.755213]),
  'root_rot': array([-9.98753896e-01, -1.56329212e-02, -4.73889973e-02,  7.41752800e-04]),
  'chest': array([ 0.98463402, -0.02421333,  0.17156441, -0.02179934]),
  'neck': array([ 0.92988643, -0.10478597, -0.19867835,  0.29130402]),
  'right_hip': array([ 0.9313745 , -0.07018205, -0.27435029, -0.2287967 ]),
  'right_knee': array([-1.23311177]),
  'right_ankle': array([ 0.96711253, -0.11525968, -0.15977069, -0.16087852]),
  'right_shoulder': array([ 0.87983352, -0.18093009,  0.33238639,  0.28753532]),
  'right_elbow': array([1.76083421]),
  'left_hip': array([ 0.88979552,  0.08174648, -0.42077278,  0.15662605]),
  'left_knee': array([-0.97332998]),
  'left_ankle': array([9.99958596e-01, 7.65695400e-04, 9.03625060e-03, 7.51972600e-04]),
  'left_shoulder': array([ 0.90963906,  0.35224765, -0.18458766, -0.12002407]),
  'left_elbow': array([0.99051862])},
 {'root_pos': array([0.00697176, 0.00595309, 0.752527  ]),
  'root_rot': array([-0.

In [23]:
trajectory.data_pos

[array([ 0.00000000e+00,  0.00000000e+00,  7.55213000e-01, -9.98753896e-01,
        -1.56329212e-02, -4.73889973e-02,  7.41752800e-04, -4.27445096e-02,
         3.45760197e-01, -3.68065505e-02, -8.77808113e-02, -4.45097357e-01,
         5.87292821e-01, -6.20093130e-01,  5.01614525e-01,  7.95526112e-01,
         1.76083421e+00,  7.17412330e-01, -4.33856011e-01, -9.74882088e-02,
         9.90518623e-01, -2.96245206e-01, -4.99436438e-01, -5.57845626e-01,
        -1.23311177e+00, -2.89100204e-01, -2.75415560e-01, -3.70020645e-01,
         4.13138633e-01, -8.08419056e-01,  5.27284911e-01, -9.73329982e-01,
         1.51798588e-03,  1.80738885e-02,  1.49028886e-03]),
 array([ 6.97175780e-03,  5.95309350e-03,  7.52527000e-01, -9.98823220e-01,
        -1.12807980e-02, -4.71493919e-02, -1.36156720e-03, -3.17845401e-02,
         3.42894366e-01, -3.98943870e-02, -8.93776128e-02, -4.40016703e-01,
         5.91198868e-01, -6.19465293e-01,  5.02453365e-01,  7.96520516e-01,
         1.75985929e+00,  7

In [24]:
x_pos_matrix = np.zeros((13,3))
x_pos_matrix[1] = np.ones(3)
x_pos_matrix

array([[0., 0., 0.],
       [1., 1., 1.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])