# Brax Environment & State

In [6]:
from BaseRodentEnv import Rodent
from brax import envs
import jax
import numpy as np
import mediapy as media

In [2]:
Rodent

Rodent_Env_Brax.Rodent

In [3]:
envs.register_environment('rodent', Rodent)

In [4]:
env = envs.get_environment(env_name='rodent')

In [17]:
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
state = jit_reset(jax.random.PRNGKey(0))

In [48]:
state.pipeline_state

State(solver_niter=Array(0, dtype=int32), time=Array(0.50000066, dtype=float32, weak_type=True), qpos=Array([-4.8841089e-03, -2.7480636e-02,  5.9930775e-02,  9.8588538e-01,
        1.1666925e-01,  1.1256338e-01, -4.1806601e-02, -7.0711441e-02,
       -1.4916677e-02,  4.4804864e-02, -2.3284039e-01, -4.8595299e-03,
        3.6644593e-02, -5.0870981e-02,  2.9366004e-01, -1.6541862e-01,
       -2.4248837e-01,  7.1096474e-01,  4.2060059e-01, -7.0588455e-02,
        3.4955928e-01, -4.4224191e-01,  6.6646045e-01,  6.9606262e-01,
        3.9798671e-01, -1.0729350e-01, -4.6670400e-02, -4.1010503e-02,
       -2.2438843e-02,  5.9090736e-03, -2.1666506e-02,  1.9220879e-02,
       -2.8136652e-02,  2.5769843e-02, -2.5776245e-02,  4.3289870e-02,
       -1.3123008e-02,  5.1498204e-02, -2.2431094e-02,  6.8788372e-02,
       -1.6553253e-02,  4.0427566e-02, -1.1611838e-02,  1.4840000e-02,
       -8.6657153e-03, -2.0848415e-03,  2.4704486e-03, -2.3633083e-03,
       -9.1410609e-04,  6.0649565e-03,  3.7975

In [34]:
rollout = [state.pipeline_state]

for i in (range(50)):
    ctrl = jax.numpy.array(np.random.uniform(-1,1, env.sys.nu))
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

In [39]:
media.show_video(env.render(rollout,camera='close_profile'), fps=1.0 / env.dt)

0
This browser does not support the video tag.


In [42]:
# import jax
# import mujoco
# from mujoco import mjx
# data_path = "adam_exp_tr.p"

In [43]:
# import pickle
# with open(data_path, "rb") as file:
#     data = pickle.load(file)

# Networks & Flax.linen

In [4]:
from flax import linen as nn
from typing import Any, Callable, Sequence
import jax.numpy as jnp

In [10]:
class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)

In [12]:
class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(self.decoder_widths + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

In [13]:
model = AutoEncoder(encoder_widths=[20, 10, 5],
                    decoder_widths=[5, 10, 20],
                    input_shape=(12,))

batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)

In [7]:

ActivationFn = Callable[[jnp.ndarray], jnp.ndarray]
Initializer = Callable[..., Any]

class Decoder(nn.Module):
    '''DEcoder for VAE'''
    layer_sizes: Sequence[int]
    activation: ActivationFn = nn.tanh
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    activate_final: bool = False
    bias: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        for i, hidden_size in enumerate(self.layer_sizes):
            x = nn.Dense(
                hidden_size,
                name=f'hidden_{i}',
                kernel_init=self.kernel_init,
                use_bias=self.bias)(x)
            if i != len(self.layer_sizes) - 1 or self.activate_final:
                x = self.activation(x)
        return x

In [8]:
class Encoder(nn.Module):
    '''Encoder for VAE'''
    layer_sizes: Sequence[int]
    activation: ActivationFn = nn.tanh
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
    bias: bool = True
    latents: int
    
    @nn.compact
    def __call__(self, x: jnp.ndarray):
        # For each layer in the sequence
        # Make a dense net and apply layernorm then tanh
        for i, hidden_size in enumerate(self.layer_sizes):
            x = nn.Dense(
                hidden_size,
                name=f'hidden_{i}',
                kernel_init=self.kernel_init,
                use_bias=self.bias)(x)
            x = nn.LayerNorm(x)
            x = self.activation(x)
            
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x

TypeError: non-default argument 'latents' follows default argument