In [15]:
import gym
from gym import spaces

import time
import random

from copy import deepcopy

from collections import namedtuple
from collections import deque

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.distributions import Categorical

In [10]:
# Record type for one environment step as stored in the replay buffer.
# Field order matters: consumers unpack batches positionally.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward', 'done'],
)

# Bounded FIFO buffer: once `capacity` entries are held, appending a new one
# evicts the oldest.
class ReplayMemory(deque):
    """Fixed-capacity replay buffer built on ``collections.deque``."""

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` entries."""
        deque.__init__(self, maxlen=capacity)

    def sample(self, amt_sample):
        """Return ``amt_sample`` stored entries drawn at random without replacement.

        Raises:
            ValueError: propagated from ``random.sample`` when ``amt_sample``
                exceeds the number of stored entries.
        """
        return random.sample(self, amt_sample)

    def store_transition(self, state, action, next_state, reward, done):
        """Convert one step to tensors and append it as a ``Transition``.

        Inputs are expected to be numpy arrays or plain numeric values. Every
        stored tensor gets a leading axis of size 1 via ``unsqueeze(0)``.
        ``state``/``next_state`` are cast to float; ``action`` and ``reward``
        are wrapped in a one-element list and cast to float; ``done`` is
        wrapped the same way and cast to int. A ``next_state`` of ``None`` is
        stored as ``None``.

        NOTE(review): relies on a module-level ``device`` that is not defined
        in the cells visible here -- confirm it is set before this is called.
        """
        def as_float_row(value):
            # tensor -> float -> add leading axis
            return torch.tensor(value, device=device).float().unsqueeze(0)

        state_row = as_float_row(state)
        action_row = as_float_row([action])
        next_row = None if next_state is None else as_float_row(next_state)
        reward_row = as_float_row([reward])
        done_row = torch.tensor([done], device=device).int().unsqueeze(0)

        self.append(Transition(
            state=state_row,
            action=action_row,
            next_state=next_row,
            reward=reward_row,
            done=done_row,
        ))
    


In [11]:
class CNNNetwork(nn.Module):
    """Small CNN mapping an RGB image plus a scalar compass reading to ``output_dim`` values.

    Layout: conv(3->6, k=5) -> ReLU -> maxpool(2) -> conv(6->16, k=5) -> ReLU
    -> maxpool(2) -> flatten -> fc(2704->32) -> ReLU -> fc(32->16) -> ReLU
    -> concat(compass) -> fc(17->output_dim).

    The first fully connected layer is sized for 16 * 13 * 13 features, which
    is what the conv/pool stack produces for a 64x64 input
    (64 -> 60 -> 30 -> 26 -> 13). Other image sizes will fail at ``fc1``.

    Args:
        input_dim: accepted for signature compatibility; not used by this class.
        output_dim: number of output units of the final layer.
    """

    def __init__(self, input_dim, output_dim):
        # Zero-argument super(): the two-argument form looks the *name*
        # `CNNNetwork` up at call time. A later cell rebinds that name to a
        # subclass, which would make this constructor call itself forever.
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)  # 3 input channels (RGB)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 13 * 13, 32)  # 16 channels x 13 x 13 positions
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16 + 1, output_dim)  # +1 for the compass scalar

    def forward(self, x, compass=None):
        """Run the network.

        Args:
            x: image tensor of shape ``(N, 3, H, W)`` or, unbatched,
                ``(3, H, W)``.
            compass: optional scalar per sample (a number, or anything
                reshapeable to ``(N, 1)``), appended to the 16 hidden features
                before the last layer. When omitted, zeros are used.
                NOTE(review): the zero default is a placeholder so the network
                runs without a compass -- confirm the intended encoding.

        Returns:
            Tensor of shape ``(N, output_dim)``, or ``(output_dim,)`` for
            unbatched input.
        """
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)

        # Convolutional feature extractor.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # Insert the compass reading just before the final linear layer.
        if compass is None:
            compass = x.new_zeros(x.shape[0], 1)
        else:
            compass = torch.as_tensor(compass, dtype=x.dtype, device=x.device)
            compass = compass.reshape(-1, 1).expand(x.shape[0], 1)
        x = self.fc3(torch.cat((x, compass), dim=1))

        return x.squeeze(0) if unbatched else x

In [12]:
# NOTE: this definition deliberately keeps the name `CNNNetwork`, so from this
# cell on the name refers to the distribution-producing subclass. Code in the
# base class that looks `CNNNetwork` up by name at call time (e.g. a
# two-argument super()) will see this subclass instead of the base.
class CNNNetwork(CNNNetwork):
    """Policy head: wraps the base network's outputs in a ``Categorical`` distribution."""

    def forward(self, x, *args, **kwargs):
        """Return a ``Categorical`` over actions for input ``x``.

        Extra positional/keyword arguments are passed through unchanged to the
        base class ``forward``.
        """
        # Call the parent's forward via super() rather than through an
        # unrelated class name.
        scores = super().forward(x, *args, **kwargs)
        # TODO: sampling from a Categorical is not differentiable; a
        # reparameterization (or an exact expectation over actions) is needed
        # wherever gradients must flow through the chosen action.
        # Passing logits is equivalent to softmax-then-probs but avoids an
        # explicit softmax.
        return Categorical(logits=scores)

In [17]:
class SoftACAgent():
    """Soft actor-critic agent for discrete actions with twin critics and target networks.

    The actor maps a batch of states to a distribution over actions; each
    critic maps a batch of states to one Q-value per action (indexed with
    ``gather``). Target critics track the online critics by Polyak averaging.

    NOTE(review): ``Linear``, ``LinearDistribution`` and ``device`` are used
    here but are not defined in the cells visible in this notebook section --
    confirm they exist in the kernel on a fresh Restart & Run All.
    """

    def __init__(self, env, config):
        """Build networks, optimizers and the replay buffer.

        Args:
            env: object exposing ``observation_space`` and ``action_space``.
            config: mapping with keys ``ALPHA``, ``GAMMA``, ``POLYAK``,
                ``MAX_MEMORY``, ``BATCH_SIZE`` and ``LEARNING_RATE``.
        """
        self.config = config

        self.observation_space = env.observation_space
        self.action_space = env.action_space

        self.alpha = self.config["ALPHA"]    # entropy temperature
        self.gamma = self.config["GAMMA"]    # discount factor
        self.polyak = self.config["POLYAK"]  # target-network averaging coefficient

        self.memory = ReplayMemory(self.config["MAX_MEMORY"])
        self.batch_size = self.config["BATCH_SIZE"]

        flat_in = gym.spaces.flatten_space(self.observation_space).shape[0]
        flat_out = gym.spaces.flatten_space(self.action_space).shape[0]

        # Responsible for estimating our policy.
        self.actor = LinearDistribution(flat_in, flat_out).to(device)

        # Our critics estimate the action value. The second critic is built
        # independently (not deep-copied): two critics with identical weights
        # trained on the same loss stay identical forever, which would make
        # the min(Q1, Q2) in the targets pointless.
        self.critic1 = Linear(flat_in, flat_out).to(device)
        self.critic2 = Linear(flat_in, flat_out).to(device)
        self.critic1_target = deepcopy(self.critic1)
        self.critic2_target = deepcopy(self.critic2)

        # Targets are only ever moved by Polyak averaging, never by gradients.
        for target in (self.critic1_target, self.critic2_target):
            for p in target.parameters():
                p.requires_grad = False

        learning_rate = self.config["LEARNING_RATE"]
        self.optimizer_actor = optim.SGD(self.actor.parameters(), lr=learning_rate, momentum=0.9)
        # The critic loss is the sum of both critics' losses, so the optimizer
        # has to own the parameters of BOTH critics.
        critic_params = list(self.critic1.parameters()) + list(self.critic2.parameters())
        self.optimizer_critic = optim.SGD(critic_params, lr=learning_rate, momentum=0.9)

    def step(self, obs):
        """Sample one action for observation ``obs`` from the current policy.

        Returns the sampled action as a Python scalar (via ``.item()``).
        """
        with torch.no_grad():
            obs = torch.tensor(obs, device=device).float()
            policy_distribution = self.actor(obs)
            action = policy_distribution.sample().item()
        return action

    def evaluate(self):
        """Run one update: critics, then actor, then target networks.

        Returns:
            ``False`` when the replay buffer cannot supply ``batch_size``
            transitions (``sample`` raised ``ValueError``); otherwise ``None``
            after the update has been applied.
        """
        try:
            transitions = self.memory.sample(self.batch_size)
        except ValueError:
            return False

        # Transpose the list of (state, action, next_state, reward, done)
        # records into one tuple per field.
        state_rows, action_rows, next_rows, reward_rows, done_rows = zip(*transitions)

        states = torch.cat(state_rows)
        actions = torch.cat(action_rows)
        rewards = torch.cat(reward_rows)
        done_signals = torch.cat(done_rows)

        # The buffer stores None for a missing next state. Substitute a zero
        # placeholder so the batch can be concatenated, and force those rows
        # to count as terminal so the placeholder never enters the backup.
        next_states = torch.cat([
            torch.zeros_like(s) if ns is None else ns
            for s, ns in zip(state_rows, next_rows)
        ])
        missing_next = torch.tensor(
            [[ns is None] for ns in next_rows],
            dtype=done_signals.dtype,
            device=done_signals.device,
        )
        done_signals = torch.max(done_signals, missing_next)

        # Optimize the critics.
        self.optimizer_critic.zero_grad()
        critic_loss = self.compute_loss_critic(states, actions, next_states, rewards, done_signals)
        critic_loss.backward()
        self.optimizer_critic.step()

        # Freeze the Q-networks so no effort is wasted computing gradients for
        # them during the policy step; always unfreeze, even on error.
        self._set_critics_trainable(False)
        try:
            # Optimize the actor.
            self.optimizer_actor.zero_grad()
            loss_actor = self.compute_loss_actor(states)
            loss_actor.backward()
            self.optimizer_actor.step()
        finally:
            self._set_critics_trainable(True)

        # Finally, update target networks by Polyak averaging.
        with torch.no_grad():
            self._polyak_update(self.critic1, self.critic1_target)
            self._polyak_update(self.critic2, self.critic2_target)

    def _set_critics_trainable(self, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of both online critics."""
        for critic in (self.critic1, self.critic2):
            for p in critic.parameters():
                p.requires_grad = flag

    def _polyak_update(self, online, target):
        """In place: target <- polyak * target + (1 - polyak) * online."""
        for p, p_targ in zip(online.parameters(), target.parameters()):
            p_targ.data.mul_(self.polyak)
            p_targ.data.add_((1 - self.polyak) * p.data)

    # This function covers line 12 from the Spinning Up article pseudocode
    def compute_loss_critic(self, states, actions, next_states, rewards, done_signals):
        """Sum of both critics' mean-squared errors against the soft Bellman target.

        ``actions`` is cast to long and used as a column index into the
        critics' per-action outputs, so it must have shape ``(batch, 1)``.
        """
        ind_actions = actions.long()
        q1 = self.critic1(states).gather(1, ind_actions)
        q2 = self.critic2(states).gather(1, ind_actions)

        # Bellman backup for Q functions.
        with torch.no_grad():
            # Target actions come from the *current* policy.
            target_distribution = self.actor(next_states)
            target_actions = target_distribution.sample()
            log_probs = target_distribution.log_prob(target_actions).unsqueeze(1)
            ind_target_actions = target_actions.long().unsqueeze(1)

            # Target Q-values: the smaller of the two target critics.
            q1_targ = self.critic1_target(next_states).gather(1, ind_target_actions)
            q2_targ = self.critic2_target(next_states).gather(1, ind_target_actions)
            q_targ = torch.min(q1_targ, q2_targ)

            y = rewards + self.gamma * (1 - done_signals) * (q_targ - self.alpha * log_probs)

        loss_critic1 = ((q1 - y) ** 2).mean()
        loss_critic2 = ((q2 - y) ** 2).mean()
        return loss_critic1 + loss_critic2

    # This function covers line 14 from the Spinning Up article pseudocode
    def compute_loss_actor(self, states):
        """Entropy-regularized policy loss, computed exactly over all discrete actions.

        loss = mean_s sum_a pi(a|s) * (alpha * log pi(a|s) - min(Q1, Q2)(s, a))

        A sampled action index is not differentiable, so scoring a sample
        would give the policy no gradient from the Q-values. Because the
        action set is finite the expectation is taken in closed form instead.
        The *online* critics are used (not the targets).

        Assumes ``self.actor`` returns a ``torch.distributions.Categorical``
        (it exposes ``.logits``) -- TODO confirm for the actor class in use.
        """
        current_distribution = self.actor(states)
        # log_softmax is a no-op on already-normalized logits and normalizes
        # them otherwise, so this is always log pi(a|s).
        log_probs = torch.log_softmax(current_distribution.logits, dim=-1)
        probs = log_probs.exp()

        # Q-values are constants for the policy step; detach so the critics
        # never collect gradients from this loss.
        q_values = torch.min(self.critic1(states), self.critic2(states)).detach()

        loss_actor = (probs * (self.alpha * log_probs - q_values)).sum(dim=-1).mean()

        return loss_actor
        