In [1]:
from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.apg import train as apg
from brax.training.agents.apg import networks as apg_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap
import jax.random
from jax import lax

In [2]:
from utils.SimpleConverter import SimpleConverter
from agent_mimic_env.agent_template import HumanoidTemplate
from agent_mimic_env.agent_eval_template import HumanoidEvalTemplate
from agent_mimic_env.agent_test_apg import HumanoidAPGTest
from utils.util_data import *
from copy import deepcopy

In [3]:
import agent_mimic_env
from agent_mimic_env import register_mimic_env


In [4]:

import distutils.util
import os
# Enable Triton GEMM kernels in XLA's GPU backend.
# The flag is appended only when it is not already present, so re-running this
# cell does not keep growing XLA_FLAGS with duplicate entries.
# NOTE(review): jax is imported in an earlier cell; confirm XLA_FLAGS is still
# honoured when set at this point (i.e. before the backend is initialised).
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags += ' ' + _TRITON_GEMM_FLAG
os.environ['XLA_FLAGS'] = xla_flags
# Select the `egl` OpenGL backend for MuJoCo rendering via the MUJOCO_GL
# environment variable (IPython `%env` magic, echoes the assignment).
# NOTE(review): mujoco is already imported in the first cell — confirm the
# variable is still picked up when it is set this late.
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl


In [5]:
import os
# Limit the XLA client memory fraction to 0.8; 0.9 causes too much lag.
os.environ.update(XLA_PYTHON_CLIENT_MEM_FRACTION="0.8")
from datetime import datetime
import functools

# Math
import jax.numpy as jp
import numpy as np
import jax
from jax import config # Analytical gradients work much better with double precision.
# Global JAX configuration for this notebook.
# Turn on JAX's NaN debugging mode.
config.update("jax_debug_nans", True)
# Enable 64-bit (double precision) types.
config.update("jax_enable_x64", True)
# Request higher-precision matrix multiplications by default.
# NOTE(review): some JAX releases may only accept a string (e.g. 'high') for
# this option — confirm the enum value is accepted by the installed version.
config.update('jax_default_matmul_precision', jax.lax.Precision.HIGH)
from brax import math

# Sim
import mujoco
import mujoco.mjx as mjx

# Brax
from brax import envs
from brax.base import Motion, Transform
from brax.io import mjcf
from brax.envs.base import PipelineEnv, State
from brax.mjx.pipeline import _reformat_contact
from brax.training.acme import running_statistics
from brax.io import model

# Algorithms
# from brax.training.agents.apg import train as apg
# from brax.training.agents.apg import networks as apg_networks
from brax.training.agents.ppo import train as ppo

# Supporting
from etils import epath
import mediapy as media
import matplotlib.pyplot as plt
from ml_collections import config_dict
from typing import Any, Dict

In [6]:
# Report the platform JAX is running on (e.g. 'cpu', 'gpu', 'tpu').
# jax.default_backend() is the public API for this query; it replaces the
# private jax.lib.xla_bridge module, which is deprecated in recent JAX
# releases, and prints the same platform string.
print(jax.default_backend())

gpu


In [7]:
import yaml
from box import Box
# Path to your YAML file
# Read the experiment settings from YAML and wrap them in a Box so that
# nested entries can be reached with attribute access (args.section.key).
yaml_file_path = 'config_params/punch.yaml'
with open(yaml_file_path) as config_stream:
    raw_config = yaml.safe_load(config_stream)
args = Box(raw_config)

In [8]:
# Show the full parsed configuration.
print(args)

{'num_envs': 256, 'num_eval_envs': 32, 'lr': 0.0003, 'max_it': 1000, 'max_grad_norm': 0.4, 'seed': 0, 'system_config': 'humanoid', 'demo_replay_mode': 'threshold', 'threshold': 0.4, 'normalize_observations': True, 'cycle_len': 130, 'ep_len': 65, 'ep_len_eval': 54, 'use_lr_scheduler': True, 'reward_scaling': 0.02, 'rot_weight': 0.5, 'vel_weight': 0.01, 'ang_weight': 0.01, 'deep_mimic_reward_weights': {'w_p': 0.3, 'w_v': 0.1, 'w_e': 0.5, 'w_c': 0.1}, 'deep_mimic_weights_factors': {'w_pose': 2, 'w_angular': 0.005, 'w_efector': 5, 'w_com': 100}, 'model': 'models/final_humanoid.xml', 'ref': 'motions/humanoid3d_punch_duplicated.txt'}


In [9]:
# Spot-check that nested config entries are reachable through attribute access.
args.deep_mimic_reward_weights.w_p

0.3

In [10]:
from agent_mimic_env.pds_controllers_agents import feedback_pd_controller, stable_pd_controller_action


In [11]:
# Build the four environments described by the loaded configuration.
env_replay, env_eval, env, env_apg = register_mimic_env(args)

# Evaluation env: wrap reset / custom step with jax.jit, then install the
# stable PD controller as its PD callback.
jit_reset_eval, jit_step_custom_eval = map(
    jax.jit, (env_eval.reset, env_eval.step_custom)
)
env_eval.set_pd_callback(stable_pd_controller_action)




In [12]:
# Inspect the simulator options (timestep, integrator, solver, ...) of the evaluation env.
env_eval.sys.opt

Option(timestep=0.002, impratio=Array(100., dtype=float64, weak_type=True), tolerance=Array(1.e-08, dtype=float64, weak_type=True), ls_tolerance=Array(0.01, dtype=float64, weak_type=True), gravity=Array([ 0.  ,  0.  , -9.81], dtype=float64), wind=Array([0., 0., 0.], dtype=float64), density=Array(0., dtype=float64, weak_type=True), viscosity=Array(0., dtype=float64, weak_type=True), has_fluid_params=False, integrator=<IntegratorType.EULER: 0>, cone=<ConeType.PYRAMIDAL: 0>, jacobian=<JacobianType.AUTO: 2>, solver=<SolverType.NEWTON: 2>, iterations=1, ls_iterations=4, disableflags=<DisableBit.EULERDAMP: 16384>)

In [13]:
# Compare the simulator timestep with the environment step duration (one per line).
print(env_eval.sys.opt.timestep, env_eval.dt, sep='\n')

0.002
0.02


In [14]:
# Dump the evaluation env's rollout / reward parameters, one value per line.
# ('rollout_lenght' is the attribute's actual spelling on the env.)
for attr_name in (
    'rollout_lenght',
    'cycle_len',
    'rot_weight',
    'vel_weight',
    'ang_weight',
    'reward_scaling',
    'dt',
):
    print(getattr(env_eval, attr_name))
# print(env_eval.sys.dt)


65
130
0.5
0.01
0.01
0.02
0.02


In [15]:
# Initialise the state from the evaluation env with a fixed seed and start the
# trajectory buffer with the initial pipeline state.
# The same pattern applies to the other envs by swapping the reset function:
#   jit_reset_replay (replay), jit_reset (env), jit_reset_apg (apg test).
eval_key = jax.random.PRNGKey(0)
state = jit_reset_eval(eval_key)
rollout = []
rollout.append(state.pipeline_state)

# Number of steps to grab.
n_steps = 500

In [16]:
# Show the metrics and info dictionaries of the freshly reset state.
for state_field in (state.metrics, state.info):
    print(state_field)

{'fall': Array(0., dtype=float64), 'min_reference_tracking': Array(0., dtype=float64, weak_type=True), 'pose_error': Array(0., dtype=float64), 'reference_angular': Array(0., dtype=float64, weak_type=True), 'reference_position': Array(0., dtype=float64, weak_type=True), 'reference_rotation': Array(0., dtype=float64, weak_type=True), 'reference_velocity': Array(0., dtype=float64, weak_type=True), 'step_index': Array(48., dtype=float64)}
{'default_pos': Array([ 1.00952578,  0.96560349,  0.763598  ,  0.69266161,  0.00396145,
        0.0290793 ,  0.72066538,  0.0344528 ,  0.38038578, -0.03930848,
        0.25575706, -0.27285779, -1.17071799, -1.60473499,  0.07730953,
       -1.0934958 ,  0.82069482,  0.59607216,  0.09617058, -0.84933419,
        1.92958378, -0.723898  , -0.63271327, -1.08916748, -1.27049323,
       -0.16462933, -0.02070159, -0.17030865,  0.37251784,  0.20128708,
        0.66214301, -0.50385212, -0.0529365 , -0.12219402, -0.05907235],      dtype=float64), 'kinematic_ref': Ar

In [17]:
# Scratch arithmetic: remainder of 195 divided by 195, normalised by 195.
# NOTE(review): `id` shadows the Python builtin for the rest of the session;
# the name is kept so that any cell relying on it keeps working.
id = divmod(195, 195)[1] / 195
id
# cy = (66%65)
# cy

0.0

In [18]:
# Roll the evaluation env forward with a constant control signal, printing the
# reward after every step and recording each pipeline state for rendering.
#
# NOTE: this cell mutates `state` and appends to `rollout`; re-run the reset
# cell first to start again from the initial state.
#
# Other values that can be inspected inside the loop while debugging:
#   state.done, state.pipeline_state.time, state.pipeline_state.geom_xpos,
#   state.metrics['step_index'], state.metrics['pose_error'],
#   state.info['reward_tuple']

# The control vector is the same for every step, so build it once instead of
# re-allocating it on each iteration.
ctrl = -0.1 * jp.ones(env_eval.sys.nu)

for i in range(env_eval.rollout_lenght - 1):
    # To replay instead: state = jit_step_replay(state, ctrl)
    state = jit_step_custom_eval(state, ctrl)
    print("Is rewards", state.reward)
    rollout.append(state.pipeline_state)

Is rewards -0.004667226793952933
Is rewards -0.007139026974603351
Is rewards -0.010862601996117102
Is rewards -0.018280714260659713
Is rewards -0.02327371213429831
Is rewards -0.03367913397381057
Is rewards -0.027897784006008536
Is rewards -0.027601226650832693
Is rewards -0.028945197387786923
Is rewards -0.030839949453252337
Is rewards -0.034474945439321114
Is rewards -0.03814130115001462
Is rewards -0.0424929151592616
Is rewards -0.04787226323898677
Is rewards -0.046876699003118805
Is rewards -0.050969622300051184
Is rewards -0.19760659083792892
Is rewards -0.05088738993834536
Is rewards -0.05432989072841864
Is rewards -0.05740656595852785
Is rewards -0.06044491399722023
Is rewards -0.06247744215960634
Is rewards -0.06445911598212348
Is rewards -0.06613885681282093
Is rewards -0.06918422718370655
Is rewards -0.07181801776474422
Is rewards -0.07573485842160547
Is rewards -0.07939094039976365
Is rewards -0.08223394383131552
Is rewards -0.08774112526230227
Is rewards -0.0907003729579358

In [19]:
from some_math.quaternion_diff import *

In [20]:
# End-effector indices: 6 = right wrist, 9 = left wrist, 12 = right ankle,
# 15 = left ankle (presumably row indices into geom_xpos — TODO confirm).

# Print the position of every geom in the last simulated state.
print(state.pipeline_state.geom_xpos)
# Named-access alternatives kept for reference:
# env_eval.named.data.geom_xpos['green_sphere', 'z']
# env_eval.named.data.qpos['swing']

# index 16

[[0.         0.         0.        ]
 [0.82599287 1.95442066 0.43775664]
 [0.69498117 2.19987885 0.45605995]
 [0.56788583 2.43856968 0.4178242 ]
 [0.71980936 2.39297306 0.1629824 ]
 [0.63571681 2.42567686 0.04088801]
 [0.49741341 2.41232124 0.04114386]
 [0.64708885 2.19664787 0.72701359]
 [0.7167798  2.05270223 0.72832616]
 [0.71577239 1.96351928 0.62178202]
 [0.81016245 1.76782492 0.20169191]
 [0.86040319 1.46352531 0.08986073]
 [0.9629177  1.2520773  0.09772377]
 [0.84622745 1.67166404 0.5745805 ]
 [1.00289712 1.3290136  0.67122301]
 [1.12151837 1.13864945 0.71942454]]


In [21]:

# `rollout` holds one pipeline state per environment step, so render with a
# copy of the system whose timestep is the env step duration (env_eval.dt).
render_sys = env_eval.sys.tree_replace({'opt.timestep': env_eval.dt})
HTML(html.render(render_sys, rollout))
