In [20]:
import agents_env
from agents_env.agent_replay_motion import HumanoidReplay
from utils.SimpleConverter import SimpleConverter
from utils.util_data import *

In [21]:

from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap

In [22]:
envs.register_environment('humanoidReplay', HumanoidReplay)

Get the motion data and the model path

In [23]:
trajectory = SimpleConverter('motions/humanoid3d_punch.txt')
trajectory.load_mocap()
model_path = 'models/final_humanoid.xml'

In [24]:
data_mocap_matrix = jp.asarray(trajectory.data)
data_pos_mocap = jp.asarray(trajectory.data_pos)
data_vel_mocap = jp.asarray(trajectory.data_vel)

data_dict_mocap = trajectory.duration_dict



In [25]:
data_dict_mocap

{0: [0.0, 0.0333333015],
 1: [0.03333330154418945, 0.0333333015],
 2: [0.0666666030883789, 0.0333333015],
 3: [0.09999990463256836, 0.0333333015],
 4: [0.1333332061767578, 0.0333333015],
 5: [0.16666650772094727, 0.0333333015],
 6: [0.19999980926513672, 0.0333333015],
 7: [0.23333311080932617, 0.0333333015],
 8: [0.2666664123535156, 0.0333333015],
 9: [0.2999997138977051, 0.0333333015],
 10: [0.33333301544189453, 0.0333333015],
 11: [0.366666316986084, 0.0333333015],
 12: [0.39999961853027344, 0.0333333015],
 13: [0.4333329200744629, 0.0333333015],
 14: [0.46666622161865234, 0.0333333015],
 15: [0.4999995231628418, 0.0333333015],
 16: [0.5333328247070312, 0.0333333015],
 17: [0.5666661262512207, 0.0333333015],
 18: [0.5999994277954102, 0.0333333015],
 19: [0.6333327293395996, 0.0333333015],
 20: [0.6666660308837891, 0.0333333015],
 21: [0.6999993324279785, 0.0333333015],
 22: [0.733332633972168, 0.0333333015],
 23: [0.7666659355163574, 0.0333333015],
 24: [0.7999992370605469, 0.0333333

In [26]:
env_name = 'humanoidReplay'
env = envs.get_environment(env_name=env_name,
                           reference_trajectory_qpos=data_pos_mocap,
                           reference_trajectory_qvel = data_vel_mocap,
                            duration_trajectory=trajectory.total_time,
                            dict_duration= data_dict_mocap,
                           model_path=model_path)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [27]:
env.reference_trajectory_qpos.shape

(65, 35)

In [28]:

def visualizer(rollout):
    media.show_video(env.render(rollout, camera='back'), fps=1.0/env.dt)

In [29]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

for k in range(len(env.dict_duration)):
    qpos = state.pipeline_state.qpos
    tmp_pos = env.reference_trajectory_qpos[k]
    updated = qpos.at[:].set(tmp_pos[:])
    
    state = state.tree_replace({'pipeline_state.qpos': updated})
    ctrl = -0.1 * jp.ones(env.sys.nv)
    state = jit_step(state, ctrl)
    #print(state.pipeline_state.time)
    rollout.append(state.pipeline_state)

In [30]:
visualizer(rollout)

0
This browser does not support the video tag.


Visualize only the arm

In [31]:

chest_x = get_joint_index(env.sys.mj_model,'chest','X')


#data of all the right arm
right_shoulder_x = get_joint_index(env.sys.mj_model,'right_shoulder','X')
right_shoulder_y = get_joint_index(env.sys.mj_model,'right_shoulder','Y')
right_shoulder_z = get_joint_index(env.sys.mj_model,'right_shoulder','Z')

right_elbow = get_joint_index(env.sys.mj_model,'right_elbow','X')


print(chest_x)
print(right_shoulder_x)
print(right_shoulder_y)
print(right_shoulder_z)
print(right_elbow)

7
13
14
15
16


In [32]:


# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

for k in range(len(env.dict_duration)):
    qpos = state.pipeline_state.qpos
    #only grab the right arm poses
    tmp_pos = env.reference_trajectory_qpos[k,right_shoulder_x:right_elbow+1]
    #gran zeor data, to concatenate
    new_pos = jp.concatenate([jp.zeros(right_shoulder_x)])
    updated = qpos.at[right_shoulder_x:right_elbow+1].set(tmp_pos)
    
    
    state = state.tree_replace({'pipeline_state.qpos': updated})
    ctrl = -0.1 * jp.ones(env.sys.nv)
    state = jit_step(state, ctrl)
    #print(state.pipeline_state.time)
    rollout.append(state.pipeline_state)

In [33]:
visualizer(rollout)

0
This browser does not support the video tag.


Now grab a specific frame, and only render that

In [34]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]
#keep track of the time
time = 0

for k in range(len(env.dict_duration)):
    qpos = state.pipeline_state.qpos
    # print("current-time", time)
    
    #so now on the dictionary grab a frame and see the duration and dt
    timestep_trajectory, dura_dt = env.dict_duration[k]
    if k >= 45:
        
        # print('time step', timestep_trajectory)
        # print('dura', dura_dt)
        # #only grab the right arm poses
        # tmp_pos = data_pos_mocap[k,right_shoulder_x:right_elbow+1]
        # updated = qpos.at[right_shoulder_x:right_elbow+1].set(tmp_pos
        tmp_pos = env.reference_trajectory_qpos[k]
        updated = qpos.at[:].set(tmp_pos[:])
        
        state = state.tree_replace({'pipeline_state.qpos': updated})
        ctrl = -0.1 * jp.ones(env.sys.nv)
    state = jit_step(state, ctrl)
    #print(state.pipeline_state.time)
    rollout.append(state.pipeline_state)
        
    #time += env.sys.dt
    time += trajectory.durations[k]

In [35]:
visualizer(rollout)

0
This browser does not support the video tag.


In [36]:
env.dict_duration

{0: [0.0, 0.0333333015],
 1: [0.03333330154418945, 0.0333333015],
 2: [0.0666666030883789, 0.0333333015],
 3: [0.09999990463256836, 0.0333333015],
 4: [0.1333332061767578, 0.0333333015],
 5: [0.16666650772094727, 0.0333333015],
 6: [0.19999980926513672, 0.0333333015],
 7: [0.23333311080932617, 0.0333333015],
 8: [0.2666664123535156, 0.0333333015],
 9: [0.2999997138977051, 0.0333333015],
 10: [0.33333301544189453, 0.0333333015],
 11: [0.366666316986084, 0.0333333015],
 12: [0.39999961853027344, 0.0333333015],
 13: [0.4333329200744629, 0.0333333015],
 14: [0.46666622161865234, 0.0333333015],
 15: [0.4999995231628418, 0.0333333015],
 16: [0.5333328247070312, 0.0333333015],
 17: [0.5666661262512207, 0.0333333015],
 18: [0.5999994277954102, 0.0333333015],
 19: [0.6333327293395996, 0.0333333015],
 20: [0.6666660308837891, 0.0333333015],
 21: [0.6999993324279785, 0.0333333015],
 22: [0.733332633972168, 0.0333333015],
 23: [0.7666659355163574, 0.0333333015],
 24: [0.7999992370605469, 0.0333333

In [37]:

def visualize_per_time_step():
    # initialize the state
    state = jit_reset(jax.random.PRNGKey(0))
    rollout = [state.pipeline_state]
    time = 0
    frames = 0
    N=1300
    for i in range(N):
        qpos = state.pipeline_state.qpos
        #print("current-time", time)
        frames = np.clip(frames,0,len(env.dict_duration)-1)
        
        #so now on the dictionary grab a frame and see the duration and dt
        timestep_trajectory, dura_dt = env.dict_duration[frames]
        
        if time >= timestep_trajectory:
            #we increment the frames
            frames+=1
            
            # print('time step', timestep_trajectory)
            # print('dura', dura_dt)
        
            # print("frames:", frames)
            # #only grab the right arm poses
            # tmp_pos = data_pos_mocap[k,right_shoulder_x:right_elbow+1]
            # updated = qpos.at[right_shoulder_x:right_elbow+1].set(tmp_pos
            tmp_pos = env.reference_trajectory_qpos[frames]
            updated = qpos.at[:].set(tmp_pos[:])
            
            state = state.tree_replace({'pipeline_state.qpos': updated})
            ctrl = -0.1 * jp.ones(env.sys.nv)
            state = jit_step(state, ctrl)

                    #print(state.pipeline_state.time)
            rollout.append(state.pipeline_state)
            
        time += env.sys.dt
    return rollout


In [38]:
rollout = visualize_per_time_step()
visualizer(rollout)

0
This browser does not support the video tag.


Visualization on other movements

In [39]:
stand_up = SimpleConverter('motions/humanoid3d_getup_faceup.txt')
stand_up.load_mocap()

In [40]:
stand_data_mocap_matrix = jp.asarray(stand_up.data)
stand_data_pos_mocap = jp.asarray(stand_up.data_pos)
stand_data_vel_mocap = jp.asarray(stand_up.data_vel)
stand_data_dict_mocap = stand_up.duration_dict


In [41]:
env.set_new_trajectory(stand_up)

In [42]:
env.duration_trajectory

3.7665161581999853

In [43]:
stand_up.total_time

3.7665161581999853

In [44]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

for k in range(len(env.dict_duration)):
    qpos = state.pipeline_state.qpos
    tmp_pos = env.reference_trajectory_qpos[k]
    updated = qpos.at[:].set(tmp_pos[:])
    
    state = state.tree_replace({'pipeline_state.qpos': updated})
    ctrl = -0.1 * jp.ones(env.sys.nv)
    state = jit_step(state, ctrl)
    #print(state.pipeline_state.time)
    rollout.append(state.pipeline_state)

In [45]:
visualizer(rollout)

0
This browser does not support the video tag.


In [46]:
len(env.dict_duration)

227

In [47]:
env.reference_trajectory_qpos.shape

(227, 35)

In [48]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]
#keep track of the time
time = 0

for k in range(len(env.dict_duration)):
    qpos = state.pipeline_state.qpos
    # print("current-time", time)
    
    #so now on the dictionary grab a frame and see the duration and dt
    timestep_trajectory, dura_dt = env.dict_duration[k]
    if k >= 50:
        
        # print('time step', timestep_trajectory)
        # print('dura', dura_dt)
        # #only grab the right arm poses
        # tmp_pos = data_pos_mocap[k,right_shoulder_x:right_elbow+1]
        # updated = qpos.at[right_shoulder_x:right_elbow+1].set(tmp_pos
        tmp_pos = env.reference_trajectory_qpos[k]
        updated = qpos.at[:].set(tmp_pos[:])
        
        state = state.tree_replace({'pipeline_state.qpos': updated})
        ctrl = -0.1 * jp.ones(env.sys.nv)
    state = jit_step(state, ctrl)
    #print(state.pipeline_state.time)
    rollout.append(state.pipeline_state)
        
    #time += env.sys.dt
    time += dura_dt

In [49]:
visualizer(rollout)

0
This browser does not support the video tag.


In [50]:
env.dict_duration

{0: [0.0, 0.0166660007],
 1: [0.016666000708937645, 0.0166660007],
 2: [0.03333200141787529, 0.0166660007],
 3: [0.049998000264167786, 0.0166660007],
 4: [0.06666400283575058, 0.0166660007],
 5: [0.08333000540733337, 0.0166660007],
 6: [0.09999600052833557, 0.0166660007],
 7: [0.11666200309991837, 0.0166660007],
 8: [0.13332800567150116, 0.0166660007],
 9: [0.14999400079250336, 0.0166660007],
 10: [0.16666001081466675, 0.0166660007],
 11: [0.18332600593566895, 0.0166660007],
 12: [0.19999200105667114, 0.0166660007],
 13: [0.21665799617767334, 0.0166660007],
 14: [0.23332400619983673, 0.0166660007],
 15: [0.24999000132083893, 0.0166660007],
 16: [0.2666560113430023, 0.0166660007],
 17: [0.2833220064640045, 0.0166660007],
 18: [0.2999880015850067, 0.0166660007],
 19: [0.3166539967060089, 0.0166660007],
 20: [0.3333200216293335, 0.0166660007],
 21: [0.3499860167503357, 0.0166660007],
 22: [0.3666520118713379, 0.0166660007],
 23: [0.3833180069923401, 0.0166660007],
 24: [0.3999840021133423

In [51]:

def visualize_per_time_step():
    # initialize the state
    state = jit_reset(jax.random.PRNGKey(0))
    rollout = [state.pipeline_state]
    time = 0
    frames = 0
    N=3000
    for i in range(N):
        qpos = state.pipeline_state.qpos
        print("current-time", frames)
        
        frames = np.clip(frames,0,len(env.dict_duration)-1)
        
        print("current-time", type(frames))
        
        #so now on the dictionary grab a frame and see the duration and dt
        timestep_trajectory, dura_dt = env.dict_duration[frames]
        
        if time >= timestep_trajectory:
            #we increment the frames
            frames+=1
            
            # print('time step', timestep_trajectory)
            # print('dura', dura_dt)
        
            # print("frames:", frames)
            # #only grab the right arm poses
            # tmp_pos = data_pos_mocap[k,right_shoulder_x:right_elbow+1]
            # updated = qpos.at[right_shoulder_x:right_elbow+1].set(tmp_pos
            tmp_pos = env.reference_trajectory_qpos[frames]
            updated = qpos.at[:].set(tmp_pos[:])
            
            state = state.tree_replace({'pipeline_state.qpos': updated})
            ctrl = -0.1 * jp.ones(env.sys.nv)
            state = jit_step(state, ctrl)

                    #print(state.pipeline_state.time)
            rollout.append(state.pipeline_state)
            
        time += env.sys.dt
    return rollout

In [52]:
rollout = visualize_per_time_step()
visualizer(rollout)

current-time 0
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 1
current-time <class 'numpy.int32'>
current-time 2
current-time <class 'numpy.int32'>
current-time 2
current-time <class 'numpy.int32'>
current-time 2
current-time <class 'numpy.int32'>
current-time 2
current-time <class 'numpy.int32'>
current-time 2
current-time <class 'numpy.int32'>
current-time 2
current-time <class 'numpy.int32'>
current-time 2
current-time <class 'numpy.int32'>
current-time 2
current-time <class 'numpy.int32'>
current-time 3
current-time <class 'numpy.int32'>
current-time 3
current-time <class 'numpy.int32'>


current-time 7
current-time <class 'numpy.int32'>
current-time 7
current-time <class 'numpy.int32'>
current-time 7
current-time <class 'numpy.int32'>
current-time 7
current-time <class 'numpy.int32'>
current-time 7
current-time <class 'numpy.int32'>
current-time 7
current-time <class 'numpy.int32'>
current-time 7
current-time <class 'numpy.int32'>
current-time 7
current-time <class 'numpy.int32'>
current-time 7
current-time <class 'numpy.int32'>
current-time 8
current-time <class 'numpy.int32'>
current-time 8
current-time <class 'numpy.int32'>
current-time 8
current-time <class 'numpy.int32'>
current-time 8
current-time <class 'numpy.int32'>
current-time 8
current-time <class 'numpy.int32'>
current-time 8
current-time <class 'numpy.int32'>
current-time 8
current-time <class 'numpy.int32'>
current-time 8
current-time <class 'numpy.int32'>
current-time 9
current-time <class 'numpy.int32'>
current-time 9
current-time <class 'numpy.int32'>
current-time 9
current-time <class 'numpy.int32'>


0
This browser does not support the video tag.


In [53]:
env.duration_trajectory

3.7665161581999853

Now I will test it with velocity

In [54]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

for k in range(len(env.dict_duration)):
    qpos = state.pipeline_state.qpos
    qvel = state.pipeline_state.qvel
    
    tmp_pos = env.reference_trajectory_qpos[k]
    tmp_vel = env.reference_trajectory_qvel[k]
    
    updated = qpos.at[:].set(tmp_pos[:])
    updated_vel = qvel.at[:].set(tmp_vel[:])
    
    state = state.tree_replace({'pipeline_state.qpos': updated})
    state = state.tree_replace({'pipeline_state.qvel': updated_vel})
    ctrl = -0.1 * jp.ones(env.sys.nv)
    state = jit_step(state, ctrl)
    #print(state.pipeline_state.time)
    rollout.append(state.pipeline_state)

In [55]:
visualizer(rollout)

0
This browser does not support the video tag.


In [56]:
env.reference_trajectory_qvel.shape

(227, 34)

In [57]:
env.sys.dt

0.002