In [2]:
#setup game
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

In [3]:
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)

In [None]:
done = True
for step in range(100000): 
    if done: 
        env.reset()
    state, reward, done, info = env.step(env.action_space.sample())
    env.render()
env.close()

In [4]:
#preprocess environment
from gym.wrappers import GrayScaleObservation
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from matplotlib import pyplot as plt


env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = GrayScaleObservation(env, keep_dim=True) #less colors = less processing
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, 4, channels_order='last') #kind of like the computer keeping track, here remembers 4 frames in last order


In [5]:
#training the reinforcement learning model
import os
from stable_baselines3 import PPO #algorithm
from stable_baselines3.common.callbacks import BaseCallback #for saving state of trained AI

In [6]:
#saves model every X steps so it doesnt have to be retrained
class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [7]:
CHECKPOINT_DIR = './train/'
LOG_DIR = './logs/'

In [8]:
callback = TrainAndLoggingCallback(check_freq=100000, save_path=CHECKPOINT_DIR)

In [9]:
#AI model
model = PPO('CnnPolicy', env, verbose=1, tensorboard_log=LOG_DIR, learning_rate=0.000001, n_steps=512)

Using cpu device
Wrapping the env in a VecTransposeImage.


In [24]:
#start of training ai model
model.learn(total_timesteps=1000000, callback=callback)

Logging to ./logs/PPO_1
----------------------------
| time/              |     |
|    fps             | 74  |
|    iterations      | 1   |
|    time_elapsed    | 6   |
|    total_timesteps | 512 |
----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 15            |
|    iterations           | 2             |
|    time_elapsed         | 65            |
|    total_timesteps      | 1024          |
| train/                  |               |
|    approx_kl            | 6.6111097e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.95         |
|    explained_variance   | -0.00246      |
|    learning_rate        | 1e-06         |
|    loss                 | 264           |
|    n_updates            | 10            |
|    policy_gradient_loss | -5.04e-05     |
|    value_loss           | 673           |
-------------------------

-------------------------------------------
| time/                   |               |
|    fps                  | 10            |
|    iterations           | 13            |
|    time_elapsed         | 652           |
|    total_timesteps      | 6656          |
| train/                  |               |
|    approx_kl            | 1.4070189e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.95         |
|    explained_variance   | 0.001         |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0712        |
|    n_updates            | 120           |
|    policy_gradient_loss | -7.24e-05     |
|    value_loss           | 0.38          |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 10            |
|    iterations           | 14            |
|    time_elapsed         | 703 

--------------------------------------------
| time/                   |                |
|    fps                  | 10             |
|    iterations           | 24             |
|    time_elapsed         | 1197           |
|    total_timesteps      | 12288          |
| train/                  |                |
|    approx_kl            | 1.44174555e-05 |
|    clip_fraction        | 0              |
|    clip_range           | 0.2            |
|    entropy_loss         | -1.94          |
|    explained_variance   | -0.00116       |
|    learning_rate        | 1e-06          |
|    loss                 | 0.0806         |
|    n_updates            | 230            |
|    policy_gradient_loss | -0.000221      |
|    value_loss           | 0.197          |
--------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 10           |
|    iterations           | 25           |
|    time_elapsed 

------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 35           |
|    time_elapsed         | 1795         |
|    total_timesteps      | 17920        |
| train/                  |              |
|    approx_kl            | 0.0001319577 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.94        |
|    explained_variance   | 0.0571       |
|    learning_rate        | 1e-06        |
|    loss                 | 48.7         |
|    n_updates            | 340          |
|    policy_gradient_loss | 0.000201     |
|    value_loss           | 109          |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 36            |
|    time_elapsed         | 1844          |
|    t

-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 46            |
|    time_elapsed         | 2392          |
|    total_timesteps      | 23552         |
| train/                  |               |
|    approx_kl            | 1.2708362e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.94         |
|    explained_variance   | -0.0137       |
|    learning_rate        | 1e-06         |
|    loss                 | 0.108         |
|    n_updates            | 450           |
|    policy_gradient_loss | -8.42e-05     |
|    value_loss           | 0.177         |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 47           |
|    time_elapsed         | 2453    

-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 57            |
|    time_elapsed         | 2994          |
|    total_timesteps      | 29184         |
| train/                  |               |
|    approx_kl            | 1.1924654e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.93         |
|    explained_variance   | 0.0905        |
|    learning_rate        | 1e-06         |
|    loss                 | 0.161         |
|    n_updates            | 560           |
|    policy_gradient_loss | 0.000114      |
|    value_loss           | 0.433         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 58            |
|    time_elapsed         | 3047

-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 68            |
|    time_elapsed         | 3567          |
|    total_timesteps      | 34816         |
| train/                  |               |
|    approx_kl            | 1.5917583e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.92         |
|    explained_variance   | 0.0101        |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0787        |
|    n_updates            | 670           |
|    policy_gradient_loss | 0.000255      |
|    value_loss           | 0.441         |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 69           |
|    time_elapsed         | 3617    

------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 79           |
|    time_elapsed         | 4132         |
|    total_timesteps      | 40448        |
| train/                  |              |
|    approx_kl            | 6.941811e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.92        |
|    explained_variance   | -0.0197      |
|    learning_rate        | 1e-06        |
|    loss                 | 0.0587       |
|    n_updates            | 780          |
|    policy_gradient_loss | -0.000806    |
|    value_loss           | 0.125        |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 80            |
|    time_elapsed         | 4185          |
|    t

------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 90           |
|    time_elapsed         | 4701         |
|    total_timesteps      | 46080        |
| train/                  |              |
|    approx_kl            | 8.739869e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.91        |
|    explained_variance   | 0.0108       |
|    learning_rate        | 1e-06        |
|    loss                 | 0.105        |
|    n_updates            | 890          |
|    policy_gradient_loss | -0.000237    |
|    value_loss           | 0.343        |
------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 91           |
|    time_elapsed         | 4752         |
|    total_

-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 101           |
|    time_elapsed         | 5276          |
|    total_timesteps      | 51712         |
| train/                  |               |
|    approx_kl            | 0.00021138391 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.89         |
|    explained_variance   | -0.063        |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0957        |
|    n_updates            | 1000          |
|    policy_gradient_loss | -0.00112      |
|    value_loss           | 0.548         |
-------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 102          |
|    time_elapsed         | 5327    

--------------------------------------------
| time/                   |                |
|    fps                  | 9              |
|    iterations           | 112            |
|    time_elapsed         | 5855           |
|    total_timesteps      | 57344          |
| train/                  |                |
|    approx_kl            | 0.000106135965 |
|    clip_fraction        | 0              |
|    clip_range           | 0.2            |
|    entropy_loss         | -1.9           |
|    explained_variance   | -0.00425       |
|    learning_rate        | 1e-06          |
|    loss                 | 0.0478         |
|    n_updates            | 1110           |
|    policy_gradient_loss | -0.000749      |
|    value_loss           | 0.113          |
--------------------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 113          |
|    time_elapsed 

-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 123           |
|    time_elapsed         | 6416          |
|    total_timesteps      | 62976         |
| train/                  |               |
|    approx_kl            | 6.1137835e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.89         |
|    explained_variance   | 0.302         |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0485        |
|    n_updates            | 1220          |
|    policy_gradient_loss | -0.000354     |
|    value_loss           | 0.172         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 124           |
|    time_elapsed         | 6470

------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 134          |
|    time_elapsed         | 6986         |
|    total_timesteps      | 68608        |
| train/                  |              |
|    approx_kl            | 0.0002728314 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.89        |
|    explained_variance   | 0.104        |
|    learning_rate        | 1e-06        |
|    loss                 | 38.1         |
|    n_updates            | 1330         |
|    policy_gradient_loss | 0.000246     |
|    value_loss           | 81.2         |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 135           |
|    time_elapsed         | 7039          |
|    t

-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 145           |
|    time_elapsed         | 7562          |
|    total_timesteps      | 74240         |
| train/                  |               |
|    approx_kl            | 0.00013056211 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.87         |
|    explained_variance   | 0.0197        |
|    learning_rate        | 1e-06         |
|    loss                 | 0.128         |
|    n_updates            | 1440          |
|    policy_gradient_loss | -0.000347     |
|    value_loss           | 0.826         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 146           |
|    time_elapsed         | 7614

-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 156           |
|    time_elapsed         | 8139          |
|    total_timesteps      | 79872         |
| train/                  |               |
|    approx_kl            | 2.6324647e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.87         |
|    explained_variance   | 0.0241        |
|    learning_rate        | 1e-06         |
|    loss                 | 0.0475        |
|    n_updates            | 1550          |
|    policy_gradient_loss | 0.000181      |
|    value_loss           | 0.207         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 157           |
|    time_elapsed         | 8191

------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 167          |
|    time_elapsed         | 8697         |
|    total_timesteps      | 85504        |
| train/                  |              |
|    approx_kl            | 0.0002051926 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.86        |
|    explained_variance   | 0.0979       |
|    learning_rate        | 1e-06        |
|    loss                 | 0.0526       |
|    n_updates            | 1660         |
|    policy_gradient_loss | -0.00083     |
|    value_loss           | 0.21         |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 168           |
|    time_elapsed         | 8749          |
|    t

-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 178           |
|    time_elapsed         | 9272          |
|    total_timesteps      | 91136         |
| train/                  |               |
|    approx_kl            | 2.0118314e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.85         |
|    explained_variance   | -0.0274       |
|    learning_rate        | 1e-06         |
|    loss                 | 0.102         |
|    n_updates            | 1770          |
|    policy_gradient_loss | 0.000226      |
|    value_loss           | 0.377         |
-------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 179           |
|    time_elapsed         | 9324

------------------------------------------
| time/                   |              |
|    fps                  | 9            |
|    iterations           | 189          |
|    time_elapsed         | 9849         |
|    total_timesteps      | 96768        |
| train/                  |              |
|    approx_kl            | 0.0023889067 |
|    clip_fraction        | 0.0082       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.83        |
|    explained_variance   | 0.29         |
|    learning_rate        | 1e-06        |
|    loss                 | 172          |
|    n_updates            | 1880         |
|    policy_gradient_loss | -0.00131     |
|    value_loss           | 439          |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 9             |
|    iterations           | 190           |
|    time_elapsed         | 9901          |
|    t

KeyboardInterrupt: 

In [11]:
model = PPO.load('./train/best_model_100000')

In [12]:
state = env.reset()

In [13]:
state = env.reset()

while True: 
    
    action, _ = model.predict(state)
    state, reward, done, info = env.step(action)
    env.render()



KeyboardInterrupt: 

In [14]:
env.close()