In [1]:
!pip install gym-super-mario-bros==7.3.0 nes_py



In [24]:
!pip install stable-baselines3[extra]

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.0.0-py3-none-any.whl (178 kB)
     -------------------------------------- 178.4/178.4 kB 3.7 MB/s eta 0:00:00
Collecting gymnasium==0.28.1
  Using cached gymnasium-0.28.1-py3-none-any.whl (925 kB)
Collecting autorom[accept-rom-license]~=0.6.0
  Using cached AutoROM-0.6.1-py3-none-any.whl (9.4 kB)
Collecting shimmy[atari]~=0.2.1
  Using cached Shimmy-0.2.1-py3-none-any.whl (25 kB)
Collecting opencv-python
  Downloading opencv_python-4.8.0.74-cp37-abi3-win_amd64.whl (38.1 MB)
     --------------------------------------- 38.1/38.1 MB 13.4 MB/s eta 0:00:00
Collecting rich
  Downloading rich-13.5.1-py3-none-any.whl (239 kB)
     ------------------------------------- 239.7/239.7 kB 14.3 MB/s eta 0:00:00
Collecting farama-notifications>=0.0.1
  Using cached Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Collecting jax-jumpy>=1.0.0
  Using cached jax_jumpy-1.0.0-py3-none-any.whl (20 kB)
Collecting AutoROM.accept-rom-li

In [49]:
import gym
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import GrayScaleObservation
import matplotlib.pyplot as plt
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from nes_py.wrappers import JoypadSpace
# Monkey-patch (workaround taken from StackOverflow): replace JoypadSpace.reset with a
# function that forwards every keyword argument to the wrapped environment's reset().
# NOTE(review): presumably needed because the stock JoypadSpace.reset does not accept the
# keyword arguments gym passes to reset() -- confirm against the installed nes_py version.
JoypadSpace.reset = lambda self, **kwargs: self.env.reset(**kwargs)

In [50]:
SIMPLE_MOVEMENT

[['NOOP'],
 ['right'],
 ['right', 'A'],
 ['right', 'B'],
 ['right', 'A', 'B'],
 ['A'],
 ['left']]

# Random Agent

In [78]:
env = gym.make("SuperMarioBros-v0", apply_api_compatibility=True, render_mode="human")
# Wrap the environment to reduce the action space: 7 actions instead of 256
env = JoypadSpace(env, SIMPLE_MOVEMENT)

# Play 1000 steps with uniformly random actions.
try:
    done = True  # forces a reset before the very first step
    for step in range(1000):
        if done:
            env.reset()
        action = env.action_space.sample()
        state, reward, terminated, truncated, info = env.step(action)
        # An episode is over when it terminated OR was truncated; either way it must be reset.
        done = terminated or truncated
        env.render()
finally:
    # Always release the emulator window, even if the loop is interrupted or raises.
    env.close()

  logger.warn(
  logger.warn(
  logger.warn(


# Preprocessing the environment

In [107]:
def make_mario_env(render_mode="rgb_array", n_stack=4):
    """Build the preprocessed, vectorised Super Mario Bros environment.

    Args:
        render_mode: Render mode forwarded to ``gym.make``.
        n_stack: Number of consecutive frames stacked by ``VecFrameStack``.

    Returns:
        A ``VecFrameStack`` wrapping a single-environment ``DummyVecEnv``.
    """
    mario = gym.make("SuperMarioBros-v0", apply_api_compatibility=True, render_mode=render_mode)
    # Wrap the environment to reduce the action space: 7 actions instead of 256
    mario = JoypadSpace(mario, SIMPLE_MOVEMENT)
    # Grayscale the observation space
    mario = GrayScaleObservation(mario, keep_dim=True)
    # Debug aid: plt.imshow(mario.reset()[0], cmap="Greys")
    # Wrap into the Dummy vectorised environment. `mario` is a local that is never rebound
    # after this point, so the factory lambda always returns the grayscale environment.
    vec_env = DummyVecEnv([lambda: mario])
    # Stack the frames (so the agent can predict the movements of enemies)
    return VecFrameStack(vec_env, n_stack)


env = make_mario_env(render_mode="rgb_array")
print("OBSERVATION SPACE", env.observation_space)
print("ACTION SPACE :", env.action_space)
print("RENDER :", env.render_mode)

OBSERVATION SPACE Box(0, 255, (240, 256, 4), uint8)
ACTION SPACE : Discrete(7)
RENDER : human


# RL Model

In [80]:
import os
from stable_baselines3 import PPO 
from stable_baselines3.common.callbacks import BaseCallback # Saving models

In [83]:
# Callback to save the model every check_freq steps
# Don't save too often because a trained model is still quite big
class TrainAndLoggingCallback(BaseCallback):
    """Checkpoint the model being trained every ``check_freq`` callback calls.

    Args:
        check_freq: A checkpoint is written whenever ``n_calls`` is a multiple of this value.
        save_path: Directory the checkpoints are written to; it is created if missing.
            ``None`` disables checkpointing.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint on every ``check_freq``-th call.

        Returns:
            Always ``True``.
        """
        # Skip saving when no save path was configured, mirroring the guard in
        # _init_callback (os.path.join would raise TypeError on None).
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            # The file name embeds the call count, e.g. 'best_model10000'.
            model_path = os.path.join(self.save_path, 'best_model{}'.format(self.n_calls))
            self.model.save(model_path)
        return True

In [84]:
# Directory handed to TrainAndLoggingCallback as save_path (model checkpoints).
CHECKPOINT_DIR = "./train/"
# Directory handed to PPO as tensorboard_log.
LOG_DIR = "./logs/"

In [85]:
# Checkpoint the model into CHECKPOINT_DIR every 10,000 callback calls.
callback = TrainAndLoggingCallback(
    check_freq=10_000,
    save_path=CHECKPOINT_DIR,
)

In [86]:
# PPO agent with a CNN policy on the preprocessed `env`; metrics are logged under LOG_DIR.
model = PPO(
    "CnnPolicy",
    env,
    learning_rate=1e-6,
    n_steps=512,
    tensorboard_log=LOG_DIR,
    verbose=1,
)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [87]:
# Train for 100,000 timesteps; `callback` is invoked during training to write checkpoints.
model.learn(total_timesteps=100_000, callback=callback)

Logging to ./logs/PPO_2
----------------------------
| time/              |     |
|    fps             | 174 |
|    iterations      | 1   |
|    time_elapsed    | 2   |
|    total_timesteps | 512 |
----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 119           |
|    iterations           | 2             |
|    time_elapsed         | 8             |
|    total_timesteps      | 1024          |
| train/                  |               |
|    approx_kl            | 4.0821964e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.95         |
|    explained_variance   | 0.00245       |
|    learning_rate        | 1e-06         |
|    loss                 | 138           |
|    n_updates            | 10            |
|    policy_gradient_loss | 9.59e-06      |
|    value_loss           | 406           |
-------------------------

-------------------------------------------
| time/                   |               |
|    fps                  | 94            |
|    iterations           | 13            |
|    time_elapsed         | 70            |
|    total_timesteps      | 6656          |
| train/                  |               |
|    approx_kl            | 4.1979132e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.94         |
|    explained_variance   | 0.00165       |
|    learning_rate        | 1e-06         |
|    loss                 | 0.125         |
|    n_updates            | 120           |
|    policy_gradient_loss | -0.000556     |
|    value_loss           | 0.237         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 93            |
|    iterations           | 14            |
|    time_elapsed         | 76  

------------------------------------------
| time/                   |              |
|    fps                  | 92           |
|    iterations           | 24           |
|    time_elapsed         | 132          |
|    total_timesteps      | 12288        |
| train/                  |              |
|    approx_kl            | 3.037625e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.94        |
|    explained_variance   | 0.00279      |
|    learning_rate        | 1e-06        |
|    loss                 | 0.322        |
|    n_updates            | 230          |
|    policy_gradient_loss | -0.000126    |
|    value_loss           | 1.09         |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 92            |
|    iterations           | 25            |
|    time_elapsed         | 138           |
|    t

-------------------------------------------
| time/                   |               |
|    fps                  | 92            |
|    iterations           | 35            |
|    time_elapsed         | 194           |
|    total_timesteps      | 17920         |
| train/                  |               |
|    approx_kl            | 0.00020032004 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.94         |
|    explained_variance   | 0.00667       |
|    learning_rate        | 1e-06         |
|    loss                 | 125           |
|    n_updates            | 340           |
|    policy_gradient_loss | -0.00113      |
|    value_loss           | 235           |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 91           |
|    iterations           | 36           |
|    time_elapsed         | 200     

------------------------------------------
| time/                   |              |
|    fps                  | 91           |
|    iterations           | 46           |
|    time_elapsed         | 256          |
|    total_timesteps      | 23552        |
| train/                  |              |
|    approx_kl            | 7.590803e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.93        |
|    explained_variance   | -0.00517     |
|    learning_rate        | 1e-06        |
|    loss                 | 0.0942       |
|    n_updates            | 450          |
|    policy_gradient_loss | -0.000591    |
|    value_loss           | 0.282        |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 91            |
|    iterations           | 47            |
|    time_elapsed         | 262           |
|    t

------------------------------------------
| time/                   |              |
|    fps                  | 92           |
|    iterations           | 57           |
|    time_elapsed         | 314          |
|    total_timesteps      | 29184        |
| train/                  |              |
|    approx_kl            | 3.531133e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.93        |
|    explained_variance   | 0.000432     |
|    learning_rate        | 1e-06        |
|    loss                 | 0.149        |
|    n_updates            | 560          |
|    policy_gradient_loss | -0.000163    |
|    value_loss           | 0.453        |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 92            |
|    iterations           | 58            |
|    time_elapsed         | 320           |
|    t

-------------------------------------------
| time/                   |               |
|    fps                  | 91            |
|    iterations           | 68            |
|    time_elapsed         | 378           |
|    total_timesteps      | 34816         |
| train/                  |               |
|    approx_kl            | 5.5163517e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.92         |
|    explained_variance   | 0.136         |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0783        |
|    n_updates            | 670           |
|    policy_gradient_loss | -0.000196     |
|    value_loss           | 0.276         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 91            |
|    iterations           | 69            |
|    time_elapsed         | 384 

-------------------------------------------
| time/                   |               |
|    fps                  | 90            |
|    iterations           | 79            |
|    time_elapsed         | 446           |
|    total_timesteps      | 40448         |
| train/                  |               |
|    approx_kl            | 1.1336058e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.92         |
|    explained_variance   | 0.198         |
|    learning_rate        | 1e-06         |
|    loss                 | 0.121         |
|    n_updates            | 780           |
|    policy_gradient_loss | 6.03e-05      |
|    value_loss           | 0.388         |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 90           |
|    iterations           | 80           |
|    time_elapsed         | 452     

------------------------------------------
| time/                   |              |
|    fps                  | 90           |
|    iterations           | 90           |
|    time_elapsed         | 511          |
|    total_timesteps      | 46080        |
| train/                  |              |
|    approx_kl            | 6.925664e-06 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.92        |
|    explained_variance   | 0.0262       |
|    learning_rate        | 1e-06        |
|    loss                 | 1.29         |
|    n_updates            | 890          |
|    policy_gradient_loss | -1.62e-05    |
|    value_loss           | 2.34         |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 90           |
|    iterations           | 91           |
|    time_elapsed         | 517          |
|    total_

-------------------------------------------
| time/                   |               |
|    fps                  | 89            |
|    iterations           | 101           |
|    time_elapsed         | 575           |
|    total_timesteps      | 51712         |
| train/                  |               |
|    approx_kl            | 0.00015287276 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.9          |
|    explained_variance   | -0.0499       |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0498        |
|    n_updates            | 1000          |
|    policy_gradient_loss | -0.00089      |
|    value_loss           | 0.134         |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 89           |
|    iterations           | 102          |
|    time_elapsed         | 581     

------------------------------------------
| time/                   |              |
|    fps                  | 89           |
|    iterations           | 112          |
|    time_elapsed         | 639          |
|    total_timesteps      | 57344        |
| train/                  |              |
|    approx_kl            | 8.137303e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.89        |
|    explained_variance   | 0.00686      |
|    learning_rate        | 1e-06        |
|    loss                 | 0.0693       |
|    n_updates            | 1110         |
|    policy_gradient_loss | -0.000133    |
|    value_loss           | 0.521        |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 89           |
|    iterations           | 113          |
|    time_elapsed         | 644          |
|    total_

--------------------------------------------
| time/                   |                |
|    fps                  | 89             |
|    iterations           | 123            |
|    time_elapsed         | 703            |
|    total_timesteps      | 62976          |
| train/                  |                |
|    approx_kl            | 0.000102029066 |
|    clip_fraction        | 0              |
|    clip_range           | 0.2            |
|    entropy_loss         | -1.86          |
|    explained_variance   | 0.215          |
|    learning_rate        | 1e-06          |
|    loss                 | 30.5           |
|    n_updates            | 1220           |
|    policy_gradient_loss | 3.65e-05       |
|    value_loss           | 88.5           |
--------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 89           |
|    iterations           | 124          |
|    time_elapsed 

-------------------------------------------
| time/                   |               |
|    fps                  | 89            |
|    iterations           | 134           |
|    time_elapsed         | 767           |
|    total_timesteps      | 68608         |
| train/                  |               |
|    approx_kl            | 0.00011946948 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.8          |
|    explained_variance   | 0.196         |
|    learning_rate        | 1e-06         |
|    loss                 | 22.8          |
|    n_updates            | 1330          |
|    policy_gradient_loss | -0.000724     |
|    value_loss           | 84.5          |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 89            |
|    iterations           | 135           |
|    time_elapsed         | 772 

-------------------------------------------
| time/                   |               |
|    fps                  | 89            |
|    iterations           | 145           |
|    time_elapsed         | 828           |
|    total_timesteps      | 74240         |
| train/                  |               |
|    approx_kl            | 0.00018953206 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.86         |
|    explained_variance   | 0.0038        |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0972        |
|    n_updates            | 1440          |
|    policy_gradient_loss | -0.000862     |
|    value_loss           | 0.178         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 89            |
|    iterations           | 146           |
|    time_elapsed         | 833 

-------------------------------------------
| time/                   |               |
|    fps                  | 90            |
|    iterations           | 156           |
|    time_elapsed         | 883           |
|    total_timesteps      | 79872         |
| train/                  |               |
|    approx_kl            | 0.00013997639 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.84         |
|    explained_variance   | -0.00828      |
|    learning_rate        | 1e-06         |
|    loss                 | 0.134         |
|    n_updates            | 1550          |
|    policy_gradient_loss | -0.000684     |
|    value_loss           | 0.275         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 90            |
|    iterations           | 157           |
|    time_elapsed         | 889 

-------------------------------------------
| time/                   |               |
|    fps                  | 90            |
|    iterations           | 167           |
|    time_elapsed         | 940           |
|    total_timesteps      | 85504         |
| train/                  |               |
|    approx_kl            | 0.00015179336 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.79         |
|    explained_variance   | 0.0152        |
|    learning_rate        | 1e-06         |
|    loss                 | 0.114         |
|    n_updates            | 1660          |
|    policy_gradient_loss | -0.000292     |
|    value_loss           | 0.718         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 91            |
|    iterations           | 168           |
|    time_elapsed         | 945 

-------------------------------------------
| time/                   |               |
|    fps                  | 91            |
|    iterations           | 178           |
|    time_elapsed         | 996           |
|    total_timesteps      | 91136         |
| train/                  |               |
|    approx_kl            | 0.00026093074 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.82         |
|    explained_variance   | 0.00301       |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0389        |
|    n_updates            | 1770          |
|    policy_gradient_loss | -0.00103      |
|    value_loss           | 0.104         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 91            |
|    iterations           | 179           |
|    time_elapsed         | 1000

-------------------------------------------
| time/                   |               |
|    fps                  | 91            |
|    iterations           | 189           |
|    time_elapsed         | 1054          |
|    total_timesteps      | 96768         |
| train/                  |               |
|    approx_kl            | 6.9042784e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.8          |
|    explained_variance   | 0.0114        |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0401        |
|    n_updates            | 1880          |
|    policy_gradient_loss | -0.000152     |
|    value_loss           | 0.147         |
-------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 91          |
|    iterations           | 190         |
|    time_elapsed         | 1059        

<stable_baselines3.ppo.ppo.PPO at 0x25823afb490>

# Testing trained model

In [108]:
# Load the checkpoint named 'best_model100000' from the directory the training callback
# saves into, so the load location cannot drift from the save location.
model = PPO.load(os.path.join(CHECKPOINT_DIR, "best_model100000.zip"))

In [113]:
# Same preprocessing pipeline as the one used for training, but rendered to a window.
mario_env = gym.make("SuperMarioBros-v0", apply_api_compatibility=True, render_mode="human")
# Wrap the environment to reduce the action space: 7 actions instead of 256
mario_env = JoypadSpace(mario_env, SIMPLE_MOVEMENT)
# Grayscale the observation space
mario_env = GrayScaleObservation(mario_env, keep_dim=True)
# Debug aid: plt.imshow(mario_env.reset()[0], cmap="Greys")
# Wrap into the Dummy vectorised environment. `mario_env` is not rebound after this point,
# so the factory lambda always returns the grayscale environment.
env = DummyVecEnv([lambda: mario_env])
# Stack the frames (so the agent can predict the movements of enemies)
env = VecFrameStack(env, 4)

try:
    state = env.reset()
    for steps in range(10000):
        # The second return value of predict() is unused here; it is bound to a named
        # throwaway instead of `_`, which IPython uses for the last cell output.
        action, _states = model.predict(state)
        state, reward, done, info = env.step(action)
        env.render()
finally:
    # Always release the emulator window, even if playback is interrupted or raises.
    env.close()

# To improve the model :

- Reduce the learning_rate
- Train for longer: 1 million timesteps (`total_timesteps`) should be a minimum
