# Imports for A3C

In [None]:
import gym
import torch as T
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
import nle

# count() provides an unbounded step counter for the per-episode game loop:
from itertools import count

# Helper Functions

In [None]:
def concat_state(state):
    """Flatten a NetHack observation dict into one 1-D feature vector.

    The ``glyphs``, ``chars``, ``colors`` and ``specials`` arrays are stacked
    along a new leading axis (so they must all share one shape, e.g. the
    21x79 map) and the result is flattened in row-major order.

    Args:
        state: Observation dict containing at least the keys ``'glyphs'``,
            ``'chars'``, ``'colors'`` and ``'specials'``.

    Returns:
        np.ndarray: 1-D array of length ``4 * rows * cols`` (6636 for a
        21x79 map), with the dtype NumPy promotes the four arrays to.
    """
    keys = ('glyphs', 'chars', 'colors', 'specials')
    # The former np.reshape(..., (21, 79, 4)) did not move the channel axis
    # (a row-major reshape never reorders elements), so it was a no-op before
    # flattening. Dropping it keeps the output identical for 21x79 maps while
    # no longer hard-coding the map size.
    return np.stack([state[key] for key in keys]).ravel()

# Shared Optimizer
## Used by all agents/workers to apply their locally computed gradients to the shared global network

In [None]:
"""
Shared Optimizer is the optimzer that is updated throughout training from each of the workers. 
"""
class SharedAdam(T.optim.Adam):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
            weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps,
                weight_decay=weight_decay)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = T.zeros_like(p.data)
                state['exp_avg_sq'] = T.zeros_like(p.data)

                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

# Actor/Critic NN
## Each agent owns a local copy of these networks, which it uses and updates during its own independent runs

In [None]:
class ActorCritic(nn.Module):
    """Separate policy (actor) and value (critic) MLPs plus a rollout memory.

    Args:
        input_dims: 3-tuple whose product is the length of the flattened
            state vector fed to the networks.
        n_actions: Number of discrete actions (size of the policy output).
        gamma: Discount factor used when computing returns.
    """
    def __init__(self, input_dims, n_actions, gamma=0.99):
        super(ActorCritic, self).__init__()

        self.gamma = gamma

        # Very simple 2-layer neural networks; hidden width is half the input.
        i, j, k = input_dims
        num_inputs = i * j * k
        hidden_layer = int(num_inputs / 2)

        # Two distinct networks, one for the policy and one for the value.
        self.pi1 = nn.Linear(num_inputs, hidden_layer)
        self.v1 = nn.Linear(num_inputs, hidden_layer)
        self.pi = nn.Linear(hidden_layer, n_actions)
        self.v = nn.Linear(hidden_layer, 1)

        # Rollout memory, filled by remember() and consumed by calc_loss().
        self.rewards = []
        self.actions = []
        self.states = []

    def remember(self, state, action, reward):
        """Append one (state, action, reward) transition to the memory."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

    def clear_memory(self):
        """Discard all remembered transitions."""
        self.states = []
        self.actions = []
        self.rewards = []

    def forward(self, state):
        """Return ``(policy_logits, value)`` for one state or a batch of states."""
        pi1 = T.sigmoid(self.pi1(state))
        v1 = T.sigmoid(self.v1(state))

        pi = self.pi(pi1)
        v = self.v(v1)

        return pi, v

    def calc_R(self, done):
        """Compute discounted returns for the remembered rewards.

        The bootstrap value is the critic's value of the last remembered
        state, or 0 when ``done``.
        NOTE(review): textbook n-step A3C bootstraps from the state *after*
        the last action; that state is not stored in memory here.

        Args:
            done: Whether the episode ended after the last remembered step.

        Returns:
            1-D float tensor with one return per remembered step (no grad).
        """
        if done:
            R = 0.0
        else:
            # Only the last state's value is needed, and returns are targets,
            # so evaluate just that state without building a graph.
            last_state = T.tensor(np.asarray(self.states[-1]), dtype=T.float)
            with T.no_grad():
                _, v = self.forward(last_state)
            R = v.item()

        batch_return = []
        for reward in reversed(self.rewards):
            R = reward + self.gamma * R
            batch_return.append(R)
        batch_return.reverse()
        return T.tensor(batch_return, dtype=T.float)

    def calc_loss(self, done):
        """Return the mean actor-critic loss over the remembered rollout.

        Args:
            done: Whether the episode ended after the last remembered step.

        Returns:
            Scalar tensor: mean of squared TD error plus policy-gradient loss.
        """
        # np.asarray first: building a tensor from a list of ndarrays is slow.
        states = T.tensor(np.asarray(self.states), dtype=T.float)
        actions = T.tensor(self.actions, dtype=T.long)

        returns = self.calc_R(done)
        pi, values = self.forward(states)
        values = values.squeeze(-1)  # (N, 1) -> (N,), stays 1-D when N == 1
        critic_loss = (returns - values) ** 2

        # Normalize over the action axis (the last one), not the batch axis.
        dist = Categorical(logits=pi)
        log_probs = dist.log_prob(actions)
        # Detach the advantage so the actor term does not train the critic.
        advantage = (returns - values).detach()
        actor_loss = -log_probs * advantage

        return (critic_loss + actor_loss).mean()

    def choose_action(self, observation):
        """Sample an action index in ``[0, n_actions)`` from the current policy.

        Args:
            observation: Raw observation dict; flattened with concat_state().

        Returns:
            int: The sampled action (stochastic, not the argmax).
        """
        state = T.tensor(concat_state(observation), dtype=T.float)
        # Acting needs no gradients; the loss is recomputed from memory later.
        with T.no_grad():
            pi, _ = self.forward(state)
        dist = Categorical(logits=pi)
        return dist.sample().item()

# Agent Class
## Each agent creates its own local networks and uses the shared optimizer to update the global network with its gradients

In [None]:
class Agent(mp.Process):
    """A3C worker process that plays episodes and updates the global network.

    Args:
        global_actor_critic: Shared ActorCritic holding the current parameters.
        optimizer: Optimizer over ``global_actor_critic``'s parameters.
        input_dims, n_actions, gamma: Passed to the local ActorCritic.
        lr: Accepted for interface compatibility; the shared optimizer owns
            the learning rate.
        name: Integer worker id, rendered as ``'wNN'``.
        global_ep_idx: Shared ``mp.Value`` counter incremented once per step.
        env_id: Gym environment id.
        num_eps: Number of episodes this worker plays.
        t_max: Update interval in steps. Defaults to the module-level
            ``T_MAX`` read here, at construction time, so the value travels
            with the instance instead of being looked up in the child process
            (where ``__main__``-only globals may not exist, e.g. under the
            ``spawn`` start method).
    """
    def __init__(self, global_actor_critic, optimizer, input_dims, n_actions,
                 gamma, lr, name, global_ep_idx, env_id, num_eps, t_max=None):
        super(Agent, self).__init__()
        self.local_actor_critic = ActorCritic(input_dims, n_actions, gamma)
        self.global_actor_critic = global_actor_critic
        self.name = 'w%02i' % name
        self.episode_idx = global_ep_idx
        self.env = gym.make(env_id)
        self.optimizer = optimizer
        self.num_eps = num_eps
        self.t_max = T_MAX if t_max is None else t_max

    def _push_update(self, done):
        """Backprop the local loss into the global net, step, then resync."""
        loss = self.local_actor_critic.calc_loss(done)
        self.optimizer.zero_grad()
        # optimizer.zero_grad() only touches the global parameters; with
        # set_to_none it no longer clears the local grad tensors aliased
        # below, so clear them explicitly to avoid accumulating across updates.
        self.local_actor_critic.zero_grad()
        loss.backward()
        for local_param, global_param in zip(
                self.local_actor_critic.parameters(),
                self.global_actor_critic.parameters()):
            global_param._grad = local_param.grad
        self.optimizer.step()
        # Pull the freshly updated global weights back into the local copy.
        self.local_actor_critic.load_state_dict(
                self.global_actor_critic.state_dict())
        self.local_actor_critic.clear_memory()

    def run(self):
        """Play ``num_eps`` episodes, updating every ``t_max`` steps or on done."""
        try:
            for curr_ep in range(self.num_eps):
                with self.episode_idx.get_lock():
                    print(self.name, ' has started. Episode: ', self.episode_idx.value)
                # A new NetHack run is about to spawn; drop any old memory.
                observation = self.env.reset()
                score = 0
                self.local_actor_critic.clear_memory()

                for ep_step in count():
                    action = self.local_actor_critic.choose_action(observation)
                    next_observation, reward, done, _ = self.env.step(action)
                    score += reward
                    # Store the transition for the next loss / return computation.
                    self.local_actor_critic.remember(
                        concat_state(observation), action, reward)

                    # Update every t_max steps of this episode, or when it ends.
                    if (ep_step + 1) % self.t_max == 0 or done:
                        self._push_update(done)
                    observation = next_observation
                    # Global counter of total worker steps.
                    with self.episode_idx.get_lock():
                        self.episode_idx.value += 1
                        print(self.name, 'Global Ep.', self.episode_idx.value, 'reward %.1f' % score, 'Local Ep.', curr_ep, 'Ep. Step', ep_step)
                    if done:
                        break
                with self.episode_idx.get_lock():
                    print(self.name, 'has finished Episode:', curr_ep, 'with a score of', score)
        finally:
            # Close the environment even if an episode raised.
            self.env.close()

# Main Function

In [None]:
if __name__ == '__main__':
    lr = 1e-4
    # Gym id of the environment; every worker makes its own instance of it.
    env_id = 'NetHackScore-v0'
    # Number of discrete actions in env_id; must equal the environment's
    # action_space.n (23 for NetHackScore-v0).
    n_actions = 23
    # Observation layout fed to the networks: the 21x79 map times the four
    # stacked observation arrays (glyphs, chars, colors, specials).
    input_dims = (21, 79, 4)
    # Number of episodes each worker plays.
    num_eps = 1
    # Intended number of steps between each worker's gradient pushes.
    T_MAX = 5

    # The main network connecting all of the agents; holds the most current
    # parameters in shared memory.
    global_actor_critic = ActorCritic(input_dims, n_actions)
    global_actor_critic.share_memory()
    optim = SharedAdam(global_actor_critic.parameters(), lr=lr,
                       betas=(0.92, 0.999))
    global_ep = mp.Value('i', 0)

    # One agent (worker process) per CPU core.
    workers = [Agent(global_actor_critic,
                     optim,
                     input_dims,
                     n_actions,
                     gamma=0.99,
                     lr=lr,
                     name=i,
                     global_ep_idx=global_ep,
                     env_id=env_id,
                     num_eps=num_eps) for i in range(mp.cpu_count())]
    # Plain loops: these calls are made for their side effects only.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    print('Finished Training')

# Misc. Testing code
## Ignore; used to sanity-check the variables and array shapes used throughout the code.

In [None]:
# Standalone environment instance for interactive inspection. The savedir
# keyword is forwarded by gym.make to the NLE environment; presumably None
# disables saving game recordings -- confirm against the NLE docs.
env_test = gym.make('NetHackScore-v0', savedir=None)

In [None]:
# Number of discrete actions (should match n_actions in the main cell).
# Uses the public action_space: newer gym wrappers refuse to forward private
# attributes such as _actions to the wrapped environment.
env_test.action_space.n

In [None]:
# Start a new game; the returned observation dict is inspected by the cells below.
observation = env_test.reset()

In [None]:
# Dump the raw observation to eyeball its keys and values.
print(observation)

In [None]:
# Sanity-check the observation -> flat-vector shapes: stack the four arrays,
# reshape to (21, 79, 4), then flatten. Note that reshape only reinterprets
# the row-major buffer; it does not move the stacked channel axis.
arr = np.stack([observation[key]
                for key in ('glyphs', 'chars', 'colors', 'specials')])
print(arr.shape)
arr = arr.reshape((21, 79, 4))
print(arr.shape)
arr = arr.flatten()
print(arr.shape)