<a href="https://colab.research.google.com/github/arampacha/hf_rl_class/blob/main/1b_PPO_hp_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
import sys
if 'google.colab' in sys.modules:
    !apt install python-opengl ffmpeg xvfb
    !pip install pyvirtualdisplay
    !pip install gym[box2d] stable-baselines3[extra] huggingface_sb3 pyglet
    !pip install ale-py==0.7.4 # To overcome an issue with gym (https://github.com/DLR-RM/stable-baselines3/issues/875)
    !pip install wandb

In [None]:
# Virtual display
from pyvirtualdisplay import Display

# Invisible virtual X display (visible=0); presumably needed so gym environments
# can render on a headless machine such as Colab.
DISPLAY_SIZE = (1400, 900)  # (width, height) in pixels
virtual_display = Display(visible=0, size=DISPLAY_SIZE)
virtual_display.start()

<pyvirtualdisplay.display.Display at 0x7f5b53fd0250>

In [None]:
import gym

from huggingface_sb3 import load_from_hub, package_to_hub, push_to_hub
from huggingface_hub import notebook_login

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor

import wandb
from wandb.integration.sb3 import WandbCallback

In [None]:
# Authenticate with Weights & Biases so runs and sweeps can be logged; when
# credentials are already cached it just reports the current user (see output).
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33marampacha[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Optional: uncomment to log in to the Hugging Face Hub (presumably before using
# the imported push_to_hub / package_to_hub helpers) and have git store the
# credential on disk.
# notebook_login()
# !git config --global credential.helper store

In [None]:
# W&B grid sweep: every combination of the values below is trained once
# (3 lr x 1 decay x 1 batch_size x 3 target_kl x 2 ent_coef = 18 runs).
# Each key under "parameters" is read as `run.config.<key>` inside `train`.
sweep_config = {
    # Must match the project passed to `wandb.sweep` ("rl", not "lr").
    "project": "hf-deep-rl-class",
    "entity": "arampacha",
    "name": "LunarLander-v2",
    "method": "grid",
    # Rank runs in the sweep UI by the final deterministic evaluation reward,
    # which `train` logs as "eval_reward_mean".
    "metric": {"name": "eval_reward_mean", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [1e-4, 5e-4, 1e-3]},
        # True -> learning rate decays linearly from `lr` to 0 during training.
        "decay": {"values": [False]},
        "batch_size": {"values": [128]},
        # None leaves PPO's KL-based early stopping disabled.
        "target_kl": {"values": [None, 0.005, 0.01]},
        "ent_coef": {"values": [0, 0.01]},
    },
}

In [None]:
from IPython.display import clear_output

def train(env_name="LunarLander-v2", total_timesteps=1_000_000, n_envs=16,
          n_eval_episodes=10, seed=None):
    """Run one W&B sweep trial: train PPO with the sampled config, then evaluate it.

    ``wandb.agent`` calls this with no arguments, so every parameter defaults
    to the value the notebook originally hard-coded.

    Hyperparameters come from ``run.config`` as filled in by the sweep:
    ``lr``, ``decay``, ``batch_size``, ``target_kl`` and ``ent_coef``.

    Args:
        env_name: Gym environment id used for both training and evaluation.
        total_timesteps: Environment steps passed to ``model.learn``.
        n_envs: Number of vectorised training environments.
        n_eval_episodes: Episodes for the final deterministic evaluation.
        seed: Optional int forwarded to ``make_vec_env`` and ``PPO``;
            ``None`` (default) leaves training unseeded, as before.

    Logs ``eval_reward_mean`` / ``eval_reward_std`` to the run. The trained
    model is saved by ``WandbCallback`` under ``<env_name>-ppo/<run id>``.
    """
    # Keep only the current trial's logs in the notebook output.
    clear_output()
    with wandb.init(sync_tensorboard=True) as run:
        cfg = run.config

        if cfg.decay:
            # SB3 calls a callable learning rate with `progress_remaining`,
            # going from 1 to 0, so this decays linearly from cfg.lr to 0.
            def learning_rate(progress_remaining):
                return progress_remaining * cfg.lr
        else:
            learning_rate = cfg.lr

        env = make_vec_env(env_name, n_envs=n_envs, seed=seed)
        try:
            model = PPO(
                "MlpPolicy",
                env,
                learning_rate=learning_rate,
                verbose=1,
                tensorboard_log=f"runs/{run.id}",
                batch_size=cfg.batch_size,
                target_kl=cfg.target_kl,
                ent_coef=cfg.ent_coef,
                seed=seed,
            )
            model.learn(
                total_timesteps=total_timesteps,
                callback=WandbCallback(
                    gradient_save_freq=100,
                    # One directory per run: a single shared path made every
                    # sweep trial overwrite the previous trial's saved model.
                    model_save_path=f"{env_name}-ppo/{run.id}",
                    verbose=2,
                ),
            )
        finally:
            env.close()

        # Separate single env wrapped in Monitor so evaluate_policy can read
        # per-episode statistics from it.
        eval_env = DummyVecEnv([lambda: Monitor(gym.make(env_name))])
        try:
            mean_reward, std_reward = evaluate_policy(
                model, eval_env, n_eval_episodes=n_eval_episodes, deterministic=True
            )
        finally:
            eval_env.close()
        run.log({"eval_reward_mean": mean_reward, "eval_reward_std": std_reward})

In [None]:
# Register the grid sweep with W&B, then start an agent that calls `train`
# once per configuration the sweep hands out until none are left.
SWEEP_ENTITY = "arampacha"
SWEEP_PROJECT = "hf-deep-rl-class"

sweep_id = wandb.sweep(sweep_config, project=SWEEP_PROJECT, entity=SWEEP_ENTITY)
wandb.agent(sweep_id, function=train)



Using cpu device
Logging to runs/w6sb18k2/PPO_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 92.7     |
|    ep_rew_mean     | -179     |
| time/              |          |
|    fps             | 3117     |
|    iterations      | 1        |
|    time_elapsed    | 10       |
|    total_timesteps | 32768    |
---------------------------------
Early stopping at step 9 due to reaching max kl: 0.01
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 97.2        |
|    ep_rew_mean          | -158        |
| time/                   |             |
|    fps                  | 2210        |
|    iterations           | 2           |
|    time_elapsed         | 29          |
|    total_timesteps      | 65536       |
| train/                  |             |
|    approx_kl            | 0.005927865 |
|    clip_fraction        | 0.0191      |
|    clip_range           | 0.2         |
|    entropy

VBox(children=(Label(value='0.155 MB of 0.155 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval_reward_mean,▁
eval_reward_std,▁
global_step,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇███
rollout/ep_len_mean,▁▁▁▁▁▁▂▁▂▂▂▂▃▃▄▄▅▅▆▇▇▇▇████████
rollout/ep_rew_mean,▁▂▃▃▂▁▂▄▄▆▅▄▇▆▆▆▇▇████▇▇▆▆▆▇▇▇▇
time/fps,█▅▅▆▆▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
train/approx_kl,█▇▆▂▃▅▆▇▆▆▆▇▇▇▇█▇▁▄▆▆▆▁▃▅▆▂▄▃▆
train/clip_fraction,▅▃▁▁▁▆▄▄▆▂▂▆▃▆▅▇▆▂▇▄▇▆▂█▄▅▂▄▂▅
train/clip_range,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/entropy_loss,▁▁▂▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇██▇▇█▇█▇█

0,1
eval_reward_mean,-83.50977
eval_reward_std,27.32187
global_step,1015808.0
rollout/ep_len_mean,956.12
rollout/ep_rew_mean,-16.932
time/fps,809.0
train/approx_kl,0.00448
train/clip_fraction,0.02142
train/clip_range,0.2
train/entropy_loss,-0.98729


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Sweep Agent: Exiting.
