# Actor Critic Methods

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import gym
import sys

from collections import deque

import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.distributions import Categorical

#### Create the environment

In [2]:
# Build the Taxi environment; later cells read it through the notebook-global name `env`.
# NOTE(review): the "Taxi-v2" id may not be registered in newer Gym releases — confirm against the installed version.
env = gym.make("Taxi-v2")

In [3]:
# Smoke-test the environment: reset, take one uniformly sampled action, draw the board.
obs = env.reset()
random_action = env.action_space.sample()
next_obs, reward, done, info = env.step(random_action)
env.render()
# NOTE(review): `env` is used again by later cells after this close() — confirm close() is harmless for this environment.
env.close()

+---------+
|R: | : :G|
| : : : : |
| : : : : |
| |[43m [0m: | : |
|[35mY[0m| : |[34;1mB[0m: |
+---------+
  (South)


In [4]:
# Report how many discrete states and actions the environment exposes.
n_observations, n_actions = env.observation_space.n, env.action_space.n
print("Observation space:", n_observations)
print("Action space:", n_actions)

Observation space: 500
Action space: 6


## Implementation

#### Algorithm
---
```
Input: a differentiable policy parameterization pi(a|s, theta)                   [Policy Network]
Input: a differentiable action-value function parameterization Q_w(s, a)         [Value Network]
Parameters: step sizes alpha_theta > 0; alpha_w > 0

Loop forever for each episode:

        Initialise S, theta
        Sample a from policy network
        
        Loop while S is not terminal for each time step:
                A = pi(.|S, theta) [policy(state)]
                Take action A, observe S', R
                delta = R + gamma * Q_w(S', A') - Q_w(S, A)  [TD(0) error, or advantage]
                theta = theta + alpha_theta * delta * grad_theta log pi(A|S, theta)     [policy gradient update]
                w = w + alpha_w * delta * x(s, a)    [TD(0)]
                A = A', S = S'
```
---

In [5]:
class PolicyNetwork(nn.Module):
    """
    Actor network: a two-layer MLP whose output is a probability vector.

    Args:
        n_inputs (int): length of the input feature vector
        n_outputs (int): number of output probabilities
    """

    def __init__(self, n_inputs, n_outputs):
        super(PolicyNetwork, self).__init__()
        self.n_inputs, self.n_outputs = n_inputs, n_outputs

        # 128-unit hidden layer followed by the output head
        self.fc1 = nn.Linear(n_inputs, 128)
        self.fc2 = nn.Linear(128, n_outputs)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): input whose last dimension has size ``n_inputs``
        Returns:
            torch.Tensor: softmax over the last dimension of the head's output
        """
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits.softmax(dim=-1)

In [6]:
class ValueNetwork(nn.Module):
    """
    Value network for value approximation: a two-layer MLP with a linear head.

    Args:
        state_size (int): length of the input feature vector
        action_size (int): width of the output layer
    """

    def __init__(self, state_size, action_size):
        super(ValueNetwork, self).__init__()
        self.state_size, self.action_size = state_size, action_size

        # 64-unit hidden layer followed by a linear output head
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, action_size)

    def forward(self, x):
        """Return the raw (un-squashed) output of the head for input ``x``."""
        return self.fc2(torch.relu(self.fc1(x)))

In [7]:
class A2CAgent:
    """
    Actor Critic Agent holding a policy (actor) network and a value (critic) network.

    Args:
        n_state (int): size of the state encoding fed to both networks
        n_action (int): number of discrete actions
        policy_network (callable): factory called as ``policy_network(n_state, n_action)``
        value_network (callable): factory called as ``value_network(n_state, 1)``
        gamma (float): discount factor stored on the agent
        env: optional environment handle to store on the agent; when omitted,
            the module-level ``env`` is used if one exists (``None`` otherwise)
    """

    def __init__(self, n_state, n_action, policy_network, value_network, gamma=0.9, env=None):
        # Prefer an explicitly injected environment. Falling back to the notebook-level
        # `env` keeps existing call sites working without making the class depend on
        # hidden kernel state.
        self.env = env if env is not None else globals().get("env")

        self.n_state = n_state
        self.n_action = n_action

        # Initialise the models; the critic emits a single scalar per state
        self.policy_network = policy_network(self.n_state, self.n_action)
        self.value_network = value_network(self.n_state, 1)

        # Hyperparameters
        self.gamma = gamma

    def act(self, state):
        """
        Sample an action from the current policy.

        Args:
            state (torch.Tensor): state encoding; cast to float before the forward pass
        Returns:
            tuple: ``(log_prob, action)`` where ``log_prob`` is the tensor
            log-probability of the sampled action (still attached to the policy
            graph) and ``action`` is a Python number obtained via ``.item()``,
            so only a single (unbatched) state is supported
        """
        # Tensors carry autograd state themselves, so no Variable wrapper is needed.
        probs = self.policy_network(state.float())
        m = Categorical(probs)
        action = m.sample()
        return m.log_prob(action), action.item()

In [8]:
# Presumably the training-episode budget.
# NOTE(review): no visible cell reads this constant; the training loop iterates over range(501) instead — confirm which is intended.
NUM_EPISODES = 1000

In [19]:
def _one_hot_state(state, n_states):
    """Encode a discrete state index as a float one-hot vector of length ``n_states``."""
    # F.one_hot only accepts int64 indices; NumPy's default integer can be 32-bit on
    # some platforms, so cast explicitly.
    index = torch.from_numpy(np.array(state)).long()
    return F.one_hot(index, num_classes=n_states).float()


agent = A2CAgent(env.observation_space.n, env.action_space.n, PolicyNetwork, ValueNetwork)
reward_list = deque(maxlen=100)  # rolling window of the most recent episode returns

critic_optimizer = torch.optim.Adam(agent.value_network.parameters(), lr=1e-3)
actor_optimizer = torch.optim.Adam(agent.policy_network.parameters(), lr=1e-5)
critic_criterion = torch.nn.MSELoss(reduction="sum")

for i in range(501):

    R = 0  # undiscounted return accumulated over this episode
    state = env.reset()

    while True:

        # Select and take action
        state_one_hot = _one_hot_state(state, agent.n_state)
        action_log_prob, action = agent.act(state_one_hot)
        next_state, reward, done, info = env.step(action)
        next_state_one_hot = _one_hot_state(next_state, agent.n_state)

        # Calculate predictions and error.
        # The bootstrap value is a fixed regression target (semi-gradient TD), so it is
        # detached: neither loss may back-propagate through V(next_state).
        value = agent.value_network(state_one_hot)
        next_value = agent.value_network(next_state_one_hot).detach()

        # A genuinely terminal state has no future return. Episodes that were merely cut
        # off by Gym's TimeLimit wrapper (flagged via info["TimeLimit.truncated"] when the
        # installed Gym provides it) still bootstrap.
        if done and not info.get("TimeLimit.truncated", False):
            next_value = next_value * 0.0

        td_target = reward + agent.gamma * next_value
        # The advantage estimate is a constant weight for the policy gradient.
        td_error = (td_target - value).detach()

        critic_loss = critic_criterion(value, td_target)
        actor_loss = (-action_log_prob * td_error).sum()

        # Gradient update for both networks. The two losses now live on disjoint graphs,
        # so no retain_graph is needed and the actor loss cannot leak gradients into the
        # critic's parameters.
        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()

        actor_loss.backward()
        critic_loss.backward()

        actor_optimizer.step()
        critic_optimizer.step()

        R += reward

        if done:
            reward_list.append(R)
            break

        state = next_state

    if i % 10 == 0:
        print("\rEpisode %s \t Average Score: %s" % (i, np.mean(reward_list)))

Episode 0 	 Average Score: -776.0
Episode 10 	 Average Score: -765.4545454545455
Episode 20 	 Average Score: -753.9047619047619
Episode 30 	 Average Score: -761.3225806451613
Episode 40 	 Average Score: -744.219512195122
Episode 50 	 Average Score: -742.156862745098
Episode 60 	 Average Score: -719.4754098360655


KeyboardInterrupt: 