# Imports

In [3]:
"""
Import Statements: From PyTorch RL Tutorial
"""
import math
import random
from collections import deque
from itertools import count

import gym
import matplotlib
import matplotlib.pyplot as plt
import nle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image

# set up matplotlib
# Heuristic: an "inline" backend name is taken to mean we are running under
# IPython/Jupyter; only in that case is IPython's display module imported.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
# Turn on pyplot interactive mode for live-updating plots.
plt.ion()

# Experiences & Memories

In [4]:
class Experience():
    """A single replay transition: (state, action, next_state, reward).

    Instances are plain containers; they are pushed into a replay buffer and
    sampled later for training.
    """

    def __init__(self, state, action, next_state, reward):
        # Bug fix: this previously read ``self.state - state`` (a discarded
        # subtraction that raised AttributeError), so no Experience could be
        # constructed and ``state`` was never stored.
        self.state = state
        self.action = action
        self.next_state = next_state
        self.reward = reward

    def __repr__(self):
        return (
            f"{type(self).__name__}(state={self.state!r}, action={self.action!r}, "
            f"next_state={self.next_state!r}, reward={self.reward!r})"
        )

In [5]:
class ReplayMemory():
    """Bounded FIFO buffer of experiences for experience replay.

    Once ``max_memory`` items are stored, each new push evicts the oldest one.
    """

    def __init__(self, max_memory):
        # A deque with maxlen evicts from the left in O(1). The previous
        # list.pop(0) shifted every remaining element on each push once full.
        self.memory = deque(maxlen=max_memory)
        self.max_memory = max_memory

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.memory)

    def push(self, newExp):
        """Append ``newExp``, discarding the oldest entry if the buffer is full."""
        self.memory.append(newExp)

    def can_sample(self, batch_size):
        """Return True if at least ``batch_size`` experiences are stored."""
        return len(self.memory) >= batch_size

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen uniformly at random.

        Raises:
            ValueError: (from ``random.sample``) if fewer than ``batch_size``
                experiences are stored; check ``can_sample`` first.
        """
        return random.sample(self.memory, batch_size)

# Policy (& Target) Network

In [6]:
class DQNAlphaNN(nn.Module):
    """Convolutional Q-network over stacked NetHack observation grids.

    Input is a tensor of shape (batch, 4, H, W). Its four channels are the
    ``glyphs``, ``chars``, ``colors`` and ``specials`` arrays (see
    ``obs_to_x``). The flattened size ``3 * 9 * 16`` assumes 21 x 79 grids,
    per the per-layer shape comments. The output has shape
    (batch, n_actions): one Q-value per action.
    """

    def __init__(self, n_actions=113):
        # n_actions defaults to the previously hard-coded 113; presumably it
        # should equal env.action_space.n -- TODO confirm for the env in use.
        super().__init__()  # was super(DQNetwork, self): an undefined name
        self.model_stack = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=3, padding='same'),
            nn.BatchNorm2d(16),
            nn.ReLU(),  # 21 * 79 * 16
            nn.MaxPool2d(3, padding=(0, 1)),  # 7 * 27 * 16
            nn.MaxPool2d(3, padding=(1, 0)),  # 3 * 9 * 16
            nn.Flatten(),
            nn.Linear(3 * 9 * 16, 3 * 64),
            nn.ReLU(),
            # The Q-value head is linear on purpose. A trailing ReLU would
            # clamp every action value to >= 0, so negative returns could never
            # be represented and gradients for those outputs would vanish.
            nn.Linear(3 * 64, n_actions),
        )

    def obs_to_x(self, obs):
        """Stack one observation dict into a (1, 4, H, W) tensor.

        The tensor is created on the same device as this model's parameters.
        """
        # Was stacking state[...] for three of the channels (an undefined
        # name) and using an undefined global ``device``.
        x = np.stack((obs['glyphs'], obs['chars'], obs['colors'], obs['specials']))
        device = next(self.parameters()).device
        return torch.tensor(x, device=device).unsqueeze(0)

    def forward(self, x):
        """Return Q-values for ``x``, casting integer observation grids to float."""
        return self.model_stack(x.float())

In [None]:
class RLNetwork():
    def __init__(self, Model, hyperparameters=None):
        """Build the environment, replay memory, and policy/target networks.

        Args:
            Model: zero-argument callable (e.g. an ``nn.Module`` subclass) used
                to construct both the policy and the target network.
            hyperparameters: optional dict of overrides passed to
                ``set_hyperparameters``; ``None`` means use all defaults.
        """
        # setup gpu if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # set hyperparameters (None -> a fresh empty dict, instead of a shared
        # mutable default argument)
        self.set_hyperparameters({} if hyperparameters is None else hyperparameters)

        # setup environment
        self.env = gym.make('NetHackChallenge-v0', savedir=None)
        # Bug fix: was ``env.action_space.n``; ``env`` is undefined in this
        # scope -- the environment is stored on ``self.env``.
        self.n_actions = self.env.action_space.n
        self.memory = ReplayMemory(self.max_memory)

        # setup policy & target networks: the target starts as an exact copy
        # of the policy's weights and is put in eval mode
        self.policy_net = Model().to(self.device)
        self.target_net = Model().to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.RMSprop(self.policy_net.parameters())

        # tracking metrics through the games
        self.steps_done = 0  # incremented once per select_action call
        self.episode_durations = []
        self.episode_rewards = []
        self.episode_avg_loss = []
    
    def set_hyperparameters(self, hyperparameters):
        self.batch_size = hyperparameters.get('batch_size', 64)
        self.gamma = hyperparameters.get('gamma', 0.999)
        self.eps_start = hyperparameters.get('eps_start', 0.9)
        self.eps_end = hyperparameters.get('eps_end', 0.05)
        self.eps_decay = hyperparameters.get('eps_decay', 200)
        self.target_update = hyperparameters.get('target_update', 10)
        self.max_memory = hyperparameters.get('max_memory', 100000)
        self.episodes = hyperparameters.get('episodes', 1)
    
    def select_action(state):
        sample = random.random()
        eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        if sample > eps_threshold:
            with torch.no_grad():
                # get the max policy value
                # only is 1 at a time
                return self.policy_net(state).max(1)[1].view(1, 1)
        else:
            return torch.tensor([[random.randrange(self.n_actions)]], device=self.device, dtype=torch.long)
    
    