# Implementing Advantage Actor-Critic (A2C)

In [31]:
import numpy as np
from atari_wrappers import nature_dqn_env

# Number of parallel training environments.
nenvs = 8  # change this to match your CPU core count if you have more than 8 cores

# Vectorized Space Invaders environment built by the project's `nature_dqn_env` helper.
env = nature_dqn_env("SpaceInvadersNoFrameskip-v4", nenvs=nenvs)

# The action space exposes per-env sub-spaces via `.spaces` (presumably identical);
# take the number of discrete actions from the first one.
n_actions = env.action_space.spaces[0].n
obs = env.reset()
# One float32 observation of 4 x 84 x 84 (presumably 4 stacked frames) per environment.
assert obs.shape == (nenvs, 4, 84, 84)
assert obs.dtype == np.float32

In [2]:
import torch
import torch.nn as nn
from torch.nn import init


def weights_init_orthogonal(m):
    """Orthogonally initialise conv / linear layers and zero their biases.

    Intended for ``nn.Module.apply``. Any module whose class name contains
    "Conv" or "Linear" gets an orthogonal weight matrix scaled by sqrt(2)
    and a bias of 0; every other module is left untouched.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type or "Linear" in layer_type:
        init.orthogonal_(m.weight.data, gain=np.sqrt(2))
        init.constant_(m.bias.data, 0.0)


class Flatten(nn.Module):
    """Collapse every dimension except the leading batch dimension into one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)


class ACAgent(nn.Module):
    """Actor-critic network: a shared convolutional torso with policy and value heads.

    Args:
        state_shape: ``(channels, height, width)`` of one observation. It sizes the
            first conv layer and the flattened feature vector, so observation shapes
            other than ``(4, 84, 84)`` are supported.
        n_actions: number of discrete actions (width of the policy logits).
        epsilon: stored as ``self.epsilon``; not used inside this class.
    """

    def __init__(self, state_shape, n_actions, epsilon=0):
        super().__init__()
        self.epsilon = epsilon
        self.n_actions = n_actions
        self.state_shape = state_shape

        in_channels, height, width = state_shape
        # Three convs with kernel_size=3, stride=2 and no padding:
        # out = (in - 3) // 2 + 1 per conv, e.g. 84 -> 41 -> 20 -> 9.
        for _ in range(3):
            height = (height - 3) // 2 + 1
            width = (width - 3) // 2 + 1
        n_flat = 64 * height * width  # 64 * 9 * 9 = 5184 for 84x84 inputs

        # Layer order/indices are unchanged (weights live in net.0, net.2, net.4, net.7),
        # so previously saved state_dicts still load.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            Flatten(),
            # NOTE(review): no activation follows this projection, so it composes with
            # the first Linear of each head into a single affine map.
            nn.Linear(n_flat, 256),
        )

        # Policy head: 256-d shared features -> unnormalised action logits.
        self.action_head = nn.Sequential(
            nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, n_actions)
        )
        # Value head: 256-d shared features -> one scalar state-value estimate.
        self.V_head = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 1))

        # Orthogonal init (gain sqrt(2)) with zero biases for every conv/linear layer.
        self.net.apply(weights_init_orthogonal)
        self.action_head.apply(weights_init_orthogonal)
        self.V_head.apply(weights_init_orthogonal)

    def forward(self, states):
        """
        input:
            states - tensor, (batch_size x channels x height x width), matching state_shape
        output:
            logits - tensor, logits of action probabilities for your actor policy, (batch_size x num_actions)
            V - tensor, critic estimation, (batch_size)
        """

        features = self.net(states)
        logits = self.action_head(features)
        V = self.V_head(features)

        # Drop the trailing singleton so V has shape (batch_size,).
        return logits, V.squeeze(1)

In [3]:
from torch.distributions import Categorical


class Policy:
    """Wraps an actor-critic model so it can act on raw numpy observations."""

    def __init__(self, model, device):
        # nn.Module.to moves the parameters and returns the same module.
        self.model = model.to(device)
        self.device = device

    def act(self, inputs):
        """
        input:
            inputs - numpy array, (batch_size x channels x width x height)
        output: dict containing keys ['actions', 'logits', 'log_probs', 'values', 'entropy']:
            'actions' - selected actions, numpy, (batch_size)
            'logits' - actions logits, tensor, (batch_size x num_actions)
            'log_probs' - log probs of selected actions, tensor, (batch_size)
            'values' - critic estimations, tensor, (batch_size)
            'entropy' - NEGATIVE policy entropy, sum_a p(a) log p(a), tensor, (batch_size);
                        the A2C loss adds it with a positive weight, so minimising the
                        loss raises the policy entropy.
        """
        # torch.tensor copies the array, so the model never aliases the caller's buffer.
        states = torch.tensor(inputs).to(self.device)
        logits, values = self.model(states)
        dist = Categorical(logits=logits)
        sampled = dist.sample()

        return {
            "actions": sampled.cpu().numpy(),
            "logits": logits,
            # log_prob works on log-softmax-normalised logits; the previous
            # torch.log(softmax(...)) returned -inf (and NaN gradients) whenever a
            # probability underflowed to zero.
            "log_probs": dist.log_prob(sampled),
            "values": values,
            # Same quantity (and sign) as sum(softmax * log_softmax) over actions.
            "entropy": -dist.entropy(),
        }

In [4]:
from runners import EnvRunner

This runner interacts with the environment for a given number of steps and returns a dictionary containing
keys 

* 'observations' 
* 'rewards' 
* 'dones'
* 'actions'
* all other keys that you defined in `Policy`

Under each of these keys there is a Python `list` with one entry per environment step; its length $T$ is the size of the partial trajectory, also called the rollout length. Let's have a look at how it works.

In [5]:
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ACAgent((4, 84, 84), n_actions)
policy = Policy(model, device)
# Short 5-step rollouts, used only by the sanity checks below.
runner = EnvRunner(env, policy, nsteps=5)

In [6]:
# Collect the next rollout: a dict of per-step lists (see the key list above).
trajectory = runner.get_next()

In [7]:
# Inspect which keys the rollout contains: the Policy.act outputs plus
# observations, rewards and dones added by the runner.
print(trajectory.keys())

dict_keys(['actions', 'logits', 'log_probs', 'values', 'entropy', 'observations', 'rewards', 'dones'])


In [8]:
# Sanity checks: the policy must provide these outputs, each with per-env shapes.
presence_checks = [
    ("logits", "Not found: policy didn't provide logits"),
    ("log_probs", "Not found: policy didn't provide log_probs of selected actions"),
    ("values", "Not found: policy didn't provide critic estimations"),
]
for key, message in presence_checks:
    assert key in trajectory, message

shape_checks = [
    ("logits", (nenvs, n_actions), "logits wrong shape"),
    ("log_probs", (nenvs,), "log_probs wrong shape"),
    ("values", (nenvs,), "values wrong shape"),
]
for key, expected, message in shape_checks:
    assert trajectory[key][0].shape == expected, message

# The runner above was created with nsteps=5, so every key holds 5 per-step entries.
for key in trajectory:
    assert (
        len(trajectory[key]) == 5
    ), f"something went wrong: 5 steps should have been done, got trajectory of length {len(trajectory[key])} for '{key}'"


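The value targets are bootstrapped (up to) $T$-step returns. For a rollout covering steps $t = 0, \dots, T-1$:

$$
\hat v(s_t) = \sum_{t'=0}^{T - t - 1}\gamma^{t'}r_{t+t'} + \gamma^{T - t} \hat{v}(s_{T}),
$$

where $s_{T}$ is the latest observation of the environment and $\hat{v}(s_T)$ is the critic's estimate for it. If an episode ends inside the rollout (its `done` flag is set), the sum stops at that step and nothing beyond it is bootstrapped. Equivalently, the targets are computed backwards as $\hat v(s_t) = r_t + \gamma\,(1 - d_t)\,\hat v(s_{t+1})$, where $d_t$ is the `done` flag at step $t$.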

Any callable taking `(trajectory, latest_observation)` can be passed to `EnvRunner` through its `transforms` argument; it is applied to each partial trajectory after it is collected.
We use this to implement a `ComputeValueTargets` callable.

In [9]:
class ComputeValueTargets:
    """Trajectory transform that adds bootstrapped return targets for the critic.

    Args:
        policy: object with a ``device`` attribute and ``act(observations)``
            returning a dict whose "values" entry holds critic estimates.
        gamma: discount factor.
    """

    def __init__(self, policy, gamma=0.99):
        self.policy = policy
        self.gamma = gamma

    def __call__(self, trajectory, latest_observation):
        """
        This method should modify trajectory inplace by adding
        an item with key 'value_targets' to it

        input:
            trajectory - dict from runner; 'rewards' and 'dones' are lists with one
                         entry per step that stack to (num_steps x num_envs)
            latest_observation - last state, numpy, (num_envs x channels x width x height)

        The added 'value_targets' is a float32 tensor of shape (num_steps x num_envs)
        on ``policy.device``, computed backwards as
            target[t] = rewards[t] + gamma * (1 - dones[t]) * target[t + 1],
        starting from the critic's estimate for ``latest_observation``.
        It carries no autograd history, because the targets are constants for the critic update.
        """
        device = self.policy.device
        rewards = torch.as_tensor(
            np.vstack(trajectory["rewards"]), dtype=torch.float32, device=device
        )
        not_done = 1.0 - torch.as_tensor(
            np.vstack(trajectory["dones"]), dtype=torch.float32, device=device
        )

        # Bootstrap value of the state after the last step. It only feeds the
        # targets, so don't build an autograd graph through the whole network.
        with torch.no_grad():
            next_value = self.policy.act(latest_observation)["values"].detach()
        next_value = next_value.to(device=device, dtype=torch.float32)

        value_targets = torch.empty_like(rewards)
        for t in reversed(range(rewards.shape[0])):
            # A done flag at step t cuts the bootstrap across the episode boundary.
            next_value = rewards[t] + self.gamma * not_done[t] * next_value
            value_targets[t] = next_value

        trajectory["value_targets"] = value_targets

In [10]:
class MergeTimeBatch:
    """Trajectory transform that fuses the leading time and env axes into one."""

    def __call__(self, trajectory, latest_observation):
        """Flatten trajectory tensors in place to shape (num_steps * num_envs,).

        'value_targets' is expected to already be a (num_steps x num_envs) tensor.
        'values', 'log_probs' and 'entropy' are per-step lists and are stacked first.
        ``latest_observation`` is part of the transform interface but is unused here.
        """
        steps, envs = np.vstack(trajectory["rewards"]).shape
        merged = steps * envs

        def flatten(tensor):
            # view(merged, -1) fuses time and env axes; squeeze(1) then drops the
            # trailing singleton left behind for per-env scalars.
            return tensor.view(merged, -1).squeeze(1)

        trajectory["value_targets"] = flatten(trajectory["value_targets"])
        for key in ("values", "log_probs", "entropy"):
            trajectory[key] = flatten(torch.stack(trajectory[key]))

Let's do more sanity checks!

In [11]:
# Rebuild the runner with post-processing transforms. MergeTimeBatch reads
# 'value_targets', so ComputeValueTargets must run before it (presumably the
# runner applies transforms in list order — confirm in runners.EnvRunner).
runner = EnvRunner(
    env, policy, nsteps=5, transforms=[ComputeValueTargets(policy), MergeTimeBatch()]
)

In [12]:
# Collect a rollout with value targets added and time/env axes merged.
trajectory = runner.get_next()

In [13]:
# More sanity checks: after MergeTimeBatch every tensor is flat over (time * env).
assert "value_targets" in trajectory, "Value targets not found"

merged_size = 5 * nenvs  # nsteps=5 rollout across nenvs environments
for name in ("log_probs", "value_targets", "values"):
    assert trajectory[name].shape == (merged_size,)

# Actor and critic outputs must still be attached to the autograd graph.
grad_checks = {
    "log_probs": "Gradients are not available for actor head!",
    "values": "Gradients are not available for critic head!",
}
for name, message in grad_checks.items():
    assert trajectory[name].requires_grad, message

Now it is time to implement the advantage actor-critic algorithm itself. You can look into the [Mnih et al. 2016](https://arxiv.org/abs/1602.01783) paper and the lectures ([part 1](https://www.youtube.com/watch?v=Ds1trXd6pos&list=PLkFD6_40KJIwhWJpGazJ9VSj9CFMkb79A&index=5), [part 2](https://www.youtube.com/watch?v=EKqxumCuAAY&list=PLkFD6_40KJIwhWJpGazJ9VSj9CFMkb79A&index=6)) by Sergey Levine.

In [14]:
from collections import defaultdict
from torch.nn.utils import clip_grad_norm_


class A2C:
    """Advantage actor-critic update rule.

    Args:
        policy: object exposing ``.model``, the nn.Module being trained.
        optimizer: torch optimizer over ``policy.model``'s parameters.
        value_loss_coef: weight of the critic loss in the total loss.
        entropy_coef: weight of the entropy term in the total loss.
        max_grad_norm: global L2 norm that gradients are clipped to before each step.
    """

    def __init__(
        self,
        policy,
        optimizer,
        value_loss_coef=0.25,
        entropy_coef=0.01,
        max_grad_norm=0.5,
    ):
        self.policy = policy
        self.optimizer = optimizer
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

    def loss(self, trajectory, write):
        """Build the scalar A2C loss from a merged (time * env) trajectory.

        Expects flat tensors under 'value_targets', 'values', 'log_probs' and
        'entropy'. ``write(tag, value)`` is the logging callback, e.g. ``runner.write``.
        """
        # Targets are constants for this update; the critic learns through 'values' only.
        advantage = trajectory["value_targets"].detach() - trajectory["values"]
        # Policy gradient: the advantage is a fixed weight, so no gradient reaches the critic.
        policy_loss = torch.mean(-trajectory["log_probs"] * advantage.detach())

        # Mean absolute advantage (L1) as the critic loss.
        critic_loss = advantage.abs().mean()

        # Policy.act in this notebook stores sum_a p log p (negative entropy) under
        # 'entropy', so adding its mean to the loss pushes entropy up.
        entropy_loss = trajectory["entropy"].mean()

        # log all losses
        write(
            "losses",
            {
                "policy loss": policy_loss,
                "critic loss": critic_loss,
                "entropy loss": entropy_loss,
            },
        )

        # additional logs
        write("critic/advantage", advantage.mean())
        write(
            "critic/values",
            {
                "value predictions": trajectory["values"].mean(),
                "value targets": trajectory["value_targets"].mean(),
            },
        )

        # return scalar loss
        return (
            self.value_loss_coef * critic_loss
            + policy_loss
            + self.entropy_coef * entropy_loss
        )

    def train(self, runner):
        """Collect one rollout with ``runner`` and take one clipped gradient step.

        Logs the gradient norm measured BEFORE clipping under "gradient norm".
        """
        trajectory = runner.get_next()

        self.optimizer.zero_grad()
        loss = self.loss(trajectory, runner.write)
        loss.backward()

        # clip_grad_norm_ returns the total norm measured before clipping.
        # Re-measuring after clipping would always report <= max_grad_norm and hide
        # how large the raw gradients are. It also skips parameters whose .grad is None.
        grad_norm = clip_grad_norm_(self.policy.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # use runner.write to log scalar to tensorboard
        runner.write("gradient norm", float(grad_norm))

In [15]:
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fresh agent for the real training run; take the action count from the
# environment instead of hard-coding 6.
model = ACAgent((4, 84, 84), n_actions)
policy = Policy(model, device)

# Policy.__init__ has already moved `model` to `device`, so the optimizer sees the final parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

a2c = A2C(
    policy,
    optimizer,
)

In [32]:
# Training runner: 20-step rollouts per update, post-processed into value
# targets and flat (time * env) tensors by the transforms.
runner = EnvRunner(
    env, policy, nsteps=20, transforms=[ComputeValueTargets(policy), MergeTimeBatch()]
)

In [33]:
from tqdm import tqdm

# Number of A2C updates (not passes over a dataset): each a2c.train call
# collects one fresh rollout and takes one gradient step.
epochs = 25000

for i in tqdm(range(epochs)):
    a2c.train(runner)

100%|██████████| 25000/25000 [1:22:51<00:00,  5.03it/s]


In [34]:
# Save the trained weights (state_dict only) to a file named "A2C" in the working directory.
torch.save(model.state_dict(), "A2C")

In [35]:
# Shut down the training environments before building the evaluation env.
env.close()

## Evaluation

In [36]:
# Evaluation environment. nenvs=None presumably gives a single, non-vectorized env
# (evaluate() below wraps each observation in a batch of one). clip_reward=False keeps
# raw game scores. episodic_life=False presumably makes one episode last the whole
# game — confirm in atari_wrappers.
env = nature_dqn_env(
    "SpaceInvadersNoFrameskip-v4",
    nenvs=None,
    clip_reward=False,
    summaries=False,
    episodic_life=False,
)

In [37]:
def evaluate(env, policy, n_games=1, t_max=10000):
    """
    Plays n_games and returns rewards.

    Args:
        env: single (non-vectorized) environment with ``reset()`` and
            ``step(action) -> (obs, reward, done, info)``.
        policy: object whose ``act(batch)`` returns a dict with an "actions" array.
        n_games: number of episodes to play.
        t_max: maximum number of steps per episode; longer episodes are truncated.

    Returns:
        numpy array with the total (undiscounted) reward of each episode.
    """
    rewards = []

    # Evaluation never backpropagates, so skip building autograd graphs in the policy.
    with torch.no_grad():
        for _ in range(n_games):
            s = env.reset()

            R = 0
            for _ in range(t_max):
                # act() expects a batch, so wrap the single observation in a batch of one.
                action = policy.act(np.array([s]))["actions"][0]

                s, r, done, _ = env.step(action)

                R += r
                if done:
                    break

            rewards.append(R)
    return np.array(rewards)

In [38]:
# Play 30 games and report the mean raw score (evaluation will take some time!).
sessions = evaluate(env, policy, n_games=30)
score = sessions.mean()
print(f"Your score: {score}")

# Passing threshold: a mean score of at least 500.
assert score >= 500, "Needs more training?"
print("Well done!")

Your score: 577.3333333333334
Well done!


In [39]:
# Close the evaluation environment.
env.close()

## Record

In [40]:
# Same settings as the evaluation env, plus monitor=True (presumably records videos
# of the played episodes — confirm where atari_wrappers writes them).
env_monitor = nature_dqn_env(
    "SpaceInvadersNoFrameskip-v4",
    nenvs=None,
    monitor=True,
    clip_reward=False,
    summaries=False,
    episodic_life=False,
)

In [41]:
# Play 3 games in the monitored environment so they get recorded.
sessions = evaluate(env_monitor, policy, n_games=3)

In [42]:
# Total raw reward of each recorded game.
sessions

array([510., 500., 800.])

In [43]:
# Close the monitored environment.
env_monitor.close()