In [9]:
%load_ext autoreload
%autoreload 2

import os  
import sys 
from PIL import Image 
# Restrict JAX to the CPU backend for this notebook. The variable is set here,
# ahead of the `import jax` below, so it is already in the environment when
# JAX is first imported.
os.environ['JAX_PLATFORMS'] = 'cpu'

import flax
import flax.linen as nn
import jax 
from jax import random
import jax.numpy as jnp 

from craftax.craftax_env import make_craftax_env_from_name
from craftax.craftax.world_gen.world_gen import generate_world
from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv

from jaxued.wrappers.autoreplay import AutoReplayWrapper

from editax.models.lstm import LSTMActorCritic, ResetLSTM
from editax.upomdp import LogWrapper


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Derive PRNG keys from a fixed seed. `rng` is carried forward to later cells;
# `_rng` is consumed by the split into `rngs` and must not be used again.
rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
rngs = jax.random.split(_rng, 3)

# Create environment
# The registry-built env is only used here as the source of `default_params`;
# `env` is rebound to the wrapped CraftaxSymbolicEnv at the end of this cell.
env = make_craftax_env_from_name("Craftax-Symbolic-v1", auto_reset=True)
env_params = env.default_params
default_statics = CraftaxSymbolicEnv.default_static_params()

# Generate the level from a fresh sub-key. The original passed `_rng` here even
# though it had already been split above, which reuses a PRNG key.
level = generate_world(rngs[0], env_params, default_statics)
print(level.map.shape)

# Build the env actually used below: symbolic Craftax wrapped for logging and
# then for auto-replay.
env = CraftaxSymbolicEnv(default_statics)
env = AutoReplayWrapper(LogWrapper(env))
print(type(env))

(9, 48, 48)
<class 'jaxued.wrappers.autoreplay.AutoReplayWrapper'>


In [3]:
batch_size = 32
seq_len = 10

# Reset the wrapped env onto the pre-generated level; only the observation
# (first element of the returned pair) is kept.
obs, _ = env.reset_to_level(rng, level, env_params)

# Replicate every observation leaf along two new leading axes, giving each
# leaf the shape (seq_len, batch_size, *leaf.shape).
obs = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (seq_len, batch_size) + leaf.shape),
    obs,
)
print(obs.shape)

(10, 32, 8268)


In [4]:
n_editor = 4

# Feature sizes: the input width is read from axis 2 of the tiled observation
# tensor; the LSTM hidden width is fixed at 256.
in_feat = obs.shape[2]
out_feat = 256

# Three PRNG keys split from a fixed seed.
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)

# ResetLSTM wrapping an optimized LSTM cell with `out_feat` features.
model = ResetLSTM(
    nn.OptimizedLSTMCell(features=out_feat),
)
print(model)

ResetLSTM(
    # attributes
    cell = OptimizedLSTMCell(
        # attributes
        features = 256
        gate_fn = sigmoid
        activation_fn = tanh
        kernel_init = init
        recurrent_kernel_init = init
        bias_init = zeros
        dtype = None
        param_dtype = float32
        carry_init = zeros
    )
)


In [5]:
# The tiled observations are fed to the LSTM directly as its inputs.
embeds = obs

# Random boolean mask, one flag per (step, batch element): True wherever a
# uniform sample exceeds 0.5.
dones = jax.random.uniform(key_2, shape=(seq_len, batch_size)) > 0.5

# The model consumes the inputs and the mask together as one pair.
xs = (embeds, dones)
print(xs[0].shape, xs[1].shape, sep="\n")

(10, 32, 8268)
(10, 32)


In [6]:
# Build the initial carry from the per-step input shape, i.e. the shape of the
# inputs with the leading (sequence) axis dropped.
step_input_shape = xs[0].shape[1:]
init_carry = model.cell.initialize_carry(key_3, step_input_shape)
print(init_carry[0].shape, init_carry[1].shape, sep="\n")


(32, 256)
(32, 256)


In [7]:
# Initialise the model's variables, using `xs` as the example input.
# NOTE(review): `key_3` is also the key passed to `initialize_carry` in the
# preceding cell — confirm sharing the key between the two calls is intended.
variables = model.init(key_3, xs)

Inner loop
x shape: (32, 8268)
Resets shape: (32,)
Carry shape: (32, 256)
Inner loop
x shape: (32, 8268)
Resets shape: (32,)
Carry shape: (32, 256)


In [8]:
# Run the LSTM over the inputs, starting from the explicit initial carry.
# `model.apply` returns the final carry and the output array.
out_carry, out_val = model.apply(
    variables,
    xs,
    initial_carry=init_carry,
)
print(out_carry[0].shape, out_carry[1].shape, out_val.shape, sep="\n")

Inner loop
x shape: (32, 8268)
Resets shape: (32,)
Carry shape: (32, 256)
Inner loop
x shape: (32, 8268)
Resets shape: (32,)
Carry shape: (32, 256)
(32, 256)
(32, 256)
(10, 32, 256)


In [52]:
# Shapes of the two components of the final carry.
# NOTE(review): the output recorded under this cell, (16, 5), does not match
# the (32, 256) printed by the apply cell, and its execution count (52) is out
# of sequence — it looks stale; re-run after Restart Kernel -> Run All.
print(out_carry[0].shape, out_carry[1].shape, sep="\n")

(16, 5)
(16, 5)


In [53]:
# Shape of `out_val`, the second value returned by `model.apply`.
# NOTE(review): the output recorded under this cell, (16, 20, 5), does not
# match the (10, 32, 256) printed by the apply cell, and its execution count
# (53) is out of sequence — it looks stale; re-run after Restart Kernel -> Run All.
out_val.shape 

(16, 20, 5)