In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import numpy as np
import gymnax
from minatar import Environment

import matplotlib.pyplot as plt

# Freeway MinAtar Environment

In [None]:
# Compare one transition of the gymnax (JAX) Freeway port against the
# reference MinAtar implementation, starting both from the same state.
GYMNAX_ID = "Freeway-MinAtar"

# 1. Create gymnax based environment
rng, env = gymnax.make(GYMNAX_ID)
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
# The state returned here is overwritten in step 2 by the state copied from
# the reference environment.
obs, state = env.reset(key_reset)

# 2. Setup reference MinAtar environment, copy its state into gymnax format.
# Sticky actions are disabled: MinAtar's Environment defaults to a 0.1
# probability of repeating the previous action instead of the one passed to
# `act`, which would make the comparison below fail at random.
env_gym = Environment("freeway", sticky_action_prob=0.0)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, GYMNAX_ID)

# 3. Transition in both environments
action = env.action_space.sample(key_action)
action_gym = gymnax.utils.minatar_action_map(action, GYMNAX_ID)
# MinAtar's `act` returns a (reward, terminal) pair, not just the reward.
reward_gym, done_gym = env_gym.act(action_gym)
obs_gym = env_gym.state()
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, GYMNAX_ID, state_jax)

# SpaceInvaders MinAtar Environment

In [None]:
# Compare one transition of the gymnax (JAX) SpaceInvaders port against the
# reference MinAtar implementation, starting both from the same state.
GYMNAX_ID = "SpaceInvaders-MinAtar"

# 1. Create gymnax based environment
rng, env = gymnax.make(GYMNAX_ID)
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
# The state returned here is overwritten in step 2 by the state copied from
# the reference environment.
obs, state = env.reset(key_reset)

# 2. Setup reference MinAtar environment, copy its state into gymnax format.
# Sticky actions are disabled: MinAtar's Environment defaults to a 0.1
# probability of repeating the previous action instead of the one passed to
# `act`, which would make the comparison below fail at random.
env_gym = Environment("space_invaders", sticky_action_prob=0.0)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, GYMNAX_ID)

# 3. Transition in both environments
action = env.action_space.sample(key_action)
action_gym = gymnax.utils.minatar_action_map(action, GYMNAX_ID)
# MinAtar's `act` returns a (reward, terminal) pair, not just the reward.
reward_gym, done_gym = env_gym.act(action_gym)
obs_gym = env_gym.state()
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, GYMNAX_ID, state_jax)