In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import gymnax
from minatar import Environment

import matplotlib.pyplot as plt



# Asterix MinAtar Environment

In [None]:
# Build the gymnax "Asterix-MinAtar" environment (PRNG key, reset/step
# functions and default parameters), split the key into separate keys for
# reset and step, and reset to get the initial observation and state.
# NOTE(review): the earlier header comment described a continuous-torque
# task and did not match this environment; the observation is indexed as a
# 3-D array (obs[:, :, i]) by the plotting cell.
rng, reset, step, env_params = gymnax.make("Asterix-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

In [None]:
# Show every channel of the gymnax observation as its own image.
# The channel count is read from the observation itself instead of being
# hard-coded, so no out-of-range channel index is ever requested.
n_channels = obs.shape[-1]
n_cols = 3
n_rows = max(1, -(-n_channels // n_cols))  # ceil division, at least one row
# squeeze=False keeps `axs` two-dimensional even for a single row.
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
for i, ax in enumerate(axs.flatten()):
    if i < n_channels:
        ax.imshow(obs[:, :, i])
        ax.set_title(f"channel {i}")
    else:
        # Hide grid slots that have no channel to display.
        ax.axis("off")

In [None]:
# Display the observation's array shape together with the environment state
# returned by the gymnax reset.
obs.shape, state

In [None]:
# Reference run of the original MinAtar Asterix game: reset, apply one
# action, and read back reward, terminal flag and observation.
env = Environment("asterix")
# The return value of reset() is deliberately not stored: binding it to `obs`
# would overwrite the gymnax observation that the plotting cell reads. The
# MinAtar observation is fetched through env.state() below instead.
env.reset()
action = 1  # integer action index, reused for the gymnax step for comparison
reward_gym, done_gym = env.act(action)
obs_gym = env.state()
reward_gym, done_gym, obs_gym.shape

In [None]:
# Display the gymnax state that is passed to `step` in the next cell.
state

In [None]:
# Advance the gymnax Asterix environment by one transition from `state`,
# using the same `action` that was applied to the MinAtar environment.
# The fifth return value is not needed here and is discarded.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
state_jax, reward_jax, done_jax

# Breakout MinAtar Environment

In [9]:
# Build the gymnax "Breakout-MinAtar" environment (PRNG key, reset/step
# functions and default parameters), split the key into separate keys for
# reset and step, and reset to get the initial observation and state.
# NOTE(review): the earlier header comment described a continuous-torque
# task and did not match this environment.
rng, reset, step, env_params = gymnax.make("Breakout-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

In [10]:
# Reference run of the original MinAtar Breakout game: reset, apply one
# action, and read back reward, terminal flag and observation.
env = Environment("breakout")
# The return value of reset() is deliberately not stored: binding it to `obs`
# would overwrite the gymnax observation returned by
# reset(key_reset, env_params). The MinAtar observation is fetched through
# env.state() below instead.
env.reset()
action = 1  # integer action index, reused for the gymnax step for comparison
reward_gym, done_gym = env.act(action)
obs_gym = env.state()
reward_gym, done_gym, obs_gym.shape

(0, False, (10, 10, 4))

In [15]:
# Advance the gymnax Breakout environment by one transition from `state`,
# using the same `action` that was applied to the MinAtar environment.
# The fifth return value is not needed here and is discarded.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
state_jax, reward_jax, done_jax

({'ball_dir': DeviceArray([2], dtype=int32),
  'ball_x': DeviceArray([1], dtype=int32),
  'ball_y': DeviceArray([4], dtype=int32),
  'brick_map': DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
               [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
               [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
               [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
               [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
  'last_x': DeviceArray([0], dtype=int32),
  'last_y': DeviceArray(3, dtype=int32),
  'pos': DeviceArray(3, dtype=int32),
  'strike': DeviceArray([0], dtype=int32),
  'terminal': DeviceArray([False], dtype=bool)},
 DeviceArray([0], dtype=int32),
 DeviceArray(False, dty