In [1]:
import os
from stable_baselines3.common.callbacks import BaseCallback

# Creating callback for PPO algorithm
class TrainAndLoggingCallback(BaseCallback):
    """Callback that periodically saves the model being trained.

    Every ``check_frequency`` callback calls, the current model is written
    to ``save_path`` under the name ``best_model_<n_calls>``.

    Args:
        check_frequency: Number of callback calls between two saves.
        save_path: Directory that receives the saved models. ``None``
            disables saving (no directory is created, nothing is written).
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_frequency, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_frequency = check_frequency
        self.save_path = save_path

    def _init_callback(self):
        """Create the output directory (if one was given) before training."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save the model every ``check_frequency`` calls.

        Returns:
            bool: Always ``True`` so that training continues. The base
            class contract treats a falsy return value as a request to
            abort training, so an implicit ``None`` must not be returned.
        """
        # Mirror the guard in _init_callback: a None save_path means
        # "do not save" rather than a TypeError inside os.path.join.
        if self.save_path is not None and self.n_calls % self.check_frequency == 0:
            model_path = os.path.join(self.save_path, f"best_model_{self.n_calls}")
            self.model.save(model_path)

        return True

In [2]:
# Directory handed to the training callback as the location for saved models
CHECKPOINT_DIR = './models/basic'
# Directory handed to PPO as its TensorBoard log location
LOG_DIR = './logs/basic'

In [3]:
from envs.doom_env import DoomEnv
from stable_baselines3.common import env_checker

# Initializing environment from the "basic" scenario config
# NOTE(review): the meaning of the second positional argument (True) is not
# visible here -- confirm against DoomEnv's signature in envs/doom_env.py
env = DoomEnv('vizdoom/scenarios/basic.cfg', True)

# Checking environment compatibility using stable_baselines3's env_checker
env_checker.check_env(env)

In [4]:
from stable_baselines3 import PPO

# Checkpointing callback: check every 10000 callback calls, save into CHECKPOINT_DIR
callback = TrainAndLoggingCallback(check_frequency=10000, save_path=CHECKPOINT_DIR)

# PPO agent with a CNN policy, writing TensorBoard logs under LOG_DIR
model = PPO(
    'CnnPolicy',
    env,
    learning_rate=0.0001,
    n_steps=2048,
    tensorboard_log=LOG_DIR,
    verbose=1,
)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.


In [5]:
# Training model using PPO algorithm for a budget of 70000 timesteps;
# `callback` is passed in so it is invoked while training runs
model.learn(total_timesteps=70000, callback=callback)

Logging to ./logs/basic\PPO_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 33.8     |
|    ep_rew_mean     | -92.6    |
| time/              |          |
|    fps             | 32       |
|    iterations      | 1        |
|    time_elapsed    | 63       |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 37.4        |
|    ep_rew_mean          | -115        |
| time/                   |             |
|    fps                  | 33          |
|    iterations           | 2           |
|    time_elapsed         | 123         |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009397702 |
|    clip_fraction        | 0.143       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.09       |
|    explained_variance   | 3.82e-05    |


<stable_baselines3.ppo.ppo.PPO at 0x17dcccaba90>

In [6]:
# Close the environment now that training has finished
env.close()