<a href="https://colab.research.google.com/github/RLWH/reinforcement-learning-notebook/blob/master/6.%20Policy%20Gradient/Actor_Critic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Actor Critic Methods

In [0]:
import numpy as np
import matplotlib.pyplot as plt
import gym
import sys

from collections import deque

import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.distributions import Categorical

#### Understanding the environment

In [0]:
env = gym.make("Taxi-v2")

In [0]:
obs = env.reset()

In [0]:
env.env.P[obs]

{0: [(1.0, 312, -1, False)],
 1: [(1.0, 112, -1, False)],
 2: [(1.0, 232, -1, False)],
 3: [(1.0, 212, -1, False)],
 4: [(1.0, 212, -10, False)],
 5: [(1.0, 212, -10, False)]}

In [0]:
env.render()

+---------+
|[35mR[0m: | : :G|
| : : : : |
|[43m [0m: : : : |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+



In [0]:
next_obs, reward, done, info = env.step(0)
env.render()

+---------+
|[35mR[0m: | : :G|
| : : : : |
| : : : : |
|[43m [0m| : | : |
|Y| : |[34;1mB[0m: |
+---------+
  (South)


In [0]:
print("Observation space:", env.observation_space.n)
print("Action space:", env.action_space.n)

Observation space: 500
Action space: 6


## Implementation

#### Algorithm
---
```
Input: a differentiable policy parameterization pi(a|s, theta)                   [Policy Network]
Input: a differentiable state-value function parameterization Q_w(s, a, w)       [Value Network]
Parameters: step sizes alpha_theta > 0; alpha_w > 0

Loop forever for each episode:

        Initialise S, theta
        Sample a from policy network
        
        Loop while S is not terminal for each time step:
                A = pi(.|S, theta) [policy(state)]
                Take action A, observe S', R
                delta = R + gamma * A(S', A', w) - A(S, A, w)  [TD(0) error, or advantage]
                theta = theta + alpha_theta * grad_pi log pi_theta(s,a) A(S,A)     [policy gradient update]
                w = w + alpha_w * delta * x(s, a)    [TD(0)]
                A = A', S = S'
```
---

In [0]:
# Check GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [0]:
class PolicyNetwork(nn.Module):
    """
    The policy network
    Args:
        n_inputs (int)
        n_outputs (int)
    """
    
    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        
        self.embeddings = nn.Embedding(self.n_inputs, 50)
        self.fc1 = nn.Linear(50, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, self.n_outputs)
        
        self.dropout1 = nn.Dropout(p=0.5)
        self.dropout2 = nn.Dropout(p=0.5)
        
    def forward(self, x):
        """
        Forward pass
        Args:
            x (torch.Tensor)
        """
        x = self.embeddings(x)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        
        return F.softmax(x, dim=-1)

In [0]:
class ValueNetwork(nn.Module):
    """Value network for value approximation"""
    
    def __init__(self, state_size, action_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        
        # MLP layers
        self.embeddings = nn.Embedding(self.state_size, 50)
        self.fc1 = nn.Linear(50, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, action_size)
        
        self.dropout1 = nn.Dropout(p=0.5)
        self.dropout2 = nn.Dropout(p=0.5)
        
    def forward(self, x):
        x = self.embeddings(x)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x

In [0]:
class A2CAgent:
    """Actor Critic Agent"""
    
    def __init__(self, n_state, n_action, policy_network, value_network, gamma=0.9):
        
        self.env = env
        
        self.n_state = n_state
        self.n_action = n_action
        
        # Initialise the model
        self.policy_network = policy_network(self.n_state, self.n_action).to(device)
        self.value_network = value_network(self.n_state, 1).to(device)
        
        # Hypterparameters
        self.gamma = gamma
    
    def act(self, state):
#         state = state.float()
        probs = self.policy_network(Variable(state))
        m = Categorical(probs)
        action = m.sample()
        log_probs = m.log_prob(action)
#         policy.saved_log_probs.append(log_prob)

        return log_probs, action.item()

In [0]:
NUM_EPISODES = 10000

In [0]:
agent = A2CAgent(env.observation_space.n, env.action_space.n, PolicyNetwork, ValueNetwork)
reward_list = deque(maxlen=100)

critic_optimizer = torch.optim.Adam(agent.value_network.parameters(), lr=1e-3)
actor_optimizer = torch.optim.Adam(agent.policy_network.parameters(), lr=1e-5)

for i in range(NUM_EPISODES + 1):
    
    # Log the results
    action_log_probs = []
    rewards = []
    states = []
    targets = []
    errors = []
    t = 0
    R = 0
    
    state = env.reset()
    
    while True:
        
        # Select and take action
        state_tensor = torch.from_numpy(np.array(state)).to(device)
#         state_one_hot = F.one_hot(state_tensor, num_classes=agent.n_state).float()
        action_log_prob, action = agent.act(state_tensor)
        next_state, reward, done, _ = env.step(action)
        
        next_state_tensor = torch.from_numpy(np.array(next_state)).to(device)
        next_action_log_prob, next_action = agent.act(next_state_tensor)
        next_next_state, next_reward, next_done, _ = env.step(next_action)
        
        if i % 5000 == 0:
            env.render()
        

        next_next_state_tensor = torch.from_numpy(np.array(next_next_state)).to(device)
#         next_state_one_hot = F.one_hot(next_state_tensor, num_classes=agent.n_state).float().to(device)
        
        # Calculate predictions and error
        next_value =  agent.value_network(next_state_tensor)
        next_next_value = agent.value_network(next_next_state_tensor)
        value = agent.value_network(state_tensor)
        td_target = reward + agent.gamma * next_value + agent.gamma ** 2 * next_next_value
        td_error = td_target - value.detach()
        
        # Gradient update for policy optimizer
        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        
        critic_loss = torch.nn.MSELoss(reduction="sum")(td_target, value)
        actor_loss = -action_log_prob * td_error
        
        actor_loss.backward(retain_graph=True)
        critic_loss.backward()
#         total_loss.backward()

        actor_optimizer.step()
        critic_optimizer.step()
        

        R += reward
        
        if done:
            reward_list.append(R)
            break
            
        state = next_state
    
    if i % 1000 == 0:
        print("\rEpisode %s \t Average Score: %s" % (i, np.mean(reward_list)))

+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| : : : : |
| |[43m [0m: | : |
|[35mY[0m| : |B: |
+---------+
  (West)
+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| : : : : |
| |[43m [0m: | : |
|[35mY[0m| : |B: |
+---------+
  (Dropoff)
+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[35mY[0m|[43m [0m: |B: |
+---------+
  (West)
+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[35mY[0m|[43m [0m: |B: |
+---------+
  (Dropoff)
+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| :[43m [0m: : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (North)
+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| : :[43m [0m: : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (Pickup)
+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| : :[43m [0m: : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (East)
+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| : : : : |
| | :[43m [0m| : |
|[35mY[0m| : |B: |
+---------+
  (Dropoff)
+---------+


KeyboardInterrupt: ignored

In [0]:
state

431

In [0]:
# agent = A2CAgent(env.observation_space.n, env.action_space.n, PolicyNetwork, ValueNetwork)
state_tensor = torch.from_numpy(np.array(431)).to(device)
probs = agent.policy_network(Variable(state_tensor))
m = Categorical(probs)
action = m.sample()
log_probs = m.log_prob(action)

In [0]:
probs

tensor([0.2010, 0.1144, 0.1608, 0.2253, 0.1806, 0.1179], device='cuda:0',
       grad_fn=<SoftmaxBackward>)

In [0]:
log_probs, action3

(tensor(-2.1383, device='cuda:0', grad_fn=<SqueezeBackward1>),
 tensor(5, device='cuda:0'))