# Brax Environment & State

In [6]:
from BaseRodentEnv import Rodent
from brax import envs
import jax
import numpy as np
import mediapy as media

In [7]:
Rodent

BaseRodentEnv.Rodent

In [8]:
envs.register_environment('rodent', Rodent)

In [9]:
env = envs.get_environment(env_name='rodent')

In [10]:
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
state = jit_reset(jax.random.PRNGKey(0))

In [11]:
state.pipeline_state

State(solver_niter=Array(0, dtype=int32), time=Array(0., dtype=float32, weak_type=True), qpos=Array([ 4.04833890e-02, -8.96714069e-03,  7.55159557e-02,  9.99944270e-01,
        1.74673495e-03,  4.56508016e-03,  9.35458764e-03,  4.98908758e-03,
        1.63373468e-03,  8.58705025e-03,  5.59443003e-03,  9.38590057e-03,
        9.71957669e-03,  5.60190668e-03, -1.04575156e-04,  9.36790463e-03,
        4.01520738e-05, -7.71624316e-03, -2.73497566e-03, -1.00942852e-03,
        7.73050543e-03,  3.20954784e-03,  3.57467635e-03,  9.61852726e-03,
       -8.71190336e-03, -3.41589446e-03, -8.77431873e-03, -6.90993061e-03,
        4.92489804e-03,  8.87549389e-03,  1.82025426e-04,  3.83699639e-03,
       -9.46147460e-03, -1.89946173e-03, -9.18047130e-03,  8.07757583e-03,
        2.22446187e-03,  9.94298235e-03, -9.68115032e-03,  6.05548825e-03,
       -6.93347445e-03, -2.66981125e-03, -7.76003813e-03,  9.18575749e-03,
       -8.71645473e-03,  2.32315055e-04,  2.46979948e-03, -1.84839009e-03,
      

In [12]:
rollout = [state.pipeline_state]

for i in (range(50)):
    ctrl = jax.numpy.array(np.random.uniform(-1,1, env.sys.nu))
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

In [13]:
media.show_video(env.render(rollout,camera='close_profile'), fps=1.0 / env.dt)

0
This browser does not support the video tag.


In [14]:
# import jax
# import mujoco
# from mujoco import mjx
# data_path = "adam_exp_tr.p"

In [15]:
# import pickle
# with open(data_path, "rb") as file:
#     data = pickle.load(file)

# Networks & Flax.linen

In [16]:
from flax import linen as nn
from typing import Any, Callable, Sequence
import jax.numpy as jnp
from jax import random

In [17]:
class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x
  
class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(self.decoder_widths + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

In [18]:
model = AutoEncoder(encoder_widths=[20, 10, 5],
                    decoder_widths=[5, 10, 20],
                    input_shape=(12,))

batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)

In [19]:
ActivationFn = Callable[[jnp.ndarray], jnp.ndarray]
Initializer = Callable[..., Any]

class Decoder(nn.Module):
    '''DEcoder for VAE'''
    layer_sizes: Sequence[int]
    activation: ActivationFn = nn.tanh
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    activate_final: bool = False
    bias: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        for i, hidden_size in enumerate(self.layer_sizes):
            x = nn.Dense(
                hidden_size,
                name=f'hidden_{i}',
                kernel_init=self.kernel_init,
                use_bias=self.bias)(x)
            if i != len(self.layer_sizes) - 1 or self.activate_final:
                x = self.activation(x)
        return x

In [20]:
class Encoder(nn.Module):
    '''Encoder for VAE'''
    layer_sizes: Sequence[int]
    activation: ActivationFn = nn.tanh
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    bias: bool = True
    latents: int=60
    
    @nn.compact
    def __call__(self, x: jnp.ndarray):
        # For each layer in the sequence, make a dense net and apply layernorm then tanh
        for i, hidden_size in enumerate(self.layer_sizes):
            x = nn.Dense(
                hidden_size,
                name=f'hidden_{i}',
                kernel_init=self.kernel_init,
                use_bias=self.bias)(x)
            x = nn.LayerNorm(x)
            x = self.activation(x)
            
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x

# reset_function()

In [55]:
traj = jnp.arange(250)

In [51]:
def reset(rng):
    rng, rng1, rng2 = jax.random.split(rng, 3)
    
    return rng, rng1, rng2

In [52]:
jit_reset = jax.jit(reset)
jit_reset(jax.random.PRNGKey(100))

(Array([1041291563, 1031891722], dtype=uint32),
 Array([18805559, 90144667], dtype=uint32),
 Array([2007850834, 3762328797], dtype=uint32))

In [57]:
key = jax.random.PRNGKey(100)
global_key, local_key = jax.random.split(key)
local_key, key_env, eval_key = jax.random.split(local_key, 3)

In [66]:
num_envs = 10
process_count = 1
local_devices_to_use = 1

key_envs = jax.random.split(key_env, num_envs // process_count)
key_envs = jnp.reshape(key_envs, (local_devices_to_use, -1) + key_envs.shape[1:])

In [67]:
key_envs

Array([[[3832035091, 1972974072],
        [1603586291, 2141669964],
        [4181382255, 3431925277],
        [1764346582,  427055401],
        [3560575122,  664938885],
        [3191181278, 2560363209],
        [2588569651, 2143078675],
        [ 558925301, 3177949683],
        [1807721867, 4207491001],
        [2322623226, 3505268039]]], dtype=uint32)

In [72]:
reset_fn = jax.jit(jax.vmap(reset))

In [73]:
env_state = reset_fn(key_envs)

ValueError: split accepts a single key, but was given a key array of shape (10, 2) != (). Use jax.vmap for batching.