In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import numpy as np
import gymnax
from minatar import Environment

import matplotlib.pyplot as plt

# Freeway MinAtar Environment

In [2]:
# Identifiers of the two implementations under comparison. The gymnax id is
# needed by make(), the state converter, the action map and the state check.
env_name = "Freeway-MinAtar"
gym_name = "freeway"

# 1. JAX side: build the gymnax environment and reset it.
rng, env = gymnax.make(env_name)
# A single split provides the carried-on key plus one key each for reset,
# step and action sampling.
split_keys = jax.random.split(rng, 4)
rng, key_reset, key_step, key_action = split_keys
obs, state = env.reset(key_reset)

# 2. Reference side: build the MinAtar environment and reset it.
env_gym = Environment(gym_name)
_ = env_gym.reset()
# `state` is rebound here, so the gymnax step below starts from the state
# derived from the MinAtar environment, not from the gymnax reset above.
state = gymnax.utils.np_state_to_jax(env_gym, env_name)

# 3. Apply one sampled action in both environments.
action = env.action_space.sample(key_action)
action_gym = gymnax.utils.minatar_action_map(action, env_name)
reward_gym = env_gym.act(action_gym)
obs_gym = env_gym.state()
done_gym = env_gym.env.terminal
step_out = env.step(key_step, state, action)
obs_jax, state_jax, reward_jax, done_jax, _ = step_out

# 4a. Check the transition outputs (observation, reward, done) of both sides.
gym_outputs = (obs_gym, reward_gym, done_gym)
jax_outputs = (obs_jax, reward_jax, done_jax)
gymnax.utils.assert_correct_transit(*gym_outputs, *jax_outputs)

# 4b. Check that the post-transition states are equal.
gymnax.utils.assert_correct_state(env_gym, env_name, state_jax)

  lax._check_user_dtype_supported(dtype, "array")
  lax._check_user_dtype_supported(dtype, "astype")


# SpaceInvaders MinAtar Environment

In [3]:
# 1. Create gymnax based environment
rng, env = gymnax.make("SpaceInvaders-MinAtar")
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
obs, state = env.reset(key_reset)

# 2. Setup gym-based environment, get state dict
env_gym = Environment("space_invaders")
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, "SpaceInvaders-MinAtar")

# 3. Transition in both environments
action = env.action_space.sample(key_action)
action_gym = gymnax.utils.minatar_action_map(action, "SpaceInvaders-MinAtar")
reward_gym = env_gym.act(action_gym)
obs_gym = env_gym.state()
done_gym = env_gym.env.terminal
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, "SpaceInvaders-MinAtar", state_jax)


# Breakout MinAtar Environment

In [5]:
# Identifiers of the two implementations under comparison. The gymnax id is
# needed by make(), the state converter, the action map and the state check.
env_name = "Breakout-MinAtar"
gym_name = "breakout"

# 1. JAX side: build the gymnax environment and reset it.
rng, env = gymnax.make(env_name)
# A single split provides the carried-on key plus one key each for reset,
# step and action sampling.
split_keys = jax.random.split(rng, 4)
rng, key_reset, key_step, key_action = split_keys
obs, state = env.reset(key_reset)

# 2. Reference side: build the MinAtar environment and reset it.
env_gym = Environment(gym_name)
_ = env_gym.reset()
# `state` is rebound here, so the gymnax step below starts from the state
# derived from the MinAtar environment, not from the gymnax reset above.
state = gymnax.utils.np_state_to_jax(env_gym, env_name)

# 3. Apply one sampled action in both environments.
action = env.action_space.sample(key_action)
action_gym = gymnax.utils.minatar_action_map(action, env_name)
reward_gym = env_gym.act(action_gym)
obs_gym = env_gym.state()
done_gym = env_gym.env.terminal
step_out = env.step(key_step, state, action)
obs_jax, state_jax, reward_jax, done_jax, _ = step_out

# 4a. Check the transition outputs (observation, reward, done) of both sides.
gym_outputs = (obs_gym, reward_gym, done_gym)
jax_outputs = (obs_jax, reward_jax, done_jax)
gymnax.utils.assert_correct_transit(*gym_outputs, *jax_outputs)

# 4b. Check that the post-transition states are equal.
gymnax.utils.assert_correct_state(env_gym, env_name, state_jax)

# Asterix MinAtar Environment

In [16]:
# 1. Create gymnax based environment
# The result of gymnax.make is unpacked as (rng, env); `rng` is then used as
# the JAX PRNG key that gets split on the next line.
rng, env = gymnax.make("Asterix-MinAtar")
# Split into four keys: `rng` is rebound to the first one, the remaining
# three are set aside for reset, step and action sampling. `env`, `key_step`
# and `key_action` are read again by the transition cell of this section.
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
obs, state = env.reset(key_reset)
# Trailing bare expression: the notebook renders the reset state as the cell
# output.
state

{'player_x': 5,
 'player_y': 5,
 'shot_timer': 0,
 'spawn_speed': 10,
 'spawn_timer': 10,
 'move_speed': 5,
 'move_timer': 5,
 'ramp_timer': 100,
 'ramp_index': 0,
 'entities': DeviceArray([[0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0]], dtype=int32),
 'time': 0,
 'terminal': 0}

In [17]:
# 2. Setup gym-based environment, get state dict
env_gym = Environment("asterix")
# The return value of reset() is not needed; it is bound to `_` and ignored.
_ = env_gym.reset()
# NOTE: this rebinds `state`, replacing the value produced by env.reset() in
# the cell that created the gymnax environment. The later env.step call
# therefore starts from the state derived from the MinAtar environment.
# (np_state_to_jax presumably converts the NumPy-side state into the gymnax
# state format -- confirm against gymnax.utils.)
state = gymnax.utils.np_state_to_jax(env_gym, "Asterix-MinAtar")
# Trailing bare expression: display the converted state as the cell output.
state

{'player_x': 5,
 'player_y': 5,
 'shot_timer': 0,
 'spawn_speed': 10,
 'spawn_timer': 10,
 'move_speed': 5,
 'move_timer': 5,
 'ramp_timer': 100,
 'ramp_index': 0,
 'entities': DeviceArray([[0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0]], dtype=int32),
 'time': 0,
 'terminal': 0}

In [18]:
# 3. Transition in both environments
# Relies on names bound by the two preceding Asterix cells: `env`,
# `key_action`, `key_step`, `env_gym` and `state`.
action = env.action_space.sample(key_action)
# Presumably translates the gymnax action into MinAtar's action encoding --
# confirm against gymnax.utils.minatar_action_map.
action_gym = gymnax.utils.minatar_action_map(action, "Asterix-MinAtar")
reward_gym = env_gym.act(action_gym)
obs_gym = env_gym.state()
done_gym = env_gym.env.terminal
# The fifth element returned by env.step is discarded.
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# Exact correctness cannot be asserted here because of random entity slot
# sampling. Instead, the test should go through the sub-steps of the overall
# step transition.
# The two checks below are therefore intentionally left disabled:
# # 4a. Check correctness of transition
# gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
#                                     obs_jax, reward_jax, done_jax)

# # 4b. Check that post-transition states are equal
# gymnax.utils.assert_correct_state(env_gym, "Asterix-MinAtar", state_jax)

AssertionError: 

# Seaquest MinAtar Environment

In [None]:
# Identifiers of the two implementations under comparison. The gymnax id is
# needed by make(), the state converter, the action map and the state check.
env_name = "Seaquest-MinAtar"
gym_name = "seaquest"

# 1. JAX side: build the gymnax environment and reset it.
rng, env = gymnax.make(env_name)
# A single split provides the carried-on key plus one key each for reset,
# step and action sampling.
split_keys = jax.random.split(rng, 4)
rng, key_reset, key_step, key_action = split_keys
obs, state = env.reset(key_reset)

# 2. Reference side: build the MinAtar environment and reset it.
env_gym = Environment(gym_name)
_ = env_gym.reset()
# `state` is rebound here, so the gymnax step below starts from the state
# derived from the MinAtar environment, not from the gymnax reset above.
state = gymnax.utils.np_state_to_jax(env_gym, env_name)

# 3. Apply one sampled action in both environments.
action = env.action_space.sample(key_action)
action_gym = gymnax.utils.minatar_action_map(action, env_name)
reward_gym = env_gym.act(action_gym)
obs_gym = env_gym.state()
done_gym = env_gym.env.terminal
step_out = env.step(key_step, state, action)
obs_jax, state_jax, reward_jax, done_jax, _ = step_out

# 4a. Check the transition outputs (observation, reward, done) of both sides.
gym_outputs = (obs_gym, reward_gym, done_gym)
jax_outputs = (obs_jax, reward_jax, done_jax)
gymnax.utils.assert_correct_transit(*gym_outputs, *jax_outputs)

# 4b. Check that the post-transition states are equal.
gymnax.utils.assert_correct_state(env_gym, env_name, state_jax)