# 1 - Test the environment

In [None]:
import warnings

In [None]:
# Disable useless warnings for this project
# NOTE(review): this is a blanket filter on the UserWarning category, so every
# UserWarning raised after this cell runs is hidden, whichever library emits it
warnings.simplefilter('ignore', category=UserWarning)

In [None]:
import gym
import gym_super_mario_bros

# To switch from RGB to gray scale
# To resize the observation of the environment
from gym.wrappers import GrayScaleObservation, ResizeObservation
# RIGHT_ONLY: action set that only lets Mario move to the right
# SIMPLE_MOVEMENT: action set that also lets Mario jump in place or move to the left
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT
# The Super Mario Bros environment
from gym_super_mario_bros.smb_env import SuperMarioBrosEnv
from matplotlib import pyplot as plt
# To wrap the game environment to simulate actions from a game controller
from nes_py.wrappers import JoypadSpace
# To clip AI agent rewards to -1, 0 or 1 depending on the sign of the reward
# To skip frames during training
# To start each episode with a bounded number of no-op (do nothing) actions
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv, NoopResetEnv
# To construct vectorized environments
from stable_baselines3.common.env_util import make_vec_env
# To do frame stacking with vectorized environments
from stable_baselines3.common.vec_env import VecFrameStack

In [None]:
# Custom wrapper to apply multiple wrappers to the environment
class SuperMarioBrosWrapper(gym.Wrapper):
    """Build a fully preprocessed Super Mario Bros environment.

    The wrapper creates its own environment from ``env_id`` and applies, in
    order: joypad action restriction, resizing to 84x84, gray scale
    conversion, no-op resets, frame skipping and reward clipping.

    Args:
        _env: Environment handed over by the caller (e.g. ``make_vec_env``).
            It is not used as the base of the pipeline: it is closed and
            replaced by a new environment built from ``env_id``.
        env_id: Identifier passed to ``gym_super_mario_bros.make``.
        actions_list: Button combinations exposed as discrete actions
            (e.g. ``RIGHT_ONLY`` or ``SIMPLE_MOVEMENT``).
    """

    def __init__(
            self,
            _env: gym.Env,
            env_id: str,
            actions_list: list
    ):
        # The incoming environment is discarded in favour of the one created
        # below, so release it explicitly instead of leaking an emulator
        if _env is not None:
            _env.close()

        # Create the game environment with OpenAI Gym from the NES-py rom
        _env = gym_super_mario_bros.make(env_id)
        # Enable actions in the actions list
        _env = JoypadSpace(_env, actions_list)
        # Resize the observation space to a square
        _env = ResizeObservation(_env, (84, 84))
        # Convert the image observation from RGB to gray scale
        _env = GrayScaleObservation(_env, keep_dim=True)
        # Set the maximum value of no operation to run to 30
        _env = NoopResetEnv(_env, noop_max=30)
        # Return only every 4-th frame (frame skipping)
        _env = MaxAndSkipEnv(_env, skip=4)
        # Clips the reward to {+1, 0, -1} by its sign.
        _env = ClipRewardEnv(_env)

        super().__init__(_env)

## 1.1 - Play a game with random actions

In [None]:
# Build the raw game and restrict its controller to the SIMPLE_MOVEMENT
# button combinations in a single expression
env = JoypadSpace(gym_super_mario_bros.make('SuperMarioBros-v0'),
                  SIMPLE_MOVEMENT)

In [None]:
# Display the meanings of all possible actions
# (the call is the last expression of the cell, so its result is the cell output)
env.get_action_meanings()

In [None]:
done = False
# Running total of the rewards collected during the episode
score = 0

try:
    # Reset the environment
    obs = env.reset()

    while not done:
        # Choose a random action among available actions
        action = env.action_space.sample()
        # Perform an action in the environment
        # It returns :
        #   the observation of the current environment
        #   the amount of reward returned after previous action
        #   the done boolean which is true whether the episode has ended
        #   the info dictionary contains auxiliary diagnostic information
        obs, reward, done, info = env.step(action)
        # Update the overall score with reward of the performed action
        score += reward

    print(f'Score : {score}')
finally:
    # Close the environment, even if the episode fails or is interrupted
    env.close()

## 1.2 - Stack frames to represent movements

In [None]:
# Single vectorized environment for the frame stacking demonstration:
#   - the game is preprocessed by SuperMarioBrosWrapper
#   - Mario can do simple movements to the right and left
#   - the last 4 frames are stacked to have a dimension of movement
env = VecFrameStack(
    make_vec_env(
        SuperMarioBrosEnv,
        wrapper_class=SuperMarioBrosWrapper,
        wrapper_kwargs={'env_id': 'SuperMarioBros-v0',
                        'actions_list': SIMPLE_MOVEMENT},
    ),
    n_stack=4,
)

In [None]:
try:
    obs = env.reset()

    # Perform 10 actions and show them to visualize how the frame stacking works
    for _ in range(10):
        # Number of stacked frames, read from the last axis of the observation
        # so the subplot grid always matches the loop below
        n_frames = obs.shape[3]

        plt.figure(figsize=(20, 16))

        # Print the images which represent the frames stacked in the environment
        for idx in range(n_frames):
            plt.subplot(1, n_frames, idx + 1)
            plt.imshow(obs[0][:, :, idx])

        plt.show()

        action = env.action_space.sample()
        # The vectorized environment expects one action per environment
        obs, reward, done, info = env.step([action])
finally:
    # Close the environment, even if plotting or stepping fails
    env.close()

# 2 - Train and evaluate the model

In [None]:
# Stable Baselines3 (SB3) : set of reliable implementations of reinforcement learning algorithms in PyTorch
# Quantile Regression Deep Q-Network algorithm
from sb3_contrib import QRDQN
# Proximal Policy Optimization algorithm
from stable_baselines3 import PPO
# To save and evaluate regularly a model during training
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
# To evaluate a model
from stable_baselines3.common.evaluation import evaluate_policy
# To give linear function to hyperparams (e.g. learning rate, clip range, ...)
from stable_baselines3.common.type_aliases import Schedule

In [None]:
# Function to make a value follow a linear evolution
def linear_schedule(initial_value: float) -> Schedule:
    """Return a schedule that scales ``initial_value`` by the remaining progress.

    Args:
        initial_value: Value produced when ``progress_remaining`` equals 1.

    Returns:
        A callable mapping ``progress_remaining`` to
        ``progress_remaining * initial_value``.
    """
    def scaled(progress_remaining: float) -> float:
        current_value = progress_remaining * initial_value
        return current_value

    return scaled

## 2.1 - Setup train and evaluation environments with [RIGHT_ONLY actions](https://github.com/Kautenja/gym-super-mario-bros/blob/4c89cf601929733800f70833c7fe62973aecdb08/gym_super_mario_bros/actions.py#L5)

In [None]:
# Training environments: 8 copies of the game, each preprocessed by
# SuperMarioBrosWrapper and limited to the RIGHT_ONLY actions, with the
# last 4 frames stacked
env = VecFrameStack(
    make_vec_env(
        SuperMarioBrosEnv,
        n_envs=8,
        wrapper_class=SuperMarioBrosWrapper,
        wrapper_kwargs={'env_id': 'SuperMarioBros-v0',
                        'actions_list': RIGHT_ONLY},
    ),
    n_stack=4,
)

In [None]:
# Evaluation environment: same preprocessing and RIGHT_ONLY actions as the
# training environments, with the last 4 frames stacked
eval_env = VecFrameStack(
    make_vec_env(
        SuperMarioBrosEnv,
        wrapper_class=SuperMarioBrosWrapper,
        wrapper_kwargs={'env_id': 'SuperMarioBros-v0',
                        'actions_list': RIGHT_ONLY},
    ),
    n_stack=4,
)

## 2.2 - Setup checkpoint and evaluation callbacks

In [None]:
# Create a checkpoint callback to save the model every million steps
# The save frequency refers to each environment (125 000 * 8 = 1 000 000),
# i.e. the figure above assumes training runs on 8 parallel environments
# Checkpoints are written under 'logs/' with the 'super_mario_bros' prefix
checkpoint_callback = CheckpointCallback(save_freq=125_000,
                                         save_path='logs/',
                                         name_prefix='super_mario_bros')

In [None]:
# Create an evaluation callback to measure the evolution of the model every hundred thousand steps
# The evaluation frequency refers to each environment (12 500 * 8 = 100 000),
# i.e. the figure above assumes training runs on 8 parallel environments
# The model is evaluated on eval_env and the best one is saved under 'logs/'
eval_callback = EvalCallback(eval_env,
                             best_model_save_path='logs/',
                             eval_freq=12_500)

In [None]:
# Create a list of callbacks to give it at the initialization of the model learning
# Both callbacks are handed over together through the `callback` argument of `model.learn`
callback_list = CallbackList([checkpoint_callback, eval_callback])

## 2.3 - Train a model with [PPO algorithm](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html)

In [None]:
# Route the checkpoints and the best model of this run to their own folder
checkpoint_callback.name_prefix = 'right_only/ppo/super_mario_bros'
eval_callback.best_model_save_path = 'logs/right_only/ppo/'
# The evaluation callback is shared between training runs: forget any best
# score kept from a previous run so this run saves its own best model
eval_callback.best_mean_reward = float('-inf')

In [None]:
# Initialize the model with atari hyperparams from https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml
# The policy is CnnPolicy (Convolutional Neural Network policy class for
# actor-critic algorithms) and the environment is `env` (vectorized
# environments with 4 frames stacked each)
ppo_hyperparams = dict(
    # Learning rate function of the current progress remaining (from 1 to 0)
    learning_rate=linear_schedule(2.5e-4),
    # Number of steps to run for each environment per update
    n_steps=128,
    # Minibatch size
    batch_size=256,
    # Number of epoch when optimizing the surrogate loss
    n_epochs=4,
    # Clipping function of the current progress remaining (from 1 to 0)
    clip_range=linear_schedule(0.1),
    # Entropy coefficient for the loss calculation
    ent_coef=0.01,
    # Value function coefficient for the loss calculation
    vf_coef=0.5,
    # Verbosity level: 0 no output, 1 info, 2 debug
    verbose=1,
    # Log location for tensorboard
    tensorboard_log='logs/tensorboard/',
)
model = PPO('CnnPolicy', env, **ppo_hyperparams)

In [None]:
# Train for 5 million timesteps; the callbacks run during training and the
# run is logged under its own name to follow the evolution with tensorboard
model.learn(
    total_timesteps=5_000_000,
    callback=callback_list,
    log_interval=100,
    tb_log_name='right_only_ppo_super_mario_bros',
)
# Save the final model in a zip file
model.save('right_only_ppo_super_mario_bros')

## 2.4 - Train a model with [QR-DQN algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/qrdqn.html)

In [None]:
# Route the checkpoints and the best model of this run to their own folder
checkpoint_callback.name_prefix = 'right_only/qrdqn/super_mario_bros'
eval_callback.best_model_save_path = 'logs/right_only/qrdqn/'
# The evaluation callback is shared between training runs: forget the best
# score reached by the previous model, otherwise this run only saves a best
# model once it beats that earlier score
eval_callback.best_mean_reward = float('-inf')

In [None]:
# Initialize the model with atari hyperparams from https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/qrdqn.yml
# The policy is CnnPolicy (the policy class for QR-DQN when using images as
# input) and the environment is `env` (vectorized environments with 4 frames
# stacked each)
qrdqn_hyperparams = dict(
    # Enable a memory efficient variant of the replay buffer at a cost of more complexity
    optimize_memory_usage=True,
    # Fraction of entire training period over which the exploration rate is reduced
    exploration_fraction=0.025,
    # Verbosity level: 0 no output, 1 info, 2 debug
    verbose=1,
    # Log location for tensorboard
    tensorboard_log='logs/tensorboard/',
)
model = QRDQN('CnnPolicy', env, **qrdqn_hyperparams)

In [None]:
# Train for 5 million timesteps with the shared callbacks, logging the run
# under its own tensorboard name
model.learn(
    total_timesteps=5_000_000,
    callback=callback_list,
    log_interval=2000,
    tb_log_name='right_only_qrdqn_super_mario_bros',
)
# Save the final model in a zip file
model.save('right_only_qrdqn_super_mario_bros')

## 2.5 - Setup train and evaluation environments with [SIMPLE_MOVEMENT actions](https://github.com/Kautenja/gym-super-mario-bros/blob/4c89cf601929733800f70833c7fe62973aecdb08/gym_super_mario_bros/actions.py#L15)

In [None]:
# Create 8 vectorized environments, as for the RIGHT_ONLY runs: the PPO
# hyperparams (n_steps=128, batch_size=256) and the callback frequencies
# (125 000 and 12 500 calls) are sized for 8 parallel environments
# Enable Mario to do simple movements to the right and left
env = make_vec_env(SuperMarioBrosEnv,
                   n_envs=8,
                   wrapper_class=SuperMarioBrosWrapper,
                   wrapper_kwargs={'env_id': 'SuperMarioBros-v0',
                                   'actions_list': SIMPLE_MOVEMENT})
# Stack the last 4 frames of the game to have a dimension of movement
env = VecFrameStack(env, n_stack=4)

In [None]:
# Evaluation environment limited to the SIMPLE_MOVEMENT actions, preprocessed
# by SuperMarioBrosWrapper, with the last 4 frames stacked
eval_env = VecFrameStack(
    make_vec_env(
        SuperMarioBrosEnv,
        wrapper_class=SuperMarioBrosWrapper,
        wrapper_kwargs={'env_id': 'SuperMarioBros-v0',
                        'actions_list': SIMPLE_MOVEMENT},
    ),
    n_stack=4,
)

In [None]:
# Point the shared evaluation callback at the SIMPLE_MOVEMENT evaluation
# environment, so evaluation uses the same action set as training
eval_callback.eval_env = eval_env

## 2.6 - Train a model with [PPO algorithm](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html)

In [None]:
# Route the checkpoints and the best model of this run to their own folder
checkpoint_callback.name_prefix = 'simple_movement/ppo/super_mario_bros'
eval_callback.best_model_save_path = 'logs/simple_movement/ppo/'
# The evaluation callback is shared between training runs: forget the best
# score reached by the previous model, otherwise this run only saves a best
# model once it beats that earlier score
eval_callback.best_mean_reward = float('-inf')

In [None]:
# PPO model for the SIMPLE_MOVEMENT environments: CnnPolicy, linearly
# decaying learning rate and clip range, logs written for tensorboard
ppo_hyperparams = dict(
    learning_rate=linear_schedule(2.5e-4),
    n_steps=128,
    batch_size=256,
    n_epochs=4,
    clip_range=linear_schedule(0.1),
    ent_coef=0.01,
    vf_coef=0.5,
    verbose=1,
    tensorboard_log='logs/tensorboard/',
)
model = PPO('CnnPolicy', env, **ppo_hyperparams)

In [None]:
# Train for 3 million timesteps with the shared callbacks, logging the run
# under its own tensorboard name
model.learn(
    total_timesteps=3_000_000,
    callback=callback_list,
    log_interval=100,
    tb_log_name='simple_movement_ppo_super_mario_bros',
)
# Save the final model in a zip file
model.save('simple_movement_ppo_super_mario_bros')

## 2.7 - Train a model with [QR-DQN algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/qrdqn.html)

In [None]:
# Route the checkpoints and the best model of this run to their own folder
checkpoint_callback.name_prefix = 'simple_movement/qrdqn/super_mario_bros'
eval_callback.best_model_save_path = 'logs/simple_movement/qrdqn/'
# The evaluation callback is shared between training runs: forget the best
# score reached by the previous model, otherwise this run only saves a best
# model once it beats that earlier score
eval_callback.best_mean_reward = float('-inf')

In [None]:
# QR-DQN model for the SIMPLE_MOVEMENT environment: CnnPolicy, memory
# efficient replay buffer, logs written for tensorboard
qrdqn_hyperparams = dict(
    optimize_memory_usage=True,
    exploration_fraction=0.025,
    verbose=1,
    tensorboard_log='logs/tensorboard/',
)
model = QRDQN('CnnPolicy', env, **qrdqn_hyperparams)

In [None]:
# Train for 3 million timesteps with the shared callbacks, logging the run
# under its own tensorboard name
model.learn(
    total_timesteps=3_000_000,
    callback=callback_list,
    log_interval=2000,
    tb_log_name='simple_movement_qrdqn_super_mario_bros',
)
# Save the final model in a zip file
model.save('simple_movement_qrdqn_super_mario_bros')

# 3 - See the results of trained models

In [None]:
# Function to evaluate a model on an environment and render the played game
def demo(_model, _env):
    """Evaluate ``_model`` on ``_env`` with rendering and print the reward statistics.

    Args:
        _model: Model handed to ``evaluate_policy``.
        _env: Environment the model is evaluated on.
    """
    stats = evaluate_policy(_model, _env, render=True)
    # evaluate_policy yields the mean reward followed by its standard deviation
    mean_reward, std_reward = stats
    print(f'mean_reward = {mean_reward:.2f} +/- {std_reward:.2f}')

## 3.1 - Setup the demonstration environments

In [None]:
# Demonstration environment limited to the RIGHT_ONLY actions, preprocessed
# by SuperMarioBrosWrapper, with the last 4 frames stacked
right_only_demo_env = VecFrameStack(
    make_vec_env(
        SuperMarioBrosEnv,
        wrapper_class=SuperMarioBrosWrapper,
        wrapper_kwargs={'env_id': 'SuperMarioBros-v0',
                        'actions_list': RIGHT_ONLY},
    ),
    n_stack=4,
)

In [None]:
# Demonstration environment limited to the SIMPLE_MOVEMENT actions,
# preprocessed by SuperMarioBrosWrapper, with the last 4 frames stacked
simple_movement_demo_env = VecFrameStack(
    make_vec_env(
        SuperMarioBrosEnv,
        wrapper_class=SuperMarioBrosWrapper,
        wrapper_kwargs={'env_id': 'SuperMarioBros-v0',
                        'actions_list': SIMPLE_MOVEMENT},
    ),
    n_stack=4,
)

## 3.2 - Demo of models trained with [PPO algorithm](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html)

### 3.2.1 - Environment with [RIGHT_ONLY actions](https://github.com/Kautenja/gym-super-mario-bros/blob/4c89cf601929733800f70833c7fe62973aecdb08/gym_super_mario_bros/actions.py#L5)

In [None]:
# Initialize the model with a saved zip file
# ('logs/right_only/ppo/' is the folder given to the evaluation callback for the RIGHT_ONLY PPO run)
demo_model = PPO.load('logs/right_only/ppo/best_model')
# Evaluate and render it on the RIGHT_ONLY demonstration environment
demo(demo_model, right_only_demo_env)

In [None]:
# Load the PPO model saved at the end of the RIGHT_ONLY training
demo_model = PPO.load('right_only_ppo_super_mario_bros')
# Evaluate and render it on the RIGHT_ONLY demonstration environment
demo(demo_model, right_only_demo_env)

### 3.2.2 - Environment with [SIMPLE_MOVEMENT actions](https://github.com/Kautenja/gym-super-mario-bros/blob/4c89cf601929733800f70833c7fe62973aecdb08/gym_super_mario_bros/actions.py#L15)

In [None]:
# Load the model stored in the folder given to the evaluation callback for the SIMPLE_MOVEMENT PPO run
demo_model = PPO.load('logs/simple_movement/ppo/best_model')
# Evaluate and render it on the SIMPLE_MOVEMENT demonstration environment
demo(demo_model, simple_movement_demo_env)

In [None]:
# Load the PPO model saved at the end of the SIMPLE_MOVEMENT training
demo_model = PPO.load('simple_movement_ppo_super_mario_bros')
# Evaluate and render it on the SIMPLE_MOVEMENT demonstration environment
demo(demo_model, simple_movement_demo_env)

## 3.3 - Demo of models trained with [QR-DQN algorithm](https://sb3-contrib.readthedocs.io/en/master/modules/qrdqn.html)

### 3.3.1 - Environment with [RIGHT_ONLY actions](https://github.com/Kautenja/gym-super-mario-bros/blob/4c89cf601929733800f70833c7fe62973aecdb08/gym_super_mario_bros/actions.py#L5)

In [None]:
# Load the model stored in the folder given to the evaluation callback for the RIGHT_ONLY QR-DQN run
demo_model = QRDQN.load('logs/right_only/qrdqn/best_model')
# Evaluate and render it on the RIGHT_ONLY demonstration environment
demo(demo_model, right_only_demo_env)

In [None]:
# Load the QR-DQN model saved at the end of the RIGHT_ONLY training
demo_model = QRDQN.load('right_only_qrdqn_super_mario_bros')
# Evaluate and render it on the RIGHT_ONLY demonstration environment
demo(demo_model, right_only_demo_env)

### 3.3.2 - Environment with [SIMPLE_MOVEMENT actions](https://github.com/Kautenja/gym-super-mario-bros/blob/4c89cf601929733800f70833c7fe62973aecdb08/gym_super_mario_bros/actions.py#L15)

In [None]:
# Load the model stored in the folder given to the evaluation callback for the SIMPLE_MOVEMENT QR-DQN run
demo_model = QRDQN.load('logs/simple_movement/qrdqn/best_model')
# Evaluate and render it on the SIMPLE_MOVEMENT demonstration environment
demo(demo_model, simple_movement_demo_env)

In [None]:
# Load the QR-DQN model saved at the end of the SIMPLE_MOVEMENT training
demo_model = QRDQN.load('simple_movement_qrdqn_super_mario_bros')
# Evaluate and render it on the SIMPLE_MOVEMENT demonstration environment
demo(demo_model, simple_movement_demo_env)

# Bonus - Make a gif of the model

In [None]:
import imageio
import numpy as np
import os
import shutil

In [None]:
# Environment used to record the gif: RIGHT_ONLY actions, preprocessed by
# SuperMarioBrosWrapper, with the last 4 frames stacked
gif_env = VecFrameStack(
    make_vec_env(
        SuperMarioBrosEnv,
        wrapper_class=SuperMarioBrosWrapper,
        wrapper_kwargs={'env_id': 'SuperMarioBros-v0',
                        'actions_list': RIGHT_ONLY},
    ),
    n_stack=4,
)

In [None]:
# Load the QR-DQN model saved at the end of the RIGHT_ONLY training
gif_model = QRDQN.load('right_only_qrdqn_super_mario_bros')
# Attach the environment that is used to record the frames
gif_model.set_env(gif_env)

In [None]:
# Create the temporary folder receiving the frames; exist_ok keeps the cell
# re-runnable when the folder is still there from a previous run
os.makedirs('tmp', exist_ok=True)

In [None]:
# https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#bonus-make-a-gif-of-a-trained-agent
images = []

try:
    obs = gif_env.reset()
    img = gif_env.render(mode='rgb_array')

    for i in range(350):
        # Keep the current frame, then let the model play one step
        images.append(img)
        action, _ = gif_model.predict(obs)
        obs, _, _, _ = gif_env.step(action)
        # Render the frame that follows the step and store it on disk
        img = gif_env.render(mode='rgb_array')
        plt.imsave(f'tmp/{i}.jpg', img)
finally:
    # Close the environment, even if recording fails or is interrupted
    gif_env.close()

In [None]:
# Assemble the gif from every second frame stored on disk (one index out of
# two among the recorded images)
imageio.mimsave(
    '../images/super_mario_bros3.gif',
    np.stack([imageio.v3.imread(f'tmp/{i}.jpg') for i in range(0, len(images), 2)]),
    fps=29,
)

In [None]:
# Delete the temporary folder together with the frames stored in it
shutil.rmtree('tmp')