# A2C on Cartpole

In this example we will see how to train a Advantage Actor critic (A2C) agent using `torchrl`. See [Documentation](https://torchrl.sanyamkapoor.com/) for an introduction to *TorchRL* and installation instructions.

## Problem Specification

The full problem can specified in **less than 50 lines of code**!

In [1]:
import argparse
import torch
import numpy as np

from torchrl import registry
from torchrl import utils
from torchrl.problems import base_hparams, A2CProblem
from torchrl.agents import BaseA2CAgent

We use a pre-built version of the A2C agent from *TorchRL* library to initialize a `Problem`. This `Problem` class is also based on a pre-built version from the library.

In [2]:
class A2CCartpole(A2CProblem):
  def init_agent(self):
    observation_space, action_space = utils.get_gym_spaces(self.runner.make_env)

    agent = BaseA2CAgent(
        observation_space,
        action_space,
        lr=self.hparams.actor_lr,
        gamma=self.hparams.gamma,
        lmbda=self.hparams.lmbda,
        alpha=self.hparams.alpha,
        beta=self.hparams.beta)

    return agent

This class requires us to extend the `init_agent` method. There is no restriction on the contents as long as it returns a valid `BaseAgent`.

## Hyperparameter Specification

We use the `HParams` object from the library to add custom properties. Again, arbitrary properties can be provided to such objects as long as they are consistently used within the previously specified `Problem` class (e.g. within the `init_agent` routine).

In [3]:
def hparams_a2c_cartpole():
    params = base_hparams.base_pg()

    params.env_id = 'CartPole-v0'

    params.num_processes = 16

    params.rollout_steps = 5
    params.max_episode_steps = 500
    params.num_total_steps = int(1e6)

    params.alpha = 0.5
    params.gamma = 0.99
    params.beta = 1e-3
    params.lmbda = 1.0

    params.batch_size = 128
    params.tau = 1e-2
    params.actor_lr = 3e-4

    return params


## Initialize Problem Instance

We use GPUs if available and some basic arguments, most importantly the seed. Make sure to run using different seeds.

**NOTE**: We use `argparse.Namespace` class as the argument to the `Problem` class which explains the type cast. If interested, track this issue [here](https://github.com/activatedgeek/torchrl/issues/61).

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

args=dict(
    seed=1,
    log_interval=1000,
    eval_interval=1000,
    num_eval=1,
)

a2c_cartpole = A2CCartpole(
    hparams_a2c_cartpole(),
    argparse.Namespace(**args),
    None, # Disable logging
    device=device,
    show_progress=True,
)

## Training the A2C Agent

Calling the `run()` routine allows us to execute training. Note that for now we have disabled logging by keeping `log_dir=None` in the above instatiation.

In [5]:
a2c_cartpole.run()

100%|██████████| 12500/12500 [05:58<00:00, 34.86epochs/s]


## Evaluate Training

Quoting the documentation, this environment is

> Considered solved when the average reward is greater than or equal to 195.0 over 100 consecutive trials.

Let's see if we were able to achieve that. We use parallel environments for faster evaluation.

In [6]:
%%time

a2c_cartpole.agent.train(False)

eval_runner = a2c_cartpole.make_runner(n_envs=10)
eval_rewards = []
for _ in range(100 // a2c_cartpole.runner.n_envs):
  eval_history = eval_runner.rollout(a2c_cartpole.agent)
  for i in range(a2c_cartpole.runner.n_envs):
    _, _, reward_history, _, _ = eval_history[0]
    eval_rewards.append(np.sum(reward_history, axis=0))
eval_runner.close()

CPU times: user 1.2 s, sys: 221 ms, total: 1.43 s
Wall time: 1.51 s


In [7]:
avg_reward, std_reward = np.average(eval_rewards), np.std(eval_rewards)
possible_win = avg_reward > 195.0

print('Reward: {} +/- {}'.format(avg_reward, std_reward))
print('Did we possibly win? {}!'.format('Yay' if possible_win else 'Nay'))

Reward: 170.5 +/- 18.714967272212903
Did we possibly win? Nay!


# Visualization

In [8]:
vis_runner = a2c_cartpole.make_runner(n_envs=1)
vis_runner.rollout(a2c_cartpole.agent,render = True)
vis_runner.close()