<a href="https://colab.research.google.com/github/Khaarl/ViZDOOM-PPO/blob/STAGING/ViZDOOM_PPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# CELL 1: Mount Google Drive

from google.colab import drive
import os

# Attach the user's Google Drive to the Colab filesystem.
drive.mount('/content/drive')

# Root folder on Drive that holds everything belonging to this project.
GDRIVE_BASE_FOLDER = "/content/drive/MyDrive/ViZDoom-PPO"

# Ensure the project folder is present and report which case applied.
if os.path.exists(GDRIVE_BASE_FOLDER):
    print(f"Google Drive folder exists: {GDRIVE_BASE_FOLDER}")
else:
    os.makedirs(GDRIVE_BASE_FOLDER)
    print(f"Created Google Drive folder: {GDRIVE_BASE_FOLDER}")

Mounted at /content/drive
Google Drive folder exists: /content/drive/MyDrive/ViZDoom-PPO


In [2]:
# CELL 2: Install Dependencies

!apt-get update
!apt-get install -y build-essential zlib1g-dev libsdl2-dev libjpeg-dev \
    nasm tar libbz2-dev libgtk2.0-dev cmake git libfluidsynth-dev libgme-dev \
    libopenal-dev timidity libwildmidi-dev unzip ffmpeg

!pip install vizdoom
!pip install stable-baselines3[extra]

Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Get:3 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:4 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:7 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ Packages [61.9 kB]
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Get:9 https://r2u.stat.illinois.edu/ubuntu jammy/main all Packages [8,590 kB]
Get:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease [24.3 kB]
Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:12 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:13 http://security.ubuntu.com/ubuntu jammy-security/main amd64 P

In [6]:
# CELL 3: Download doom2.wad from Google Drive and Scenario Config to local storage (Corrected)

import os
import shutil

# Define local paths for scenario and storage
LOCAL_SCENARIO_PATH = "/content/scenarios/deathmatch.cfg"
LOCAL_STORAGE_PATH = "/content/scenarios/training_data"
LOCAL_MODEL_PATH = "/content/scenarios/training_data/models"
LOCAL_LOG_PATH = "/content/scenarios/training_data/logs"
LOCAL_TENSORBOARD_PATH = "/content/scenarios/training_data/tensorboard"
LOCAL_WAD_PATH = "/content/scenarios/doom2.wad"

# Define Google Drive path for the WAD file
GDRIVE_WAD_FOLDER = os.path.join(GDRIVE_BASE_FOLDER, "WADS")
GDRIVE_WAD_PATH = os.path.join(GDRIVE_WAD_FOLDER, "doom2.wad")

# Create local storage directories if they don't exist
for path in [LOCAL_STORAGE_PATH, LOCAL_MODEL_PATH, LOCAL_LOG_PATH, LOCAL_TENSORBOARD_PATH]:
    os.makedirs(path, exist_ok=True)
    print(f"Created local directory: {path}")

# Download doom2.wad from Google Drive if it exists
if os.path.exists(GDRIVE_WAD_PATH):
    shutil.copy(GDRIVE_WAD_PATH, LOCAL_WAD_PATH)
    print(f"Copied doom2.wad from {GDRIVE_WAD_PATH} to {LOCAL_WAD_PATH}")
else:
    print(f"doom2.wad not found in {GDRIVE_WAD_FOLDER}. Please make sure it exists.")

# Download deathmatch.cfg if it doesn't exist
if not os.path.exists(LOCAL_SCENARIO_PATH):
    !wget https://raw.githubusercontent.com/mwydmuch/ViZDoom/master/scenarios/deathmatch.cfg -P /content/scenarios/
    print(f"Downloaded deathmatch.cfg to {LOCAL_SCENARIO_PATH}")
else:
    print(f"deathmatch.cfg already exists at {LOCAL_SCENARIO_PATH}")

Created local directory: /content/scenarios/training_data
Created local directory: /content/scenarios/training_data/models
Created local directory: /content/scenarios/training_data/logs
Created local directory: /content/scenarios/training_data/tensorboard
Copied doom2.wad from /content/drive/MyDrive/ViZDoom-PPO/WADS/doom2.wad to /content/scenarios/doom2.wad
--2025-01-09 19:38:53--  https://raw.githubusercontent.com/mwydmuch/ViZDoom/master/scenarios/deathmatch.cfg
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1327 (1.3K) [text/plain]
Saving to: ‘/content/scenarios/deathmatch.cfg’


2025-01-09 19:38:53 (46.6 MB/s) - ‘/content/scenarios/deathmatch.cfg’ saved [1327/1327]

Downloaded deathmatch.cfg to /content/scenarios/deathmatch.cfg


In [7]:
# CELL 4: Define ViZDoom Environment with Reward Shaping (Corrected)

from vizdoom import *
import random
import time
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import os

class VizdoomEnv(gym.Env):
    """Gymnasium environment wrapping a ViZDoom game with shaped rewards.

    Observations are single-channel (GRAY8) frames of shape
    ``(height, width, 1)`` and dtype ``uint8``. Actions are
    ``Discrete(n_buttons)``: each step presses exactly one of the buttons made
    available by the scenario config, repeated for ``frame_skip`` tics.

    The reward returned by ``step`` is the game's own reward plus the shaping
    terms computed in ``_shape_reward``.

    Args:
        scenario_path: Path to the ViZDoom scenario ``.cfg`` file.
        frame_skip: Number of tics every action is repeated for.
        doom_game_path: Path to the base game WAD (defaults to the location the
            notebook copies ``doom2.wad`` to).
    """

    # Index of each value inside ``state.game_variables``. This order must
    # match the list passed to ``set_available_game_variables`` in ``__init__``.
    _KILLCOUNT, _HEALTH, _AMMO, _POS_X, _POS_Y = range(5)

    # Lower-cased actor names treated as enemies when the object itself does
    # not provide an ``is_enemy()`` method.
    # NOTE(review): assumes ``Object.name`` carries ZDoom actor class names for
    # the stock Doom 2 monsters -- confirm against the scenario in use.
    ENEMY_NAMES = frozenset({
        "zombieman", "shotgunguy", "chaingunguy", "doomimp", "demon",
        "spectre", "lostsoul", "cacodemon", "hellknight", "baronofhell",
        "arachnotron", "painelemental", "revenant", "fatso", "archvile",
        "spidermastermind", "cyberdemon", "wolfensteinss",
    })

    # Enemy-distance shaping is only applied within this range (tune as needed).
    ENEMY_DISTANCE_THRESHOLD = 500

    def __init__(self, scenario_path, frame_skip=4,
                 doom_game_path="/content/scenarios/doom2.wad"):
        super().__init__()
        self.game = DoomGame()
        self.game.load_config(scenario_path)
        self.game.set_doom_game_path(doom_game_path)
        self.game.set_window_visible(False)
        self.game.set_mode(Mode.PLAYER)
        self.game.set_screen_format(ScreenFormat.GRAY8)
        self.game.set_screen_resolution(ScreenResolution.RES_640X480)
        # Fix the layout of ``game_variables`` here instead of relying on
        # whatever the scenario .cfg lists: the reward shaping indexes the
        # array by position (kills, health, ammo, x, y).
        self.game.set_available_game_variables([
            GameVariable.KILLCOUNT,
            GameVariable.HEALTH,
            GameVariable.SELECTED_WEAPON_AMMO,
            GameVariable.POSITION_X,
            GameVariable.POSITION_Y,
        ])
        # ``state.objects`` is read by ``_get_closest_enemy_distance``; it has
        # to be enabled before ``init()``.
        self.game.set_objects_info_enabled(True)
        self.game.init()
        self.frame_skip = frame_skip
        self.action_space = spaces.Discrete(self.game.get_available_buttons_size())
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.game.get_screen_height(), self.game.get_screen_width(), 1),
            dtype=np.uint8,
        )

        # Values from the previous step, used to compute per-step deltas.
        self.previous_kill_count = 0
        self.previous_ammo = 0
        self.previous_health = 100
        self.min_dist_prev = float('inf')

    def step(self, action):
        """Press one button for ``frame_skip`` tics and return the transition.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``truncated`` is always ``False`` and ``info`` is empty. Once the
            episode has finished the observation is an all-zero frame and no
            shaping is added, because ViZDoom exposes no state at that point.
        """
        buttons = np.zeros(self.game.get_available_buttons_size())
        buttons[action] = 1

        reward = self.game.make_action(buttons.tolist(), self.frame_skip)
        done = self.game.is_episode_finished()

        # get_state() yields None after the episode ended; never dereference it.
        game_state = None if done else self.game.get_state()
        if game_state is None:
            observation = np.zeros(self.observation_space.shape, dtype=np.uint8)
            shaped_reward = reward
        else:
            observation = np.expand_dims(game_state.screen_buffer, axis=-1)
            shaped_reward = reward + self._shape_reward(game_state)

        info = {}
        return observation, shaped_reward, done, False, info

    def _shape_reward(self, state=None):
        """Return the shaping bonus for the current step.

        Args:
            state: Game state to evaluate; fetched from the game when omitted.

        Returns:
            Sum of: +100 per new kill, -0.1 per unit of ammo spent, -1 per
            health point lost, +0.1 survival bonus and +/-0.05 for moving
            towards/away from the closest enemy. ``0`` when no state or game
            variables are available.
        """
        if state is None:
            state = self.game.get_state()
        if state is None or state.game_variables is None:
            return 0

        game_vars = state.game_variables
        kill_count = game_vars[self._KILLCOUNT]
        health = game_vars[self._HEALTH]
        ammo = game_vars[self._AMMO]

        reward = 0

        # Reward for kills
        reward += (kill_count - self.previous_kill_count) * 100.0

        # Penalty for ammo used (picking up ammo yields a small bonus)
        reward -= (self.previous_ammo - ammo) * 0.1

        # Penalty for health loss (healing is neither rewarded nor punished)
        health_loss = self.previous_health - health
        if health_loss > 0:
            reward -= health_loss * 1.0

        # Encourage survival
        reward += 0.1

        # Reward approaching the closest enemy, penalise retreating from it.
        threshold = self.ENEMY_DISTANCE_THRESHOLD
        min_dist_now = self._get_closest_enemy_distance(state)
        if min_dist_now < self.min_dist_prev and min_dist_now < threshold:
            reward += 0.05
        elif min_dist_now > self.min_dist_prev and self.min_dist_prev < threshold:
            reward -= 0.05

        # Remember this step's values for the next delta computation.
        self.min_dist_prev = min_dist_now
        self.previous_kill_count = kill_count
        self.previous_ammo = ammo
        self.previous_health = health

        return reward

    def _is_enemy(self, obj):
        """Tell whether a state object should count as an enemy.

        Uses the object's own ``is_enemy()`` when it has one; otherwise falls
        back to matching its lower-cased ``name`` against ``ENEMY_NAMES``.
        """
        checker = getattr(obj, "is_enemy", None)
        if callable(checker):
            return bool(checker())
        name = getattr(obj, "name", "") or ""
        return name.lower() in self.ENEMY_NAMES

    def _get_closest_enemy_distance(self, state=None):
        """Return the 2D distance from the player to the nearest enemy.

        Args:
            state: Game state to evaluate; fetched from the game when omitted.

        Returns:
            Euclidean distance in the x/y plane, or ``inf`` when there is no
            state, no game variables, no object info or no enemy.
        """
        min_dist = float('inf')
        if state is None:
            state = self.game.get_state()
        if state is None or state.game_variables is None:
            return min_dist

        px = state.game_variables[self._POS_X]
        py = state.game_variables[self._POS_Y]

        # ``objects`` may be None when object info is unavailable.
        for obj in (state.objects or ()):
            if self._is_enemy(obj):
                dist = ((px - obj.position_x) ** 2 + (py - obj.position_y) ** 2) ** 0.5
                min_dist = min(min_dist, dist)
        return min_dist

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(first_observation, info)``.

        Also re-initialises the per-episode trackers used by the reward
        shaping from the first state of the new episode.
        """
        super().reset(seed=seed)
        self.game.new_episode()

        game_state = self.game.get_state()
        if game_state is None:
            # Defensive fallback: no state available right after new_episode().
            observation = np.zeros(self.observation_space.shape, dtype=np.uint8)
            game_vars = None
        else:
            observation = np.expand_dims(game_state.screen_buffer, axis=-1)
            game_vars = game_state.game_variables

        self.previous_kill_count = 0
        self.previous_ammo = game_vars[self._AMMO] if game_vars is not None else 0
        self.previous_health = game_vars[self._HEALTH] if game_vars is not None else 100
        self.min_dist_prev = float('inf')
        return observation, {}

    def close(self):
        """Shut down the underlying ViZDoom game instance."""
        self.game.close()

In [8]:
# CELL 5: Train PPO Agent with User Input, Loading, and Saving

from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
import os

# Ask how long to train. The value is a number of environment TIMESTEPS (what
# PPO.learn expects), not episodes; the variable name `num_episodes` is kept
# so that any other cell referring to it keeps working.
while True:
    try:
        num_episodes = int(input("Enter the number of training timesteps: "))
    except ValueError:
        print("Invalid input. Please enter a number.")
        continue
    if num_episodes > 0:
        break
    print("Please enter a positive number of timesteps.")

# Prompt user to load a pre-trained model
load_pretrained = input("Do you want to load a pre-trained model? (yes/no): ").strip().lower()

model = None
if load_pretrained == "yes":
    while True:
        pretrained_model_path = input("Enter the path to the pre-trained model: ").strip()
        # model.save() appends ".zip", while the path printed after saving has
        # no extension -- accept both spellings.
        if os.path.exists(pretrained_model_path) or os.path.exists(pretrained_model_path + ".zip"):
            try:
                model = PPO.load(pretrained_model_path)
                print(f"Successfully loaded model from: {pretrained_model_path}")
                break
            except Exception as e:
                print(f"Error loading model: {e}")
                print("Please enter a valid path.")
        else:
            print("Model path does not exist. Please enter a valid path.")

# One monitored environment for both the "new model" and the "loaded model"
# path, so episode statistics are always logged and the game is always closed.
env = Monitor(VizdoomEnv(LOCAL_SCENARIO_PATH), LOCAL_LOG_PATH)

try:
    if model is None:
        # Create a new model
        model = PPO("CnnPolicy", env, verbose=1, tensorboard_log=LOCAL_TENSORBOARD_PATH)
    else:
        # Attach the freshly created environment to the loaded model
        model.set_env(env)

    # Checkpoint to local storage every 10% of the run, but never more often
    # than every 10000 steps (runs shorter than that produce no checkpoint;
    # the final model is still saved below).
    checkpoint_callback = CheckpointCallback(save_freq=max(10000, num_episodes // 10),
                                             save_path=LOCAL_MODEL_PATH,
                                             name_prefix="ppo_deathmatch")

    # Train the agent
    model.learn(total_timesteps=num_episodes, callback=checkpoint_callback)

    # Save the final model to local storage
    final_model_local_path = os.path.join(LOCAL_MODEL_PATH, "ppo_deathmatch_final")
    model.save(final_model_local_path)
    print(f"Final model saved locally to: {final_model_local_path}")

    # Optionally save the final model to Google Drive
    save_to_gdrive = input("Do you want to save the final model to Google Drive? (yes/no): ").strip().lower()
    if save_to_gdrive == "yes":
        gdrive_model_path = os.path.join(GDRIVE_BASE_FOLDER, "models", "ppo_deathmatch_final")
        try:
            # Make sure the target folder exists on Drive before writing.
            os.makedirs(os.path.dirname(gdrive_model_path), exist_ok=True)
            model.save(gdrive_model_path)
            print(f"Final model saved to Google Drive: {gdrive_model_path}")
        except Exception as e:
            print(f"Error saving to Google Drive: {e}")
finally:
    # Release the ViZDoom instance even when training or saving fails.
    env.close()

print("Training finished!")

Enter the number of training timesteps: 1
Do you want to load a pre-trained model? (yes/no): 


FileDoesNotExistException: File "/content/scenarios/deathmatch.wad" does not exist.