In [1]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
import optax

In [2]:
# Toy batch of 5 transitions used to prototype the goal-classification cell.
# `done` flags transitions that ended an episode; `obs` holds one 8-feature
# row per transition (row i pairs with done[i]).
done = jnp.asarray([True, False, True, False, True], dtype=jnp.bool_)
obs = jnp.asarray(
    [
        [0.5162862, 0.2074894, 0.7240187, -1.0787021, 0.04069785, -0.06364991, 1.0, 0.0],
        [0.11500321, 1.1899251, 0.50547856, -0.419965, -0.2609647, -0.33009827, 0.0, 0.0],
        [-0.63004315, -0.23130883, -0.3797977, -0.47536182, 2.8034565, 1.9403803, 0.0, 0.0],
        [-0.153934, 1.1053196, -0.3867789, -0.58739966, 0.3082726, 0.06706445, 0.0, 0.0],
        [-0.0033843, 1.4856167, 0.04007225, 0.00840264, -0.11051258, -0.24021249, 0.0, 0.0],
    ],
    dtype=jnp.float32,
)

# Show both shapes, one per line.
print(done.shape, obs.shape, sep="\n")

CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8906
Minimum supported: 8900
Installed version: 8500
The local installation version must be no lower than 8900. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(5,)
(5, 8)


In [3]:
def classify_goal(x, threshold=0.33):
    """Bucket each value of ``x`` into one of three integer classes.

    Class ids (int32, same shape as ``x``):
      * 0 -- ``x < -threshold``
      * 1 -- ``-threshold <= x < threshold``
      * 2 -- ``x >= threshold``

    NaN entries fail every comparison and therefore fall into class 0.

    Built from nested ``jnp.where`` calls instead of boolean-mask ``.at[...]``
    updates, so the function stays traceable under ``jax.jit`` / ``jax.vmap``.

    Args:
        x: array-like of values to classify (any shape).
        threshold: half-width of the middle bucket; defaults to 0.33.

    Returns:
        int32 array with the same shape as ``x``.
    """
    x = jnp.asarray(x)
    # Test the upper edge first; everything not >= -threshold defaults to 0.
    classes = jnp.where(x >= threshold, 2, jnp.where(x >= -threshold, 1, 0))
    # Pin the dtype so the result stays int32 even with jax_enable_x64.
    return classes.astype(jnp.int32)

# Positions of the transitions flagged as done.
done_idx = jnp.flatnonzero(done)
# Column 0 of obs at those positions, and the bucket classify_goal assigns it.
terminal_feature0 = obs[done_idx, 0]
goal_idx = classify_goal(terminal_feature0)

print(done_idx, terminal_feature0, goal_idx, sep="\n")

[0 2 4]
[ 0.5162862  -0.63004315 -0.0033843 ]
[2 0 1]


In [2]:
from craftax_classic.envs.craftax_symbolic_env import CraftaxClassicSymbolicEnv
from environment_base.wrappers import AutoResetEnvWrapper, BatchEnvWrapper

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
# One independent key per stochastic call below: reset, action sample, step.
rngs = jax.random.split(_rng, 3)
reset_key, action_key, step_key = rngs

# Single (unbatched) Craftax-Classic symbolic env, wrapped in
# AutoResetEnvWrapper (presumably resets when an episode ends — per its name).
env = AutoResetEnvWrapper(CraftaxClassicSymbolicEnv())
env_params = env.default_params

# NOTE: this rebinds `obs` and `done`, replacing the toy arrays from earlier cells.
obs, state = env.reset(reset_key, env_params)

# Take one uniformly sampled action from the env's action space.
action = env.action_space(env_params).sample(action_key)

obs, state, reward, done, info = env.step(step_key, state, action, env_params)

CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8906
Minimum supported: 8900
Installed version: 8500
The local installation version must be no lower than 8900. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


map_view (7, 9)
map_view_one_hot (7, 9, 17)
mob_map (7, 9, 4)


  return lax_numpy.astype(arr, dtype)


map_view (7, 9)
map_view_one_hot (7, 9, 17)
mob_map (7, 9, 4)


In [8]:
# Shape of the most recent `obs` in the kernel (the env.step observation when run in order).
print(jnp.shape(obs))

(1345,)


In [24]:
# Load the iris dataset from sklearn.
from sklearn.datasets import load_iris

iris = load_iris()

# Features as a JAX array. sklearn returns float64; JAX stores it as float32
# unless 64-bit mode (jax_enable_x64) is enabled.
X = jnp.array(iris.data)

# One-hot encode the integer targets. The class count comes from the dataset
# instead of a hard-coded 3. jax.nn.one_hot yields an all-zero row for an
# out-of-range label rather than silently indexing a clamped identity row.
num_classes = len(iris.target_names)
y = jax.nn.one_hot(jnp.array(iris.target), num_classes)

assert X.shape[0] == y.shape[0], "features and labels must have the same number of rows"

print(X.shape, y.shape)

(150, 4) (150, 3)


In [28]:
class MLP(nn.Module):
    """Fully connected classifier: three ReLU hidden layers and a log-softmax head.

    Dropout with rate ``dropout_rate`` is applied before each hidden ``Dense``
    layer, including directly on the input. It is active only when
    ``train=True``; Flax's ``nn.Dropout`` then requires a ``'dropout'`` PRNG
    stream to be passed to ``apply``.

    Returns log-probabilities over ``output_size`` classes.
    """

    hidden1_size: int
    hidden2_size: int
    hidden3_size: int
    output_size: int

    dropout_rate: float

    @nn.compact
    def __call__(self, x, train=False):
        hidden_widths = (self.hidden1_size, self.hidden2_size, self.hidden3_size)
        for width in hidden_widths:
            # Dropout -> Dense -> ReLU. Flax names submodules in creation
            # order, so the parameters are Dense_0..Dense_2 here, Dense_3 below.
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
            x = nn.relu(nn.Dense(features=width)(x))
        logits = nn.Dense(features=self.output_size)(x)
        return nn.log_softmax(logits)

In [33]:
# Initialise MLP parameters from a single, unbatched all-ones input.
# NOTE(review): input_shape is (768,), but the iris features loaded earlier
# have 4 columns (X.shape == (150, 4)); params built here would not accept X.
# Confirm whether 768 is intended (e.g. for another dataset).
key = jax.random.PRNGKey(0)
input_shape = (768,)
dummy_input = jnp.ones(input_shape, dtype=jnp.float32)

model = MLP(
    hidden1_size=64,
    hidden2_size=64,
    hidden3_size=16,
    output_size=3,
    dropout_rate=0.2,
)
# train defaults to False, so init needs no 'dropout' rng.
params = model.init(key, dummy_input)

print(model)

MLP(
    # attributes
    hidden1_size = 64
    hidden2_size = 64
    hidden3_size = 16
    output_size = 3
    dropout_rate = 0.2
)


In [14]:
# Prototype splitting flat observations into a spatial map block and trailing
# metadata features, using a dummy batch of 10 all-ones observations.
MAP_SHAPE = (7, 9, 21)  # spatial dims and channel count of the map block

dummy_obs = jnp.ones((10, 1345))
print(dummy_obs.shape)

map_size = MAP_SHAPE[0] * MAP_SHAPE[1] * MAP_SHAPE[2]  # 1323 flat map entries
maps, metadata = jnp.split(dummy_obs, [map_size], axis=1)
maps = maps.reshape((-1, *MAP_SHAPE))

# Print the split shapes so the cell's output reflects what the code computes.
print(maps.shape)
print(metadata.shape)

# Everything after the map block is metadata (1345 - 1323 = 22 features here).
assert metadata.shape == (dummy_obs.shape[0], dummy_obs.shape[1] - map_size)

(10, 1345)
(10, 7, 9, 21)
(10, 22)


In [48]:
class QNetCraftax(nn.Module):
    """Q-network over flat symbolic observations with a spatial map prefix.

    The first ``7 * 9 * 21`` features of each observation are reshaped to a
    (7, 9, 21) map and encoded by two Conv -> ReLU -> 2x2 max-pool stages.
    The remaining features (metadata) are concatenated with the flattened
    map encoding and passed through a 128 -> 64 ReLU MLP. That MLP feeds a
    linear head with one output per action.

    Accepts a batch ``(batch, obs_dim)`` or a single observation
    ``(obs_dim,)``. Returns ``(batch, action_size)`` or ``(action_size,)``
    respectively.
    """

    action_size: int = 17  # number of discrete actions (output units)

    @nn.compact
    def __call__(self, x):
        map_shape = (7, 9, 21)
        map_size = map_shape[0] * map_shape[1] * map_shape[2]

        # Allow a single unbatched observation by adding a leading batch axis.
        unbatched = x.ndim == 1
        if unbatched:
            x = x[None, :]

        maps, metadata = jnp.split(x, [map_size], axis=-1)
        maps = maps.reshape((-1, *map_shape))

        maps = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(maps)
        maps = nn.relu(maps)
        maps = nn.max_pool(maps, window_shape=(2, 2), strides=(2, 2), padding='VALID')
        maps = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(maps)
        maps = nn.relu(maps)
        maps = nn.max_pool(maps, window_shape=(2, 2), strides=(2, 2), padding='VALID')
        maps_features = maps.reshape((maps.shape[0], -1))

        fc_inputs = jnp.concatenate((maps_features, metadata), axis=-1)
        y = nn.relu(nn.Dense(128)(fc_inputs))
        # Feed the 128-unit layer's activations forward, so both hidden layers
        # are actually part of the network.
        y = nn.relu(nn.Dense(64)(y))
        # Linear head: Q-values are unbounded return estimates, so no softmax.
        # A softmax would force them to sum to 1.
        q_values = nn.Dense(self.action_size)(y)

        return q_values[0] if unbatched else q_values

In [49]:
# Build the Q-network and initialise its parameters on the dummy observation batch.
# NOTE: this rebinds `key`, `model` and `params`, replacing the MLP versions.
model = QNetCraftax(action_size=17)
key = jax.random.PRNGKey(0)
params = model.init(key, dummy_obs)