In [1]:
from IPython.display import clear_output

In [None]:
# Install the RL stack. gym is pinned to an old release because gym_super_mario_bros / nes_py
# are gym (not Gymnasium) environments and must be wrapped for stable-baselines3 (see the
# compatibility note in the markdown below). The other packages are left unpinned.
# NOTE(review): swig is presumably needed to build a native dependency pulled in by
# gymnasium[all] (e.g. Box2D) — TODO confirm.
%pip install swig
%pip install stable-baselines3 gymnasium[all] gym_super_mario_bros nes_py gym==0.10.9  # might need a restart of the session.

clear_output()  # hide the long pip install logs from the saved notebook

In [9]:
# Optional: matplotlib is required by the plotting/animation imports below. Uncomment this
# (as `%pip install matplotlib`) only if it is not already installed in the kernel's environment.
#pip install matplotlib

clear_output()

# Content

In this demo, we will train a PPO agent that learns to play the classic Super Mario Bros. game.

For the agent, we will use the stable-baselines3 implementation.

For the environment, we will use `gym_super_mario_bros`. Read more about it [here](https://github.com/Kautenja/gym-super-mario-bros/).

Note that the stable-baselines3 implementations expect a Gymnasium environment, not a gym environment. (Gymnasium is the maintained successor of gym; gym is deprecated, but many environments, including this one, are still written for it.)

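Fortunately, Gymnasium can wrap an old-style gym env as a Gymnasium env; we do this below with its `GymV21Environment-v0` compatibility environment. We do, however, need to install a compatible version of gym, which is why `gym==0.10.9` is pinned in the install cell.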

In [10]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

from gymnasium.wrappers import GrayScaleObservation
import gymnasium as gym
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

In [11]:
def frames_to_video(frames, fps=24):
    """Render a sequence of image frames as an inline HTML5 video.

    Args:
        frames: Non-empty sequence of image arrays, all assumed to share the shape of the
            first frame. A frame is either grayscale, (H, W) or (H, W, 1), or colour,
            (H, W, 3) / (H, W, 4).
        fps: Playback rate in frames per second; must be positive.

    Returns:
        An ``IPython.display.HTML`` object wrapping the encoded video. Encoding goes through
        matplotlib's ``Animation.to_html5_video``, which typically needs ffmpeg installed.

    Raises:
        ValueError: if ``frames`` is empty or ``fps`` is not positive.
    """
    if len(frames) == 0:
        raise ValueError("frames_to_video() needs at least one frame")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    def _prepare(frame):
        # Drop a trailing singleton channel (e.g. from GrayScaleObservation(keep_dim=True))
        # so matplotlib treats the frame as a 2-D grayscale image.
        if len(frame.shape) == 3 and frame.shape[-1] == 1:
            return frame[..., 0]
        return frame

    first = _prepare(frames[0])
    height, width = first.shape[:2]
    fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)  # 1 pixel per frame pixel
    ax = plt.axes()
    ax.set_axis_off()

    # The colormap can only be chosen when the image is created: AxesImage.set_data()
    # accepts just the pixel array (passing cmap= there raises TypeError).
    cmap = 'gray' if len(first.shape) == 2 else None
    im = ax.imshow(first, cmap=cmap)

    def init():
        im.set_data(first)
        return im,

    def update(index):
        im.set_data(_prepare(frames[index]))
        return im,

    anim = FuncAnimation(fig, update, frames=len(frames), init_func=init, blit=True, interval=1000 / fps)
    # Close the figure so the notebook doesn't also display a static copy of the first frame.
    plt.close(fig)
    return HTML(anim.to_html5_video())

## Making the environment

On top of converting the environment to the Gymnasium API, we will wrap it in a vectorized environment (provided by stable-baselines3).

This lets us collect experience from multiple environments at once, which speeds up training. We use `DummyVecEnv`, which steps the environments one after another in a single process. For an environment with a higher compute cost per step, we could instead use `SubprocVecEnv`, which runs each environment in its own subprocess.

In [12]:
def make_env(render_mode=None):
    """Build one Super Mario Bros environment usable by stable-baselines3.

    Wrapping order:
      1. the raw gym env 'SuperMarioBros-v0', restricted to the SIMPLE_MOVEMENT action set;
      2. converted to the Gymnasium API via the "GymV21Environment-v0" compatibility env;
      3. observations converted to grayscale.

    Args:
        render_mode: forwarded to the compatibility env (e.g. "rgb_array"), or None.

    Returns:
        The wrapped Gymnasium environment.
    """
    legacy_env = JoypadSpace(gym_super_mario_bros.make('SuperMarioBros-v0'), SIMPLE_MOVEMENT)
    modern_env = gym.make("GymV21Environment-v0", env=legacy_env, render_mode=render_mode)
    # keep_dim=True keeps a trailing channel axis of size 1 instead of returning a 2-D image.
    return GrayScaleObservation(modern_env, keep_dim=True)

def get_vec_env(render_mode=None, num_envs=1):
    """Create a DummyVecEnv holding ``num_envs`` independent Mario environments.

    Args:
        render_mode: passed to every ``make_env`` call.
        num_envs: how many environment copies to vectorize.

    Returns:
        A ``DummyVecEnv`` over freshly built environments.
    """
    def _factory():
        # DummyVecEnv calls each factory once, so every slot gets its own env instance.
        return make_env(render_mode=render_mode)

    return DummyVecEnv([_factory] * num_envs)

In [13]:
NUM_ENVS = 4  # number of parallel environments; change according to RAM and cores
env = get_vec_env(num_envs=NUM_ENVS)

  result = entry_point.load(False)


## Creating and training the model

In [14]:
def lr_scheduler(progress):
    """Linear learning-rate schedule for stable-baselines3.

    ``progress`` is the remaining training progress. SB3 passes ``progress_remaining``,
    which goes from 1.0 at the start of training to 0.0 at the end. The rate therefore
    decays linearly from 5e-6 down to 1e-6.
    """
    start_lr = 5e-6  # value when progress == 1
    end_lr = 1e-6    # value when progress == 0
    span = start_lr - end_lr
    elapsed = 1 - progress
    return start_lr - span * elapsed

In [15]:
# PPO with a convolutional policy over the vectorized env; the learning rate follows
# lr_scheduler and each env collects n_steps transitions per rollout.
model = PPO(
    policy="CnnPolicy",
    env=env,
    learning_rate=lr_scheduler,
    n_steps=512,
    verbose=0,
)

In [None]:
TOTAL_TIMESTEPS = 1_000_000  # 1M steps. Takes some time. (SB3 expects an int here, not 1e6.)
model.learn(total_timesteps=TOTAL_TIMESTEPS, progress_bar=True)
# Persist the trained weights so a kernel restart doesn't force re-training;
# reload later with PPO.load("ppo_mario").
model.save("ppo_mario")

## Visualizing the results

In [None]:
# Roll out one episode with the trained policy and record rendered frames for the video.
MAX_FRAMES = 5000  # caps video length in case Mario gets stuck on an under-trained model; can be raised

t_env = get_vec_env(render_mode="rgb_array")
frames = []
try:
    state = t_env.reset()
    frames.append(t_env.render())  # include the starting frame, before any action
    while len(frames) <= MAX_FRAMES:
        action, _ = model.predict(state)
        # VecEnv.step already returns a fresh observation array, so no copy is needed.
        state, _reward, done, _info = t_env.step(action)
        frames.append(t_env.render())
        if done.any():  # single env: the episode has ended
            break
finally:
    # Always release the emulator, even if prediction or stepping raises.
    t_env.close()

In [None]:
# Play the recorded episode back at 60 fps (encoding typically needs ffmpeg available to matplotlib).
frames_to_video(frames=frames, fps=60)