In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import numpy as np
import gymnax
from minatar import Environment

import matplotlib.pyplot as plt

In [2]:
# Smoke test: build the gymnax Asterix env, reset it, and take one step (action 0).
# Fresh subkeys are split off so `reset` and `step` never consume the same PRNG key.
rng, env = gymnax.make("Asterix-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = env.reset(key_reset)
n_obs, n_state, reward, done, info = env.step(key_step, state, 0)



In [3]:
# Smoke test: build the gymnax Breakout env, reset it, and take one step (action 0).
# Fresh subkeys are split off so `reset` and `step` never consume the same PRNG key.
rng, env = gymnax.make("Breakout-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = env.reset(key_reset)
n_obs, n_state, reward, done, info = env.step(key_step, state, 0)

In [4]:
# Smoke test: build the gymnax Freeway env, reset it, and take one step (action 0).
# Fresh subkeys are split off so `reset` and `step` never consume the same PRNG key.
rng, env = gymnax.make("Freeway-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = env.reset(key_reset)
n_obs, n_state, reward, done, info = env.step(key_step, state, 0)

  lax._check_user_dtype_supported(dtype, "array")


In [None]:
# Smoke test: build the gymnax SpaceInvaders env, reset it, and take one step (action 0).
# Fresh subkeys are split off so `reset` and `step` never consume the same PRNG key.
rng, env = gymnax.make("SpaceInvaders-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = env.reset(key_reset)
n_obs, n_state, reward, done, info = env.step(key_step, state, 0)

# Asterix MinAtar Environment

In [None]:
# Build and reset the gymnax Asterix env. Observations are 3D arrays indexed as
# obs[:, :, channel]; actions are discrete integers.
# NOTE(review): this unpacks the 4-value `gymnax.make` API (rng, reset, step, env_params),
# whereas the smoke-test cells at the top use `rng, env = gymnax.make(...)`; only one
# of the two can match the installed gymnax version — confirm.
rng, reset, step, env_params = gymnax.make("Asterix-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

In [None]:
# Show every channel of the Asterix reset observation in its own panel.
# The channel count comes from the data: JAX clamps out-of-range indices, so a
# hard-coded count that is too large would silently repeat the last channel.
n_channels = obs.shape[-1]
n_cols = 3
n_rows = -(-n_channels // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
fig.suptitle("gymnax Asterix reset obs")
for ch, ax in enumerate(axs.flat):
    if ch < n_channels:
        ax.imshow(obs[:, :, ch])
        ax.set_title(f"channel {ch}")
    else:
        ax.axis("off")  # hide panels left over in the last grid row

In [None]:
# Inspect the reset observation's shape together with the gymnax state returned by `reset`.
obs.shape, state

In [None]:
# Reference MinAtar Asterix: reset, take one step with `action`, and read back the
# post-step observation. Sticky actions are disabled, as in the other MinAtar cells,
# so MinAtar never randomly repeats the previous action instead of `action`.
env = Environment("asterix", sticky_action_prob=0.0)
# The return value is not bound: `obs` must keep holding the gymnax reset observation
# that the plotting/inspection cells above use. MinAtar's observation is read via `state()`.
env.reset()
action = 1
reward_gym, done_gym = env.act(action)
obs_gym = env.state()
reward_gym, done_gym, obs_gym.shape

In [None]:
# Display the gymnax Asterix `state` (from the `reset` cell above) before it is stepped below.
state

In [None]:
# Advance the gymnax Asterix state one step using the same `action` given to MinAtar.
# NOTE(review): `state` is the gymnax `reset` state, not a snapshot of the MinAtar
# env above, so these outputs are not a like-for-like comparison with MinAtar's.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
state_jax, reward_jax, done_jax

# Breakout MinAtar Environment

In [None]:
# Reference MinAtar Breakout with sticky actions disabled, so each `act` applies
# exactly the requested action; grab the initial observation.
env = Environment(
    "breakout",
    sticky_action_prob=0.0,
)
env.reset()
obs_gym = env.state()

In [None]:
# Snapshot the MinAtar Breakout game fields into the dict passed as `state` to the
# gymnax `step`. `brick_map` is copied because MinAtar's Breakout `act` assigns into
# that array in place; an aliased reference would change when `env.act` runs,
# corrupting this "pre-step" snapshot.
state = {'ball_dir': env.env.ball_dir,
         'ball_x': env.env.ball_x,
         'ball_y': env.env.ball_y,
         'brick_map': np.array(env.env.brick_map),
         'last_x': env.env.last_x,
         'last_y': env.env.last_y,
         'pos': env.env.pos,
         'strike': env.env.strike,
         'terminal': env.env.terminal}

In [None]:
# Step MinAtar Breakout once with `action`; show the reward, done flag, and the shape
# of the post-step observation (the one compared against the JAX step below).
action = 1
reward_gym, done_gym = env.act(action)
next_obs_gym = env.state()
reward_gym, done_gym, next_obs_gym.shape

In [None]:
# Build and reset the gymnax Breakout env (3D observations indexed as
# obs[:, :, channel]; discrete integer actions).
# The reset outputs get their own names so they don't overwrite the MinAtar-derived
# `state` dict that the next cell passes to `step`.
# NOTE(review): 4-value `gymnax.make` API, unlike the `rng, env = gymnax.make(...)`
# smoke-test cells at the top — confirm which matches the installed gymnax.
rng, reset, step, env_params = gymnax.make("Breakout-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs_reset, state_reset = reset(key_reset, env_params)

In [None]:
# Run the gymnax Breakout `step` from the MinAtar snapshot held in `state`,
# using the same `action` MinAtar just took.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
state_jax, reward_jax, done_jax

In [None]:
# True iff the JAX observation matches MinAtar's post-step observation element-for-element.
(obs_jax == next_obs_gym).all()

In [None]:
# Breakout parity check: take random actions in MinAtar and, at every step, run the
# gymnax `step` from a snapshot of the same pre-step game state. Stops at the first
# observation mismatch or when the MinAtar episode ends.
seed = 3
env = Environment("breakout", sticky_action_prob=0.0, random_seed=seed)
env.reset()
np.random.seed(seed)  # legacy global seed so the random action sequence is reproducible
action_space = [1, 3]
for i in range(10):
    # Snapshot MinAtar's pre-step state. `brick_map` is copied because MinAtar's
    # Breakout `act` writes into that array in place, which would otherwise alter
    # the snapshot before the gymnax `step` below reads it.
    state_jax = {'ball_dir': env.env.ball_dir,
                 'ball_x': env.env.ball_x,
                 'ball_y': env.env.ball_y,
                 'brick_map': np.array(env.env.brick_map),
                 'last_x': env.env.last_x,
                 'last_y': env.env.last_y,
                 'pos': env.env.pos,
                 'strike': env.env.strike,
                 'terminal': env.env.terminal}
    action = np.random.choice(action_space)
    reward_gym, done_gym = env.act(action)
    next_obs_gym = env.state()

    # NOTE(review): the same `key_step` is reused every iteration; this is only fine
    # if gymnax Breakout's `step` does not draw randomness from its key — confirm.
    obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                       state_jax, action)

    if not (obs_jax == next_obs_gym).all():
        print("===")
        print("Problem")
        print("step:", i, "action:", action)
        # Per-channel agreement. Loop over `ch` (not `i`) so the failing step index
        # is not shadowed, and take the channel count from the observation itself.
        for ch in range(obs_jax.shape[-1]):
            print((obs_jax[:, :, ch] == next_obs_gym[:, :, ch]).all())
        break
    if done_gym:
        break
    

In [None]:
# Single-channel check: does channel 1 of the JAX observation match MinAtar's?
(obs_jax[:, :, 1] == next_obs_gym[:, :, 1]).all()

In [None]:
# Reference side: each channel of MinAtar's post-step Breakout observation, titled
# so this figure can be told apart from the JAX one below.
n_channels = next_obs_gym.shape[-1]
n_cols = 2
n_rows = -(-n_channels // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
fig.suptitle("MinAtar next_obs_gym")
for ch, ax in enumerate(axs.flat):
    if ch < n_channels:
        ax.imshow(next_obs_gym[:, :, ch])
        ax.set_title(f"channel {ch}")
    else:
        ax.axis("off")  # hide panels left over in the last grid row

In [None]:
# JAX side: each channel of the gymnax post-step Breakout observation, titled so
# this figure can be told apart from the MinAtar one above.
n_channels = obs_jax.shape[-1]
n_cols = 2
n_rows = -(-n_channels // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
fig.suptitle("gymnax obs_jax")
for ch, ax in enumerate(axs.flat):
    if ch < n_channels:
        ax.imshow(obs_jax[:, :, ch])
        ax.set_title(f"channel {ch}")
    else:
        ax.axis("off")  # hide panels left over in the last grid row

# Freeway MinAtar Environment

In [None]:
# Build and reset the gymnax Freeway env (3D observations indexed as
# obs[:, :, channel]; discrete integer actions).
# NOTE(review): 4-value `gymnax.make` API, unlike the `rng, env = gymnax.make(...)`
# smoke-test cells at the top — confirm which matches the installed gymnax.
rng, reset, step, env_params = gymnax.make("Freeway-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

In [None]:
# Reference MinAtar Freeway: snapshot the game fields, step once with `action`, and
# snapshot again, so the gymnax `step` result can be compared against both.
def freeway_snapshot(game):
    """Return MinAtar Freeway game fields in the dict form passed as `state` to gymnax `step`.

    `cars` goes through `jnp.array`, which copies it, so the snapshot does not alias
    the game's own `cars` container.
    """
    return {"pos": game.pos,
            "cars": jnp.array(game.cars),
            "move_timer": game.move_timer,
            "terminate_timer": game.terminate_timer,
            "terminal": game.terminal}


env = Environment("freeway", sticky_action_prob=0.0)
env.reset()
obs_gym = env.state()
state_jax = freeway_snapshot(env.env)

action = 2
reward_gym, done_gym = env.act(action)
next_obs_gym = env.state()
state_next_jax = freeway_snapshot(env.env)
reward_gym, done_gym, next_obs_gym.shape

In [None]:
# MinAtar Freeway state snapshots from just before and just after `env.act(action)`.
state_jax, state_next_jax

In [None]:
# Run the gymnax Freeway `step` from the pre-step MinAtar snapshot `state_jax`
# with the same `action`; compare the result with `state_next_jax` above.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state_jax, action
)
state_jax, reward_jax, done_jax

In [None]:
# Show every channel of the Freeway observation in its own panel; the channel count
# comes from the data instead of being hard-coded.
# NOTE(review): this plots `obs` from the gymnax `reset`, not the post-step
# `obs_jax` — confirm which one was intended.
n_channels = obs.shape[-1]
n_cols = 4
n_rows = -(-n_channels // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
fig.suptitle("gymnax Freeway reset obs")
for ch, ax in enumerate(axs.flat):
    if ch < n_channels:
        ax.imshow(obs[:, :, ch])
        ax.set_title(f"channel {ch}")
    else:
        ax.axis("off")  # hide panels left over in the last grid row

# Seaquest MinAtar Environment

In [None]:
# Build and reset the gymnax Seaquest env (MinAtar: channel-stacked observations,
# discrete integer actions).
# NOTE(review): 4-value `gymnax.make` API, unlike the `rng, env = gymnax.make(...)`
# smoke-test cells at the top — confirm which matches the installed gymnax.
rng, reset, step, env_params = gymnax.make("Seaquest-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

# SpaceInvaders MinAtar Environment

In [None]:
# Build and reset the gymnax SpaceInvaders env (MinAtar: channel-stacked
# observations, discrete integer actions).
# NOTE(review): 4-value `gymnax.make` API, unlike the `rng, env = gymnax.make(...)`
# smoke-test cells at the top — confirm which matches the installed gymnax.
rng, reset, step, env_params = gymnax.make("SpaceInvaders-MinAtar")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

In [None]:
# Reference MinAtar Space Invaders with sticky actions disabled; grab the initial
# observation.
env = Environment(
    "space_invaders",
    sticky_action_prob=0.0,
)
env.reset()
obs_gym = env.state()


In [None]:
# Step MinAtar Space Invaders with action 1; the cell displays the (reward, done) pair `act` returns.
env.act(1)