In [1]:
import os
# Both variables are set here, ahead of the mujoco/jax imports in the next cell.
# Select MuJoCo's EGL rendering backend.
os.environ['MUJOCO_GL'] = 'egl'
# Point XLA at the CUDA toolkit directory; adjust the path to the locally
# installed toolkit version (e.g. /usr/local/cuda-12.3).
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda-12.1'

In [2]:
import jax.numpy as jp
from mujoco import mjx
from brax.io import html
from IPython.display import HTML, clear_output
import matplotlib.pyplot as plt
import jax
from brax import envs
from qube_mjx import Qube_mjx
from tqdm import tqdm
from mujoco.mjx._src.forward import forward

In [3]:
from jax import config

# Master switch for NaN debugging. Later cells read this flag to decide
# whether the step / loss / descent functions get wrapped in jax.jit.
debug_nans = True

# Forward the switch to JAX's own `jax_debug_nans` option.
config.update("jax_debug_nans", debug_nans)

In [4]:
def array_to_state(x, bef_state, env):
    """Return a copy of ``bef_state`` whose qpos/qvel are taken from ``x``.

    Args:
        x: flat array laid out as ``[qpos, qvel]``. Only the first
            ``len(qpos) + len(qvel)`` entries are read.
        bef_state: environment state that provides the template
            ``pipeline_state`` (and therefore the qpos/qvel sizes).
        env: environment whose ``sys`` model is handed to ``mjx.forward``.

    Returns:
        ``bef_state`` with its ``pipeline_state`` replaced by the updated one.
    """
    pipeline_state = bef_state.pipeline_state
    n_q = pipeline_state.qpos.shape[0]
    n_qd = pipeline_state.qvel.shape[0]
    qpos = x[:n_q]
    qvel = x[n_q:n_q + n_qd]

    # Overwrite position and velocity only, then run the MJX forward pass so
    # the remaining fields of the pipeline state are intended to be brought in
    # line with the new qpos/qvel (same recipe as initialising a state from a
    # given position and velocity).
    pipeline_state = pipeline_state.replace(qpos=qpos, qvel=qvel)
    # Use the public entry point instead of the private mujoco.mjx._src module.
    pipeline_state = mjx.forward(env.sys, pipeline_state)
    return bef_state.replace(pipeline_state=pipeline_state)

In [5]:
env_name = 'qube_mjx'
envs.register_environment(env_name, Qube_mjx)
env = envs.create(env_name=env_name, auto_reset=False)
jit_reset = jax.jit(env.reset)

# When debugging NaNs, step the unwrapped environment without jit: a brax
# wrapper would otherwise route the step through jax.lax.scan, and the error
# would only report that a NaN appeared inside the scan, not which op made it.
# Otherwise the same unwrapped step is compiled with jax.jit.
jit_step = env.unwrapped.step if debug_nans else jax.jit(env.unwrapped.step)

In [6]:
def loss_fn(state, v0, c0):
    """One-step matching loss between a candidate state and a reference state.

    Args:
        state: flat ``[qpos, qvel]`` array written into ``v0`` before stepping.
        v0: template environment state for the candidate.
        c0: reference environment state.

    Returns:
        Mean squared qpos difference plus mean squared qvel difference between
        the two states after one zero-action step each.
    """
    zero_action = jp.zeros(env.action_size)

    # Advance candidate and reference by a single step.
    v1 = jit_step(array_to_state(state, v0, env), zero_action)
    c1 = jit_step(c0, zero_action)

    pos_err = c1.pipeline_state.qpos - v1.pipeline_state.qpos
    vel_err = c1.pipeline_state.qvel - v1.pipeline_state.qvel
    return jp.mean(jp.square(pos_err)) + jp.mean(jp.square(vel_err))

# Evaluates loss_fn and its gradient with respect to the first argument
# (the flat state array) in a single call.
grad_fn = jax.value_and_grad(loss_fn, argnums=0)

def descent_step(v0, c0, alpha):
    """Take one gradient-descent step on the qpos/qvel stored in ``v0``.

    Args:
        v0: candidate environment state being optimised.
        c0: reference environment state passed through to the loss.
        alpha: step size.

    Returns:
        Tuple ``(new_v0, loss)`` where ``loss`` was evaluated at the state
        before the update.
    """
    pipeline = v0.pipeline_state
    flat_state = jp.concatenate((pipeline.qpos, pipeline.qvel))
    loss_value, gradient = grad_fn(flat_state, v0, c0)
    updated_state = array_to_state(flat_state - alpha * gradient, v0, env)
    return updated_state, loss_value

# Compile both functions only when NaN debugging is off; with debugging on
# they are left as plain Python functions.
if not debug_nans:
    loss_fn = jax.jit(loss_fn)
    descent_step = jax.jit(descent_step)

In [7]:
# Gradient-descent hyperparameters, identical for every run, so they are
# defined once instead of being re-assigned inside the loop.
steps = 1000
alpha = 1e-2

# One loss curve (list of loss values) per seed.
losses = []

for i in range(5):
    # Separate keys for the environment reset, the qpos noise and the qvel noise.
    reset_key, c_key = jax.random.split(jax.random.PRNGKey(i))
    q_key, q_dot_key = jax.random.split(c_key)

    # Reference initial state.
    c0 = jit_reset(reset_key)
    q = c0.pipeline_state.qpos
    q_dot = c0.pipeline_state.qvel

    # Small random offsets (standard normal divided by 10) for qpos and qvel.
    eps_q = jax.random.normal(q_key, shape=(q.size,)) / 10
    eps_q_dot = jax.random.normal(q_dot_key, shape=(q_dot.size,)) / 10

    # Perturbed candidate state. It is built from c0 directly: resetting a
    # second time with the same key would only reproduce the same state.
    v0 = array_to_state(jp.concatenate([q + eps_q, q_dot + eps_q_dot]), c0, env)

    # Loss at the perturbed starting point, before any update.
    loss = [loss_fn(jp.concatenate((v0.pipeline_state.qpos, v0.pipeline_state.qvel)), v0, c0)]

    # Gradient descent on the candidate's qpos/qvel.
    for _ in tqdm(range(steps)):
        v0, loss_val = descent_step(v0, c0, alpha)
        loss.append(loss_val)
    losses.append(loss)


FloatingPointError: invalid value (nan) encountered in jit(div)

In [None]:
# One curve per completed run, on a logarithmic loss axis.
fig, ax = plt.subplots()
for run_idx, loss in enumerate(losses):
    ax.plot(loss, label=f"run {run_idx}")
ax.set_yscale('log')
ax.set_xlabel("gradient-descent step")
ax.set_ylabel("loss")
ax.set_title("One-step matching loss during gradient descent")
if losses:
    # Skip the legend when there is nothing to label (e.g. no run finished).
    ax.legend()
plt.show()