In [8]:
import numpy as np
import random
import time
import cv2
import os
import stable_baselines3
import gymnasium as gym
from vizdoom import *
from gymnasium import spaces
from gym.spaces import Discrete, Box
from matplotlib import pyplot as plt
from stable_baselines3.common.callbacks import BaseCallback

In [9]:
class DOOMGym(gym.Env):
    """Gymnasium environment wrapping a ViZDoom ``DoomGame`` for one scenario.

    Observations are the raw ``screen_buffer`` of the current game state. The
    action space has three discrete actions; action ``i`` presses only button
    ``i`` (row ``i`` of a 3x3 identity matrix) for ``frame_skip`` tics.

    Args:
        render: show the ViZDoom window when truthy, hide it otherwise.
        config_path: scenario ``.cfg`` file passed to ``DoomGame.load_config``
            (relative paths resolve against the current working directory).
        frame_skip: number of tics each action is repeated for in ``step``.
    """

    def __init__(self, render=True, config_path="github\\ViZDoom\\scenarios\\basic.cfg", frame_skip=4):
        super().__init__()

        self.game = DoomGame()  # type: ignore  # name comes from `from vizdoom import *`
        self.game.load_config(config_path)
        self.game.set_window_visible(bool(render))
        self.game.init()

        self.frame_skip = frame_skip
        # One-hot button vectors, built once instead of on every step.
        self._actions = np.identity(3, dtype=np.uint8)

        # NOTE(review): assumes the scenario's screen buffer is channels-first
        # 3x240x320 uint8 (i.e. the .cfg's screen format/resolution) — confirm
        # against the config; a mismatch breaks SB3's CnnPolicy preprocessing.
        self.observation_space = spaces.Box(low=0, high=255, shape=(3, 240, 320), dtype=np.uint8)
        self.action_space = spaces.Discrete(3)

    def step(self, action):
        """Apply discrete ``action`` for ``frame_skip`` tics.

        Returns the Gymnasium 5-tuple ``(observation, reward, terminated,
        truncated, info)``. When ViZDoom provides no state (e.g. once the
        episode has finished) a zero image matching the observation space is
        returned with an empty ``info``; otherwise ``info`` holds ``"ammo"``
        (the first configured game variable).
        """
        reward = self.game.make_action(self._actions[action], self.frame_skip)

        # Fetch the state once; it may be None after the episode ends.
        game_state = self.game.get_state()
        if game_state is not None:
            observation = game_state.screen_buffer
            info = {"ammo": game_state.game_variables[0]}
        else:
            # Match the declared dtype (np.zeros defaults to float64).
            observation = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
            info = {}

        # NOTE(review): a scenario timeout also ends the episode here and is
        # reported as `terminated`; consider mapping it to `truncated` instead.
        terminated = self.game.is_episode_finished()
        truncated = False
        return observation, reward, terminated, truncated, info

    def render(self):
        """No-op: any display happens in the ViZDoom window enabled in ``__init__``."""
        return None

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: seeds Gymnasium's ``self.np_random`` and, when not ``None``,
                ViZDoom's own RNG so episodes can be reproduced.
            options: accepted for Gymnasium API compatibility; unused.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.game.set_seed(seed)

        self.game.new_episode()
        observation = self.game.get_state().screen_buffer
        return observation, {}

    def close(self):
        """Shut down the underlying ViZDoom game."""
        self.game.close()

In [10]:
class TLCallback(BaseCallback):
    """Periodically save the model being trained.

    Every ``check_freq`` callback calls the model is saved to
    ``<save_path>/best_model_<n_calls>``. No evaluation is performed, so despite
    the file name these are periodic checkpoints, not the best-scoring model.

    Args:
        check_freq: save interval in callback calls; must be >= 1.
        save_path: checkpoint directory (created on init); ``None`` disables saving.
        verbose: verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: if ``check_freq`` is less than 1.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        if check_freq < 1:
            # 0 would raise ZeroDivisionError in `_on_step` partway through training.
            raise ValueError(f"check_freq must be >= 1, got {check_freq!r}")
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory before training starts."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint every ``check_freq`` calls; always continue training."""
        # Honour save_path=None here too, matching _init_callback.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)
        return True

In [11]:
# Output locations: periodic model checkpoints, then TensorBoard logs.
CPT_DIR, LOG_DIR = "./train/train_basic", "./logs/log_basic"

In [12]:
# Checkpoint the model into CPT_DIR every CHECKPOINT_FREQ callback calls.
CHECKPOINT_FREQ = 1000
callback = TLCallback(check_freq=CHECKPOINT_FREQ, save_path=CPT_DIR)

In [95]:
from stable_baselines3.common import env_checker

In [107]:
# Validate the environment against the Gymnasium/SB3 API, then release its
# Doom instance so it does not keep running next to the training env below.
checker_env = DOOMGym(render=False)
try:
    env_checker.check_env(checker_env)
finally:
    checker_env.close()

In [13]:
# Train PPO on raw screen pixels with the Doom window hidden.
SEED = 42                  # fixed seed so training runs are reproducible
LEARNING_RATE = 0.0001
N_STEPS = 256              # env steps collected per PPO rollout
TOTAL_TIMESTEPS = 100_000

env = DOOMGym(render=False)
model = stable_baselines3.PPO(
    "CnnPolicy",
    env,
    tensorboard_log=LOG_DIR,
    verbose=1,
    learning_rate=LEARNING_RATE,
    n_steps=N_STEPS,
    seed=SEED,
)
model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=callback)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./logs/log_basic\PPO_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 23.2     |
|    ep_rew_mean     | -29.6    |
| time/              |          |
|    fps             | 13       |
|    iterations      | 1        |
|    time_elapsed    | 19       |
|    total_timesteps | 256      |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 27.9         |
|    ep_rew_mean          | -44.1        |
| time/                   |              |
|    fps                  | 3            |
|    iterations           | 2            |
|    time_elapsed         | 145          |
|    total_timesteps      | 512          |
| train/                  |              |
|    approx_kl            | 0.0025812471 |
|    clip_fraction        | 0.215        |
|    clip_range 

IndexError: tuple index out of range

In [118]:
# Gymnasium's reset() returns (observation, info); unpack so `state` is the image.
state, info = env.reset()

In [None]:
# One manual step with action index 2, which presses only the third configured
# button (which button that is depends on the scenario .cfg).
# Unpack so the large image observation is not dumped as cell output.
state, reward, terminated, truncated, info = env.step(2)
reward, terminated, truncated, info

In [66]:
# Shut down the ViZDoom game held by `env` (DOOMGym.close -> DoomGame.close).
# The PPO model's wrapped env refers to this same instance, so it cannot collect
# further rollouts after this cell.
env.close()