In [24]:
from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
# from brax.training.agents.apg import train as apg
# from brax.training.agents.apg import networks as apg_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from jax import vmap
import jax.random
from jax import lax

In [2]:
import distutils.util
import os
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl


In [3]:
import os
# Fraction of GPU memory the JAX/XLA client preallocates. It only takes effect
# if set before the JAX backend is initialized. 0.9 caused too much lag, so 0.8.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"
from datetime import datetime
import functools

# Math
import jax.numpy as jp
import numpy as np
import jax
from jax import config # Analytical gradients work much better with double precision.
# Raise on NaNs produced by JAX ops. This adds runtime overhead, so keep it on
# only while debugging.
config.update("jax_debug_nans", True)
# Analytical gradients work much better with double precision.
config.update("jax_enable_x64", True)
# This option takes a string name ("default", "high", "highest", ...), not a
# jax.lax.Precision enum member; "high" corresponds to Precision.HIGH.
config.update('jax_default_matmul_precision', 'high')
from brax import math

# Sim
import mujoco
import mujoco.mjx as mjx

# Brax
from brax import envs
from brax.base import Motion, Transform
from brax.io import mjcf
from brax.envs.base import PipelineEnv, State
from brax.mjx.pipeline import _reformat_contact
from brax.training.acme import running_statistics
from brax.io import model

# Algorithms
# from brax.training.agents.apg import train as apg
# from brax.training.agents.apg import networks as apg_networks
from brax.training.agents.ppo import train as ppo

# Supporting
from etils import epath
import mediapy as media
import matplotlib.pyplot as plt
from ml_collections import config_dict
from typing import Any, Dict


In [4]:
from jax.lib import xla_bridge
# Report which platform JAX will run on ("gpu", "cpu", "tpu", ...).
# jax.default_backend() is the public replacement for the deprecated
# jax.lib.xla_bridge.get_backend().platform.
print(jax.default_backend())

gpu


In [5]:
import agent_mimic_env
from agent_mimic_env import register_mimic_env, register_ppo_env

In [6]:
from agent_mimic_env.pds_controllers_agents import feedback_pd_controller, stable_pd_controller_action


In [7]:
import yaml
from box import Box
# Experiment hyper-parameters for the punch motion. They are wrapped in a Box
# so keys can be read as attributes (e.g. args.num_envs).
yaml_file_path = 'config_params/punch.yaml'
with open(yaml_file_path, 'r') as config_file:
    raw_config = yaml.safe_load(config_file)
args = Box(raw_config)



In [8]:
# Show the loaded configuration so the run's hyper-parameters are recorded.
print(args)

{'num_envs': 256, 'num_eval_envs': 32, 'lr': 0.0003, 'max_it': 1000, 'max_grad_norm': 0.4, 'seed': 0, 'system_config': 'humanoid', 'demo_replay_mode': 'threshold', 'threshold': 0.4, 'normalize_observations': True, 'cycle_len': 130, 'ep_len': 65, 'ep_len_eval': 54, 'use_lr_scheduler': True, 'reward_scaling': 0.02, 'rot_weight': 0.5, 'vel_weight': 0.3, 'ang_weight': 0.01, 'deep_mimic_reward_weights': {'w_p': 0.65, 'w_v': 0.1, 'w_e': 0.15, 'w_c': 0.1}, 'deep_mimic_weights_factors': {'w_pose': 2, 'w_angular': 0.1, 'w_efector': 40, 'w_com': 10}, 'model': 'models/final_humanoid.xml', 'ref': 'motions/humanoid3d_punch_duplicated.txt'}


In [9]:
# Build the PPO mimic environment from the loaded config and install the
# stable PD controller as its action callback.
env_ppo = register_ppo_env(args)
env_ppo.set_pd_callback(stable_pd_controller_action)

# JIT-compile reset and the custom step used by the manual rollout below
# (env_ppo.step is the alternative step function, not compiled here).
jit_reset, jit_step = jax.jit(env_ppo.reset), jax.jit(env_ppo.step_custom)



this is the model:  models/final_humanoid.xml


In [10]:
# Sanity check that double precision (jax_enable_x64) is active: a float
# computation should come back as float64, not float32.
x64_probe = jp.exp(11)
print(x64_probe.dtype)
x64_probe


Array(59874.1417152, dtype=float64, weak_type=True)

In [11]:

# Geom indices of the tracked end effectors.
s = env_ppo.dict_ee
print(s)

# Their model geom_pos entries (in MuJoCo, the geom's offset within its
# parent body's frame).
xd = env_ppo.sys.geom_pos[s]
xd

# Right wrist, index 6: initial pos in the MJX model is "0 0 -0.258947".
# Left wrist, index 9: initial pos "0 0 -0.258947".
# Right ankle, index 12: pos [0.045, 0., -0.0225].
# Left ankle, index 15: pos [0.045, 0., -0.0225].

[ 6  9 12 15]


Array([[ 0.      ,  0.      , -0.258947],
       [ 0.      ,  0.      , -0.258947],
       [ 0.045   ,  0.      , -0.0225  ],
       [ 0.045   ,  0.      , -0.0225  ]], dtype=float64)

In [12]:
# Physics options (including the simulator timestep) and the env step size.
sim_options, env_step_dt = env_ppo.sys.opt, env_ppo.dt
print(sim_options)
print(env_step_dt)

Option(timestep=0.002, impratio=Array(100., dtype=float64, weak_type=True), tolerance=Array(1.e-08, dtype=float64, weak_type=True), ls_tolerance=Array(0.01, dtype=float64, weak_type=True), gravity=Array([ 0.  ,  0.  , -9.81], dtype=float64), wind=Array([0., 0., 0.], dtype=float64), density=Array(0., dtype=float64, weak_type=True), viscosity=Array(0., dtype=float64, weak_type=True), has_fluid_params=False, integrator=<IntegratorType.EULER: 0>, cone=<ConeType.PYRAMIDAL: 0>, jacobian=<JacobianType.AUTO: 2>, solver=<SolverType.NEWTON: 2>, iterations=1, ls_iterations=4, disableflags=<DisableBit.EULERDAMP: 16384>)
0.02


In [13]:
# Reset once with a fixed PRNG key; the rollout list starts with that state.
reset_key = jax.random.PRNGKey(0)
state = jit_reset(reset_key)
rollout = [state.pipeline_state]

# NOTE(review): not used by the visible cells; the rollout loop below steps
# env_ppo.rollout_lenght - 1 times instead.
n_steps = 500

In [14]:
# Optional inspection (disabled): concatenate the reset state's qpos and qvel
# into one vector and check its shape.
# pos = state.pipeline_state.qpos
# vel = state.pipeline_state.qvel

# xd = np.concatenate([pos,vel])

# xd.shape

In [15]:
# Generalized positions of the reset state.
state.pipeline_state.qpos

Array([ 1.00952578,  0.96560349,  0.763598  ,  0.69266161,  0.00396145,
        0.0290793 ,  0.72066538,  0.0344528 ,  0.38038578, -0.03930848,
        0.25575706, -0.27285779, -1.17071799, -1.60473499,  0.07730953,
       -1.0934958 ,  0.82069482,  0.59607216,  0.09617058, -0.84933419,
        1.92958378, -0.723898  , -0.63271327, -1.08916748, -1.27049323,
       -0.16462933, -0.02070159, -0.17030865,  0.37251784,  0.20128708,
        0.66214301, -0.50385212, -0.0529365 , -0.12219402, -0.05907235],      dtype=float64)

In [16]:
# Global (world-frame) positions of the end-effector geoms in the reset state.
ee_ids = env_ppo.dict_ee
state.pipeline_state.geom_xpos[ee_ids]

Array([[1.67751076, 1.1740035 , 1.33588565],
       [0.87219029, 1.07226696, 0.94031171],
       [1.18705463, 1.14263128, 0.07462683],
       [0.71621663, 0.63931363, 0.02716332]], dtype=float64)

In [17]:
# Print the reset state's metrics dict, then its info dict.
for state_summary in (state.metrics, state.info):
    print(state_summary)

{'fall': Array(0., dtype=float64), 'pose_error': Array(0., dtype=float64), 'reference_angular': Array(0., dtype=float64, weak_type=True), 'reference_com': Array(0., dtype=float64, weak_type=True), 'reference_end_effector': Array(0., dtype=float64, weak_type=True), 'reference_quaternions': Array(0., dtype=float64, weak_type=True), 'step_index': Array(48., dtype=float64)}
{'default_pos': Array([ 1.00952578,  0.96560349,  0.763598  ,  0.69266161,  0.00396145,
        0.0290793 ,  0.72066538,  0.0344528 ,  0.38038578, -0.03930848,
        0.25575706, -0.27285779, -1.17071799, -1.60473499,  0.07730953,
       -1.0934958 ,  0.82069482,  0.59607216,  0.09617058, -0.84933419,
        1.92958378, -0.723898  , -0.63271327, -1.08916748, -1.27049323,
       -0.16462933, -0.02070159, -0.17030865,  0.37251784,  0.20128708,
        0.66214301, -0.50385212, -0.0529365 , -0.12219402, -0.05907235],      dtype=float64), 'kinematic_ref': Array([ 1.00952578,  0.96560349,  0.763598  ,  0.69266161,  0.003961

In [18]:
# Centre of mass of the subtree rooted at body 1. Body 0 is MuJoCo's world
# body, so this is presumably the humanoid's root (confirm against the model).
root_body_id = 1
state.pipeline_state.subtree_com[root_body_id]

Array([1.03295954, 1.00265482, 0.82999913], dtype=float64)

In [19]:
# Episode length used by the manual rollout below. The attribute is spelled
# "rollout_lenght" in the environment class, so it must be accessed that way.
env_ppo.rollout_lenght

65

In [20]:
# Roll the environment forward with a constant control to sanity-check the
# custom step and the termination logic. `rollout` already holds the reset
# state, so rollout_lenght - 1 steps yield rollout_lenght frames when the
# episode does not terminate early (assuming the reset cell was just run).
ctrl = -0.1 * jp.ones(env_ppo.sys.nu)  # constant action: build it once
for _ in range(env_ppo.rollout_lenght - 1):
    state = jit_step(state, ctrl)
    if state.done:
        # Report the simulation time at which the episode terminated.
        print(state.pipeline_state.time)
        break
    # Per-step diagnostics are available in state.reward, state.metrics
    # (e.g. 'reference_quaternions', 'pose_error') and state.info.
    rollout.append(state.pipeline_state)

reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward0.9999999999999862:
reward1.0:
reward1.0:
reward0.9999999999999913:
reward1.0:
reward0.9999999999999925:
reward0.9999999999999911:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward0.9999999999999873:
reward1.0:
reward0.9999999999999825:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward0.9999999999999827:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward1.0:
reward0.9999999999999869:
reward1.0:
reward1.0:


In [21]:
from some_math.rotation6D import quaternion_to_rotation_6d

In [22]:
# Reward formulas for comparing the simulated state against the reference.
def com_reward(com_current, ref_com, scale=100.0):
    """Exponential reward for matching the reference centre of mass.

    Computes ``exp(-scale * d)``. ``d`` is the Euclidean distance between
    ``com_current`` and ``ref_com`` taken along the last axis, then summed
    over any leading axes (one row per tracked point).

    Args:
        com_current: current COM position(s); the last axis holds coordinates.
        ref_com: reference COM position(s), broadcastable to ``com_current``.
        scale: sharpness of the exponential kernel (default 100, the
            previously hard-coded value).

    Returns:
        Scalar reward in [0, 1]; exactly 1 when the positions coincide.
    """
    # Per-row distance over the coordinate axis, summed over rows. No prints
    # here: under jax.jit they would only show tracers at trace time.
    com_dist = jp.sum(jp.linalg.norm(com_current - ref_com, axis=-1))
    return jp.exp(-scale * com_dist)

def end_effector_diff(current_ee, current_ref_ee, scale=5.0):
    """Exponential reward for matching reference end-effector positions.

    Computes ``exp(-scale * d)``. ``d`` is the sum, over end effectors, of the
    Euclidean distance between each current and reference position. The
    distance is taken along the last axis, so a single end effector (1-D
    input) and a stack of them (one row each) are both accepted. For 2-D
    inputs this matches the previous ``axis=1`` behaviour.

    Args:
        current_ee: current end-effector position(s); the last axis holds
            coordinates.
        current_ref_ee: reference position(s), broadcastable to
            ``current_ee``.
        scale: sharpness of the exponential kernel (default 5, the
            previously hard-coded value).

    Returns:
        Scalar reward in [0, 1]; exactly 1 when all positions coincide.
    """
    ee_dist = jp.sum(jp.linalg.norm(current_ee - current_ref_ee, axis=-1))
    return jp.exp(-scale * ee_dist)
    





Testing rewards

In [23]:
# Render the rollout. Frames were stored once per env step, so the renderer's
# timestep is swapped for env_ppo.dt so playback runs at the right speed.
render_sys = env_ppo.sys.tree_replace({'opt.timestep': env_ppo.dt})
HTML(html.render(render_sys, rollout))