In [1]:
from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.apg import train as apg
from brax.training.agents.apg import networks as apg_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap
import jax.random
from jax import lax

In [2]:
from utils.SimpleConverter import SimpleConverter
from agent_mimic_env.agent_template import HumanoidTemplate
from agent_mimic_env.agent_eval_template import HumanoidEvalTemplate
from agent_mimic_env.agent_test_apg import HumanoidAPGTest
from utils.util_data import *
from copy import deepcopy

In [3]:
import agent_mimic_env
from agent_mimic_env import register_mimic_env


In [4]:

import distutils.util
import os
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl


In [5]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8" # 0.9 causes too much lag. 
from datetime import datetime
import functools

# Math
import jax.numpy as jp
import numpy as np
import jax
from jax import config # Analytical gradients work much better with double precision.
config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)
config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH)
from brax import math

# Sim
import mujoco
import mujoco.mjx as mjx

# Brax
from brax import envs
from brax.base import Motion, Transform
from brax.io import mjcf
from brax.envs.base import PipelineEnv, State
from brax.mjx.pipeline import _reformat_contact
from brax.training.acme import running_statistics
from brax.io import model

# Algorithms
# from brax.training.agents.apg import train as apg
# from brax.training.agents.apg import networks as apg_networks
from brax.training.agents.ppo import train as ppo

# Supporting
from etils import epath
import mediapy as media
import matplotlib.pyplot as plt
from ml_collections import config_dict
from typing import Any, Dict

In [6]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [7]:
import yaml
from box import Box
# Path to your YAML file
yaml_file_path = 'config_params/walk.yaml'
# Load the YAML file
with open(yaml_file_path, 'r') as file:
    args = Box(yaml.safe_load(file))

In [8]:
print(args)

{'num_envs': 256, 'num_eval_envs': 32, 'lr': 0.0003, 'max_it': 1000, 'max_grad_norm': 0.4, 'seed': 0, 'system_config': 'humanoid', 'demo_replay_mode': 'threshold', 'threshold': 0.4, 'normalize_observations': True, 'cycle_len': 195, 'ep_len': 39, 'ep_len_eval': 54, 'use_lr_scheduler': True, 'reward_scaling': 0.02, 'rot_weight': 0.5, 'vel_weight': 0.01, 'ang_weight': 0.01, 'deep_mimic_reward_weights': {'w_p': 0.3, 'w_v': 0.1, 'w_e': 0.5, 'w_c': 0.1}, 'deep_mimic_weights_factors': {'w_pose': 2, 'w_angular': 0.005, 'w_efector': 5, 'w_com': 100}, 'ref': 'motions/humanoid3d_walk_duplicated.txt'}


In [9]:
args.deep_mimic_reward_weights.w_p

0.3

In [10]:
from agent_mimic_env.pds_controllers_agents import feedback_pd_controller, stable_pd_controller_action


In [11]:
env_replay,env_eval,env,env_apg=register_mimic_env(args)


#for the env
jit_reset = jax.jit(env.reset)
#jit_step = jax.jit(env.step)
jit_step_custom_env = jax.jit(env.step_custom)
env.set_pd_callback(stable_pd_controller_action)


# #for the eval
jit_reset_eval = jax.jit(env_eval.reset)
jit_step_custom_eval = jax.jit(env_eval.step_custom)
env_eval.set_pd_callback(stable_pd_controller_action)

#for the replay
jit_reset_replay = jax.jit(env_replay.reset)
jit_step_replay= jax.jit(env_replay.step)

#for apg
# jit_reset_apg = jax.jit(env_apg.reset)
# jit_step_custom_apg= jax.jit(env_apg.step_custom)
# env_apg.set_pd_callback(stable_pd_controller_action)





qpos init (195, 35)
qvel init (195, 34)


In [12]:
print(env_eval.sys.opt.timestep)
print(env_eval.dt)

0.002
0.016


In [13]:
print(env_eval.rollout_lenght)
print(env_eval.cycle_len)
print(env_eval.rot_weight)
print(env_eval.vel_weight)
print(env_eval.ang_weight)
print(env_eval.reward_scaling)
print(env_eval.dt)
# print(env_eval.sys.dt)

print(env.rollout_lenght)
print(env.cycle_len)
print(env.rot_weight)
print(env.vel_weight)
print(env.ang_weight)
print(env.reward_scaling)
print(env.dt)
print(env.dt)
print(env.err_threshold)

195
195
0.5
0.01
0.01
0.02
0.016
195
195
0.5
0.01
0.01
0.02
0.016
0.016
0.4


In [14]:
# initialize the state

#for the replay
state = jit_reset_replay(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

#for the eval
# state = jit_reset_eval(jax.random.PRNGKey(0))
# rollout = [state.pipeline_state]
#for env
# state = jit_reset(jax.random.PRNGKey(0))
# rollout = [state.pipeline_state]

#for apg test
# state = jit_reset_apg(jax.random.PRNGKey(0))
# rollout = [state.pipeline_state]

# grab a 500 steps
n_steps = 500

In [15]:
print(state.metrics)
print(state.info)

{'fall': Array(0., dtype=float64), 'pose_error': Array(0., dtype=float64), 'reference_angular': Array(0., dtype=float64, weak_type=True), 'reference_position': Array(0., dtype=float64, weak_type=True), 'reference_rotation': Array(0., dtype=float64, weak_type=True), 'reference_velocity': Array(0., dtype=float64, weak_type=True), 'step_index': Array(0, dtype=int64, weak_type=True)}
{'reward_tuple': {'reference_angular': Array(0., dtype=float64, weak_type=True), 'reference_position': Array(0., dtype=float64, weak_type=True), 'reference_rotation': Array(0., dtype=float64, weak_type=True), 'reference_velocity': Array(0., dtype=float64, weak_type=True)}, 'rng': Array([0, 0], dtype=uint32), 'steps': Array(0., dtype=float64, weak_type=True)}


In [16]:
id = 64 + 1

id =id % 65
id
# cy = (66%65)
# cy

0

In [17]:
for i in range(env_replay.cycle_len-1):
    
    #print(i)
    #ctrl = -0.1 * jp.ones(env_eval.sys.nu)
    #ctrl = -0.1 * jp.ones(env.sys.nu)
    ctrl = -0.1 * jp.ones(env_replay.sys.nu)
    #ctrl = -0.1 * jp.ones(env_apg.sys.nu)
    #time
    #time = state.pipeline_state.time
    
    #print('time: ',state.pipeline_state.time)    
         
    state = jit_step_replay(state, ctrl)
    #state = jit_step_custom_eval(state, ctrl)
    #state = jit_step(state, ctrl)
    #state = jit_step_custom_env(state,ctrl)
    #state = jit_step_custom_apg(state,ctrl)
    
    #print(state.pipeline_state.geom_xpos)
    #if state.done:
     #   print(time)
      #  break
    #print("Is done", state.done)
    # print("Is rewards", state.reward)
    print('is step', state.metrics['step_index'])
    #print("reward tuple", state.info['reward_tuple'])
    # print("Is ", state.metrics['pose_error'])
    
    rollout.append(state.pipeline_state)

is step 1
is step 2
is step 3
is step 4
is step 5
is step 6
is step 7
is step 8
is step 9
is step 10
is step 11
is step 12
is step 13
is step 14
is step 15
is step 16
is step 17
is step 18
is step 19
is step 20
is step 21
is step 22
is step 23
is step 24
is step 25
is step 26
is step 27
is step 28
is step 29
is step 30
is step 31
is step 32
is step 33
is step 34
is step 35
is step 36
is step 37
is step 38
is step 39
is step 40
is step 41
is step 42
is step 43
is step 44
is step 45
is step 46
is step 47
is step 48
is step 49
is step 50
is step 51
is step 52
is step 53
is step 54
is step 55
is step 56
is step 57
is step 58
is step 59
is step 60
is step 61
is step 62
is step 63
is step 64
is step 65
is step 66
is step 67
is step 68
is step 69
is step 70
is step 71
is step 72
is step 73
is step 74
is step 75
is step 76
is step 77
is step 78
is step 79
is step 80
is step 81
is step 82
is step 83
is step 84
is step 85
is step 86
is step 87
is step 88
is step 89
is step 90
is step 91
is step 

In [18]:
from some_math.quaternion_diff import *

In [19]:
#6 right writs 9 left writs 12 right ankle, 15 left ankle

print(state.pipeline_state.geom_xpos)


#index 16

[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 6.20090570e+00 -1.59753636e-03  9.17060086e-01]
 [ 6.23262037e+00 -6.46088548e-03  1.20139818e+00]
 [ 6.26244145e+00 -8.95108268e-03  1.47868209e+00]
 [ 6.19138362e+00 -2.22853758e-01  1.19757965e+00]
 [ 6.10929672e+00 -2.82049935e-01  9.64719533e-01]
 [ 6.07610463e+00 -3.16459124e-01  8.34256659e-01]
 [ 6.22300922e+00  2.18574503e-01  1.20201332e+00]
 [ 6.32704048e+00  2.72369877e-01  9.88154886e-01]
 [ 6.41894919e+00  2.81096666e-01  8.84314135e-01]
 [ 6.28964892e+00 -9.10024563e-02  6.53397154e-01]
 [ 6.41492240e+00 -1.02829318e-01  2.64580211e-01]
 [ 6.50260983e+00 -1.09762609e-01  4.37442467e-02]
 [ 6.13159279e+00  8.43163728e-02  6.47011893e-01]
 [ 5.96795246e+00  6.77869378e-02  2.76887902e-01]
 [ 5.88630560e+00  5.20817792e-02  6.81942568e-02]]


In [20]:
HTML(html.render(env_replay.sys.tree_replace({'opt.timestep': env_replay.dt}), rollout))
#HTML(html.render(env_eval.sys.tree_replace({'opt.timestep': env_eval.dt}), rollout))
#HTML(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))
#HTML(html.render(env_apg.sys.tree_replace({'opt.timestep': env_apg.dt}), rollout))
