# Advantage Actor Critic (A2C) - One Step Learning

## Training A2C Agent in CartPole-v1

### Actor-Critic Network

Actor와 Critic 각각이 독립된 구조이다.

In [None]:
import torch
import torch.nn as nn

class Actor(nn.Module):
    """Policy network: maps an observation to one unnormalized score per action.

    The network is an MLP of the form
    ``obs_features -> 128 -> 64 -> num_actions`` with a ReLU after each
    hidden linear layer and no activation on the output layer.
    """

    def __init__(self, obs_features: int, num_actions: int) -> None:
        """Build the MLP.

        Args:
            obs_features: Size of the last dimension of the observation tensor.
            num_actions: Number of output scores (one per action).
        """
        super().__init__()

        # Layers are created in input-to-output order so that parameter
        # initialisation consumes the RNG in the same sequence as a literal
        # nn.Sequential(...) definition would.
        hidden_widths = (128, 64)
        layers = []
        in_width = obs_features
        for out_width in hidden_widths:
            layers.append(nn.Linear(in_width, out_width))
            layers.append(nn.ReLU())
            in_width = out_width
        layers.append(nn.Linear(in_width, num_actions))

        self.actor = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the raw output-layer scores for ``obs``."""
        scores = self.actor(obs)
        return scores
    
class Critic(nn.Module):
    """State-value network: maps an observation to a single scalar estimate.

    The network is an MLP of the form ``obs_features -> 128 -> 64 -> 1`` with
    a ReLU after each hidden linear layer and no activation on the output.
    """

    def __init__(self, obs_features: int) -> None:
        """Build the MLP.

        Args:
            obs_features: Size of the last dimension of the observation tensor.
        """
        super().__init__()

        # Same construction order as a literal nn.Sequential(...) so that
        # seeded initialisation is reproducible.
        hidden_widths = (128, 64)
        layers = []
        in_width = obs_features
        for out_width in hidden_widths:
            layers.append(nn.Linear(in_width, out_width))
            layers.append(nn.ReLU())
            in_width = out_width
        layers.append(nn.Linear(in_width, 1))

        self.critic = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the value estimate for ``obs``; the last dimension has size 1."""
        value = self.critic(obs)
        return value

### A2C 구현

one-step advantage 추정치를 사용한다. `A2C.train()` 메소드는 매 time step마다 호출된다.

In [1]:
from typing import Tuple
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

class A2C:
    """One-step Advantage Actor-Critic agent with separate actor and critic networks.

    The critic is regressed onto the one-step TD target
    ``r + gamma * (1 - terminated) * V(s')`` and the actor is updated with the
    policy-gradient loss ``-log pi(a|s) * A``, where the advantage ``A`` is the
    TD target minus ``V(s)`` (treated as a constant for the actor update).
    """

    def __init__(self,
                 obs_features: int,
                 num_actions: int,
                 gamma: float = 0.99,
                 actor_lr: float = 0.001,
                 critic_lr: float = 0.001) -> None:
        """Create the networks and one Adam optimizer per network.

        Args:
            obs_features: Size of the observation vector fed to both networks.
            num_actions: Number of discrete actions the actor scores.
            gamma: Discount factor applied to the next-state value.
            actor_lr: Learning rate of the actor's Adam optimizer.
            critic_lr: Learning rate of the critic's Adam optimizer.
        """
        self.gamma = gamma

        self.actor = Actor(obs_features, num_actions)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.critic = Critic(obs_features)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

    def select_action(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample an action from the current policy.

        Args:
            obs: Observation tensor passed to the actor.

        Returns:
            A tuple ``(action, log_pi)`` where ``action`` is sampled from the
            categorical distribution defined by the actor's logits and
            ``log_pi`` is ``log pi(action | obs)``. Unless called under
            ``torch.no_grad()``, ``log_pi`` stays attached to the actor's
            graph so it can be passed to :meth:`train`.
        """
        logits = self.actor(obs)
        dist = Categorical(logits=logits)
        action = dist.sample()
        log_pi = dist.log_prob(action)
        return action, log_pi

    def train(self,
              obs: torch.Tensor,
              next_obs: torch.Tensor,
              reward: torch.Tensor,
              terminated: torch.Tensor,
              log_pi: torch.Tensor) -> Tuple[float, float]:
        """Run one actor update and one critic update from one-step transitions.

        Works for a single transition (unbatched ``obs``, scalar ``reward``)
        as well as for a batch of transitions.

        Args:
            obs: Observation(s) the action was taken in.
            next_obs: Observation(s) reached after the action.
            reward: Reward(s) received for the transition.
            terminated: 1/True where the transition ended the episode in a
                terminal state (no bootstrapping), 0/False otherwise.
            log_pi: ``log pi(a|s)`` from :meth:`select_action`; must still be
                attached to the actor's graph.

        Returns:
            ``(actor_loss, critic_loss)`` as Python floats.
        """
        # Flatten every per-transition quantity to 1-D. The critic emits a
        # trailing dimension of size 1; mixing (B, 1) values with (B,) rewards
        # would otherwise broadcast silently to (B, B) and corrupt both losses.
        state_value = self.critic(obs).reshape(-1)
        with torch.no_grad():
            next_state_value = self.critic(next_obs).reshape(-1)
        reward = reward.reshape(-1).to(state_value.dtype)
        terminated = terminated.reshape(-1)
        log_pi = log_pi.reshape(-1)

        target_state_value = self._compute_target_state_value(next_state_value, reward, terminated)
        # The advantage is a constant w.r.t. the actor update, hence detach().
        advantage = self._compute_advantage(state_value.detach(), target_state_value)

        actor_loss = self._compute_actor_loss(log_pi, advantage)
        critic_loss = self._compute_critic_loss(state_value, target_state_value)

        self._train_step(actor_loss, self.actor_optimizer)
        self._train_step(critic_loss, self.critic_optimizer)

        return actor_loss.item(), critic_loss.item()

    def _compute_target_state_value(self,
                                   next_state_value: torch.Tensor,
                                   reward: torch.Tensor,
                                   terminated: torch.Tensor) -> torch.Tensor:
        """Return the one-step TD target ``r + gamma * (1 - terminated) * V(s')``."""
        # Cast so that bool flags are accepted ("1 - bool_tensor" is an error).
        not_terminated = 1.0 - terminated.to(next_state_value.dtype)
        return reward + not_terminated * self.gamma * next_state_value

    def _compute_advantage(self,
                          state_value: torch.Tensor,
                          target_state_value: torch.Tensor) -> torch.Tensor:
        """Return the one-step advantage estimate ``target - V(s)``."""
        return target_state_value - state_value

    def _compute_actor_loss(self,
                           log_pi: torch.Tensor,
                           advantage: torch.Tensor) -> torch.Tensor:
        """Return the policy-gradient loss ``-mean(log_pi * advantage)``."""
        return -(log_pi * advantage).mean()

    def _compute_critic_loss(self,
                            state_value: torch.Tensor,
                            target_state_value: torch.Tensor) -> torch.Tensor:
        """Return the mean squared error between ``V(s)`` and the TD target."""
        return F.mse_loss(state_value, target_state_value)

    def _train_step(self,
                    loss: torch.Tensor,
                    optimizer: optim.Optimizer) -> None:
        """Zero the gradients, backpropagate ``loss`` and apply one optimizer step."""
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

### Make CartPole-v1 Environment

In [2]:
import gym

# Environment id shared by the two environment instances created below.
env_id = "CartPole-v1"

env = gym.make(env_id) # used for training (no render mode requested)
inference_env = gym.make(env_id, render_mode="human") # used for inference, created with render_mode="human"

#### Observation Features

In [3]:
# Shape tuple of the observation space; its first entry is later passed to A2C as the network input size.
obs_features = env.observation_space.shape
obs_features

(4,)

#### Number of Actions

In [4]:
# Number of actions reported by the environment's action space.
num_actions = env.action_space.n
num_actions

2

### Instantiate Agent

In [5]:
# Build the agent from the observation size and the action count; gamma keeps its default value.
agent = A2C(obs_features[0], num_actions)

### Check Outputs

In [6]:
# Reset the training environment; the second return value is not needed here.
obs, _ = env.reset()
obs

array([ 0.04262747,  0.04798423,  0.00884291, -0.00670878], dtype=float32)

In [7]:
# Sample one action for the observation above; log_pi is computed without no_grad, so it carries a grad_fn.
action, log_pi = agent.select_action(torch.from_numpy(obs))
(action, log_pi)

(tensor(1), tensor(-0.5673, grad_fn=<SqueezeBackward1>))

In [8]:
# Smoke-test a single step with the sampled action (note: this advances `env` by one step).
env.step(action.numpy())

(array([ 0.04358716,  0.24297826,  0.00870873, -0.2965886 ], dtype=float32),
 1.0,
 False,
 False,
 {})

### Training Start!

In [9]:
from torch.utils.tensorboard import SummaryWriter

# Train the agent online (one update per environment step), log per-episode
# statistics to TensorBoard, and periodically run one rendered episode.
logger = SummaryWriter("results/CartPole-v1_A2C")

total_episodes = 501
inference_freq = 50  # run one inference episode every `inference_freq` training episodes

try:
    for episode in range(total_episodes):
        obs, _ = env.reset()
        done = False

        cumulative_reward = 0
        actor_loss_mean = 0.0
        critic_loss_mean = 0.0
        n = 0

        while not done:
            # take action and observe
            action, log_pi = agent.select_action(torch.from_numpy(obs))
            next_obs, reward, terminated, truncated, _ = env.step(action.numpy())
            # The episode ends on either signal, but only a true termination
            # may switch off bootstrapping: after a time-limit truncation the
            # next state still has value, so `truncated` must not be passed to
            # the agent as a terminal flag.
            done = terminated or truncated

            # train the agent
            actor_loss, critic_loss = agent.train(
                torch.from_numpy(obs),
                torch.from_numpy(next_obs),
                torch.tensor(reward, dtype=torch.float32),
                torch.tensor(terminated, dtype=torch.float32),
                log_pi
            )

            # update info (incremental running means over the episode's steps)
            n += 1
            actor_loss_mean += (actor_loss - actor_loss_mean) / n
            critic_loss_mean += (critic_loss - critic_loss_mean) / n
            cumulative_reward += reward

            # next step
            obs = next_obs

        # log data
        logger.add_scalar("Cumulative Reward", cumulative_reward, episode)
        logger.add_scalar("Actor Loss", actor_loss_mean, episode)
        logger.add_scalar("Critic Loss", critic_loss_mean, episode)

        # inference
        if episode % inference_freq == 0:
            obs, _ = inference_env.reset()
            done = False
            inference_cumulative_reward = 0
            while not done:
                # No update happens here, so skip building the autograd graph.
                with torch.no_grad():
                    action, _ = agent.select_action(torch.from_numpy(obs))
                next_obs, reward, terminated, truncated, _ = inference_env.step(action.numpy())
                done = terminated or truncated

                inference_cumulative_reward += reward
                obs = next_obs

            print(f"inference - episode: {episode}, cumulative reward: {inference_cumulative_reward}")
finally:
    # Make sure buffered TensorBoard events reach disk even if training is interrupted.
    logger.flush()
    logger.close()

env.close()
inference_env.close()

inference - episode: 0, cumulative reward: 27.0
inference - episode: 50, cumulative reward: 21.0
inference - episode: 100, cumulative reward: 64.0
inference - episode: 150, cumulative reward: 50.0
inference - episode: 200, cumulative reward: 175.0
inference - episode: 250, cumulative reward: 27.0
inference - episode: 300, cumulative reward: 10.0
inference - episode: 350, cumulative reward: 8.0
inference - episode: 400, cumulative reward: 9.0
inference - episode: 450, cumulative reward: 10.0
inference - episode: 500, cumulative reward: 9.0
