In [1]:
import random
import time
from collections import deque
from collections import namedtuple

import gym
from gym import spaces
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [2]:
# Fixed-capacity replay buffer: once `capacity` entries are stored, each new entry pushes out the oldest one (FIFO eviction); `sample` draws stored entries uniformly at random.
class ReplayMemory(deque):
    """Bounded transition store backed by ``collections.deque``.

    When ``capacity`` items are already held, storing another one evicts the
    oldest entry.
    """

    def __init__(self, capacity):
        deque.__init__(self, maxlen=capacity)

    def sample(self, amt_sample):
        """Return ``amt_sample`` stored items drawn uniformly at random, without replacement."""
        picked = random.sample(self, amt_sample)
        return picked

    def store_transition(self, trans):
        """Append ``trans``; if the buffer is full the oldest entry is dropped."""
        self.append(trans)
    
# Named record for one replay-memory entry, fields in order: state, action, next_state, reward.
Transition = namedtuple('Transition', 'state action next_state reward')

In [3]:
class CNN(nn.Module):
    """Two-conv / three-fc network that maps a batch of square images to one score per action.

    With the default arguments the flattened feature size is 16*13*13, which matches
    e.g. (N, 3, 64, 64) inputs, and the output has shape (N, 12).

    Args:
        in_channels: number of input image channels (default 3).
        n_actions: number of output scores (default 12).
        input_size: height/width of the square input image (default 64).

    Raises:
        ValueError: if ``input_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, in_channels=3, n_actions=12, input_size=64):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # Each unpadded 5x5 conv trims 4 pixels and each 2x2 pool halves (floor):
        # 64 -> 60 -> 30 -> 26 -> 13 for the default input size.
        side = ((input_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for this network")
        # Computed once so __init__ and forward cannot disagree about the flatten size.
        self.flat_features = 16 * side * side
        self.fc1 = nn.Linear(self.flat_features, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, n_actions)

    def forward(self, x):
        """Return raw (unactivated) scores of shape (N, n_actions)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # view(-1, ...) keeps the original behaviour: an unbatched (C, H, W) image
        # (where the installed Conv2d accepts one) yields a batch of size 1.
        x = x.view(-1, self.flat_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
class DQNAgent():
    """Epsilon-greedy DQN agent with a policy network, a target network and replay memory.

    The policy network emits one score per binary action flag (12 are decoded in
    ``step``); the greedy action switches on every flag whose score is positive.

    Args:
        obs_space: observation space; stored but not otherwise used by this class.
        action_space: action space; its ``sample()`` supplies exploratory actions.
        config: mapping providing "EPS", "EPS_DECAY", "GAMMA", "MEMORY_CAPACITY"
            and "BATCH_SIZE".
        device: torch device for both networks and internally created tensors.
            When omitted, a notebook-level ``device`` global is used if one is
            defined, otherwise CUDA when available, else CPU.
    """

    def __init__(self, obs_space, action_space, config, device=None):
        self.config = config

        self.observation_space = obs_space
        self.action_space = action_space

        self.epsilon = self.config["EPS"]
        self.epsilon_decay = self.config["EPS_DECAY"]
        self.gamma = self.config["GAMMA"]

        if device is None:
            # Backward compatibility: this class used to read a global ``device`` set in another cell.
            device = globals().get("device")
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.policy_network = CNN().to(self.device)
        self.target_network = CNN().to(self.device)
        self.target_network.load_state_dict(self.policy_network.state_dict())
        self.target_network.eval()

        self.optimizer = optim.SGD(self.policy_network.parameters(), lr=0.01, momentum=0.9)

        self.memory = ReplayMemory(self.config["MEMORY_CAPACITY"])
        self.batch_size = self.config["BATCH_SIZE"]

    def step(self, obs):
        """Choose an action for ``obs``.

        With probability ``epsilon`` returns ``action_space.sample()``; otherwise
        thresholds the first row of the policy network's output (score > 0 -> 1.0,
        else 0.0) and returns it as a nested action dict.
        """
        if np.random.uniform(0, 1) < self.epsilon:
            return self.action_space.sample()

        with torch.no_grad():
            # Boolean threshold instead of clone + in-place writes; a NaN score maps to 0.0.
            flags = (self.policy_network(obs)[0] > 0).float().tolist()

        return {
            'attack': flags[0],
            'back': flags[1],
            'camera': {
                'look_up': flags[2],
                'look_down': flags[3],
                'look_right': flags[4],
                'look_left': flags[5]
            },
            'forward': flags[6],
            'jump': flags[7],
            'left': flags[8],
            'right': flags[9],
            'sneak': flags[10],
            'sprint': flags[11]
        }

    @staticmethod
    def _greedy_value(q_table):
        """Per-row value of the best multi-binary action: the sum of the row's positive Q entries.

        ``evaluate`` scores an action as ``Q . a`` with ``a`` a 0/1 vector, so the
        maximum over ``a`` switches on exactly the positive components -- the same
        rule ``step`` uses to act greedily.
        """
        return q_table.clamp(min=0).sum(dim=1)

    def evaluate(self, old_state, action_taken, new_state, reward):
        """Run one optimisation step on a random minibatch from replay memory.

        The loss is L1 between ``Q_policy(s) . a`` and ``r + gamma * V_target(s')``,
        where ``V_target`` is 0 for transitions whose ``next_state`` is None.
        Does nothing until memory holds at least ``batch_size`` transitions.

        NOTE(review): the four arguments are unused -- this method does not store
        the given transition; callers presumably push transitions into
        ``self.memory`` themselves. TODO confirm.

        Returns:
            None.
        """
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        # Unpack in Transition field order: (state, action, next_state, reward).
        states, actions, next_states, rewards = zip(*transitions)

        all_states = torch.cat(states)
        all_actions = torch.cat(actions)
        all_rewards = torch.cat(rewards)

        # True where the transition has a successor state to bootstrap from.
        non_final_mask = torch.tensor([s is not None for s in next_states],
                                      device=self.device, dtype=torch.bool)

        # Q-value of the taken multi-binary action: row-wise dot product of Q with the 0/1 action vector.
        policy_q_table = self.policy_network(all_states)
        actual = torch.einsum('ij,ij->i', policy_q_table, all_actions)

        # Stays 0 for final transitions.
        next_state_max_q = torch.zeros(self.batch_size, device=self.device)
        if non_final_mask.any():
            non_final_next_states = torch.cat([s for s in next_states if s is not None])
            # The target is a constant: no gradients may flow into the target network.
            with torch.no_grad():
                target_q_table = self.target_network(non_final_next_states)
            # Per-sample greedy value (previously one scalar summed over the whole batch).
            next_state_max_q[non_final_mask] = self._greedy_value(target_q_table)

        y = all_rewards + self.gamma * next_state_max_q

        loss = F.l1_loss(actual, y)
        self.optimizer.zero_grad()
        loss.backward()
        # Clamp each gradient element to [-1, 1]; parameters without a grad are skipped.
        torch.nn.utils.clip_grad_value_(self.policy_network.parameters(), 1)
        self.optimizer.step()

    def agent_sync_networks(self):
        """Copy the policy network's weights into the target network."""
        self.target_network.load_state_dict(self.policy_network.state_dict())

    def decay_epsilon(self):
        """Multiply epsilon by the decay factor and clip the result to [0.01, 0.99]."""
        self.epsilon = np.clip(self.epsilon * self.epsilon_decay, .01, .99)