# DQN on CartPole

This example shows how to train a DQN agent using `torchrl`. It assumes basic knowledge of reinforcement learning and DQN. See the [Documentation](https://torchrl.sanyamkapoor.com/) for an introduction to *TorchRL* and installation instructions.

## Problem Specification

The full problem can be specified in **less than 50 lines of code**!

In [None]:
import argparse
import torch
import numpy as np

from torchrl import registry
from torchrl import utils
from torchrl.problems import base_hparams, DQNProblem
from torchrl.agents import BaseDQNAgent

We use the pre-built DQN agent from the *TorchRL* library to initialize a `Problem`. The `Problem` class itself extends a pre-built base class from the library, `DQNProblem`.

In [None]:
class DQNCartpole(DQNProblem):
  def init_agent(self):
    observation_space, action_space = utils.get_gym_spaces(self.runner.make_env)

    agent = BaseDQNAgent(
        observation_space,
        action_space,
        double_dqn=self.hparams.double_dqn,
        lr=self.hparams.actor_lr,
        gamma=self.hparams.gamma,
        target_update_interval=self.hparams.target_update_interval)

    return agent

This class requires us to override the `init_agent` method. There is no restriction on its contents as long as it returns a valid `BaseAgent`.
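As a reminder of what the `gamma` and `double_dqn` arguments control, here is a minimal, library-independent sketch of the one-step DQN target. It is not TorchRL's implementation: the function name `td_target` and the use of plain lists for Q-values are assumptions made purely for illustration.

```python
def td_target(reward, done, next_q_online, next_q_target,
              gamma=0.99, double_dqn=False):
    """One-step TD target for a single transition (illustrative only).

    next_q_online / next_q_target hold per-action Q-value estimates for the
    next state, from the online and target networks respectively.
    """
    if done:
        # Terminal transition: nothing to bootstrap from.
        return reward
    if double_dqn:
        # Double DQN: the online network selects the action,
        # the target network evaluates it.
        best_action = max(range(len(next_q_online)),
                          key=lambda a: next_q_online[a])
        bootstrap = next_q_target[best_action]
    else:
        # Vanilla DQN: the target network both selects and evaluates.
        bootstrap = max(next_q_target)
    return reward + gamma * bootstrap


online = [1.0, 3.0]  # online network prefers action 1
target = [2.5, 2.0]  # target network prefers action 0

vanilla = td_target(1.0, False, online, target, gamma=0.99)
double = td_target(1.0, False, online, target, gamma=0.99, double_dqn=True)
terminal = td_target(1.0, True, online, target)
print(vanilla, double, terminal)
```

In this toy case the Double DQN target is lower than the vanilla one, because the action chosen by the online network is not the one the target network rates highest.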

## Hyperparameter Specification

We use the `HParams` object from the library to add custom properties. Arbitrary properties can be set on such objects, as long as they are used consistently within the `Problem` class defined above (e.g. within `init_agent`).

In [None]:
def hparams_dqn_cartpole():
    params = base_hparams.base_dqn()

    params.env_id = 'CartPole-v1'

    params.rollout_steps = 1
    params.num_processes = 1
    params.actor_lr = 1e-3
    params.gamma = 0.99
    params.target_update_interval = 10
    params.eps_min = 1e-2
    params.buffer_size = 1000
    params.batch_size = 32
    params.num_total_steps = 5000
    params.num_eps_steps = 500

    return params
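How `eps_min` and `num_eps_steps` interact is easier to see with numbers. The sketch below assumes a linear decay from a starting value of 1.0 down to `eps_min` over `num_eps_steps` steps. Both the linear shape and the starting value are assumptions, and the schedule the library actually uses may differ. The helper `linear_epsilon` is hypothetical.

```python
def linear_epsilon(step, eps_max=1.0, eps_min=1e-2, num_eps_steps=500):
    """Linearly anneal epsilon from eps_max to eps_min, then hold it constant."""
    # Fraction of the annealing period that has elapsed, clipped to [0, 1].
    fraction = min(max(step, 0) / num_eps_steps, 1.0)
    return eps_max + fraction * (eps_min - eps_max)


for step in (0, 250, 500, 5000):
    print(step, round(linear_epsilon(step), 4))
```

Under these assumptions, exploration reaches its floor after 500 of the 5000 total steps, so the agent acts almost greedily for the remaining 90% of training.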

## Initialize Problem Instance

We run on a GPU if one is available and pass a few basic arguments, most importantly the seed. Make sure to repeat the run with several different seeds.

**NOTE**: The `Problem` class expects its arguments as an `argparse.Namespace`, which explains the type cast below. If interested, track this issue [here](https://github.com/salmanazarr/torchrl/issues/61).

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

args = dict(
    seed=1,
    log_interval=1000,
    eval_interval=1000,
    num_eval=1,
)

dqn_cartpole = DQNCartpole(
    hparams_dqn_cartpole(),
    argparse.Namespace(**args),
    None, # Disable logging
    device=device,
    show_progress=True,
)
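The `argparse.Namespace(**args)` cast simply turns dictionary keys into attributes. The snippet below uses only the standard library to show this, and sketches one way to prepare the arguments for several seeds. The helper `make_args` and the particular seed values are illustrative assumptions, not part of TorchRL.

```python
import argparse


def make_args(seed):
    """Build the argument namespace for one run (hypothetical helper)."""
    return argparse.Namespace(
        seed=seed,
        log_interval=1000,
        eval_interval=1000,
        num_eval=1,
    )


# One namespace per seed, so that results can be compared across runs.
runs = [make_args(seed) for seed in (1, 2, 3)]

print(runs[0].seed)   # attribute access instead of args['seed']
print(vars(runs[1]))  # vars() recovers the underlying dict
```

Each namespace could then be passed to its own `DQNCartpole` instance in place of the single `argparse.Namespace(**args)` used in the cell above.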

## Training the DQN Agent

Calling `run()` starts training. Note that we have disabled logging for now by passing `log_dir=None` in the instantiation above.

In [None]:
dqn_cartpole.run()
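As a reminder of what `buffer_size`, `batch_size`, and `target_update_interval` typically control during such a run, here is a library-independent sketch of the bookkeeping in a generic DQN loop. It is not what `run()` executes internally: `ReplayBuffer`, `training_bookkeeping`, and the dummy transitions are hypothetical. The sketch counts the target update interval in gradient updates, which is an assumption, since the library might count environment steps instead.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-capacity FIFO buffer; the oldest transitions are evicted first."""

    def __init__(self, capacity):
        self.storage = deque(maxlen=capacity)

    def push(self, transition):
        self.storage.append(transition)

    def sample(self, batch_size):
        # Uniform sampling without replacement.
        return random.sample(list(self.storage), batch_size)

    def __len__(self):
        return len(self.storage)


def training_bookkeeping(num_total_steps=5000, buffer_size=1000,
                         batch_size=32, target_update_interval=10):
    """Count gradient updates and target-network syncs for a dummy run."""
    buffer = ReplayBuffer(buffer_size)
    num_updates = num_syncs = 0
    for step in range(num_total_steps):
        # Placeholder transition: (state, action, reward, next_state, done).
        buffer.push((step, 0, 1.0, step + 1, False))
        if len(buffer) < batch_size:
            continue  # not enough data to form a batch yet
        batch = buffer.sample(batch_size)
        assert len(batch) == batch_size
        num_updates += 1  # a real agent would take a gradient step here
        if num_updates % target_update_interval == 0:
            num_syncs += 1  # a real agent would copy online weights to the target
    return len(buffer), num_updates, num_syncs


final_size, num_updates, num_syncs = training_bookkeeping()
print(final_size, num_updates, num_syncs)
```

With the hyperparameters from this example, the buffer saturates at 1000 transitions, so for most of the 5000 steps each batch is drawn from only the most recent experience.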

## Evaluate Training

Quoting the documentation, this environment is

> Considered solved when the average reward is greater than or equal to 195.0 over 100 consecutive trials.

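Let's see whether we achieved that. Note that this threshold was written for `CartPole-v0`; `CartPole-v1`, which `env_id` selects here, registers a higher reward threshold of 475.0 in Gym, so 195.0 is a lenient bar. We use parallel environments for faster evaluation.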

In [None]:
%%time

# Switch the agent to evaluation mode.
dqn_cartpole.agent.train(False)

eval_runner = dqn_cartpole.make_runner(n_envs=10)
eval_rewards = []
# Collect 100 evaluation episodes in total, eval_runner.n_envs at a time.
for _ in range(100 // eval_runner.n_envs):
  eval_history = eval_runner.rollout(dqn_cartpole.agent)
  for i in range(eval_runner.n_envs):
    # Read the trajectory of environment i (the original always read index 0).
    _, _, reward_history, _, _ = eval_history[i]
    eval_rewards.append(np.sum(reward_history, axis=0))
eval_runner.close()

In [None]:
avg_reward, std_reward = np.average(eval_rewards), np.std(eval_rewards)
possible_win = avg_reward >= 195.0  # the criterion is "greater than or equal to"

print('Reward: {} +/- {}'.format(avg_reward, std_reward))
print('Did we possibly win? {}!'.format('Yay' if possible_win else 'Nay'))
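The quoted criterion speaks of 100 *consecutive* trials, whereas the cell above averages a single batch of evaluation episodes. Here is a minimal sketch of the stricter rolling check, assuming a plain list of per-episode returns. The helper `first_solved_index` and the synthetic returns are hypothetical and not part of TorchRL.

```python
def first_solved_index(episode_returns, threshold=195.0, window=100):
    """Return the index of the episode that completes the first run of
    `window` consecutive episodes whose mean return is >= threshold,
    or None if no such run exists."""
    running_sum = 0.0
    for i, ret in enumerate(episode_returns):
        running_sum += ret
        if i >= window:
            # Drop the episode that just fell out of the window.
            running_sum -= episode_returns[i - window]
        if i >= window - 1 and running_sum / window >= threshold:
            return i
    return None


# Synthetic data: 50 weak episodes followed by 150 episodes with return 200.
returns = [20.0] * 50 + [200.0] * 150
print(first_solved_index(returns))
```

With these synthetic returns the check first passes at episode index 147, the point at which only two weak episodes remain inside the 100-episode window.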

## What's next?

See [Documentation](https://torchrl.sanyamkapoor.com/) for full details on the flexibility of the API.