# 3. Train the Model

In [1]:
pip install pyqt5 --user

Note: you may need to restart the kernel to use updated packages.


## 3.1 Create Callback

In [2]:
# Import os for file path management
import os 
# Import Base Callback for saving models
from stable_baselines3.common.callbacks import BaseCallback
# Check Environment    
from stable_baselines3.common import env_checker

In [3]:
class TrainAndLoggingCallback(BaseCallback):
    """Periodically save a checkpoint of the model being trained.

    Every ``check_freq`` callback calls, the current model is saved to
    ``<save_path>/best_model_<n_calls>``. Despite the file name, these are
    unconditional periodic snapshots, not the best-scoring model.

    Args:
        check_freq: Number of callback calls between checkpoints; must be > 0.
        save_path: Directory to write checkpoints into. If None, no
            checkpoints are written (training still proceeds).
        verbose: Verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: If ``check_freq`` is not positive (a zero value would
            otherwise fail with ZeroDivisionError at the first step).
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        if check_freq <= 0:
            raise ValueError('check_freq must be a positive integer, got {}'.format(check_freq))
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        # Create the checkpoint directory up front so the first save does not
        # fail on a missing folder.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        # Guard on save_path so a callback built with save_path=None (which
        # _init_callback already tolerates) does not crash in os.path.join.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        # Returning False would abort training (stable-baselines3 callback
        # contract); this callback never stops training.
        return True

In [4]:
# Output locations: periodic model checkpoints (written by the training
# callback) and TensorBoard logs (written by the DQN model).
CHECKPOINT_DIR, LOG_DIR = './train/', './logs/'

In [5]:
# Checkpoint the model into CHECKPOINT_DIR every 1000 callback calls.
callback = TrainAndLoggingCallback(
    save_path=CHECKPOINT_DIR,
    check_freq=1000,
)

## 3.2 Build DQN and Train

In [6]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

In [7]:
# Custom environment from the local spidermanENV module (not shown here).
# It is passed to DQN below and driven via reset()/step() in the test loop.
# NOTE(review): the training output shows it sends input through PyDirectInput,
# so the game window presumably needs focus and the mouse must stay away from
# screen corners (fail-safe) — confirm.
from spidermanENV import Spiderman_ENV
env = Spiderman_ENV()

In [8]:
# Disabled debug helper: display the environment's current observation.
# WARNING: if uncommented as-is, `while True` never exits — the kernel must be
# interrupted to stop it. Prefer a single plt.imshow(...) call. cv2 is unused.
# from matplotlib import pyplot as plt
# import cv2
# while True:
#     plt.imshow(env.get_observation())

In [9]:
# DQN agent with a convolutional policy; training curves go to LOG_DIR for TensorBoard.
# NOTE(review): a 650k-transition replay buffer of image observations can need
# a lot of RAM — confirm against the observation size of Spiderman_ENV.
model = DQN(
    'CnnPolicy',
    env,
    tensorboard_log=LOG_DIR,
    verbose=1,
    buffer_size=650000,     # replay-buffer capacity
    learning_starts=5000,   # transitions collected before learning begins
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




In [10]:
TOTAL_TIMESTEPS = 100000

# Train the agent. If the run is aborted (e.g. PyDirectInput's fail-safe fires
# or the kernel is interrupted), save the weights learned so far before
# re-raising. Otherwise all progress since the last periodic checkpoint is lost.
try:
    model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=callback)
except BaseException:
    model.save(os.path.join(CHECKPOINT_DIR, 'interrupted_model'))
    raise

Logging to ./logs/DQN_26
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 77.5     |
|    ep_rew_mean      | -42.8    |
|    exploration_rate | 0.971    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 4        |
|    time_elapsed     | 76       |
|    total_timesteps  | 310      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 73.5     |
|    ep_rew_mean      | -21.4    |
|    exploration_rate | 0.944    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 4        |
|    time_elapsed     | 143      |
|    total_timesteps  | 588      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 70.7     |
|    ep_rew_mean      | -14.2    |
|    exploration_rate | 0.919    |
| time/               |       

FailSafeException: PyDirectInput fail-safe triggered from mouse moving to a corner of the screen. To disable this fail-safe, set pydirectinput.FAILSAFE to False. DISABLING FAIL-SAFE IS NOT RECOMMENDED.

In [None]:
# DQN.load is a classmethod that builds and RETURNS a new model; calling it on
# an existing instance without reassigning leaves `model` unchanged. The file
# name matches the callback's 'best_model_<n_calls>' pattern (the original path
# had a stray space: 'best_mode l_50000'). Passing env attaches the environment
# for prediction.
model = DQN.load('train_first/best_model_50000', env=env)

# 4. Test the Model

In [None]:
# `time` was used below but never imported in this notebook.
import time

N_TEST_EPISODES = 5

# Play N_TEST_EPISODES episodes with the trained agent and report each total reward.
for episode in range(N_TEST_EPISODES):
    obs = env.reset()
    done = False
    total_reward = 0
    while not done:
        # deterministic=True: take the greedy action, with no epsilon-random
        # exploration, when evaluating.
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(int(action))
        time.sleep(0.01)  # short pause between actions
        total_reward += reward
    print('Total Reward for episode {} is {}'.format(episode, total_reward))
    time.sleep(2)  # pause before starting the next episode