In [15]:
import math
import random
import time
from collections import deque
from collections import namedtuple
from copy import deepcopy

import gym
from gym import spaces

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.distributions import Normal

In [10]:
# One environment step as stored in replay memory.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'done'))

# FIFO replay buffer: once `capacity` entries are stored, each append evicts the oldest.
class ReplayMemory(deque):
    """Bounded FIFO buffer of ``Transition`` records with uniform random sampling.

    Implemented as a ``deque`` with ``maxlen=capacity``.
    """

    def __init__(self, capacity):
        super().__init__(maxlen=capacity)

    def sample(self, amt_sample):
        """Return ``amt_sample`` distinct entries drawn uniformly without replacement.

        Raises:
            ValueError: (from ``random.sample``) if ``amt_sample`` is negative or
                larger than the number of stored entries.
        """
        return random.sample(self, amt_sample)

    # expects to take in numpy or numeric values
    def store_transition(self, state, action, next_state, reward, done):
        """Append one ``Transition``; the oldest entry is dropped when the buffer is full."""
        self.append(Transition(state, action, next_state, reward, done))

    # deque rebuilds subclass instances by calling ``type(self)(iterable, maxlen)``.
    # That does not match ``__init__(capacity)`` above, so copy.copy / copy.deepcopy /
    # pickle / .copy() raised TypeError. These overrides rebuild from ``maxlen`` and refill.
    def __copy__(self):
        clone = type(self)(self.maxlen)
        clone.extend(self)
        clone.__dict__.update(self.__dict__)
        return clone

    def copy(self):
        """Shallow copy that preserves the subclass type and capacity."""
        return self.__copy__()

    def __reduce__(self):
        # (callable, args, state, list-items iterator) -> ReplayMemory(maxlen), then append items.
        return (type(self), (self.maxlen,), self.__dict__ or None, iter(self))
    


In [1]:
class CNNNetwork(nn.Module):
    """Small LeNet-style CNN mapping a 3-channel image to ``output_dim`` features.

    The flattened size ``16 * 13 * 13`` fixes the accepted input resolution: two
    (conv k=5, maxpool 2) stages must leave a 13x13 map, e.g. 64x64 ->
    conv 60 -> pool 30 -> conv 26 -> pool 13. ``pov`` is expected as
    (batch, 3, H, W) or unbatched (3, H, W).

    Args:
        output_dim: number of output features.
    """

    def __init__(self, output_dim):
        super(CNNNetwork, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)  # 3 input channels
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # 16 channels x 13 x 13 spatial map (valid for e.g. 64x64 inputs, see docstring)
        self.fc1 = nn.Linear(16 * 13 * 13, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, output_dim)

    def forward(self, pov):
        """Return a (batch, output_dim) tensor of non-negative features."""
        x = self.pool(F.relu(self.conv1(pov)))
        x = self.pool(F.relu(self.conv2(x)))
        # An unbatched (C, H, W) input becomes a batch of one, as before.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        # Flatten per sample so a wrong input resolution fails loudly in fc1
        # instead of view(-1, ...) silently folding pixels into extra batch rows.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): ReLU on the output layer clamps features to >= 0. With
        # output_dim=1 a negative pre-activation zeroes both the feature and its
        # gradient; consider dropping this activation.
        x = F.relu(self.fc3(x))
        return x

NameError: name 'nn' is not defined

In [None]:
class CriticNetwork(nn.Module):
    """Action-value estimator Q(pov, compass, action) returning one value per sample.

    Args:
        action_dim: number of action dimensions; ``action`` is reshaped to
            (batch, action_dim) in ``forward``.
    """

    def __init__(self, action_dim):
        super(CriticNetwork, self).__init__()
        # Image encoder producing one scalar feature per sample.
        self.cnn = CNNNetwork(1)

        # Input = 1 CNN feature + compass value repeated 4 times + action vector.
        self.action_dim = action_dim
        self.fc1 = nn.Linear(1 + 4 + action_dim, 1)

    def forward(self, pov, compass, action):
        """Return Q-values of shape (batch, 1)."""
        # Call the sub-module (not .forward) so registered hooks run.
        x = self.cnn(pov)
        # NOTE(review): feeding the same compass value 4 times is equivalent to one
        # input whose weight is the sum of 4 weights; kept so fc1's shape (and saved
        # checkpoints) stay compatible.
        compass = compass.view(-1, 1).expand(-1, 4)
        action = action.view(-1, self.action_dim)
        return self.fc1(torch.cat((x, compass, action), dim=1))

In [None]:
class ActorNetwork(nn.Module):
    """Squashed-Gaussian policy mapping (pov, compass) to a bounded action and its log-prob.

    Args:
        action_dim: number of action dimensions.
        action_scale: half-width of the action range; sampled actions lie in
            (-action_scale, action_scale). Defaults to 30, the previously hard-coded value.
    """

    def __init__(self, action_dim, action_scale=30):
        super(ActorNetwork, self).__init__()
        # Image encoder producing one scalar feature per sample.
        self.cnn = CNNNetwork(1)
        # Input = 1 CNN feature + compass value repeated 4 times (see forward).
        self.fc1 = nn.Linear(1 + 4, action_dim)

        # Heads for the mean and log standard deviation of the pre-squash Gaussian.
        self.avg_out = nn.Linear(action_dim, action_dim)
        self.dev_out = nn.Linear(action_dim, action_dim)
        self.action_scale = action_scale

    def forward(self, pov, compass):
        """Sample an action with the reparameterisation trick.

        Returns:
            (action, logp): ``action`` of shape (batch, action_dim), squashed into
            (-action_scale, action_scale); ``logp`` of shape (batch, 1), the
            log-density of the tanh-squashed sample (the constant rescaling term
            for ``action_scale`` is not included).
        """
        # Call the sub-module (not .forward) so registered hooks run.
        x = self.cnn(pov)
        compass = compass.view(-1, 1).expand(-1, 4)
        x = self.fc1(torch.cat((x, compass), dim=1))
        avg = self.avg_out(x)
        log_dev = self.dev_out(x)
        # Clamp the log standard deviation to [-20, 2], i.e. std in ~[2.1e-9, 7.39].
        log_dev = torch.clamp(log_dev, -20, 2)
        dev = torch.exp(log_dev)

        # Pre-squash distribution; rsample keeps the sample differentiable w.r.t. avg/dev.
        distribution = Normal(avg, dev)
        action = distribution.rsample()

        logp = self.calc_logp(action, distribution)

        # Squash into (-action_scale, action_scale).
        action = self.squash(action, self.action_scale)

        return action, logp

    def squash(self, action, range):
        """Map an unbounded action into (-range, range) via tanh."""
        return torch.tanh(action) * range

    # Equation 21 from the original SAC paper (Appendix C).
    def calc_logp(self, action, distribution):
        """Log-probability of tanh(action) under ``distribution``, summed over action dims.

        Uses the exact identity log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)),
        which stays finite when tanh saturates. The old epsilon-in-the-log form
        biased the result and clipped it at log(eps).

        Returns:
            Tensor of shape (batch, 1).
        """
        log_abs_det_jacobian = 2 * (math.log(2.0) - action - F.softplus(-2 * action))
        return (distribution.log_prob(action) - log_abs_det_jacobian).sum(axis=1, keepdim=True)

In [17]:
class SoftACAgent():
    """Soft Actor-Critic agent with clipped double-Q and a fixed entropy weight ``alpha``.

    Follows the Spinning Up SAC pseudocode: two online critics with
    Polyak-averaged target copies, and a squashed-Gaussian actor.

    ``config`` keys read: ``ALPHA``, ``GAMMA``, ``POLYAK``, ``MAX_MEMORY``,
    ``BATCH_SIZE``, ``LEARNING_RATE``.

    NOTE(review): ``device`` is read as a module-level global that is not defined
    in the cells visible here; make sure it exists before constructing the agent.
    """

    def __init__(self, env, config):
        self.config = config

        self.observation_space = env.observation_space
        self.action_space = env.action_space

        self.alpha = self.config["ALPHA"]
        self.gamma = self.config["GAMMA"]
        self.polyak = self.config["POLYAK"]

        self.memory = ReplayMemory(self.config["MAX_MEMORY"])
        self.batch_size = self.config["BATCH_SIZE"]

        # Policy network: returns a squashed 1-D action and its log-probability.
        self.actor = ActorNetwork(1).to(device)

        # Two independently initialised critics. A deepcopy would start them with
        # identical weights; trained on the same loss with the same optimiser they
        # would then stay identical, and the min() in the target would do nothing.
        self.critic1 = CriticNetwork(1).to(device)
        self.critic2 = CriticNetwork(1).to(device)
        self.critic1_target = deepcopy(self.critic1)
        self.critic2_target = deepcopy(self.critic2)

        # Targets change only through Polyak averaging, never through gradients.
        for p in self.critic1_target.parameters():
            p.requires_grad = False
        for p in self.critic2_target.parameters():
            p.requires_grad = False

        self.optimizer_actor = optim.SGD(self.actor.parameters(), lr=self.config["LEARNING_RATE"], momentum=0.9)
        # The combined critic loss trains BOTH critics, so both parameter sets must
        # be registered (previously only critic1 was, leaving critic2 untrained).
        critic_params = list(self.critic1.parameters()) + list(self.critic2.parameters())
        self.optimizer_critic = optim.SGD(critic_params, lr=self.config["LEARNING_RATE"], momentum=0.9)

    def step(self, obs):
        """Sample an action from the current policy without tracking gradients.

        ``obs`` must provide ``"pov"`` and ``"compassAngle"`` entries in the tensor
        form ``ActorNetwork`` accepts.
        """
        with torch.no_grad():
            action, _ = self.actor(obs["pov"], obs["compassAngle"])
        return action

    def evaluate(self):
        """Run one SAC update (critics, actor, targets) on a random minibatch.

        Returns:
            False when the replay memory cannot supply ``batch_size`` transitions
            (no update is made); True after a completed update.
        """
        try:
            transitions = self.memory.sample(self.batch_size)
        except ValueError:
            # random.sample: not enough stored transitions for a full minibatch yet.
            return False

        # Transpose a list of Transitions into one Transition of per-field tuples.
        batch = Transition(*zip(*transitions))

        # NOTE(review): assumes every stored field is a tensor with a leading batch
        # axis (torch.cat requirement) - confirm against the data-collection loop.
        pov = torch.cat([state["pov"] for state in batch.state])
        compass = torch.cat([state["compassAngle"] for state in batch.state])
        actions = torch.cat(batch.action)
        next_pov = torch.cat([state["pov"] for state in batch.next_state])
        next_compass = torch.cat([state["compassAngle"] for state in batch.next_state])
        rewards = torch.cat(batch.reward)
        done_signals = torch.cat(batch.done)

        # 1) Critic update: both critics share one loss and one optimiser.
        self.optimizer_critic.zero_grad()
        critic_loss = self.compute_loss_critic(pov, compass, actions, next_pov, next_compass, rewards, done_signals)
        critic_loss.backward()
        self.optimizer_critic.step()

        # 2) Actor update. Freeze the online critics so no parameter gradients are
        #    computed for them; gradients still flow through them into the actions.
        self._set_critics_requires_grad(False)
        try:
            self.optimizer_actor.zero_grad()
            loss_actor = self.compute_loss_actor(pov, compass)
            loss_actor.backward()
            self.optimizer_actor.step()
        finally:
            # Always unfreeze, even if the actor step raised.
            self._set_critics_requires_grad(True)

        # 3) Move target networks towards the online critics.
        self._update_targets()
        return True

    def _set_critics_requires_grad(self, flag):
        # Toggle gradient tracking for the online critics' parameters.
        for p in self.critic1.parameters():
            p.requires_grad = flag
        for p in self.critic2.parameters():
            p.requires_grad = flag

    def _update_targets(self):
        # Polyak averaging: target <- polyak * target + (1 - polyak) * online.
        with torch.no_grad():
            pairs = ((self.critic1, self.critic1_target), (self.critic2, self.critic2_target))
            for online, target in pairs:
                for p, p_targ in zip(online.parameters(), target.parameters()):
                    p_targ.mul_(self.polyak)
                    p_targ.add_((1 - self.polyak) * p)

    # This function covers line 12 from the Spinning Up Article pseudocode
    def compute_loss_critic(self, pov, compass, actions, next_pov, next_compass, rewards, done_signals):
        """Summed MSE of both critics against the entropy-regularised Bellman target.

        ``rewards`` and ``done_signals`` may be shaped (batch,) or (batch, 1);
        ``done_signals`` may be bool or numeric.
        """
        q1 = self.critic1(pov, compass, actions)
        q2 = self.critic2(pov, compass, actions)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from the *current* policy
            next_actions, next_log_probs = self.actor(next_pov, next_compass)

            # Clipped double-Q target values
            q1_targ = self.critic1_target(next_pov, next_compass, next_actions)
            q2_targ = self.critic2_target(next_pov, next_compass, next_actions)
            q_targ = torch.min(q1_targ, q2_targ)

            # Reshape to (batch, 1) to match q_targ / log-probs; a flat (batch,)
            # vector would otherwise broadcast y to (batch, batch). The cast makes
            # `1 - done` valid for bool tensors.
            rewards = rewards.reshape(-1, 1)
            not_done = 1 - done_signals.reshape(-1, 1).to(q_targ.dtype)

            y = rewards + self.gamma * not_done * (q_targ - self.alpha * next_log_probs)

        loss_critic1 = ((q1 - y) ** 2).mean()
        loss_critic2 = ((q2 - y) ** 2).mean()
        return loss_critic1 + loss_critic2

    # This function covers line 14 from the Spinning Up Article pseudocode
    def compute_loss_actor(self, pov, compass):
        """Entropy-regularised policy loss: mean(alpha * log pi(a|s) - min(Q1, Q2)).

        Uses the ONLINE critics, as in the pseudocode; the target critics serve only
        the Bellman backup.
        """
        actions, log_probs = self.actor(pov, compass)

        q1 = self.critic1(pov, compass, actions)
        q2 = self.critic2(pov, compass, actions)
        q = torch.min(q1, q2)

        return (self.alpha * log_probs - q).mean()
        