In [1]:
import agents_env
from agents_env.agent_mimic import HumanoidTrain
from utils.SimpleConverter import SimpleConverter
from utils.util_data import *
from some_math.math_utils import generate_trajectory,compute_cubic_trajectory,start_trajectories


In [2]:

from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap

In [3]:
envs.register_environment('humanoidMimic', HumanoidTrain)

In [4]:
#since we are going to use custom trajectory, we will set up the initial
#and the end time
t_init = 1
t_end = 3

In [5]:
#this is just dummy data to initialize the agent
trajectory = SimpleConverter('motions/humanoid3d_punch.txt')
trajectory.load_mocap()
model_path = 'models/final_humanoid_no_gravity.xml'

data_mocap_matrix = jp.asarray(trajectory.data)
data_pos_mocap = jp.asarray(trajectory.data_pos)
data_vel_mocap = jp.asarray(trajectory.data_vel)

data_dict_mocap = trajectory.duration_dict

In [6]:
data_pos_mocap[0]

Array([ 0.0000000e+00,  0.0000000e+00,  7.5521302e-01, -9.9875391e-01,
       -1.5632922e-02, -4.7388997e-02,  7.4175280e-04, -4.2744510e-02,
        3.4576020e-01, -3.6806550e-02, -8.7780811e-02, -4.4509736e-01,
        5.8729285e-01, -6.2009311e-01,  5.0161451e-01,  7.9552609e-01,
        1.7608342e+00,  7.1741235e-01, -4.3385601e-01, -9.7488210e-02,
        9.9051863e-01, -2.9624522e-01, -4.9943644e-01, -5.5784565e-01,
       -1.2331117e+00, -2.8910020e-01, -2.7541557e-01, -3.7002066e-01,
        4.1313863e-01, -8.0841905e-01,  5.2728492e-01, -9.7332996e-01,
        1.5179858e-03,  1.8073888e-02,  1.4902889e-03], dtype=float32)

In [7]:
#get th kp and kd for the agent
kp,kd = generate_kp_kd_gains()
print(kp)
print(kd)

[1000 1000 1000  100  100  100  400  400  400  300  400  400  400  300
  500  500  500  500  400  400  400  500  500  500  500  400  400  400]
[100 100 100  10  10  10  40  40  40  30  40  40  40  30  50  50  50  50
  40  40  40  50  50  50  50  40  40  40]


In [8]:
env_name = 'humanoidMimic'
env = envs.get_environment(env_name=env_name,
                           reference_trajectory_qpos=data_pos_mocap,
                           reference_trajectory_qvel = data_vel_mocap,
                            duration_trajectory=trajectory.total_time,
                            dict_duration= data_dict_mocap,
                           model_path=model_path,
                           kp_gains = kp,
                           kd_gains = kd)
jit_reset = jax.jit(env.reset)
#jit_step = jax.jit(env.step)
jit_step_selected_joints = jax.jit(env.step_selected_joints_custom_target_and_joints)
jit_step_selected_joints6 = jax.jit(env.step_selected_joints_custom_target_and_joints6)
jit_custom_traj = jax.jit(env.step_custom_target)

In [9]:
trajec_dict = dict()


a_jnt_right =get_actuator_indx(env.sys.mj_model,'right_shoulder','Y')
a_jnt_left =get_actuator_indx(env.sys.mj_model,'left_shoulder','X')
a_jnt_left_elbow = get_actuator_indx(env.sys.mj_model,'left_elbow','X')
a_jnt_right_elbow = get_actuator_indx(env.sys.mj_model,'right_elbow','X')

right_knee = get_actuator_indx(env.sys.mj_model,'right_knee','X')
left_knee = get_actuator_indx(env.sys.mj_model,'left_knee','X')

trajec_dict[a_jnt_right] = generate_trajectory(t_init,t_end, 0, -1.5)
#trajectory left
trajec_dict[a_jnt_left] = generate_trajectory(t_init, t_end, 0, 1.5)
# #left elbow   
trajec_dict[a_jnt_left_elbow]= generate_trajectory(t_init, t_end, 0, 1.5)
# #right elbow
trajec_dict[a_jnt_right_elbow] = generate_trajectory(t_init, t_end, 0, 1.5)

trajec_dict[right_knee] = generate_trajectory(t_init, t_end, 0, 0)

trajec_dict[left_knee] = generate_trajectory(t_init, t_end, 0, 0)


start_trajec = start_trajectories(trajec_dict)

In [10]:
#select the pd_control
from agents_env.pds_controllers_agents import stable_pd_controller_custom_trajectory
env.set_pd_callback(stable_pd_controller_custom_trajectory)

In [11]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

In [12]:
state.pipeline_state.xpos

Array([[ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.9       ],
       [ 0.        ,  0.        ,  1.136151  ],
       [ 0.        ,  0.        ,  1.360045  ],
       [-0.02405   , -0.18311   ,  1.379651  ],
       [-0.02405   , -0.18311   ,  1.1048629 ],
       [-0.02405   ,  0.18311   ,  1.379651  ],
       [-0.02405   ,  0.18311   ,  1.1048629 ],
       [ 0.        , -0.084887  ,  0.9       ],
       [ 0.        , -0.084887  ,  0.47845396],
       [ 0.        , -0.084887  ,  0.06858397],
       [ 0.        ,  0.084887  ,  0.9       ],
       [ 0.        ,  0.084887  ,  0.47845396],
       [ 0.        ,  0.084887  ,  0.06858397]], dtype=float32)

In [13]:
state.pipeline_state.qpos


Array([0. , 0. , 0.9, 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32)

In [14]:
state.pipeline_state.xquat

Array([[1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.]], dtype=float32)

In [15]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]
# grab a 500 steps
n_steps = 500
for i in range(n_steps):
    
    
    ctrl = -0.1 * jp.ones(env.sys.nu)
    
    print(state.pipeline_state.xpos)
    #time
    time = state.pipeline_state.time
    
    time = jp.clip(time, t_init, t_end)
    
   
         
    state = jit_step_selected_joints(state, ctrl,start_trajec,a_jnt_right,
                                     a_jnt_left,a_jnt_left_elbow,
                                     a_jnt_right_elbow,right_knee,left_knee,time)
    
    rollout.append(state.pipeline_state)

[[ 0.          0.          0.        ]
 [ 0.          0.          0.9       ]
 [ 0.          0.          1.136151  ]
 [ 0.          0.          1.360045  ]
 [-0.02405    -0.18311     1.379651  ]
 [-0.02405    -0.18311     1.1048629 ]
 [-0.02405     0.18311     1.379651  ]
 [-0.02405     0.18311     1.1048629 ]
 [ 0.         -0.084887    0.9       ]
 [ 0.         -0.084887    0.47845396]
 [ 0.         -0.084887    0.06858397]
 [ 0.          0.084887    0.9       ]
 [ 0.          0.084887    0.47845396]
 [ 0.          0.084887    0.06858397]]
[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
 [-5.5359988e-07 -4.4057310e-07  8.9999998e-01]
 [-6.1875943e-07 -5.2500241e-07  1.1361510e+00]
 [-1.3593495e-07 -2.1095633e-07  1.3600450e+00]
 [-2.4051178e-02 -1.8311004e-01  1.3796513e+00]
 [-2.4047578e-02 -1.8311033e-01  1.1048633e+00]
 [-2.4049006e-02  1.8310995e-01  1.3796507e+00]
 [-2.4049833e-02  1.8311368e-01  1.1048627e+00]
 [-1.0074179e-06 -8.4887438e-02  8.9999992e-01]
 [-4.8414466e-07 -8.4

In [16]:
media.show_video(env.render(rollout, camera='back'), fps=1.0/env.dt)

0
This browser does not support the video tag.


In [20]:
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

In [21]:
state.pipeline_state.qpos

Array([0. , 0. , 0.9, 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32)

In [None]:
state.pipeline_state.xquat

Array([[ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 9.97073293e-01, -2.37808656e-02, -4.49362490e-03,
        -7.25199953e-02],
       [ 9.96299207e-01,  1.57909449e-02,  3.56603898e-02,
        -7.65958503e-02],
       [ 9.96191382e-01,  1.59777664e-02,  3.85766551e-02,
        -7.65456706e-02],
       [ 7.64126122e-01, -7.17872158e-02, -6.40024602e-01,
        -3.64195369e-02],
       [ 1.16652645e-01, -7.72100911e-02, -9.89905238e-01,
         2.27681659e-02],
       [ 7.25966752e-01,  6.87516272e-01, -1.49261337e-02,
         8.41637887e-03],
       [ 5.21675587e-01,  5.09395778e-01, -5.05079329e-01,
        -4.61806715e-01],
       [ 9.97008801e-01, -1.21779134e-03, -2.85205431e-02,
        -7.18229339e-02],
       [ 9.97266233e-01, -4.11111163e-04, -1.73218939e-02,
        -7.18320832e-02],
       [ 9.96874213e-01, -2.79931184e-02,  1.38660250e-02,
        -7.25670084e-02],
       [ 9.97462630e-01, -1.81207545e-02, -9.99111496e-03,
      