In [1]:
import agents_env
from agents_env.agent_mimic_eval import HumanoidEnvTrainEval
from utils.SimpleConverter import SimpleConverter
from utils.util_data import *
from copy import deepcopy

In [2]:
from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap
import jax.random
from jax import lax

In [3]:

#import distutils.util
import os
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl


Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl


In [4]:
from agents_env.agent_env_template import HumanoidDiff
from some_math.rotation6D import quaternion_to_rotation_6d
from some_math.math_utils_jax import *
from agents_env.losses import *

In [5]:
# A custom trajectory is used further down, so fix the time window it covers:
# t_init is the initial time and t_end is the end time.
t_init, t_end = 1, 3

In [6]:
# Dummy reference motion, only needed so the agent can be initialized.
model_path = 'models/final_humanoid_no_gravity.xml'
trajectory = SimpleConverter('motions/humanoid3d_punch.txt')
trajectory.load_mocap()

# Copy the converter's arrays into JAX arrays.
data_mocap_matrix, data_pos_mocap, data_vel_mocap = (
    jp.asarray(field)
    for field in (trajectory.data, trajectory.data_pos, trajectory.data_vel)
)

data_dict_mocap = trajectory.duration_dict

# Reference arrays later passed to the environment as reference_x_pos / reference_x_rot.
data_xpos_mocap, data_xrot_mocap = (
    jp.asarray(field)
    for field in (trajectory.data_xpos, trajectory.data_xrot)
)


In [7]:
# Get the kp and kd gain arrays for the agent and show both.
kp, kd = generate_kp_kd_gains()
for gains in (kp, kd):
    print(gains)

[1000 1000 1000  100  100  100  400  400  400  300  400  400  400  300
  500  500  500  500  400  400  400  500  500  500  500  400  400  400]
[100 100 100  10  10  10  40  40  40  30  40  40  40  30  50  50  50  50
  40  40  40  50  50  50  50  40  40  40]


In [8]:
# Dummy reference motion, only needed so the agent can be initialized.
# NOTE(review): this cell loads the same motion file and rebinds the same names
# as the earlier mocap-loading cell; it looks redundant — confirm before removing.
model_path = 'models/final_humanoid_no_gravity.xml'
trajectory = SimpleConverter('motions/humanoid3d_punch.txt')
trajectory.load_mocap()

# Copy the converter's arrays into JAX arrays.
data_mocap_matrix, data_pos_mocap, data_vel_mocap = (
    jp.asarray(field)
    for field in (trajectory.data, trajectory.data_pos, trajectory.data_vel)
)

data_dict_mocap = trajectory.duration_dict

In [9]:
# PD controllers that can be plugged into the environment (one is chosen below).
from agents_env.pds_controllers_agents import stable_pd_controller_custom_trajectory, stable_pd_controller_trajectory,feedback_pd_controller

# Register the evaluation environment class under a name Brax can look up.
env_name = 'humanoidEnvMimicEval'
envs.register_environment(env_name, HumanoidEnvTrainEval)

# Constructor arguments: reference motion, timing info, model file and PD gains.
env_kwargs = dict(
    reference_trajectory_qpos=data_pos_mocap,
    reference_trajectory_qvel=data_vel_mocap,
    duration_trajectory=trajectory.total_time,
    dict_duration=data_dict_mocap,
    model_path=model_path,
    kp_gains=kp,
    kd_gains=kd,
    reference_x_pos=data_xpos_mocap,
    reference_x_rot=data_xrot_mocap,
)
env = envs.get_environment(env_name=env_name, **env_kwargs)

# JIT-wrapped reset/step entry points used by the rollout cells.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

# Select the PD controller used by the environment.
# Alternative: env.set_pd_callback(stable_pd_controller_custom_trajectory)
env.set_pd_callback(feedback_pd_controller)

In [10]:
from some_math.math_utils import generate_trajectory,compute_cubic_trajectory,start_trajectories

In [11]:
# Shorthand for looking up an actuator index on this environment's MuJoCo model.
actuator_of = functools.partial(get_actuator_indx, env.sys.mj_model)

# Actuator indices of the joints that get a custom trajectory.
a_jnt_right = actuator_of('right_shoulder', 'Y')
a_jnt_left = actuator_of('left_shoulder', 'X')
a_jnt_left_elbow = actuator_of('left_elbow', 'X')
a_jnt_right_elbow = actuator_of('right_elbow', 'X')

right_knee = actuator_of('right_knee', 'X')
left_knee = actuator_of('left_knee', 'X')

# One trajectory per actuator over [t_init, t_end].
# NOTE(review): the last two arguments are presumably the start and end
# values of the trajectory — confirm against generate_trajectory.
trajec_dict = {
    a_jnt_right: generate_trajectory(t_init, t_end, 0, -1.5),        # right shoulder
    a_jnt_left: generate_trajectory(t_init, t_end, 0, 1.5),          # left shoulder
    a_jnt_left_elbow: generate_trajectory(t_init, t_end, 0, 1.5),    # left elbow
    a_jnt_right_elbow: generate_trajectory(t_init, t_end, 0, 1.5),   # right elbow
    right_knee: generate_trajectory(t_init, t_end, 0, 0),            # right knee
    left_knee: generate_trajectory(t_init, t_end, 0, 0),             # left knee
}

start_trajec = start_trajectories(trajec_dict)

In [12]:
# Number of simulation steps to record.
n_steps = 500

# Reset the environment with a fixed PRNG key and start the rollout
# with the initial pipeline state.
reset_key = jax.random.PRNGKey(0)
state = jit_reset(reset_key)
rollout = [state.pipeline_state]

In [13]:

# The control vector is identical on every step, so build it once instead of
# re-allocating it inside the loop.
ctrl = -0.1 * jp.ones(env.sys.nu)

for i in range(n_steps):
    # Clamp the simulation time to [t_init, t_end] so the custom trajectory is
    # only ever evaluated inside the window it was generated for.
    time = jp.clip(state.pipeline_state.time, t_init, t_end)

    # Target computed from the cubic trajectory at the clamped time.
    target = compute_cubic_trajectory(time, start_trajec)

    # Advance the environment one step and record the resulting pipeline state.
    state = jit_step(state, ctrl, target, time)
    rollout.append(state.pipeline_state)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(28,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 

In [14]:
# Render the recorded rollout from the 'back' camera and play it at the
# environment's step rate (1 / env.dt frames per second).
frames = env.render(rollout, camera='back')
media.show_video(frames, fps=1.0 / env.dt)
# Alternative viewer: HTML(html.render(env.sys.replace(dt=env.dt), rollout))

0
This browser does not support the video tag.
