In [283]:
from __future__ import print_function
import vizdoom as vzd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

from vizdoom import GameVariable
from random import choice as random_choice
from time import sleep
from matplotlib import pyplot as plt
from collections import deque

In [2]:
# Pause (in seconds) inserted after every action so a human can follow the game window.
# Without this everything would go too fast for you to keep track of what's happening.
sleep_time = 1.0 / vzd.DEFAULT_TICRATE  # one tic; ~0.0286 s assuming DEFAULT_TICRATE is 35 -- TODO confirm

In [48]:
def create_game():
    """Build and configure a ViZDoom game for the ``basic`` scenario.

    The game is configured entirely in code and is NOT initialised here;
    the caller is expected to call ``game.init()``.

    Returns:
        tuple: ``(game, actions)``. ``game`` is the configured
        ``vzd.DoomGame``. ``actions`` is a list of one-hot button lists, one
        per registered button, in registration order
        (MOVE_LEFT, MOVE_RIGHT, ATTACK).
    """
    doom = vzd.DoomGame()

    # Scenario wad and the map inside it (a wad can hold several maps).
    doom.set_doom_scenario_path("../scenarios/basic.wad")
    doom.set_doom_map("map01")

    # Screen buffer: 640x480, RGB channels last.
    doom.set_screen_resolution(vzd.ScreenResolution.RES_640X480)
    doom.set_screen_format(vzd.ScreenFormat.RGB24)

    # Extra per-state data: depth, labels and automap buffers, plus
    # object and sector (map layout) information.
    for enable_extra in (
        doom.set_depth_buffer_enabled,
        doom.set_labels_buffer_enabled,
        doom.set_automap_buffer_enabled,
        doom.set_objects_info_enabled,
        doom.set_sectors_info_enabled,
    ):
        enable_extra(True)

    # Rendering switches, applied in the listed order.
    render_settings = (
        (doom.set_render_hud, False),
        (doom.set_render_minimal_hud, False),  # only relevant if the hud is on
        (doom.set_render_crosshair, True),
        (doom.set_render_weapon, True),
        (doom.set_render_decals, False),  # bullet holes and blood on the walls
        (doom.set_render_particles, False),
        (doom.set_render_effects_sprites, False),  # smoke and blood
        (doom.set_render_messages, False),  # in-game messages
        (doom.set_render_corpses, False),
        (doom.set_render_screen_flashes, True),  # damage / item pick-up flashes
    )
    for apply_setting, enabled in render_settings:
        apply_setting(enabled)

    # Buttons the agent may press; the order defines the action vector layout.
    buttons = (vzd.Button.MOVE_LEFT, vzd.Button.MOVE_RIGHT, vzd.Button.ATTACK)
    for button in buttons:
        doom.add_available_button(button)

    # Game variables exposed alongside each state.
    doom.add_available_game_variable(vzd.GameVariable.AMMO2)

    doom.set_episode_timeout(200)  # episode ends after 200 tics
    doom.set_episode_start_time(10)  # skip the first 10 tics (weapon raise)
    doom.set_window_visible(True)
    doom.set_sound_enabled(False)
    doom.set_living_reward(-1)  # reward of -1 for every move
    doom.set_mode(vzd.Mode.PLAYER)

    # One action per button: exactly that button pressed, all others released.
    # Button combinations are deliberately left out.
    n_buttons = len(buttons)
    one_hot_actions = [
        [slot == pressed for slot in range(n_buttons)]
        for pressed in range(n_buttons)
    ]

    return doom, one_hot_actions

In [256]:
def run(game, agent, actions, episodes, verbose=True, print_step_info=False):
    """Play ``episodes`` episodes and close the game afterwards.

    Each step preprocesses the current screen buffer, pushes it onto a
    4-frame stack and picks an action. Relies on the module-level
    ``preprocess`` function and ``sleep_time`` value.

    Args:
        game: Configured but not yet initialised game (e.g. from
            ``create_game``); ``init()`` and ``close()`` are called here.
        agent: Object exposing ``chose_action(state)``. It receives the
            stacked frames as a single tensor. Pass ``None`` to pick
            actions at random instead.
        actions: List of button lists accepted by ``game.make_action``.
        episodes: Number of episodes to play.
        verbose: Print the total reward after each episode.
        print_step_info: Print state number and reward after every action.
    """
    stack_size = 4
    # Shape of a frame after preprocess() on a 640x480 screen buffer.
    frame_shape = (299, 399)

    game.init()
    try:
        for episode in range(episodes):
            game.new_episode()
            print("Episode #" + str(episode + 1))

            # Start every episode from a stack of blank frames.
            stacked_frames = deque(
                [torch.zeros(frame_shape) for _ in range(stack_size)],
                maxlen=stack_size,
            )

            while not game.is_episode_finished():
                state = game.get_state()
                stacked_frames.append(preprocess(state.screen_buffer))

                if agent is None:
                    action = random_choice(actions)
                else:
                    # Stacking requires every frame to share frame_shape.
                    action = agent.chose_action(torch.stack(tuple(stacked_frames)))

                reward = game.make_action(action)

                if print_step_info:
                    print("State #" + str(state.number))
                    print("Reward:", reward)
                    print("=====================")

                if sleep_time > 0:
                    sleep(sleep_time)

            if verbose:
                print("Episode finished.")
                print("Total reward:", game.get_total_reward())
                print("************************")
    finally:
        # Release the engine even when an episode raises.
        game.close()

In [260]:
def preprocess(img):
    """Convert an RGB frame into a cropped greyscale float tensor.

    The last axis of ``img`` is collapsed with the weights
    (0.2989, 0.5870, 0.1140). The result then drops its first 181 rows,
    its first 121 columns and its last 120 columns.

    Args:
        img: Array whose last axis holds the three colour channels.

    Returns:
        torch.Tensor: the cropped greyscale image as float32.
    """
    luma_weights = [0.2989, 0.5870, 0.1140]
    kept_rows = slice(181, None)
    kept_cols = slice(121, -120)

    grayscale = img @ luma_weights
    return torch.tensor(grayscale[kept_rows, kept_cols], dtype=torch.float32)

In [258]:
def plot_state(state, gray):
    """Show a screen buffer in a 12x8 figure.

    Args:
        state: The image to display, or an object carrying it in a
            ``screen_buffer`` attribute.
        gray: If true, run the image through ``preprocess`` and draw it
            with the "gray" colormap; otherwise draw it unchanged.
    """
    # Use the argument rather than a global `screen`, so the function
    # works on a fresh kernel.
    screen = getattr(state, "screen_buffer", state)

    plt.figure(figsize=(12, 8))
    if gray:
        plt.imshow(preprocess(screen), "gray")
    else:
        plt.imshow(screen)

In [259]:
# Build the environment and play one episode. run() takes
# (game, agent, actions, episodes); no agent exists yet, so pass None.
game, actions = create_game()
run(game, None, actions, 1)

Episode #1
Episode finished.
Total reward: 91.0
************************


In [163]:
#screen = run(game, actions, 10)
# NOTE(review): run() already closes the game when it returns, so this is
# presumably manual cleanup after an interrupted run -- confirm it is still needed.
game.close()

Episode #1


In [169]:
#plot_state(screen, False), plot_state(screen, True)

In [213]:
class Qnet(nn.Module):
    """Convolutional network mapping a stack of 4 frames to per-action scores.

    Four stride-2 conv blocks (conv -> batch norm -> ReLU) are followed by
    global average pooling and two fully connected layers. The output goes
    through a softmax over the action axis.

    NOTE(review): a softmax squashes the outputs into a probability
    distribution, which is unusual for Q-values -- kept as in the original,
    confirm it is intended.

    Args:
        action_size: Number of actions, i.e. the size of the output vector.
    """

    def __init__(self, action_size):
        super(Qnet, self).__init__()
        self.action_size = action_size

        # Shape comments assume a (4, 299, 399) input per sample.
        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 16, 3, padding=0, stride=2, bias=False), # (16, 149, 199)
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=(1,0), stride=2, bias=False), # (32, 75, 99)
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=0, stride=2, bias=False), # (64, 37, 49)
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=0, stride=2, bias=False), # (128, 18, 24)
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.fc1 = nn.Sequential(
            nn.Linear(128, 100),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(100, self.action_size),
            # Normalise over the action axis (last), not over the batch.
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Score every action for a batch of frame stacks.

        Args:
            x: Tensor of shape (batch, 4, height, width).

        Returns:
            Tensor of shape (batch, action_size). For a batch of exactly one
            sample the batch axis is dropped and the shape is (action_size,),
            matching what this method returned before batches were supported.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Global average pool: (batch, 128, h, w) -> (batch, 128). The batch
        # axis is kept so that larger batches reach the linear layers intact.
        x = x.mean((2, 3))
        x = self.fc1(x)
        x = self.fc2(x)

        if x.shape[0] == 1:
            # Backward compatibility with the original 1-D single-sample output.
            x = x.squeeze(0)

        return x

In [290]:
class DQNAgent:
    """Epsilon-greedy agent that scores actions with a ``Qnet``.

    Args:
        action_size: Number of outputs of the underlying network.
        actions: List of actions; index ``i`` corresponds to network output ``i``.
        epsilon: Probability of picking a random action instead of the
            highest-scoring one.
    """

    def __init__(self, action_size, actions, epsilon=0.99):
        self.action_size = action_size
        self.q_net = Qnet(action_size)
        self.epsilon = epsilon
        self.actions = actions

    def chose_action(self, state):
        """Pick an action for a single state.

        With probability ``epsilon`` a random entry of ``self.actions`` is
        returned; otherwise the entry whose network score is highest.

        Args:
            state: One stack of frames, shaped (4, height, width) or
                (1, 4, height, width).

        Returns:
            One element of ``self.actions``.
        """
        if np.random.uniform() < self.epsilon:
            # Explore: sample from the agent's own action list, not a global.
            return random_choice(self.actions)

        frames = torch.as_tensor(state, dtype=torch.float32)
        if frames.dim() == 3:
            # The network expects a leading batch axis.
            frames = frames.unsqueeze(0)

        # Score in eval mode without gradients so that choosing an action
        # does not update the batch-norm running statistics.
        was_training = self.q_net.training
        self.q_net.eval()
        try:
            with torch.no_grad():
                scores = self.q_net(frames)
        finally:
            self.q_net.train(was_training)

        # Exploit: a single state yields action_size scores once flattened.
        best_index = int(scores.reshape(-1).argmax())
        return self.actions[best_index]

In [285]:
# Sanity check: random_choice returns a single element of the given list.
random_choice([1,2,3])

3

In [215]:
# Instantiate the network with 3 outputs, one per action returned by create_game().
q = Qnet(action_size=3)

In [288]:
# Sanity check of the draw that DQNAgent.chose_action compares against epsilon.
np.random.uniform()

0.5175494098527065

In [247]:
# Shape of the blank frames that run() uses to pre-fill the frame stack.
torch.zeros(299, 399).shape

torch.Size([299, 399])

In [245]:
# Check that a preprocessed frame matches the blank-frame shape above.
# NOTE(review): `screen` is not assigned by any cell visible here -- this
# relies on leftover kernel state and fails on a fresh kernel.
preprocess(screen).shape

torch.Size([299, 399])

In [234]:
# Tile one 2-D preprocessed frame into a (1, 4, H, W) batch for the network.
# The trailing .float() is redundant: preprocess() already returns a float tensor.
# NOTE(review): `screen` is not assigned by any cell visible here.
a = preprocess(screen).repeat((1, 4, 1, 1)).float()

In [235]:
# Forward pass of the tiled frame batch through the network.
b = q(a)

In [223]:
# Bounded deque: once it holds 4 frames, appending drops the oldest one.
stacked_frames = deque(maxlen=4)

In [224]:
# Fill the stack with four copies of the same preprocessed frame.
# NOTE(review): `screen` is not assigned by any cell visible here.
for _ in range(4):
    stacked_frames.append(preprocess(screen))

In [243]:
# Stack the four frames along a new leading axis; the recorded output
# below shows the resulting shape.
torch.stack(tuple(stacked_frames)).shape

torch.Size([4, 299, 399])