In [None]:
# Check which GPU (if any) is attached to this runtime.
!nvidia-smi

### Install dependencies

In [None]:
# xvfb provides the virtual X display used by the `xvfb-run xagents play ...` cells below.
# `-y` pre-answers apt-get's confirmation prompt so the install does not stall waiting for input.
!sudo apt-get install -y xvfb

In [None]:
# `%pip` (rather than a bare `pip`, which only works via automagic) installs into the running kernel's environment.
%pip install xagents

In [None]:
# NOTE(review): the reason for pinning matplotlib 3.1.3 is not stated here — presumably compatibility
# with the plotting code below; confirm before changing the pin.
%pip install matplotlib==3.1.3

### Training (trial 1)

We will train A2C and PPO agents on the CartPole-v1 environment. Since no hyperparameter optimization has been conducted yet, both agents are expected to yield suboptimal results.

### PPO (training)

In [None]:
# Baseline (untuned) PPO run; its history file feeds the benchmark plot further down.
!xagents train ppo --env CartPole-v1 --seed 55 \
--n-envs 16 --max-steps 300000 \
--checkpoints ppo-cartpole.tf --history-checkpoint ppo-cartpole.parquet

### A2C (training)

In [None]:
# Baseline (untuned) A2C run with `--n-steps 128`; its history file feeds the benchmark plot further down.
!xagents train a2c --env CartPole-v1 --seed 55 --n-envs 16 --n-steps 128 --max-steps 300000 \
--checkpoints a2c-cartpole.tf --history-checkpoint a2c-cartpole.parquet

### Tuning

In this section, we are going to tune hyperparameters for A2C and PPO.

**Notes:**
* The `xagents <command> <agent>` syntax displays the available options for the given command and agent. We will use this syntax for displaying tuning options for both agents.
* There are multiple hyperparameter types, listed under the `hp_type` column of the menu displayed below. We will use the following two:
1. `log_uniform` hp_type: sampled between a minimum and a maximum bound, so you pass either 1 value or 2 values (min and max) as follows:

  `xagents tune <agent> --<log-uniform-hp> <min-val> <max-val>`

2. `categorical` hp_type: accepts any number of candidate values, passed as follows:

  `xagents tune <agent> --<categorical-hp> <val1> <val2> <val3> ...`

In [None]:
# With no options, this lists the tunable A2C hyperparameters together with their `hp_type`.
!xagents tune a2c

### A2C (tuning)

In [None]:
# Run settings first (env, study storage, trial budget, parallelism), then the hyperparameter search space.
!xagents tune a2c --env CartPole-v1 --study a2c-cartpole --storage sqlite:///a2c-cartpole.sqlite \
--n-trials 20 --warmup-trials 4 --trial-steps 200000 --n-jobs 2 \
--n-envs 16 --lr 1e-5 1e-2 --opt-epsilon 1e-7 1e-3 --gamma 0.9 0.99 --entropy-coef 1e-5 0.5 \
--grad-norm 0.1 10 --n-steps 8 16 32 64 128 256 512 1024

We can use the [optuna.visualization.matplotlib](https://optuna.readthedocs.io/en/latest/reference/visualization/matplotlib.html) API to visualize hyperparameter importances.



In [None]:
import optuna
import matplotlib.pyplot as plt

# Load the A2C study written by the tuning cell above and plot per-hyperparameter importances.
# Both arguments are passed by keyword: newer Optuna releases deprecate positional args to `load_study`.
a2c_study = optuna.load_study(study_name='a2c-cartpole', storage='sqlite:///a2c-cartpole.sqlite')
optuna.visualization.matplotlib.plot_param_importances(a2c_study)
plt.show()

In [None]:
# With no options, this lists the tunable PPO hyperparameters together with their `hp_type`.
!xagents tune ppo

### PPO (tuning)

In [None]:
# Each hyperparameter flag is given exactly once; `--opt-epsilon` uses the 1e-7..1e-3 range (same as the A2C study).
!xagents tune ppo --env CartPole-v1 --study ppo-cartpole --storage sqlite:///ppo-cartpole.sqlite \
--trial-steps 200000 --warmup-trials 4 --n-trials 20 --advantage-epsilon 1e-8 1e-5 --clip-norm 0.01 0.5 \
--entropy-coef 1e-4 0.3 --gamma 0.9 0.999 --grad-norm 0.1 10 --lam 0.7 0.99 --lr 1e-5 1e-2 \
--n-steps 16 32 64 128 256 512 1024 --n-envs 16 32 \
--opt-epsilon 1e-7 1e-3 --n-jobs 2

In [None]:
# Load the PPO study written by the tuning cell above and plot per-hyperparameter importances.
# Keyword arguments avoid the positional-argument deprecation in newer Optuna releases.
ppo_study = optuna.load_study(study_name='ppo-cartpole', storage='sqlite:///ppo-cartpole.sqlite')
optuna.visualization.matplotlib.plot_param_importances(ppo_study)
plt.show()

Display the best A2C hyperparameters found by the study, then use them to retrain.

In [None]:
# Hyperparameters of the best-scoring A2C trial.
a2c_study.best_trial.params

### A2C (training using tuned hyperparameters)

In [None]:
# Hyperparameter values below were copied from `a2c_study.best_params` of a previous tuning run;
# re-running the (unseeded) study may yield different values.
!xagents train a2c --env CartPole-v1 --seed 55 --max-steps 300000 --n-envs 16 \
--checkpoints a2c-cartpole-tuned.tf --history-checkpoint a2c-cartpole-tuned.parquet \
--n-steps 8 --lr 0.0012985885268425004 --opt-epsilon 0.0009386796496510724 \
--gamma 0.9387388102974632 --entropy-coef 0.010565924673903932 --grad-norm 0.9964628998438626

In [None]:
# Hyperparameters of the best-scoring PPO trial.
ppo_study.best_trial.params

### PPO (training using tuned hyperparameters) 

In [None]:
# Hyperparameter values below were presumably copied from `ppo_study.best_params` (displayed above)
# of a previous tuning run; re-running the (unseeded) study may yield different values.
!xagents train ppo --env CartPole-v1 --seed 55 --max-steps 300000 --n-envs 32 \
--checkpoints ppo-cartpole-tuned.tf --history-checkpoint ppo-cartpole-tuned.parquet \
--n-steps 16 --lr 0.001549335940636045 --opt-epsilon 8.539506175014364e-07 \
--gamma 0.93959608546301 --lam 0.9818834679479003 --entropy-coef 0.06363366133416302 \
--grad-norm 6.2465542151066495 --clip-norm 0.0503693625084303 --advantage-epsilon 1.3475350681876062e-08

### Benchmarks

In [None]:
from xagents.utils.common import plot_history

# Training-history files mapped to their legend labels, in plotting order.
history_labels = {
    'a2c-cartpole.parquet': 'A2C',
    'a2c-cartpole-tuned.parquet': 'A2C(tuned)',
    'ppo-cartpole.parquet': 'PPO',
    'ppo-cartpole-tuned.parquet': 'PPO(tuned)',
}
plot_history(
    list(history_labels),
    list(history_labels.values()),
    'CartPole-v1',
    history_interval=50,
)
plt.show()


### Play and save an episode video for A2C and PPO

In [None]:
# Play the tuned A2C agent under a virtual X display (xvfb-run) and record the episode into a2c-vid/.
!xvfb-run xagents play a2c --weights a2c-cartpole-tuned.tf --env CartPole-v1 --video-dir a2c-vid

In [None]:
from IPython.display import HTML
from base64 import b64encode
import glob


def get_vid_url(vid_folder):
  """Return the first ``.mp4`` video in ``vid_folder`` as a base64 ``data:`` URL.

  The URL can be embedded directly in an HTML ``<video>`` tag, so the video
  does not have to be served separately from the notebook.

  Args:
    vid_folder: Directory that ``xagents play --video-dir`` wrote videos to.

  Returns:
    A ``data:video/mp4;base64,...`` string.

  Raises:
    FileNotFoundError: If ``vid_folder`` contains no ``.mp4`` files (e.g. the
      matching ``xagents play`` cell has not been run yet).
  """
  # Sort for a deterministic pick (glob order is arbitrary); escape the folder so any
  # glob metacharacters in its name are matched literally.
  videos = sorted(glob.glob(f'{glob.escape(vid_folder)}/*.mp4'))
  if not videos:
    raise FileNotFoundError(f'No .mp4 files found in {vid_folder!r}')
  # Context manager closes the file handle instead of leaking it.
  with open(videos[0], 'rb') as video_file:
    encoded = b64encode(video_file.read()).decode()
  return 'data:video/mp4;base64,' + encoded

In [None]:
# Show the A2C episode first: it was recorded by the `xagents play a2c` cell above, whereas
# ppo-vid/ only exists after the `xagents play ppo` cell below has run.
a2c_url = get_vid_url('a2c-vid')
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % a2c_url)

In [None]:
!xvfb-run xagents play ppo --env CartPole-v1 --weights ppo-cartpole-tuned.tf --video-dir ppo-vid

In [None]:
ppo_url = get_vid_url('ppo-vid')
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % ppo_url)