In [80]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import numpy as np
import gymnax
from minatar import Environment

import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Asterix MinAtar Environment

In [None]:
# 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
rng, reset, step, env_params = gymnax.make("Asterix-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

In [None]:
fig, axs = plt.subplots(2, 3)
for i in range(5):
    axs.flatten()[i].imshow(obs[:, :, i])

In [None]:
obs.shape, state

In [None]:
env = Environment("asterix")
obs = env.reset()
action = 1
reward_gym, done_gym = env.act(action)
obs_gym = env.state()
reward_gym, done_gym, obs_gym.shape

In [None]:
state

In [None]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
state_jax, reward_jax, done_jax

# Breakout MinAtar Environment

In [73]:
env = Environment("breakout")
env.reset()
obs_gym = env.state()

In [74]:
state = {'ball_dir': env.env.ball_dir,
         'ball_x': env.env.ball_x,
         'ball_y': env.env.ball_y,
         'brick_map': env.env.brick_map,
         'last_x': env.env.last_x,
         'last_y': env.env.last_y,
         'pos': env.env.pos,
         'strike': env.env.strike,
         'terminal': env.env.terminal}

In [75]:
action = 1
reward_gym, done_gym = env.act(action)
next_obs_gym = env.state()
reward_gym, done_gym, obs_gym.shape

(0, False, (10, 10, 4))

In [76]:
# 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
rng, reset, step, env_params = gymnax.make("Breakout-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
o, s = reset(key_reset, env_params)

In [77]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
state_jax, reward_jax, done_jax

({'ball_dir': Buffer(3, dtype=int32),
  'ball_x': Buffer(8, dtype=int32),
  'ball_y': Buffer(4, dtype=int32),
  'brick_map': Buffer([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
  'last_x': Buffer(9, dtype=int32),
  'last_y': Buffer(3, dtype=int32),
  'pos': Buffer(3, dtype=int32),
  'strike': Buffer(0, dtype=int32),
  'terminal': Buffer(False, dtype=bool)},
 Buffer(0, dtype=int32),
 Buffer(False, dtype=bool))

In [78]:
(obs_jax == next_obs_gym).all()

DeviceArray(True, dtype=bool)

In [88]:
env.reset()
action_space = [0, 1, 3]
for i in range(10):
    state = {'ball_dir': env.env.ball_dir,
             'ball_x': env.env.ball_x,
             'ball_y': env.env.ball_y,
             'brick_map': env.env.brick_map,
             'last_x': env.env.last_x,
             'last_y': env.env.last_y,
             'pos': env.env.pos,
             'strike': env.env.strike,
             'terminal': env.env.terminal}
    action = np.random.choice(action_space)
    reward_gym, done_gym = env.act(action)
    next_obs_gym = env.state()
    
    obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
    print((obs_jax == next_obs_gym).all())

True
True
True
True
True
True
False
False
False
False
