In [None]:
import functools
import jax
import os
import time
# jax.config.update('jax_platform_name', "cpu")

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

# Path of the NVIDIA EGL vendor (ICD) JSON file that is checked and, if absent, created below.
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'

# Only write the file when nothing exists at that path; an existing file is left untouched.
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write(
        '{\n'
        '    "file_format_version" : "1.0.0",\n'
        '    "ICD" : {\n'
        '        "library_path" : "libEGL_nvidia.so.0"\n'
        '    }\n'
        '}\n'
    )

# Configure MuJoCo to use the EGL rendering backend (requires GPU).
# `%env` is an IPython line magic that sets an environment variable for this kernel process.
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl
# Ask XLA not to preallocate GPU memory up front (see the JAX GPU memory allocation docs).
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo

import distutils.util
import os
import subprocess


import jax.numpy as jnp
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

from ambersim.envs.exo import Exo
# Instantiate the Exo environment; the cells below reach it through the global `env`.
env = Exo()

In [None]:
# JIT-wrap the environment entry points once; later cells reuse these wrappers.
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

# warming up JIT-compiled function
# Only `jit_env_reset` is warmed up here; `jit_env_step` is compiled lazily on
# its first call.
print("warming up JIT compilation...")
rng = jax.random.PRNGKey(0)
# rng = jax.random.split(rng, num_env)

# JAX dispatches work asynchronously, so wait for the result before stopping
# the clock; otherwise the measurement can end while the device is still busy.
start = time.perf_counter()
state = jax.block_until_ready(jit_env_reset(rng))
end = time.perf_counter()
print(f"compilation time: {end - start}")

In [None]:

# initial_state['q'] = initial_state['q'].at[2].set(0.945)
# `rollout` collects the state *before* each step, so it ends up holding the
# reset state plus the states preceding each of the 100 steps.
rollout = []
actions = []  # NOTE(review): declared but never filled in this cell -- confirm whether actions should be recorded.
state = jit_env_reset(rng)
tCur = 0

# The zero action never changes, so build it once instead of on every iteration.
act = jnp.zeros(12)
for _ in range(100):
    start = time.perf_counter()
    rollout.append(state)

    # Wait for the asynchronously dispatched step to finish so that the printed
    # duration covers the actual computation, not just the dispatch.
    state = jax.block_until_ready(jit_env_step(state, act))
    tCur = tCur + env.dt
    end = time.perf_counter()
    print(f"step time: {end - start}")


In [None]:
# First two entries of the observation held by the current `state`.
state.obs[:2]

In [None]:
import mujoco
from mujoco import mjx
def get_image(renderer, env, state, camera):
    """Render a single frame for `state` and return the renderer's output.

    Args:
        renderer: Object exposing `update_scene(data, camera=...)` and `render()`.
        env: Environment whose `model` attribute is used to build the `mujoco.MjData`.
        state: MJX data that is copied into a freshly created `mujoco.MjData`.
        camera: Camera object; entries 0 and 1 of its `lookat` are overwritten
            in place with `qpos[0]` and `qpos[1]` of the copied data.

    Returns:
        Whatever `renderer.render()` returns.
    """
    # Host-side container that receives the contents of the mjx data.
    mj_data = mujoco.MjData(env.model)
    mjx.device_get_into(mj_data, state)
    mujoco.mj_forward(env.model, mj_data)

    # Point the camera at the first two generalized coordinates.
    for axis in (0, 1):
        camera.lookat[axis] = mj_data.qpos[axis]

    # Refresh the scene from the host-side data before drawing.
    renderer.update_scene(mj_data, camera=camera)

    # Disabled time overlay, kept for reference:
    # time = d.time
    # curTime = f"Time = {time:.3f}"
    # mujoco.mjr_overlay(mujoco.mjtFont.mjFONT_NORMAL,mujoco.mjtGridPos.mjGRID_TOPLEFT,renderer._rect,curTime,'test',renderer._mjr_context)

    return renderer.render()

# Presumably prepares `env.renderer` / `env.camera`, which are read below -- confirm against Exo.
env.getRender()

# One rendered frame per recorded rollout state.
images = [
    get_image(env.renderer, env, step_state.pipeline_state, env.camera)
    for step_state in rollout
]

# Play back at one frame per environment step.
media.show_video(images, fps=1.0 / env.dt)

In [None]:
# Inspect the `qfrc_actuator` field of the current pipeline state.
state.pipeline_state.qfrc_actuator

In [None]:
output_file = 'walking_box0.mp4'
# Save the video
# NOTE(review): written at a fixed 30 fps, whereas the inline preview uses
# fps=1.0 / env.dt -- confirm the different playback speed is intended.
media.write_video(output_file, images,fps=30) 

In [None]:
import utils.env_utils as env_utils
# Rebinds the global `rng` to a batch of 5 keys split from PRNGKey(0).
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 5)
# NOTE(review): `updateGeomQuat` is called unqualified, but no visible cell defines or
# imports that name; presumably `env_utils.updateGeomQuat` was meant -- confirm.
# in_axes=(None, 0, None): only the index array is mapped over; `env.sys` and the whole
# `rng` batch are passed unchanged to every call.
rand_geom_quat = jax.vmap(updateGeomQuat,in_axes=(None,0,None))(env.sys,jnp.array([0,1,2,3,4]),rng)

In [None]:
# from traj_opt import *
# Re-create the jitted entry points used for cost evaluation. This rebinds
# `jit_env_reset` to `env.reset2GivenState`, replacing the earlier wrapper
# around `env.reset`.
jit_env_reset = jax.jit(env.reset2GivenState)
jit_env_step = jax.jit(env.step)
# `evalCostFunc`, `alpha` and `costs` are built in the cell that follows the
# `evalCost` definition. Building them here referenced `evalCost` before it
# existed, which raises NameError on a fresh kernel.
                

In [None]:
def evalCost(exo_behav, jit_env_reset, jit_env_step, dt, alpha, step_dur, timesteps=1):
    """Roll the environment forward and return the mean per-step cost.

    The environment is reset to `exo_behav.getInitState()`. On each step the
    desired state is queried from `exo_behav.getDesiredGivenBez`, turned into a
    torque by `exo_behav._PD_control`, and applied with `jit_env_step`. The
    cost of a step is the negated `reward` of the resulting state.

    Args:
        exo_behav: Object providing `getInitState()`,
            `getDesiredGivenBez(alpha, step_dur, t)` and
            `_PD_control(state, des_state)`.
        jit_env_reset: Callable mapping the initial state to an environment state.
        jit_env_step: Callable `(state, action) -> state`; the returned state
            must expose `reward`.
        dt: Amount added to the running time after every step.
        alpha: Parameters forwarded unchanged to `getDesiredGivenBez`.
        step_dur: Step duration forwarded unchanged to `getDesiredGivenBez`.
        timesteps: Number of environment steps to simulate (at least 1).

    Returns:
        Scalar array holding the mean of the per-step costs.

    Raises:
        ValueError: If `timesteps` is smaller than 1. Previously this case
            returned NaN (the mean of an empty array) without any error.
    """
    if timesteps < 1:
        raise ValueError(f"timesteps must be at least 1, got {timesteps}")

    initial_state = exo_behav.getInitState()
    state = jit_env_reset(initial_state)

    costs = []
    tCur = 0
    for _ in range(timesteps):
        # The desired state is evaluated at the time *before* the step is taken.
        des_state = exo_behav.getDesiredGivenBez(alpha, step_dur, tCur)
        pd_torque = exo_behav._PD_control(state, des_state)
        state = jit_env_step(state, pd_torque)
        costs.append(-state.reward)
        tCur = tCur + dt

    return jnp.mean(jnp.array(costs))

In [None]:
# Bind the behaviour object, the jitted reset/step functions and the time step,
# leaving (alpha, step_dur[, timesteps]) as the remaining arguments.
# NOTE(review): `exo_behav` is not defined in any visible cell -- confirm where it comes from.
evalCostFunc = functools.partial(evalCost, exo_behav,jit_env_reset, jit_env_step, env.dt)
alpha = exo_behav.alpha
# Mean cost for step_dur=1 with the default single timestep.
costs = evalCostFunc(alpha,1)

In [None]:
# Builds the gradient function of the cost w.r.t. its first argument (alpha).
# The result is neither called nor stored, so this cell only displays the function object.
jax.grad(evalCostFunc, argnums=0)

In [None]:
from jax import jacfwd, jit
# Jitted forward-mode Jacobian of the cost w.r.t. its first argument (jacfwd's default argnums=0).
asdf = jit(jacfwd(evalCostFunc))
# Evaluate at an all-ones (12, 8) input with step_dur=1; being the first call, it also triggers compilation.
asdf(jnp.ones((12, 8)), 1)

In [None]:
# Jacobian at the behaviour's own `alpha`. `step_dur` is not defined at notebook
# level in any visible cell, so use the same step duration (1) as the other evaluations.
asdf(alpha, 1)

In [None]:
import time
# Benchmark 100 Jacobian evaluations. The input is built once, outside the timed
# region, and every result is awaited: JAX dispatches asynchronously, so without
# blocking the loop would mostly measure dispatch overhead.
bench_alpha = jnp.ones((12, 8))
start = time.perf_counter()
for _ in range(100):
    jax.block_until_ready(asdf(bench_alpha, 1))
end = time.perf_counter()
total = end - start
print(f"{total=}")

In [None]:
# NOTE(review): in file order this runs before the cell that defines `grad_evalCost`, and it
# reads `exo_behavior` while the other cells use `exo_behav` -- confirm which object is meant.
grad_evalCost(exo_behavior.alpha)

In [None]:
def grad_evalCost(alpha, step_dur=1):
    """Return the gradient of `evalCostFunc` with respect to `alpha`.

    Args:
        alpha: Point at which the gradient is evaluated.
        step_dur: Step duration passed through to `evalCostFunc`; it is not
            differentiated.

    Returns:
        The value produced by `jax.grad(evalCostFunc, argnums=0)` at
        `(alpha, step_dur)`.
    """
    # argnums=0 selects the first positional argument (alpha) as the differentiation variable.
    return jax.grad(evalCostFunc, argnums=0)(alpha, step_dur)

# Gradient of the cost at the current `alpha`, using the default step_dur=1.
grad_evalCost(alpha)