# 1 - Test the environment

In [None]:
import warnings

In [None]:
# Silence the deprecation and user warnings emitted by the RL libraries: they are noise here.
# Categories are registered in the same order as before (DeprecationWarning, then UserWarning).
for ignored_category in (DeprecationWarning, UserWarning):
    warnings.simplefilter('ignore', category=ignored_category)

In [None]:
import gym

from matplotlib import pyplot as plt
# To construct vectorized environments for Atari games
from stable_baselines3.common.env_util import make_atari_env
# To do frame stacking with vectorized environments
from stable_baselines3.common.vec_env import VecFrameStack

## 1.1 - Play games with random actions

In [None]:
# Create a single (non-vectorized) Space Invaders environment with OpenAI Gym from the ALE ROM.
# It is only used for the random-play sanity check of section 1.1.
env = gym.make('ALE/SpaceInvaders-v5')

In [None]:
# Display the meanings of all possible discrete actions
# (presumably list position i is the meaning of action i — this is the reference for the
# per-action boolean mask returned by mask_fn in section 2)
env.unwrapped.get_action_meanings()

In [None]:
# Sanity check: play 5 episodes with a uniformly random policy and print each episode's score
for episode_number in range(1, 6):
    # Start a new episode
    obs = env.reset()
    episode_score = 0
    done = False

    while not done:
        # Take a random action among the available ones; step() returns:
        #   - obs    : the observation of the environment after the action
        #   - reward : the reward obtained for this action
        #   - done   : True if the episode has ended
        #   - info   : dictionary of auxiliary diagnostic information
        obs, reward, done, info = env.step(env.action_space.sample())
        episode_score += reward

    print(f'Episode : {episode_number} --> Score : {episode_score}')

# Close the environment
env.close()

## 1.2 - Stack frames to represent movements

In [None]:
# Build a vectorized Space Invaders environment with Atari preprocessing (make_atari_env),
# then stack its 4 most recent frames so one observation carries a notion of movement.
env = VecFrameStack(make_atari_env('ALE/SpaceInvaders-v5'), n_stack=4)

In [None]:
obs = env.reset()

# obs is indexed as (env, row, column, stacked frame), so obs[0][:, :, k] is the k-th
# stacked frame of the first environment. The subplot grid follows the actual number of
# stacked frames instead of assuming 4.
n_frames = obs.shape[3]

# Perform 10 random actions and plot the stacked frames after each one,
# to visualize how frame stacking works
for _ in range(10):
    plt.figure(figsize=(20, 16))

    # One column per stacked frame
    for frame_idx in range(n_frames):
        plt.subplot(1, n_frames, frame_idx + 1)
        plt.imshow(obs[0][:, :, frame_idx])

    plt.show()

    action = env.action_space.sample()
    # Vectorized env: step() takes one action per environment, hence the list
    obs, reward, done, info = env.step([action])

env.close()

# 2 - Train and evaluate models

In [None]:
import numpy as np

# Stable Baselines3 Contrib : Contrib package for Stable Baselines3 (SB3) - Experimental code
# Reinforcement learning algorithms
from sb3_contrib import MaskablePPO, QRDQN
# To mask actions in an environment
from sb3_contrib.common.wrappers import ActionMasker
# To do preprocessing for Atari games
from stable_baselines3.common.atari_wrappers import AtariWrapper
# To save and evaluate regularly a model during training
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
# To construct vectorized environments
from stable_baselines3.common.env_util import make_vec_env
# To evaluate a model
from stable_baselines3.common.evaluation import evaluate_policy
# To give linear function to hyperparams (e.g. learning rate, clip range, ...)
from stable_baselines3.common.type_aliases import Schedule

In [None]:
# Function returning the action mask used by ActionMasker: one boolean per discrete action,
# True = action allowed. The first four actions stay available; the last two are always masked.
# NOTE(review): the mask length must match env.action_space.n (presumably 6 for Space Invaders'
# action set) — compare with env.unwrapped.get_action_meanings() from section 1.1.
def mask_fn(_env: gym.Env) -> np.ndarray:
    """Return the constant action mask as a boolean array (the env argument is unused)."""
    # A fresh array per call, matching the declared return type
    return np.array([True, True, True, True, False, False], dtype=bool)

In [None]:
# Custom wrapper applying several wrappers to one environment, usable as `wrapper_class`
# of make_vec_env: action masking first, then the Atari preprocessing.
class ActionMaskerAtariWrapper(gym.Wrapper):
    """Wrap an environment with ActionMasker, then AtariWrapper.

    :param _env: the environment to wrap.
    :param action_mask_fn: function ``env -> mask`` used by ActionMasker. When None (default),
        the module-level ``mask_fn`` is used.
    :param atari_wrapper_kwargs: extra keyword arguments forwarded unchanged to AtariWrapper;
        AtariWrapper's own defaults apply to anything omitted.
    """

    def __init__(self, _env: gym.Env, action_mask_fn=None, **atari_wrapper_kwargs):
        if action_mask_fn is None:
            # Looked up at construction time (not bound as a default value) so that
            # re-defining mask_fn in the notebook is still picked up.
            action_mask_fn = mask_fn
        # Wrap the environment to expose a function which returns the masked actions
        _env = ActionMasker(_env, action_mask_fn)
        # Wrap the environment to preprocess it for Atari games
        # NOTE(review): 'ALE/...-v5' ids presumably already skip frames internally, and AtariWrapper
        # adds its own frame skipping on top; if that is unintended, pass frame_skip=1 via
        # atari_wrapper_kwargs (or use a NoFrameskip env id) — confirm before retraining.
        _env = AtariWrapper(_env, **atari_wrapper_kwargs)

        super().__init__(_env)

In [None]:
# Linear schedule for SB3 hyperparameters such as the learning rate or the clip range.
# SB3 calls the returned function with progress_remaining, which goes from 1 (start of
# training) to 0 (end), so the value decreases linearly from initial_value down to 0.
def linear_schedule(initial_value: float) -> Schedule:
    return lambda progress_remaining: initial_value * progress_remaining

## 2.1 - Set up training and evaluation environments

In [None]:
# Training environment: 8 vectorized copies of the game, each wrapped with action masking
# and Atari preprocessing, with the 4 most recent frames stacked on top.
env = VecFrameStack(
    make_vec_env('ALE/SpaceInvaders-v5',
                 n_envs=8,
                 wrapper_class=ActionMaskerAtariWrapper),
    n_stack=4,
)

In [None]:
# Separate evaluation environment with the same wrappers and frame stacking as the
# training one (make_vec_env's default number of environments).
eval_env = VecFrameStack(
    make_vec_env('ALE/SpaceInvaders-v5', wrapper_class=ActionMaskerAtariWrapper),
    n_stack=4,
)

## 2.2 - Set up checkpoint and evaluation callbacks

In [None]:
# Save a checkpoint of the model every million steps.
# The save frequency refers to each of the 8 training environments: 1_000_000 // 8 = 125_000.
checkpoint_callback = CheckpointCallback(
    save_freq=1_000_000 // 8,
    save_path='logs/',
    name_prefix='space_invaders',
)

In [None]:
# Measure the evolution of the model every hundred thousand steps and keep the best model.
# The evaluation frequency refers to each of the 8 training environments: 100_000 // 8 = 12_500.
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path='logs/',
    eval_freq=100_000 // 8,
)

In [None]:
# Combine the checkpoint and evaluation callbacks so both can be passed to model.learn()
callback_list = CallbackList([checkpoint_callback, eval_callback])

## 2.3 - Train a model with [QR-DQN algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/qrdqn.html)

In [None]:
# Specify the paths where logs will be stored for this model: checkpoints are written as
# logs/qrdqn/space_invaders_<steps>_steps.zip (save_path 'logs/' + this prefix) and the
# best evaluated model goes to logs/qrdqn/
checkpoint_callback.name_prefix = 'qrdqn/space_invaders'
eval_callback.best_model_save_path = 'logs/qrdqn/'

In [None]:
# Initialize the model with atari hyperparams from https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/qrdqn.yml
# Hyperparams are the following :
#   - CnnPolicy : the policy class for QR-DQN when using images as input
#   - env : 8 vectorized environments with 4 frames stacked each
#   - optimize_memory_usage=True : enable a memory efficient variant of the replay buffer at the cost of more complexity
#   - exploration_fraction=0.025 : fraction of the entire training period over which the exploration rate is reduced
#   - verbose=1 : the verbosity level: 0 no output, 1 info, 2 debug
#   - tensorboard_log='logs/tensorboard/' : the log location for tensorboard
# All other hyperparams keep QRDQN's defaults.
# NOTE(review): newer SB3 versions reportedly refuse optimize_memory_usage=True together with the
# replay buffer's handle_timeout_termination=True; if construction fails, try
# replay_buffer_kwargs=dict(handle_timeout_termination=False) — confirm against the installed version.
# NOTE(review): QR-DQN presumably ignores the action mask exposed by the env wrapper, so it can
# still choose the masked actions — keep in mind when comparing with Maskable PPO.
model = QRDQN('CnnPolicy',
              env,
              optimize_memory_usage=True,
              exploration_fraction=0.025,
              verbose=1,
              tensorboard_log='logs/tensorboard/')

In [None]:
# Train the model for 10M steps (counted over all 8 environments), executing the checkpoint and
# evaluation callbacks and logging to tensorboard under the run name 'qrdqn_space_invaders'
model.learn(total_timesteps=int(1e7),
            callback=callback_list,
            log_interval=2000,
            tb_log_name='qrdqn_space_invaders')
# Save the final model in a zip file (qrdqn_space_invaders.zip, demoed in section 3.2)
model.save('qrdqn_space_invaders')

## 2.4 - Train a model with [Maskable PPO algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html)

In [None]:
# Fresh callbacks for Maskable PPO, logging under logs/maskable_ppo/.
# - Evaluation uses MaskableEvalCallback: the plain EvalCallback evaluates the policy without
#   its action masks, so it would not measure the policy the way it is trained and used.
# - Both callbacks are re-created instead of re-targeting the QR-DQN ones, so no internal state
#   is carried over (e.g. the best mean reward reached by QR-DQN, which would otherwise have to be
#   beaten before any Maskable PPO best model is saved).
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback

# Same frequencies as before: per environment, 125_000 * 8 = 1M and 12_500 * 8 = 100k steps
checkpoint_callback = CheckpointCallback(save_freq=125_000,
                                         save_path='logs/',
                                         name_prefix='maskable_ppo/space_invaders')
eval_callback = MaskableEvalCallback(eval_env,
                                     best_model_save_path='logs/maskable_ppo/',
                                     eval_freq=12_500)
callback_list = CallbackList([checkpoint_callback, eval_callback])

In [None]:
# Initialize the model with atari hyperparams from https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml
# Hyperparams are the following :
#   - CnnPolicy : Convolutional Neural Network policy class for actor-critic algorithms
#   - env : 8 vectorized environments with 4 frames stacked each
#   - learning_rate=linear_schedule(2.5e-4) : learning rate as a function of the progress remaining (from 1 to 0)
#   - n_steps=128 : number of steps to run for each environment per update
#                   (so each rollout holds 128 * 8 = 1024 transitions)
#   - batch_size=256 : minibatch size used to split each 1024-transition rollout during optimization
#   - n_epochs=4 : number of epochs when optimizing the surrogate loss
#   - clip_range=linear_schedule(0.1) : clipping parameter as a function of the progress remaining (from 1 to 0)
#   - ent_coef=0.01 : entropy coefficient for the loss calculation
#   - vf_coef=0.5 : value function coefficient for the loss calculation
#   - verbose=1 : the verbosity level: 0 no output, 1 info, 2 debug
#   - tensorboard_log='logs/tensorboard/' : the log location for tensorboard
model = MaskablePPO('CnnPolicy',
                    env,
                    learning_rate=linear_schedule(2.5e-4),
                    n_steps=128,
                    batch_size=256,
                    n_epochs=4,
                    clip_range=linear_schedule(0.1),
                    ent_coef=0.01,
                    vf_coef=0.5,
                    verbose=1,
                    tensorboard_log='logs/tensorboard/')

In [None]:
# Train the Maskable PPO model for 10M steps with the callbacks in callback_list,
# logging to tensorboard under the run name 'maskable_ppo_space_invaders'
model.learn(total_timesteps=int(1e7),
            callback=callback_list,
            log_interval=100,
            tb_log_name='maskable_ppo_space_invaders')
# Save the final model in a zip file (maskable_ppo_space_invaders.zip)
model.save('maskable_ppo_space_invaders')

## 2.5 - Error : Python process interruption

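The Python kernel crashed shortly after 6M steps. The model is therefore re-created, loaded with the parameters saved in the 6M-step checkpoint, and trained for 4M more steps to reach 10M steps in total. Because the learning rate and the clip range follow linear schedules, the re-created model starts them from the values reached at 6M steps, when 40% of training remained (0.4 × 2.5e-4 ≈ 1e-4 and 0.4 × 0.1 = 0.04).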

In [None]:
# Write the follow-up run's checkpoints under their own prefix
# NOTE(review): presumably kept separate because the follow-up learn() call counts its steps
# from 0 again, so the original prefix would overwrite the first run's checkpoint files — confirm.
checkpoint_callback.name_prefix = 'maskable_ppo/follow-up/space_invaders'

In [None]:
# Reinitialize the model with the hyperparams updated with values from before the interruption.
# With 10M steps planned, 40% of training remained at 6M steps, so the linear schedules had
# reached about 0.4 * 2.5e-4 = 1e-4 (learning rate) and 0.4 * 0.1 = 0.04 (clip range);
# 1.0001e-4 is presumably the learning rate value logged around the 6M checkpoint.
# Over the 4M follow-up steps both values again decay linearly to 0, as in the original plan.
model = MaskablePPO('CnnPolicy',
                    env,
                    learning_rate=linear_schedule(1.0001e-4),
                    n_steps=128,
                    batch_size=256,
                    n_epochs=4,
                    clip_range=linear_schedule(0.04),
                    ent_coef=0.01,
                    vf_coef=0.5,
                    verbose=1,
                    tensorboard_log='logs/tensorboard/')

In [None]:
# Set the model parameters to values from the last saved model: the checkpoint written by
# the checkpoint callback at 6M steps
model.set_parameters('logs/maskable_ppo/space_invaders_6000000_steps.zip')

In [None]:
# Train for the remaining 4M steps (6M + 4M = 10M in total) with the same tb_log_name as the
# first run, then save the final model to maskable_ppo_space_invaders.zip
model.learn(total_timesteps=int(4e6),
            callback=callback_list,
            log_interval=100,
            tb_log_name='maskable_ppo_space_invaders')
model.save('maskable_ppo_space_invaders')

# 3 - See the results of trained models

In [None]:
# Function to evaluate a model on an environment, render the played games and print the score
def demo(_model, _env):
    """Evaluate ``_model`` on ``_env`` with rendering and print the mean +/- std episode reward.

    MaskablePPO models are evaluated with sb3-contrib's mask-aware ``evaluate_policy``, which
    feeds the environment's action masks to the model; this requires ``_env`` to expose them
    (e.g. through ActionMaskerAtariWrapper). SB3's standard ``evaluate_policy`` does not pass
    masks, so it is only used for the other models (e.g. QRDQN).
    """
    if isinstance(_model, MaskablePPO):
        # Local import: only needed for action-masked models
        from sb3_contrib.common.maskable.evaluation import evaluate_policy as maskable_evaluate_policy
        mean_reward, std_reward = maskable_evaluate_policy(_model, _env, render=True)
    else:
        mean_reward, std_reward = evaluate_policy(_model, _env, render=True)
    print(f'mean_reward = {mean_reward:.2f} +/- {std_reward:.2f}')

## 3.1 - Set up the demonstration environment

In [None]:
# Demonstration environment: same wrappers and frame stacking as the training environment
demo_env = VecFrameStack(
    make_vec_env('ALE/SpaceInvaders-v5', wrapper_class=ActionMaskerAtariWrapper),
    n_stack=4,
)

## 3.2 - Demo of models trained with [QR-DQN algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/qrdqn.html)

In [None]:
# Initialize the model with a saved zip file: the best QR-DQN model kept by the evaluation
# callback during training, then run the demo
demo_model = QRDQN.load('logs/qrdqn/best_model')
demo(demo_model, demo_env)

In [None]:
# Demo of the final QR-DQN model saved at the end of training
demo_model = QRDQN.load('qrdqn_space_invaders')
demo(demo_model, demo_env)

## 3.3 - Demo of models trained with [Maskable PPO algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html)

In [None]:
# Demo of the best Maskable PPO model kept by the evaluation callback during training
demo_model = MaskablePPO.load('logs/maskable_ppo/best_model')
demo(demo_model, demo_env)

In [None]:
# Demo of the final Maskable PPO model saved at the end of training
demo_model = MaskablePPO.load('maskable_ppo_space_invaders')
demo(demo_model, demo_env)

# Bonus - Make a gif of the model

In [None]:
import imageio

In [None]:
# Environment used to record the gif. It is built like the training, evaluation and demo
# environments (action masking + Atari preprocessing + 4 stacked frames), so it exposes the
# action masks the Maskable PPO agent was trained with.
gif_env = VecFrameStack(make_vec_env('ALE/SpaceInvaders-v5',
                                     wrapper_class=ActionMaskerAtariWrapper),
                        n_stack=4)

In [None]:
# Load the final Maskable PPO model and attach the gif environment to it
gif_model = MaskablePPO.load('maskable_ppo_space_invaders')
gif_model.set_env(gif_env)

In [None]:
# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#bonus-make-a-gif-of-a-trained-agent
# Record 350 rendered frames of the Maskable PPO agent playing.
from sb3_contrib.common.maskable.utils import get_action_masks, is_masking_supported

# Give the model the action masks it was trained with whenever gif_env exposes them;
# otherwise predict without masks.
use_action_masks = is_masking_supported(gif_env)

images = []
obs = gif_env.reset()
img = gif_env.render(mode='rgb_array')

for i in range(350):
    images.append(img)
    action_masks = get_action_masks(gif_env) if use_action_masks else None
    action, _ = gif_model.predict(obs, action_masks=action_masks)
    obs, _, _, _ = gif_env.step(action)
    img = gif_env.render(mode='rgb_array')

gif_env.close()

In [None]:
# Save every other captured frame (even indices) as an animated gif played at 29 fps
imageio.mimsave('../images/space_invaders.gif',
                [np.array(frame) for frame in images[::2]],
                fps=29)