In [1]:
from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.apg import train as apg
from brax.training.agents.apg import networks as apg_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap
import jax.random
from jax import lax

In [2]:
from utils.SimpleConverter import SimpleConverter
from agent_mimic_env.agent_template import HumanoidTemplate
from agent_mimic_env.agent_eval_template import HumanoidEvalTemplate
from agent_mimic_env.agent_test_apg import HumanoidAPGTest
from utils.util_data import *
from copy import deepcopy

In [3]:
import agent_mimic_env
from agent_mimic_env import register_mimic_env


In [4]:

import distutils.util
import os
# Add the XLA GPU Triton-GEMM flag to XLA_FLAGS, preserving whatever the
# environment already carries. The flag is appended only when it is not
# already present, so re-running this cell does not keep growing the
# variable with duplicate copies.
_triton_gemm_flag = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if _triton_gemm_flag not in xla_flags.split():
    xla_flags += ' ' + _triton_gemm_flag
os.environ['XLA_FLAGS'] = xla_flags
# Announce, then set, the MuJoCo GL backend environment variable to EGL.
# NOTE(review): `mujoco` is already imported in the first cell. MuJoCo appears
# to choose its GL backend when the package is imported, so setting MUJOCO_GL
# here may come too late to take effect — confirm, and move this above the
# imports if so.
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl


In [5]:
from jax.lib import xla_bridge
# Report which platform JAX selected (e.g. 'gpu'). `jax.default_backend()` is
# the public API for this and returns the same platform string as the
# deprecated `jax.lib.xla_bridge.get_backend().platform` path.
print(jax.default_backend())

gpu


In [6]:
import yaml
from box import Box
# Location of the experiment configuration (YAML).
yaml_file_path = 'config_params/punch.yaml'
# Parse the YAML document, then wrap the result in a Box so that settings can
# be read with attribute access (e.g. args.reward_weights.w_p).
with open(yaml_file_path, 'r') as config_file:
    raw_config = yaml.safe_load(config_file)
args = Box(raw_config)

In [7]:
# Show the full configuration that was loaded from the YAML file.
print(args)

{'num_envs': 256, 'num_eval_envs': 32, 'lr': 0.0003, 'max_it': 1000, 'max_grad_norm': 0.4, 'seed': 0, 'system_config': 'humanoid', 'demo_replay_mode': 'threshold', 'threshold': 0.4, 'normalize_observations': True, 'cycle_len': 0, 'ep_len': 30, 'ep_len_eval': 54, 'use_lr_scheduler': True, 'reward_scaling': 0.02, 'rot_weight': 0.5, 'vel_weight': 0.01, 'ang_weight': 0.01, 'reward_weights': {'w_p': 0.6, 'w_v': 0.1, 'w_e': 0.2, 'w_c': 0.1}, 'ref': 'motions/humanoid3d_punch.txt'}


In [8]:
# Spot-check nested attribute access on the Box-wrapped config.
args.reward_weights.w_p

0.6

In [9]:
from agent_mimic_env.pds_controllers_agents import feedback_pd_controller, stable_pd_controller_action


In [10]:
# Build the environment variants from the loaded config. The returned tuple is
# unpacked as (env_replay, env_eval, env, env_apg).
env_replay,env_eval,env,env_apg=register_mimic_env(args)


# --- env: jitted reset and custom step, with the stable PD controller ---
jit_reset = jax.jit(env.reset)
#jit_step = jax.jit(env.step)
jit_step_custom_env = jax.jit(env.step_custom)
# NOTE(review): the callback is installed after jax.jit has wrapped
# step_custom. jax.jit traces lazily on first call, so presumably the callback
# is in place by then — confirm set_pd_callback does not rebind step_custom.
env.set_pd_callback(stable_pd_controller_action)


# --- env_eval: same pattern as env ---
jit_reset_eval = jax.jit(env_eval.reset)
jit_step_custom_eval = jax.jit(env_eval.step_custom)
env_eval.set_pd_callback(stable_pd_controller_action)

# --- env_replay: uses the plain `step`; no PD callback is set here ---
jit_reset_replay = jax.jit(env_replay.reset)
jit_step_replay= jax.jit(env_replay.step)

# --- env_apg: currently disabled; uncomment all three lines to enable ---
#jit_reset_apg = jax.jit(env_apg.reset)
#jit_step_custom_apg= jax.jit(env_apg.step_custom)
#env_apg.set_pd_callback(stable_pd_controller_action)





qpos init (65, 35)
qvel init (65, 34)


In [11]:
# One value per line: env_eval's physics timestep, env_eval's control step
# (dt), and the full solver options of `env`.
print(env_eval.sys.opt.timestep, env_eval.dt, env.sys.opt, sep='\n')

0.002
0.016
Option(timestep=0.002, impratio=Array(1., dtype=float32, weak_type=True), tolerance=Array(1.e-08, dtype=float32, weak_type=True), ls_tolerance=Array(0.01, dtype=float32, weak_type=True), gravity=Array([ 0.  ,  0.  , -9.81], dtype=float32), wind=Array([0., 0., 0.], dtype=float32), density=Array(0., dtype=float32, weak_type=True), viscosity=Array(0., dtype=float32, weak_type=True), has_fluid_params=False, integrator=<IntegratorType.EULER: 0>, cone=<ConeType.PYRAMIDAL: 0>, jacobian=<JacobianType.AUTO: 2>, solver=<SolverType.NEWTON: 2>, iterations=1, ls_iterations=6, disableflags=<DisableBit.EULERDAMP: 16384>)


In [12]:
# Print the rollout / reward settings of `env_eval` and then of `env`, one bare
# value per line, in the same order as the attribute list below.
_common_fields = (
    'rollout_lenght',  # attribute name is spelled this way on the env objects
    'cycle_len',
    'rot_weight',
    'vel_weight',
    'ang_weight',
    'reward_scaling',
    'dt',
)
for _field in _common_fields:
    print(getattr(env_eval, _field))

# For `env`, `dt` is printed a second time and is followed by `err_threshold`.
for _field in _common_fields + ('dt', 'err_threshold'):
    print(getattr(env, _field))

65
65
0.5
0.01
0.01
0.02
0.016
65
65
0.5
0.01
0.01
0.02
0.016
0.016
0.4


In [13]:
# Start a fresh episode on `env` and begin recording pipeline states for
# rendering. To inspect another variant, reset with `jit_reset_replay` or
# `jit_reset_eval` instead (and use the matching step function later on).
initial_key = jax.random.PRNGKey(0)
state = jit_reset(initial_key)
rollout = [state.pipeline_state]

# Step budget.
# NOTE(review): the rollout loop in the next cell iterates over
# env.rollout_lenght - 1 and does not read n_steps.
n_steps = 500

In [14]:
# Show the metrics and the info dict of the freshly reset state, one per line.
print(state.metrics, state.info, sep='\n')

{'pose_error': Array(0., dtype=float32, weak_type=True), 'reference_angular': Array(0., dtype=float32, weak_type=True), 'reference_position': Array(0., dtype=float32, weak_type=True), 'reference_rotation': Array(0., dtype=float32, weak_type=True), 'reference_velocity': Array(0., dtype=float32, weak_type=True)}
{'index_step': Array(0., dtype=float32, weak_type=True), 'reward_tuple': {'reference_angular': Array(0., dtype=float32, weak_type=True), 'reference_position': Array(0., dtype=float32, weak_type=True), 'reference_rotation': Array(0., dtype=float32, weak_type=True), 'reference_velocity': Array(0., dtype=float32, weak_type=True)}, 'rng': Array([0, 0], dtype=uint32), 'steps': Array(0., dtype=float32, weak_type=True)}


In [15]:
# Step `env` forward env.rollout_lenght - 1 times with a constant control
# signal. Each iteration logs the simulation time before the step and the
# reward, reference index and pose error after it, then records the pipeline
# state for rendering.
#
# The control vector never changes, so it is built once here rather than on
# every iteration. To drive a different variant, change both the env used for
# `sys.nu` and the jitted step function below (e.g. jit_step_replay,
# jit_step_custom_eval).
ctrl = -0.1 * jp.ones(env.sys.nu)

for _step in range(env.rollout_lenght - 1):
    print('time: ', state.pipeline_state.time)

    state = jit_step_custom_env(state, ctrl)

    print("Is rewards", state.reward)
    print('is step', state.info['index_step'])
    print("Is pose", state.metrics['pose_error'])

    rollout.append(state.pipeline_state)

time:  0.0
to converted reference: 0
error demo replay: 0.009272252209484577
Is rewards -0.0005570441
is step 1.0
Is pose 0.0030484262
time:  0.016
to converted reference: 0
error demo replay: 0.017997145652770996
Is rewards -0.00018392537
is step 2.0
Is pose 0.0054769497
time:  0.032
to converted reference: 0
error demo replay: 0.02556402049958706
Is rewards -0.0002562675
is step 3.0
Is pose 0.007019113
time:  0.048000004
to converted reference: 0
error demo replay: 0.03300696611404419
Is rewards -0.00017771126
is step 4.0
Is pose 0.009385679
time:  0.064
to converted reference: 0
error demo replay: 0.04052149876952171
Is rewards -0.00021774767
is step 5.0
Is pose 0.0114693865
time:  0.079999976
to converted reference: 0
error demo replay: 0.04915088415145874
Is rewards -0.00030210192
is step 6.0
Is pose 0.014450781
time:  0.09599995
to converted reference: 0
error demo replay: 0.05986785143613815
Is rewards -0.00093334314
is step 7.0
Is pose 0.017599795
time:  0.11199992
to converted

In [16]:
from some_math.quaternion_diff import *

In [17]:
# Geom indices of interest (per the original note): 6 right wrist, 9 left wrist, 12 right ankle, 15 left ankle.

print(state.pipeline_state.geom_xpos)


# NOTE(review): the original note here read "index 16"; the recorded output has 16 rows (indices 0-15), so presumably this refers to the geom count — confirm.

[[0.         0.         0.        ]
 [1.244459   1.4314553  0.8863175 ]
 [1.2936954  1.4463661  1.1639953 ]
 [1.3743398  1.4491659  1.4287616 ]
 [1.2938482  1.2095433  1.2009286 ]
 [1.3595285  1.1709361  1.089154  ]
 [1.4924569  1.2099655  1.0785233 ]
 [1.1725777  1.6986383  1.1928931 ]
 [1.1430155  1.7022065  1.0385554 ]
 [1.2249813  1.643487   0.9429525 ]
 [1.3444555  1.3574903  0.62328047]
 [1.3284034  1.3157953  0.27006006]
 [1.2513252  1.2712458  0.06409848]
 [1.275435   1.6062667  0.63669026]
 [1.2866184  1.7668781  0.269242  ]
 [1.2806823  1.8599191  0.04734407]]


In [18]:
# Render the recorded rollout in the browser. `rollout` holds one pipeline
# state per env step, so the system passed to the renderer has its
# 'opt.timestep' replaced by env.dt. To view another variant, use that
# variant's `sys` and `dt` here (env_replay, env_eval or env_apg).
render_sys = env.sys.tree_replace({'opt.timestep': env.dt})
HTML(html.render(render_sys, rollout))
