# Advantage Actor Critic (A2C) - Batch Learning

## Generalized Advantage Estimation (GAE)

Reference: [High-Dimensional Continuous Control Using Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438)

Advantage Function:

$$
A^\pi(s_t, a_t) := Q^\pi(s_t, a_t) - V^\pi(s_t)
$$

Define the TD error $\delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_t)$. The $n$-step advantage estimates can then be defined as follows:

$$
\begin{align*}
    \hat{A}_t^{(1)} &:= \delta_t^V &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
    \hat{A}_t^{(2)} &:= \delta_t^V + \gamma \delta_{t+1}^V &= r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t) \\
    \hat{A}_t^{(3)} &:= \delta_t^V + \gamma \delta_{t+1}^V + \gamma^2 \delta_{t+2}^V &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \gamma^3 V(s_{t+3}) - V(s_t) \\
    &\;\,\vdots \\
    \hat{A}_t^{(n)} &:= \sum_{l=0}^{n-1}\gamma^l \delta_{t+l}^V &= r_t + \gamma r_{t+1} + \cdots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) - V(s_t)
\end{align*}
$$
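
The two forms on each row agree because the intermediate value terms telescope. The following is a minimal numerical check of that identity in plain Python. The trajectory numbers and the helper names (`td_error`, `n_step_advantage_from_deltas`, `n_step_advantage_from_returns`) are made up for illustration and do not come from the notebook.

```python
# Hypothetical toy trajectory; the numbers are illustrative only.
gamma = 0.9
rewards = [1.0, 0.5, -0.2, 2.0]        # r_t, ..., r_{t+3}
values = [0.3, 0.8, -0.1, 0.4, 1.2]    # V(s_t), ..., V(s_{t+4})


def td_error(l):
    """delta_{t+l} = r_{t+l} + gamma * V(s_{t+l+1}) - V(s_{t+l})."""
    return rewards[l] + gamma * values[l + 1] - values[l]


def n_step_advantage_from_deltas(n):
    """Left-hand form: discounted sum of the first n TD errors."""
    return sum(gamma ** l * td_error(l) for l in range(n))


def n_step_advantage_from_returns(n):
    """Right-hand form: n-step return bootstrapped with V(s_{t+n}), minus V(s_t)."""
    n_step_return = sum(gamma ** l * rewards[l] for l in range(n))
    return n_step_return + gamma ** n * values[n] - values[0]


for n in range(1, len(rewards) + 1):
    lhs = n_step_advantage_from_deltas(n)
    rhs = n_step_advantage_from_returns(n)
    print(f"n={n}: {lhs:.6f} vs {rhs:.6f}")
```

Both columns should match up to floating-point rounding for every $n$.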

The *Generalized Advantage Estimator* (GAE) is defined as the exponentially weighted average of the $n$-step advantage estimates:

$$
\begin{align*}
    \hat{A}_t^{\text{GAE}(\gamma, \lambda)} &:= (1 - \lambda)\Big(\hat{A}_t^{(1)} + \lambda \hat{A}_t^{(2)} + \lambda^2 \hat{A}_t^{(3)} + \dots \Big) \\
    &= \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}^V
\end{align*}
$$
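
The step between the two lines can be filled in as follows, assuming $0 \le \lambda < 1$ so that the geometric series converge. Substitute the TD-error form of each $n$-step estimate, collect the coefficient of every $\delta_{t+l}^V$, and sum the geometric series:

$$
\begin{align*}
    \hat{A}_t^{\text{GAE}(\gamma, \lambda)} &= (1 - \lambda)\Big(\delta_t^V + \lambda(\delta_t^V + \gamma \delta_{t+1}^V) + \lambda^2(\delta_t^V + \gamma \delta_{t+1}^V + \gamma^2 \delta_{t+2}^V) + \dots \Big) \\
    &= (1 - \lambda)\Big(\delta_t^V (1 + \lambda + \lambda^2 + \dots) + \gamma \delta_{t+1}^V (\lambda + \lambda^2 + \dots) + \gamma^2 \delta_{t+2}^V (\lambda^2 + \lambda^3 + \dots) + \dots \Big) \\
    &= (1 - \lambda)\Big(\delta_t^V \frac{1}{1 - \lambda} + \gamma \delta_{t+1}^V \frac{\lambda}{1 - \lambda} + \gamma^2 \delta_{t+2}^V \frac{\lambda^2}{1 - \lambda} + \dots \Big) \\
    &= \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}^V
\end{align*}
$$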

If $\lambda = 0$, then $\hat{A}_t = \delta_t^V$, the one-step TD error, which has low variance but high bias. Conversely, if $\lambda = 1$, then $\hat{A}_t = \sum_{l=0}^\infty \gamma^l \delta_{t+l}^V = \sum_{l=0}^\infty \gamma^l r_{t+l} - V(s_t)$, which has low bias but high variance. The parameter $\lambda$ therefore controls the tradeoff between variance and bias.
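
The two extremes can be checked on a short finite episode. This is a rough sketch under stated assumptions: the rewards and values are invented, the episode ends after the last reward (so the value after it is taken as 0 and all later TD errors vanish), and `gae_direct` is a hypothetical helper, not a function from the notebook.

```python
def gae_direct(deltas, gamma, lam):
    """GAE at t = 0 as sum_l (gamma * lam)^l * delta_l, assuming delta = 0 past the end."""
    return sum((gamma * lam) ** l * d for l, d in enumerate(deltas))


gamma = 0.9
rewards = [1.0, 0.5, -0.2, 2.0]
values = [0.3, 0.8, -0.1, 0.4, 0.0]  # last entry: value after the episode ends (assumed 0)

# one-step TD errors delta_l = r_l + gamma * V(s_{l+1}) - V(s_l)
deltas = [r + gamma * values[i + 1] - values[i] for i, r in enumerate(rewards)]

adv_lam0 = gae_direct(deltas, gamma, 0.0)  # only the first TD error survives
adv_lam1 = gae_direct(deltas, gamma, 1.0)  # all TD errors, discounted by gamma only
discounted_return = sum(gamma ** l * r for l, r in enumerate(rewards))

print(adv_lam0, deltas[0])
print(adv_lam1, discounted_return - values[0])
```

With $\lambda = 0$ the estimate relies entirely on the learned value function, while with $\lambda = 1$ it reduces to the Monte Carlo return minus the baseline $V(s_t)$.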

Because the sum is geometric in $\gamma \lambda$, GAE can be computed recursively, working backward in time:

$$
\begin{align*}
    \hat{A}_t^{\text{GAE}} = \delta_t^V + (\gamma \lambda) \hat{A}_{t+1}^{\text{GAE}}
\end{align*}
$$
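
A small plain-Python sketch can compare the backward recursion with the direct sum on a finite list of TD errors. It assumes the advantage beyond the last step is 0 (a truncated rollout); the function names and the numbers are illustrative only.

```python
def gae_direct(deltas, gamma, lam):
    """A_t = sum_l (gamma * lam)^l * delta_{t+l}, truncated at the end of the list."""
    T = len(deltas)
    return [
        sum((gamma * lam) ** l * deltas[t + l] for l in range(T - t))
        for t in range(T)
    ]


def gae_recursive(deltas, gamma, lam):
    """A_t = delta_t + (gamma * lam) * A_{t+1}, with A_T = 0, computed backward."""
    advantages = [0.0] * len(deltas)
    next_advantage = 0.0
    for t in reversed(range(len(deltas))):
        next_advantage = deltas[t] + gamma * lam * next_advantage
        advantages[t] = next_advantage
    return advantages


deltas = [0.4, -0.1, 0.3, 0.7, -0.5]  # hypothetical TD errors
direct = gae_direct(deltas, gamma=0.99, lam=0.95)
recursive = gae_recursive(deltas, gamma=0.99, lam=0.95)
print(direct)
print(recursive)
```

The recursion needs one pass over the rollout, whereas the direct sum costs quadratic time in the rollout length.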

In [63]:
import torch

def compute_gae(state_value: torch.Tensor, 
                reward: torch.Tensor, 
                terminated: torch.Tensor,
                gamma: float,
                lam: float) -> torch.Tensor:
    """
    Compute generalized advantage estimation (GAE) over n-step transitions. See details in https://arxiv.org/abs/1506.02438.

    Args:
        state_value (Tensor): `(n_steps + 1,)`, state values over the n steps, including the next state value of the last transition
        reward (Tensor): `(n_steps,)`, rewards over the n steps
        terminated (Tensor): `(n_steps,)`, termination flags over the n steps (1.0 if the episode ended at that step, otherwise 0.0)
        gamma (float): discount factor for future rewards
        lam (float): lambda which controls the balance between bias and variance

    Returns:
        gae (Tensor): `(n_steps,)`, GAE over the n steps
    """
    
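    n_step = len(reward)
    gae = torch.empty_like(reward)
    discounted_gae = 0.0  # GAE of the following step; zero beyond the last transition
    not_terminated = 1 - terminated  # 0 where an episode ended, so nothing leaks across episodes
    # one-step TD errors; the bootstrap term is dropped at terminal transitions
    delta = reward + not_terminated * gamma * state_value[1:] - state_value[:-1]
    discount_factor = gamma * lam
    
    # compute GAE backward in time: A_t = delta_t + (gamma * lam) * A_{t+1}
    for t in reversed(range(n_step)):
        discounted_gae = delta[t] + not_terminated[t] * discount_factor * discounted_gae
        gae[t] = discounted_gae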
     
    return gae
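
To see what the termination mask does, here is a plain-Python sketch (lists instead of tensors, so it runs without PyTorch) that mirrors the logic of `compute_gae`. The helper `gae_with_termination` and all the numbers are assumptions made for illustration. The buffer holds two episodes back to back, and the check is that each episode's advantages match what would be computed for that episode alone.

```python
def gae_with_termination(values, rewards, terminated, gamma, lam):
    """List-based reference version; `values` has len(rewards) + 1 entries."""
    n = len(rewards)
    advantages = [0.0] * n
    running = 0.0
    for t in reversed(range(n)):
        mask = 1.0 - terminated[t]  # 0 at a terminal transition
        delta = rewards[t] + mask * gamma * values[t + 1] - values[t]
        running = delta + mask * gamma * lam * running
        advantages[t] = running
    return advantages


gamma, lam = 0.99, 0.95
values = [0.5, 0.4, 0.7, 0.6, 0.2]   # n_steps + 1 entries
rewards = [1.0, 1.0, 1.0, 1.0]
terminated = [0.0, 1.0, 0.0, 0.0]    # the first episode ends at step index 1

full = gae_with_termination(values, rewards, terminated, gamma, lam)

# Each episode on its own; the value after the terminal step is irrelevant (set to 0).
first = gae_with_termination([0.5, 0.4, 0.0], [1.0, 1.0], [0.0, 1.0], gamma, lam)
second = gae_with_termination([0.7, 0.6, 0.2], [1.0, 1.0], [0.0, 0.0], gamma, lam)

print(full)
print(first + second)
```

Without the mask, the advantage at step 1 would bootstrap from the first state of the next episode, which is unrelated to the trajectory that just ended.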

## Training A2C Agent in CartPole-v1

### Actor-Critic Network

The Actor and the Critic share parameters (parameter sharing). This lets both learn common features of the state space and improves computational efficiency. In addition, batch learning is used, which increases sample efficiency.

In [64]:
from typing import Tuple
import torch.nn as nn

class ActorCriticSharedNetwork(nn.Module):
    def __init__(self,
                 obs_features: int,
                 num_actions: int) -> None:
        super().__init__()
        
        hidden_features = 64
        
        # parameter sharing
        self.encoding_layer = nn.Sequential(
            nn.Linear(obs_features, 128),
            nn.ReLU(),
            nn.Linear(128, hidden_features),
            nn.ReLU()
        )
        
        self.actor = nn.Linear(hidden_features, num_actions)
        self.critic = nn.Linear(hidden_features, 1)
        
    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Feed forward method.

        Args:
            obs (torch.Tensor): observation

        Returns:
            logits, state_value
        """
        encoding = self.encoding_layer(obs)
        logits = self.actor(encoding)
        state_value = self.critic(encoding)
        return logits, state_value.squeeze_(dim=-1)

### A2C 구현

GAE is used to compute the advantage. The `A2C.train()` method is called every $n$ steps.

Configuration:

* `gamma`: discount factor $\gamma$
* `lam`: $\lambda$, controls the bias-variance tradeoff of GAE
* `value_loss_coef`: multiplier that scales the critic loss when it is combined with the actor loss
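
As a rough illustration of how `value_loss_coef` enters the objective, the sketch below combines a policy-gradient actor loss with a mean-squared-error critic loss over one small batch. It uses plain Python lists; the function name `a2c_loss` and the batch values are hypothetical, and the advantage is treated as a constant.

```python
def a2c_loss(log_pi, advantage, state_value, target_value, value_loss_coef):
    """Return (actor_loss, critic_loss, total_loss) for one batch of n steps."""
    n = len(log_pi)
    # actor: maximize advantage-weighted log-probability, so minimize its negative mean
    actor_loss = -sum(a * lp for a, lp in zip(advantage, log_pi)) / n
    # critic: mean squared error between predicted and target state values
    critic_loss = sum((v - tv) ** 2 for v, tv in zip(state_value, target_value)) / n
    total_loss = actor_loss + value_loss_coef * critic_loss
    return actor_loss, critic_loss, total_loss


actor_loss, critic_loss, total_loss = a2c_loss(
    log_pi=[-0.7, -0.6],
    advantage=[1.0, -0.5],
    state_value=[0.2, 0.4],
    target_value=[1.2, -0.1],
    value_loss_coef=0.5,
)
print(actor_loss, critic_loss, total_loss)
```

With a shared encoder, both losses send gradients through the same layers, so the coefficient sets how strongly the value error shapes the shared features relative to the policy gradient.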

In [65]:
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

class A2C:
    def __init__(self,
                 obs_features: int,
                 num_actions: int,
                 gamma: float = 0.99,
                 lam: float = 0.95,
                 value_loss_coef: float = 0.5) -> None:
        self.gamma = gamma
        self.lam = lam
        self.value_loss_coef = value_loss_coef
        
        self.actor_critic = ActorCriticSharedNetwork(obs_features, num_actions)
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=0.001)
    
    def select_action(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns action, log pi(a|s), state value V(s).
        """
        # feed forward
        logits, state_value = self.actor_critic(obs)
        dist = Categorical(logits=logits)
        action = dist.sample()
        log_pi = dist.log_prob(action)
        return action, log_pi, state_value
    
    def train(self,
              final_next_obs: torch.Tensor,
              reward: torch.Tensor,
              terminated: torch.Tensor,
              log_pi: torch.Tensor,
              state_value: torch.Tensor) -> Tuple[float, float]:
        """
        Train A2C agent.

        Args:
            final_next_obs (torch.Tensor): `(obs_features,)`, next observation of the last transition
            reward (torch.Tensor): `(n_steps,)`
            terminated (torch.Tensor): `(n_steps,)`
            log_pi (torch.Tensor): `(n_steps,)`
            state_value (torch.Tensor): `(n_steps,)`

        Returns:
            actor_loss (float): computed actor loss
            critic_loss (float): computed critic loss
        """
        # compute final next state value
        with torch.no_grad():
            _, final_next_state_value = self.actor_critic(final_next_obs)
            # (n_steps + 1,)
            state_value_with_final = torch.cat((state_value, final_next_state_value.unsqueeze(0)))
        
        # compute advantage
        advantage = compute_gae(
            state_value_with_final,
            reward,
            terminated,
            self.gamma,
            self.lam
        )
        
        # compute target state value
        target_state_value = advantage + state_value.detach()
        
        # compute actor-critic loss
        actor_loss = self._compute_actor_loss(log_pi, advantage)
        critic_loss = self._compute_critic_loss(state_value, target_state_value)
        
        # train step
        loss = actor_loss + self.value_loss_coef * critic_loss
        self._train_step(loss)
        
        return actor_loss.item(), critic_loss.item()
        
        
    def _compute_actor_loss(self,
                           log_pi: torch.Tensor,
                           advantage: torch.Tensor) -> torch.Tensor:
        return -(advantage * log_pi).mean()
    
    def _compute_critic_loss(self,
                            state_value: torch.Tensor,
                            target_state_value: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(state_value, target_state_value)
    
    def _train_step(self,
                    loss: torch.Tensor):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

### Make CartPole-v1 Environment

In [66]:
import gym

env_id = "CartPole-v1"

env = gym.make(env_id) # for training
inference_env = gym.make(env_id, render_mode="human") # for inference

#### Observation Features

In [67]:
obs_features = env.observation_space.shape
obs_features

(4,)

#### Number of Actions

In [68]:
num_actions = env.action_space.n
num_actions

2

### Instantiate Agent

In [69]:
agent = A2C(obs_features[0], num_actions)

### Check Outputs

In [70]:
obs, _ = env.reset()
obs

array([ 0.0003315 , -0.04382272, -0.04053487, -0.01080584], dtype=float32)

In [71]:
action, log_pi, state_value = agent.select_action(torch.from_numpy(obs))
(action, log_pi, state_value)

(tensor(0),
 tensor(-0.7390, grad_fn=<SqueezeBackward1>),
 tensor(0.0305, grad_fn=<SqueezeBackward3>))

In [72]:
env.step(action.numpy())

(array([-0.00054495, -0.2383406 , -0.04075098,  0.26881734], dtype=float32),
 1.0,
 False,
 False,
 {})

### Training Start!

In [73]:
def reset_buffer(n_steps: int):
    reward_arr = [None] * n_steps
    terminated_arr = [None] * n_steps
    log_pi_arr = [None] * n_steps
    state_value_arr = [None] * n_steps
    return reward_arr, terminated_arr, log_pi_arr, state_value_arr

In [74]:
from torch.utils.tensorboard import SummaryWriter

logger = SummaryWriter("results/CartPole-v1_A2C_BatchLearning")

total_episodes = 501
inference_freq = 50
n_steps = 16

# reset buffers
t = 0
reward_arr, terminated_arr, log_pi_arr, state_value_arr = reset_buffer(n_steps)

for episode in range(total_episodes):
    obs, _ = env.reset()
    terminated = False
    
    cumulative_reward = 0
    
    while not terminated:
        # take action and observe
        action, log_pi, state_value = agent.select_action(torch.from_numpy(obs))
        next_obs, reward, terminated, truncated, _ = env.step(action.numpy())
        terminated = terminated | truncated
        
        # update buffer
        reward_arr[t] = reward
        terminated_arr[t] = terminated
        log_pi_arr[t] = log_pi
        state_value_arr[t] = state_value
        t += 1
        
        if t == n_steps:
            # train the agent
            actor_loss, critic_loss = agent.train(
                torch.from_numpy(next_obs),
                torch.tensor(reward_arr),
                torch.tensor(terminated_arr, dtype=torch.float32),
                torch.stack(log_pi_arr),
                torch.stack(state_value_arr)
            )
            
            # log losses
            logger.add_scalar("Actor Loss", actor_loss, episode)
            logger.add_scalar("Critic Loss", critic_loss, episode)
            
            # reset buffers
            t = 0
            reward_arr, terminated_arr, log_pi_arr, state_value_arr = reset_buffer(n_steps)
                
        # next step
        cumulative_reward += reward
        obs = next_obs
        
    # log cumulative reward
    logger.add_scalar("Cumulative Reward", cumulative_reward, episode)
    
    # inference
    if episode % inference_freq == 0:
        obs, _ = inference_env.reset()
        terminated = False
        inference_cumulative_reward = 0
        while not terminated:
            with torch.no_grad():
                action, _, _ = agent.select_action(torch.from_numpy(obs))
            next_obs, reward, terminated, truncated, _ = inference_env.step(action.numpy())
            terminated = terminated | truncated

            inference_cumulative_reward += reward
            obs = next_obs
            
        print(f"inference - episode: {episode}, cumulative reward: {inference_cumulative_reward}")
        
logger.flush()
logger.close()

env.close()
inference_env.close()

inference - episode: 0, cumulative reward: 12.0
inference - episode: 50, cumulative reward: 24.0
inference - episode: 100, cumulative reward: 14.0
inference - episode: 150, cumulative reward: 21.0
inference - episode: 200, cumulative reward: 21.0
inference - episode: 250, cumulative reward: 77.0
inference - episode: 300, cumulative reward: 89.0
inference - episode: 350, cumulative reward: 98.0
inference - episode: 400, cumulative reward: 228.0
inference - episode: 450, cumulative reward: 222.0
inference - episode: 500, cumulative reward: 277.0
