# Actor Critic Methods

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import gym
import sys

from collections import deque

import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.distributions import Categorical

#### Create the environment

In [18]:
env = gym.make("CartPole-v1")

In [20]:
obs = env.reset()
next_obs, reward, done, info = env.step(env.action_space.sample())
env.render()
env.close()

In [22]:
print("Observation space:", env.observation_space.shape[0])
print("Action space:", env.action_space.n)

Observation space: 4
Action space: 2


## Implementation

#### Algorithm
---
```
Input: a differentiable policy parameterization pi(a|s, theta)                   [Policy Network]
Input: a differentiable state-value function parameterization Q_w(s, a, w)       [Value Network]
Parameters: step sizes alpha_theta > 0; alpha_w > 0

Loop forever for each episode:

        Initialise S, theta
        Sample a from policy network
        
        Loop while S is not terminal for each time step:
                A = pi(.|S, theta) [policy(state)]
                Take action A, observe S', R
                delta = R + gamma * A(S', A', w) - A(S, A, w)  [TD(0) error, or advantage]
                theta = theta + alpha_theta * grad_pi log pi_theta(s,a) A(S,A)     [policy gradient update]
                w = w + alpha_w * delta * x(s, a)    [TD(0)]
                A = A', S = S'
```
---

In [134]:
class PolicyNetwork(nn.Module):
    """
    The policy network
    Args:
        n_inputs (int)
        n_outputs (int)
    """
    
    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        
        self.reward_history = []
        self.loss_history = []
        
        self.fc1 = nn.Linear(self.n_inputs, 128)
        self.dropout1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(128, self.n_outputs)
        self.softmax = nn.Softmax(dim=-1)
        
        # save log probs history and rewards history
        self.saved_log_probs = []
        self.rewards = []
        
    def reset(self):
        self.saved_log_probs = []
        self.rewards = []
        
    def forward(self, x):
        """
        Forward pass
        Args:
            x (torch.Tensor)
        """
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = self.fc2(x)
        return F.softmax(x, dim=-1)

In [135]:
class ValueNetwork(nn.Module):
    """Value network for value approximation"""
    
    def __init__(self, state_size, action_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        
        # MLP layers
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, action_size)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [140]:
class A2CAgent:
    """Actor Critic Agent"""
    
    def __init__(self, n_state, n_action, policy_network, value_network, gamma=0.9):
        
        self.env = env
        
        self.n_state = n_state
        self.n_action = n_action
        
        # Initialise the model
        self.policy_network = policy_network(self.n_state, self.n_action)
        self.value_network = value_network(self.n_state, 1)
        
        # Hypterparameters
        self.gamma = gamma
    
    def act(self, state):
        state = state.float()
        probs = self.policy_network(Variable(state))
        m = Categorical(probs)
        action = m.sample()
        log_probs = m.log_prob(action)
#         policy.saved_log_probs.append(log_prob)

        return log_probs, action.item()

In [141]:
NUM_EPISODES = 1000

In [157]:
agent = A2CAgent(env.observation_space.shape[0], env.action_space.n, PolicyNetwork, ValueNetwork)
reward_list = deque(maxlen=100)


for i in range(2000):
    
    # Log the results
    actions = []
    rewards = []
    states = []
    targets = []
    errors = []
    t = 0
    R = 0
    
    state = env.reset()
    
    while True:
        
        # Select and take action
        state_tensor = torch.from_numpy(state).float()
        
        action_log_probs, action = agent.act(state_tensor)
        next_state, reward, done, _ = env.step(action)
        
        next_state_tensor = torch.from_numpy(next_state).float()
        
        # Calculate predictions and error
        predicted_next_state_value =  agent.value_network(next_state_tensor)
        predicted_current_state_value = agent.value_network(state_tensor)
        target = reward + agent.gamma * predicted_next_state_value
        error = target - predicted_current_state_value
        
        # Gradient update for policy optimizer
        value_optimizer = torch.optim.Adam(agent.value_network.parameters(), lr=1e-2)
        value_optimizer.zero_grad()
        critic_loss = torch.nn.MSELoss(reduction="sum")(predicted_current_state_value, target)
        critic_loss.backward(retain_graph=True)
        value_optimizer.step()

        
        # Gradient update for value optimizer
#         predicted_next_state_value =  agent.value_network(next_state_tensor)
#         predicted_current_state_value = agent.value_network(state_tensor)
#         target = reward + agent.gamma * predicted_next_state_value
#         error = target - predicted_current_state_value
        
        policy_optimizer = torch.optim.Adam(agent.policy_network.parameters(), lr=1e-2)
        policy_optimizer.zero_grad()
        actor_loss = -action_log_probs * error
        actor_loss.backward()
        policy_optimizer.step()
        
        
        # Log the results
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        targets.append(target)
        errors.append(error)
        R += reward
        
        
        if done:
            reward_list.append(R)
            break
            
        state = next_state
    
    if i % 100 == 0:
        print("\rEpisode %s \t Average Score: %s" % (i, np.mean(reward_list)))

Episode 0 	 Average Score: 14.0
Episode 100 	 Average Score: 9.48
Episode 200 	 Average Score: 9.44


KeyboardInterrupt: 

In [122]:
print(rewards)

[1.0, 1.0, 1.0, 1.0, 1.0]
