# CartPole tests with policy gradient methods

This notebook contains a simple test for each implemented policy gradient method. To check that each one works properly, we rely on the [CartPole](https://gym.openai.com/envs/CartPole-v0/) environment, provided out of the box by OpenAI Gym. As stated in Gym's documentation, the problem is considered "solved" when the agent obtains an average return of at least 195 over 100 consecutive episodes.
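Gym's "solved" criterion can be checked with a rolling window of episode returns. The sketch below is a minimal, self-contained illustration; `make_solved_tracker` is a hypothetical helper, not part of `src`, and the training code may compute its mean return differently.

```python
from collections import deque


def make_solved_tracker(window=100, threshold=195.0):
    """Return a callable that records episode returns and reports whether
    the mean over the last `window` episodes has reached `threshold`."""
    recent = deque(maxlen=window)  # keeps only the last `window` returns

    def record(episode_return):
        recent.append(episode_return)
        # The criterion needs a full window of consecutive episodes
        if len(recent) < window:
            return False
        return sum(recent) / window >= threshold

    return record


# CartPole-v0 caps episodes at 200 steps (+1 reward per step), so 200 is the best return
is_solved = make_solved_tracker()
results = [is_solved(200.0) for _ in range(100)]
print(results[-2], results[-1])  # False True
```

The `deque(maxlen=...)` discards the oldest return automatically, so each check costs a sum over at most `window` values.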

## Pre-requisites

The cells below make the project's `src` package importable, then install and import the libraries needed to run the notebook examples.

In [None]:
import sys
sys.path.append('../')

In [None]:
%%capture
!pip install -r ../init/requirements.txt

In [None]:
import io
import base64

import numpy as np
import torch
import gym
from gym import wrappers
from IPython.display import HTML

from src import models, policies

%load_ext autoreload
%autoreload 2

## Utilities

The cells below create the environment and define common variables used throughout the notebook.

In [None]:
env = gym.make('CartPole-v0')

In [None]:
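# CartPole observations have 4 components and there are 2 discrete actions
observation_space_size = env.observation_space.shape[0]
action_space_size = env.action_space.n
hidden_sizes = [32, 32]  # units in each hidden layer of the MLPs
epochs = 800
steps_per_epoch = 200
minibatch_size = 100
episodes_mean_return = 100  # window of episodes for the mean return (Gym's criterion uses 100)
# Read the W&B API key and close the file afterwards
with open("../wandb_api_key_file", "r") as f:
    wandb_api_key = f.read().strip()
wandb_config = {
    "api_key": wandb_api_key,
    "project": "cpr-appropriation",
    "entity": "wadaboa",
}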
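With these sizes, each policy network maps a 4-dimensional CartPole observation (cart position, cart velocity, pole angle, pole angular velocity) to log-probabilities over the 2 actions (push left, push right), while each baseline network outputs a single state value. The NumPy sketch below shows such a forward pass. It assumes ReLU hidden activations and that the `log_softmax` flag of `models.MLP` toggles a log-softmax output layer; neither is confirmed by this notebook, and `init_mlp` and `mlp_forward` are hypothetical names.

```python
import numpy as np


def log_softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def init_mlp(input_size, hidden_sizes, output_size, seed=0):
    """Create (weight, bias) pairs for a fully connected network."""
    rng = np.random.default_rng(seed)
    sizes = [input_size, *hidden_sizes, output_size]
    return [
        (rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def mlp_forward(params, x, use_log_softmax=True):
    """ReLU hidden layers (assumed); optional log-softmax on the output layer."""
    for weight, bias in params[:-1]:
        x = np.maximum(0.0, x @ weight + bias)
    weight, bias = params[-1]
    out = x @ weight + bias
    return log_softmax(out) if use_log_softmax else out


policy_params = init_mlp(4, [32, 32], 2)    # observation -> action log-probabilities
baseline_params = init_mlp(4, [32, 32], 1)  # observation -> state value
obs = np.array([0.01, -0.02, 0.03, 0.04])
print(np.exp(mlp_forward(policy_params, obs)).sum())  # approximately 1.0
print(mlp_forward(baseline_params, obs, use_log_softmax=False).shape)  # (1,)
```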

## VPG

This section trains a CartPole agent with our custom Vanilla Policy Gradient (VPG) implementation, using a separate network as a learned baseline.
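The core of VPG with a learned baseline fits in a few lines. This NumPy sketch computes discounted rewards-to-go, advantages against the baseline, and the policy and baseline losses for a single episode. The discount `gamma`, the use of rewards-to-go, and the mean-squared-error baseline loss are assumptions that may differ from what `policies.VPGPolicy` does; `rewards_to_go` and `vpg_losses` are hypothetical helpers.

```python
import numpy as np


def rewards_to_go(rewards, gamma=0.99):
    """Discounted return from each timestep to the end of the episode."""
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def vpg_losses(log_probs, rewards, values, gamma=0.99):
    """Policy and baseline losses for one episode.

    log_probs: log pi(a_t | s_t) of the actions actually taken
    values:    baseline estimates V(s_t)
    """
    returns = rewards_to_go(rewards, gamma)
    advantages = returns - values  # subtracting the baseline reduces gradient variance
    # Minimizing this maximizes E[log pi(a_t|s_t) * A_t]; in an autodiff
    # framework the advantages would be treated as constants
    policy_loss = -np.mean(log_probs * advantages)
    # The baseline is regressed onto the observed returns
    baseline_loss = np.mean((returns - values) ** 2)
    return policy_loss, baseline_loss


print(rewards_to_go([1.0, 1.0, 1.0], gamma=0.5).tolist())  # [1.75, 1.5, 1.0]
```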

In [None]:
vpg_policy_nn = models.MLP(observation_space_size, hidden_sizes, action_space_size)
vpg_baseline_nn = models.MLP(observation_space_size, hidden_sizes, 1, log_softmax=False)
vpg_policy = policies.VPGPolicy(env, vpg_policy_nn, baseline_nn=vpg_baseline_nn)
vpg_policy.train(
    epochs,
    steps_per_epoch,
    minibatch_size,
    enable_wandb=True,
    wandb_config={**wandb_config, "group": "VPG"},
    episodes_mean_return=episodes_mean_return
)

## TRPO

This section trains a CartPole agent with our custom Trust Region Policy Optimization (TRPO) implementation.

In [None]:
beta = 1.0
kl_target = 0.01
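The names `beta` and `kl_target` suggest a KL-penalty formulation: maximize the surrogate advantage minus `beta` times the KL divergence from the old policy, adapting `beta` so the measured KL stays near `kl_target`. Classic TRPO instead enforces a hard KL constraint through conjugate gradient and a line search, and this notebook does not show which of the two `policies.TRPOPolicy` implements. The NumPy sketch below illustrates only the penalty variant, with the adaptive rule described by Schulman et al. (2017) in the PPO paper; all function names are hypothetical.

```python
import numpy as np


def categorical_kl(old_probs, new_probs):
    """Mean KL(old || new) over a batch of categorical distributions."""
    return np.mean(np.sum(old_probs * (np.log(old_probs) - np.log(new_probs)), axis=-1))


def kl_penalized_objective(ratios, advantages, kl, beta):
    """Surrogate advantage minus a KL penalty (to be maximized)."""
    return np.mean(ratios * advantages) - beta * kl


def update_beta(beta, kl, kl_target):
    """Adaptive penalty: halve beta if the policy moved too little,
    double it if it moved too much (thresholds kl_target / 1.5 and kl_target * 1.5)."""
    if kl < kl_target / 1.5:
        return beta / 2.0
    if kl > kl_target * 1.5:
        return beta * 2.0
    return beta


old = np.array([[0.5, 0.5]])
new = np.array([[0.6, 0.4]])
kl = categorical_kl(old, new)  # about 0.0204 nats
print(update_beta(1.0, kl, kl_target=0.01))  # 2.0, since kl > 0.015
```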

In [None]:
trpo_policy_nn = models.MLP(observation_space_size, hidden_sizes, action_space_size)
trpo_baseline_nn = models.MLP(observation_space_size, hidden_sizes, 1, log_softmax=False)
trpo_policy = policies.TRPOPolicy(env, trpo_policy_nn, trpo_baseline_nn, beta=beta, kl_target=kl_target)
trpo_policy.train(
    epochs,
    steps_per_epoch,
    minibatch_size,
    enable_wandb=True,
    wandb_config={**wandb_config, "group": "TRPO"},
    episodes_mean_return=episodes_mean_return
)

## PPO

This section trains a CartPole agent with our custom Proximal Policy Optimization (PPO) implementation.

In [None]:
alpha = 0.5
beta = 0.01
eps = 0.2
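PPO replaces the trust-region machinery with a clipped surrogate objective, where `eps` bounds how much a change in the probability ratio can improve the objective. Reading `alpha` and `beta` as the value-loss and entropy-bonus coefficients of the combined PPO loss is an assumption based on common conventions; their actual role is defined in `policies.PPOPolicy`. A minimal NumPy sketch, with `ppo_loss` as a hypothetical name:

```python
import numpy as np


def ppo_loss(ratios, advantages, returns, values, entropies, alpha=0.5, beta=0.01, eps=0.2):
    """Combined PPO loss to minimize.

    ratios:    pi_new(a|s) / pi_old(a|s) for the sampled actions
    entropies: entropy of pi_new(.|s) at each state
    """
    unclipped = ratios * advantages
    clipped = np.clip(ratios, 1.0 - eps, 1.0 + eps) * advantages
    # Pessimistic bound: keep the smaller of the two surrogate terms
    policy_objective = np.mean(np.minimum(unclipped, clipped))
    value_loss = np.mean((returns - values) ** 2)
    entropy_bonus = np.mean(entropies)
    return -policy_objective + alpha * value_loss - beta * entropy_bonus


# A ratio of 1.5 with a positive advantage only counts as 1.2; a ratio of 0.5
# with a negative advantage counts as 0.8 * A = -0.8 (the smaller term)
zeros = np.zeros(2)
loss = ppo_loss(np.array([1.5, 0.5]), np.array([1.0, -1.0]), zeros, zeros, zeros)
print(loss)  # approximately -0.2
```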

In [None]:
ppo_policy_nn = models.MLP(observation_space_size, hidden_sizes, action_space_size)
ppo_baseline_nn = models.MLP(observation_space_size, hidden_sizes, 1, log_softmax=False)
ppo_policy = policies.PPOPolicy(env, ppo_policy_nn, ppo_baseline_nn, alpha=alpha, beta=beta, eps=eps)
ppo_policy.train(
    epochs,
    steps_per_epoch,
    minibatch_size,
    enable_wandb=False,
    save_every=200,
    checkpoints_path="../checkpoints",
    wandb_config={**wandb_config, "group": "PPO"},
    episodes_mean_return=episodes_mean_return
)

## Evaluation

In this section, we evaluate one of the trained models on the CartPole environment and visualize the results through videos recorded by Gym's `Monitor` wrapper.
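The last cell of this section hardcodes the `Monitor` file naming pattern. A small helper that embeds whichever MP4 was written most recently is slightly more robust; this is a sketch, and `latest_video_html` is a hypothetical name, not a Gym or IPython function.

```python
import base64
import glob
import os


def latest_video_html(video_dir, width=360):
    """Embed the most recently written MP4 in `video_dir` as an HTML <video> tag."""
    paths = glob.glob(os.path.join(video_dir, "*.mp4"))
    if not paths:
        raise FileNotFoundError(f"no .mp4 files found in {video_dir}")
    latest = max(paths, key=os.path.getmtime)
    with open(latest, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    # Inline the video as a base64 data URI so no file server is needed
    return (
        f'<video width="{width}" height="auto" controls>'
        f'<source src="data:video/mp4;base64,{encoded}" type="video/mp4" />'
        "</video>"
    )
```

In the notebook, it would be used as `HTML(latest_video_html("../gym-results"))`.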

In [None]:
policy = ppo_policy
# Path to a checkpoint saved during training (e.g. under ../checkpoints)
checkpoint = ""
policy.load(checkpoint)

In [None]:
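policy.policy_nn.eval()
policy.baseline_nn.eval()
# Wrap the environment so that Gym records a video of the episode
env = wrappers.Monitor(env, "../gym-results", force=True)
observation = env.reset()
with torch.no_grad():  # inference only, no gradients needed
    for _ in range(1000):
        # The policy MLP ends in a log-softmax, so argmax picks the most likely action
        log_probs = policy.policy_nn(torch.tensor(observation, dtype=torch.float32))
        action = log_probs.argmax().item()
        observation, _, done, _ = env.step(action)
        if done:
            break
env.close()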

In [None]:
# Read the first recorded episode and embed it as a base64 data URI
video_path = '../gym-results/openaigym.video.%s.video000000.mp4' % env.file_infix
with io.open(video_path, 'rb') as f:
    encoded = base64.b64encode(f.read())
HTML(
    data='''
        <video width="360" height="auto" controls>
            <source src="data:video/mp4;base64,{0}" type="video/mp4" />
        </video>
    '''.format(encoded.decode('ascii'))
)