In [1]:
from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.apg import train as apg
from brax.training.agents.apg import networks as apg_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap
import jax.random
from jax import lax

In [2]:
from utils.SimpleConverter import SimpleConverter
from agent_mimic_env.agent_template import HumanoidTemplate
from agent_mimic_env.agent_eval_template import HumanoidEvalTemplate
from agent_mimic_env.agent_test_apg import HumanoidAPGTest
from utils.util_data import *
from copy import deepcopy

In [3]:
import agent_mimic_env
from agent_mimic_env import register_mimic_env


In [4]:

# Configure process-level environment variables for XLA and MuJoCo.
# NOTE(review): `distutils.util` is not used anywhere in the visible cells and
# the `distutils` package no longer ships with recent Python releases — confirm
# whether this import is still needed.
import distutils.util
import os
# Append (rather than overwrite) so XLA flags already set by the user survive.
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
print('Setting environment variable to use GPU rendering:')
# NOTE(review): `mujoco` was already imported in the first cell; confirm that
# MUJOCO_GL is still honoured when it is set this late.
%env MUJOCO_GL=egl

Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl


In [5]:
# Report which backend JAX selected (the recorded run printed "gpu").
# Uses the public `jax.default_backend()` API instead of reaching into
# `jax.lib.xla_bridge`, which is a private, deprecated entry point that newer
# JAX releases no longer expose. `jax` itself is imported in the first cell.
print(jax.default_backend())

gpu


In [6]:
import yaml
from box import Box

# Location of the experiment configuration (YAML).
yaml_file_path = 'config_params/punch.yaml'

# Parse the YAML with the safe loader, then wrap the resulting mapping in a
# Box so that entries can be read with attribute access (args.some_key).
with open(yaml_file_path, 'r') as file:
    raw_config = yaml.safe_load(file)
args = Box(raw_config)

In [7]:
# Dump the full configuration loaded from the YAML file for a visual check.
print(args)

{'num_envs': 256, 'num_eval_envs': 32, 'lr': 0.0003, 'max_it': 1000, 'max_grad_norm': 0.4, 'seed': 0, 'system_config': 'humanoid', 'demo_replay_mode': 'threshold', 'threshold': 0.4, 'normalize_observations': True, 'cycle_len': 0, 'ep_len': 30, 'ep_len_eval': 54, 'use_lr_scheduler': True, 'reward_scaling': 0.02, 'rot_weight': 0.5, 'vel_weight': 0.01, 'ang_weight': 0.01, 'reward_weights': {'w_p': 0.6, 'w_v': 0.1, 'w_e': 0.2, 'w_c': 0.1}, 'ref': 'motions/humanoid3d_punch.txt'}


In [8]:
# Spot-check attribute-style access into the nested `reward_weights` section;
# the value is shown as the cell output.
args.reward_weights.w_p

0.6

In [9]:
from agent_mimic_env.pds_controllers_agents import feedback_pd_controller, stable_pd_controller_action


In [10]:
# Build the four environments (replay, eval, training, APG) from the config.
env_replay, env_eval, env, env_apg = register_mimic_env(args)

# Training environment: jitted reset and custom step.
jit_reset = jax.jit(env.reset)
# jit_step = jax.jit(env.step)
jit_step_custom_env = jax.jit(env.step_custom)

# Evaluation environment.
jit_reset_eval = jax.jit(env_eval.reset)
jit_step_custom_eval = jax.jit(env_eval.step_custom)

# Replay environment: uses the plain `step` and gets no PD callback.
jit_reset_replay = jax.jit(env_replay.reset)
jit_step_replay = jax.jit(env_replay.step)

# APG environment.
jit_reset_apg = jax.jit(env_apg.reset)
jit_step_custom_apg = jax.jit(env_apg.step_custom)

# Every environment stepped through `step_custom` is given the stable PD
# controller. As in the original cell, each env's jit wrappers are created
# before its callback is registered, and the registration order is unchanged.
for pd_driven_env in (env, env_eval, env_apg):
    pd_driven_env.set_pd_callback(stable_pd_controller_action)





In [11]:
# Compare the simulator option `opt.timestep` with the environment's `dt`.
# The recorded run printed 0.002 and 0.016 respectively.
# NOTE(review): presumably `dt` spans several physics sub-steps — confirm.
print(env_eval.sys.opt.timestep)
print(env_eval.dt)

0.002
0.016


In [12]:
# Attributes reported for both environments, in print order.
# (`rollout_lenght` is the attribute's actual spelling on the env objects.)
_REPORT_FIELDS = (
    'rollout_lenght',
    'cycle_len',
    'rot_weight',
    'vel_weight',
    'ang_weight',
    'reward_scaling',
    'dt',
)

# Evaluation environment first ...
for field in _REPORT_FIELDS:
    print(getattr(env_eval, field))
# print(env_eval.sys.dt)

# ... then the training environment.
for field in _REPORT_FIELDS:
    print(getattr(env, field))
# The original cell prints `env.dt` a second time; kept so the output matches.
print(env.dt)
print(env.err_threshold)

65
65
0.5
0.01
0.01
0.02
0.016
65
65
0.5
0.01
0.01
0.02
0.016
0.016
0.4


In [13]:
# Step budget kept from the original cell ("grab 500 steps").
# NOTE(review): `n_steps` is not read by any cell visible here — confirm use.
n_steps = 500

# Initialise the state and start the rollout buffer with the first frame.
# To inspect a different environment, swap the reset function below:
#   replay : state = jit_reset_replay(jax.random.PRNGKey(0))
#   eval   : state = jit_reset_eval(jax.random.PRNGKey(0))
#   env    : state = jit_reset(jax.random.PRNGKey(0))
# Currently active: the APG test environment.
reset_key = jax.random.PRNGKey(0)
state = jit_reset_apg(reset_key)
rollout = [state.pipeline_state]

In [14]:
# Inspect the freshly reset state: the metrics dict and the info dict
# (the recorded output shows `kinematic_ref`, `pose_error` and `reward_tuple`
# among the info entries).
print(state.metrics)
print(state.info)

{'reference_angular': Array(0., dtype=float32, weak_type=True), 'reference_position': Array(0., dtype=float32, weak_type=True), 'reference_rotation': Array(0., dtype=float32, weak_type=True), 'reference_velocity': Array(0., dtype=float32, weak_type=True)}
{'kinematic_ref': Array([ 0.0000000e+00,  0.0000000e+00,  7.5521302e-01, -9.9875391e-01,
       -1.5632922e-02, -4.7388997e-02,  7.4175280e-04, -4.2744510e-02,
        3.4576020e-01, -3.6806550e-02, -8.7780811e-02, -4.4509736e-01,
        5.8729285e-01, -6.2009311e-01,  5.0161451e-01,  7.9552609e-01,
        1.7608342e+00,  7.1741235e-01, -4.3385601e-01, -9.7488210e-02,
        9.9051863e-01, -2.9624522e-01, -4.9943644e-01, -5.5784565e-01,
       -1.2331117e+00, -2.8910020e-01, -2.7541557e-01, -3.7002066e-01,
        4.1313863e-01, -8.0841905e-01,  5.2728492e-01, -9.7332996e-01,
        1.5179858e-03,  1.8073888e-02,  1.4902889e-03], dtype=float32), 'pose_error': Array(0., dtype=float32, weak_type=True), 'reward_tuple': {'reference_an

In [15]:
# Roll the APG environment forward with a constant control signal and record
# every pipeline state for rendering.
#
# To roll out another environment instead, replace `env_apg` and
# `jit_step_custom_apg` with the matching pair created earlier:
#   env_eval / jit_step_custom_eval, env / jit_step_custom_env,
#   env_replay / jit_step_replay.
#
# NOTE: this cell advances `state` and appends to `rollout` in place, so
# re-running it without re-running the reset cell continues the same rollout.

# The control vector is identical on every step, so build it once instead of
# re-allocating it inside the loop.
ctrl = -0.1 * jp.ones(env_apg.sys.nu)

for i in range(env_apg.rollout_lenght - 1):
    # Simulation time before the step is applied.
    print('time: ',state.pipeline_state.time)

    state = jit_step_custom_apg(state,ctrl)

    # Step counter reported by the environment after stepping.
    print('is step', state.info['step_index'])
    # Other quantities that can be printed while debugging:
    #   state.done, state.reward, state.info['reward_tuple'],
    #   state.metrics['pose_error'], state.pipeline_state.geom_xpos

    rollout.append(state.pipeline_state)

time:  0.0
is step 1.0
time:  0.016
is step 2.0
time:  0.032
is step 3.0
time:  0.048000004
is step 4.0
time:  0.064
is step 5.0
time:  0.079999976
is step 6.0
time:  0.09599995
is step 7.0
time:  0.11199992
is step 8.0
time:  0.1279999
is step 9.0
time:  0.14399993
is step 10.0
time:  0.15999997
is step 11.0
time:  0.176
is step 12.0
time:  0.19200003
is step 13.0
time:  0.20800006
is step 14.0
time:  0.2240001
is step 15.0
time:  0.24000013
is step 16.0
time:  0.25600016
is step 17.0
time:  0.2720002
is step 18.0
time:  0.28800023
is step 19.0
time:  0.30400026
is step 20.0
time:  0.3200003
is step 21.0
time:  0.33600032
is step 22.0
time:  0.35200036
is step 23.0
time:  0.3680004
is step 24.0
time:  0.38400042
is step 25.0
time:  0.40000045
is step 26.0
time:  0.4160005
is step 27.0
time:  0.43200052
is step 28.0
time:  0.44800055
is step 29.0
time:  0.46400058
is step 30.0
time:  0.48000062
is step 31.0
time:  0.49600065
is step 32.0
time:  0.5120005
is step 33.0
time:  0.5280003
i

In [16]:
from some_math.quaternion_diff import *

In [17]:
# Rows of interest (per the original note — TODO confirm against the model):
# 6 = right wrist, 9 = left wrist, 12 = right ankle, 15 = left ankle

# Print the world position of every geom in the final rollout state.
print(state.pipeline_state.geom_xpos)


# Original note: "index 16". The recorded output has 16 rows (indices 0-15);
# TODO confirm what this note refers to.

[[ 0.          0.          0.        ]
 [ 0.17457122 -0.04440825  0.72948873]
 [ 0.05302095 -0.0968165   0.9798447 ]
 [-0.04608192 -0.17313755  1.2264762 ]
 [ 0.02906167 -0.35319     0.94155866]
 [ 0.13455467 -0.4777181   0.76139414]
 [ 0.22441347 -0.5568892   0.6909413 ]
 [-0.02213942  0.16763127  1.0179135 ]
 [ 0.05875943  0.19120699  0.9022077 ]
 [ 0.16650769  0.10724488  0.8767724 ]
 [ 0.41481942 -0.0881189   0.5763192 ]
 [ 0.6173413  -0.01388184  0.30129325]
 [ 0.67606306  0.03621648  0.08926342]
 [ 0.3571118   0.15534595  0.5905572 ]
 [ 0.47069862  0.2965108   0.3024173 ]
 [ 0.45456767  0.35851815  0.06761813]]


In [19]:
# Render the recorded rollout as an interactive HTML viewer.
# The system's `opt.timestep` is replaced by the environment's `dt` for
# rendering, since one rollout entry was appended per environment step.
# For the other environments, build the render system from the matching env:
#   env_replay.sys.tree_replace({'opt.timestep': env_replay.dt})
#   env_eval.sys.tree_replace({'opt.timestep': env_eval.dt})
#   env.sys.tree_replace({'opt.timestep': env.dt})
render_sys = env_apg.sys.tree_replace({'opt.timestep': env_apg.dt})
HTML(html.render(render_sys, rollout))
