# Brax Environment & State
Note: restart the kernel for changes made to imported modules (e.g. `BaseRodentEnv`) to take effect.

In [1]:
from BaseRodentEnv import Rodent
from brax import envs
import jax
import numpy as np
import mediapy as media
from flax import linen as nn
from typing import Any, Callable, Sequence
import jax.numpy as jnp
from jax import random

In [2]:
Rodent

BaseRodentEnv.Rodent

In [3]:
envs.register_environment('rodent', Rodent)

In [4]:
env = envs.get_environment(env_name='rodent')

In [None]:
# JIT-compile the env's reset/step so repeated calls reuse the compiled functions.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
# Initial environment state, produced from a fixed PRNG key.
state = jit_reset(jax.random.PRNGKey(0))

In [None]:
state.pipeline_state

In [None]:
# Roll the environment forward with uniformly random controls in [-1, 1),
# collecting every pipeline state so the trajectory can be rendered afterwards.
NUM_ROLLOUT_STEPS = 50
# Dedicated, seeded generator so the rollout is reproducible across kernel restarts
# (the global np.random state is unseeded and shared with any other code).
ctrl_rng = np.random.default_rng(0)

rollout = [state.pipeline_state]

for _ in range(NUM_ROLLOUT_STEPS):
    # One control value per actuator (env.sys.nu entries).
    ctrl = jax.numpy.array(ctrl_rng.uniform(-1, 1, env.sys.nu))
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

In [None]:
# Render the collected states with the 'close_profile' camera; fps is set to 1 / env.dt.
media.show_video(env.render(rollout,camera='close_profile'), fps=1.0 / env.dt)

In [None]:
# import jax
# import mujoco
# from mujoco import mjx
# data_path = "adam_exp_tr.p"

In [None]:
# import pickle
# with open(data_path, "rb") as file:
#     data = pickle.load(file)

# Networks & Flax.linen

In [None]:
class MLP(nn.Module):
  """Plain multi-layer perceptron.

  Attributes:
    features: output width of each Dense layer, in order. Every layer except
      the last is followed by a ReLU; the last layer is left linear.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    hidden_widths = self.features[:-1]
    output_width = self.features[-1]
    for width in hidden_widths:
      x = nn.Dense(width)(x)
      x = nn.relu(x)
    # Final projection has no activation.
    return nn.Dense(output_width)(x)
  
class AutoEncoder(nn.Module):
  """Auto-encoder built from two MLPs.

  Attributes:
    encoder_widths: layer widths of the encoder MLP.
    decoder_widths: layer widths of the decoder MLP; one extra output layer
      sized to the flattened input is appended in `setup`.
    input_shape: per-sample input shape (without the leading batch axis).
  """
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    # The decoder must emit one value per element of a single input sample.
    flat_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(self.decoder_widths + (flat_dim,))

  def __call__(self, x):
    """Encode `x` and immediately decode it again."""
    latent = self.encode(x)
    return self.decode(latent)

  def encode(self, x):
    """Flatten each sample of `x` and run it through the encoder MLP."""
    assert x.shape[1:] == self.input_shape
    batch_size = x.shape[0]
    flat = jnp.reshape(x, (batch_size, -1))
    return self.encoder(flat)

  def decode(self, z):
    """Decode `z`, squash with a sigmoid and restore the per-sample shape."""
    squashed = nn.sigmoid(self.decoder(z))
    batch_size = squashed.shape[0]
    return jnp.reshape(squashed, (batch_size,) + self.input_shape)

In [None]:
# Example: build an AutoEncoder for flat 12-element inputs.
model = AutoEncoder(encoder_widths=[20, 10, 5],
                    decoder_widths=[5, 10, 20],
                    input_shape=(12,))

# Dummy batch of 16 samples, used both to initialise the parameters and as input.
batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
# `method=` makes `apply` run the given module method instead of `__call__`.
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)

In [None]:
# Type of an activation function: takes an array, returns an array.
ActivationFn = Callable[[jnp.ndarray], jnp.ndarray]
# Type used for the `kernel_init` fields of the modules below: any callable.
Initializer = Callable[..., Any]

class Decoder(nn.Module):
    '''Decoder for VAE: a stack of Dense layers with an activation in between.

    Attributes:
        layer_sizes: output width of each Dense layer, in order.
        activation: nonlinearity applied after each layer (see `activate_final`).
        kernel_init: initializer for every Dense kernel.
        activate_final: if True the activation is also applied after the last layer.
        bias: whether the Dense layers use a bias term.
    '''
    layer_sizes: Sequence[int]
    activation: ActivationFn = nn.tanh
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    activate_final: bool = False
    bias: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        last_index = len(self.layer_sizes) - 1
        for idx, width in enumerate(self.layer_sizes):
            dense = nn.Dense(
                width,
                name=f'hidden_{idx}',
                kernel_init=self.kernel_init,
                use_bias=self.bias)
            x = dense(x)
            # Every layer but the last is activated; the last only on request.
            if self.activate_final or idx < last_index:
                x = self.activation(x)
        return x

In [None]:
class Encoder(nn.Module):
    '''Encoder for VAE.

    Passes the input through a stack of Dense -> LayerNorm -> activation
    blocks, then projects the result through two separate Dense heads
    (named 'fc2_mean' and 'fc2_logvar').

    Attributes:
        layer_sizes: output width of each hidden Dense layer, in order.
        activation: nonlinearity applied after each LayerNorm.
        kernel_init: initializer for the hidden Dense kernels.
        bias: whether the hidden Dense layers use a bias term.
        latents: output width of each of the two heads.

    Returns (from `__call__`):
        A `(mean_x, logvar_x)` tuple, each with last dimension `latents`.
    '''
    layer_sizes: Sequence[int]
    activation: ActivationFn = nn.tanh
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    bias: bool = True
    latents: int = 60

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        # For each hidden layer: Dense, then LayerNorm, then the activation.
        for i, hidden_size in enumerate(self.layer_sizes):
            x = nn.Dense(
                hidden_size,
                name=f'hidden_{i}',
                kernel_init=self.kernel_init,
                use_bias=self.bias)(x)
            # LayerNorm is a Module: it must be constructed first and then
            # applied. `nn.LayerNorm(x)` would pass the activations as a
            # constructor field and leave `x` bound to a Module instance.
            x = nn.LayerNorm()(x)
            x = self.activation(x)

        # The two heads use nn.Dense defaults (not `kernel_init` / `bias`).
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x

# reset_function()

In [2]:
# Register the Rodent class under the name 'rodent', then instantiate it by that name.
envs.register_environment('rodent', Rodent)
env = envs.get_environment(env_name='rodent')

Make random trajectory

In [17]:
# traj = {'qpos':jnp.arange(250), 'qvel':jnp.arange(250)}

# key = jax.random.PRNGKey(1)
# max_start_index = len(traj['qpos']) - 150
# start_index = jax.random.randint(key, (1,), 0, max_start_index)[0]

# sliced_qpos = traj['qpos'][start_index:start_index + 150]
# sliced_qvel = traj['qvel'][start_index:start_index + 150]

# sliced_traj = {'qpos': sliced_qpos, 'qvel': sliced_qvel}
# sliced_traj

# key = random.PRNGKey(90)
# frame = jnp.arange(250)
# traj = {'qpos':jnp.arange(5 * 74 * 250), 'qvel':jnp.arange(5 * 74 * 250)}
# random_frame = jax.random.randint(key,(1,), 0, len(frame))[0]
# # need to make sure the slice index does not go out of bounds

# # qpos should be one 74 length jp.array here
# qpos = traj['qpos'][250*random_frame: 250*random_frame+74] # start with specific frame
# qvel = traj['qvel'][250*random_frame: 250*random_frame+74]

# jax.lax.dynamic_slice(traj['qpos'], [250*random_frame], [74])

In [4]:
# JIT-compiled reset of the unwrapped env.
jit_reset = jax.jit(env.reset)
# jit_reset(jax.random.PRNGKey(0))

Make random keys with specific dimensions

In [5]:
# Derive several PRNG keys from one root key by repeated splitting.
key = jax.random.PRNGKey(100)
global_key, local_key = jax.random.split(key)
local_key, key_env, eval_key = jax.random.split(local_key, 3)
# Display the key that is used below to seed the environments.
key_env

Array([1468288049, 2628321454], dtype=uint32)

In [6]:
# jax.process_count()
# jax.local_device_count()

In [7]:
num_envs = 1
process_count = jax.process_count()
local_device_count = jax.local_device_count()
local_devices_to_use = local_device_count

# Split key_env into (num_envs // process_count) keys.
key_envs = jax.random.split(key_env, num_envs // process_count)
print(key_envs)
# Add a leading axis of size local_devices_to_use; the remaining keys fill the second axis.
key_envs = jnp.reshape(key_envs, (local_devices_to_use, -1) + key_envs.shape[1:])
print(key_envs)

[[ 962186518 1447979288]]
[[[ 962186518 1447979288]]]


Wrapper function

In [8]:
# Apply brax's training wrapper to the env (action_repeat=1, episode_length=10).
wrap_for_training = envs.training.wrap
env_wrapped = wrap_for_training(env,action_repeat=1,episode_length=10)

`vmap` in JAX is a powerful function that vectorizes a given function over one or more of its input dimensions. This means it automatically applies the function in a batched manner across an axis of an array, efficiently handling multiple inputs at once without the need for explicit loops. This is particularly useful for speeding up calculations in machine learning and numerical computing by leveraging vectorized operations, which are typically faster than iterating through inputs one by one.

In [18]:
# vmap the wrapped reset over the leading axis of its key argument, then JIT-compile it.
reset_fn = jax.jit(jax.vmap(env_wrapped.reset))
reset_fn

<PjitFunction of <function AutoResetWrapper.reset at 0x3119b77f0>>

In [19]:
env_state = reset_fn(key_envs)

(74,) (73,)




In [20]:
env_state.pipeline_state.qpos.shape

(1, 1, 74)

In [21]:
env.sys.qpos0.shape

(74,)

In [22]:
print(env.sys.nq, env.sys.nv)

74 73


In [23]:
env_state.pipeline_state.qvel

Array([[[26750, 26751, 26752, 26753, 26754, 26755, 26756, 26757, 26758,
         26759, 26760, 26761, 26762, 26763, 26764, 26765, 26766, 26767,
         26768, 26769, 26770, 26771, 26772, 26773, 26774, 26775, 26776,
         26777, 26778, 26779, 26780, 26781, 26782, 26783, 26784, 26785,
         26786, 26787, 26788, 26789, 26790, 26791, 26792, 26793, 26794,
         26795, 26796, 26797, 26798, 26799, 26800, 26801, 26802, 26803,
         26804, 26805, 26806, 26807, 26808, 26809, 26810, 26811, 26812,
         26813, 26814, 26815, 26816, 26817, 26818, 26819, 26820, 26821,
         26822]]], dtype=int32)