<a href="https://colab.research.google.com/github/arampacha/hf_rl_class/blob/main/1a_LunarLander_v2_PPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
import sys
if 'google.colab' in sys.modules:
    !apt install python-opengl ffmpeg xvfb
    !pip install pyvirtualdisplay
    !pip install gym[box2d] stable-baselines3[extra] huggingface_sb3 pyglet
    !pip install ale-py==0.7.4 # To overcome an issue with gym (https://github.com/DLR-RM/stable-baselines3/issues/875)
    !pip install wandb

In [None]:
# Virtual display: start an invisible (visible=0) display so environment
# rendering — used for the video recording below — works without a screen.
from pyvirtualdisplay import Display

virtual_display = Display(visible=0, size=(1400, 900))
# Trailing `;` suppresses the Display object's repr in the cell output.
virtual_display.start();

<pyvirtualdisplay.display.Display at 0x7f73b232e610>

In [4]:
import gym

from huggingface_sb3 import load_from_hub, package_to_hub, push_to_hub
from huggingface_hub import notebook_login

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from stable_baselines3.common.monitor import Monitor

import wandb
from wandb.integration.sb3 import WandbCallback

In [2]:
# Authenticate with Weights & Biases; the experiment cell logs to it through
# wandb.init(...). The returned bool is shown as the cell output
# (presumably True once an API key is configured — see wandb docs).
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33marampacha[0m (use `wandb login --relogin` to force relogin)


True

In [3]:
# Authenticate with the Hugging Face Hub (interactive token widget); the
# experiment cell pushes the trained model with package_to_hub(...).
notebook_login()
# Persist git credentials so later git pushes to the Hub can presumably reuse
# the token. NOTE: the `store` helper writes credentials in plain text
# (~/.git-credentials) — fine on a throwaway Colab VM, avoid on shared machines.
!git config --global credential.helper store

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Gym environment id used for training, evaluation and naming.
ENV_NAME = 'LunarLander-v2'

# Run hyper-parameters; the whole dict is also passed to wandb.init(config=...).
config = dict(
    policy_type="MlpPolicy",
    total_timesteps=2_000_000,
    seed=124,
    lr=2e-3,        # base learning rate
    decay=True,     # True -> use linear_decay_sched instead of a constant lr
    batch_size=128,
)

In [None]:
def linear_decay_sched(pct: float, initial_lr=None) -> float:
    """Linear learning-rate schedule for Stable-Baselines3.

    SB3 calls a callable ``learning_rate`` with ``progress_remaining``, which
    goes from 1.0 at the start of training to 0.0 at the end, so the returned
    rate decreases linearly from ``initial_lr`` towards 0.

    Args:
        pct: Fraction of training remaining (1.0 -> 0.0), as passed by SB3.
        initial_lr: Starting learning rate. When omitted, ``config["lr"]`` is
            looked up at call time (the original behaviour).

    Returns:
        The learning rate to use at this point of training.
    """
    if initial_lr is None:
        initial_lr = config["lr"]
    return pct * initial_lr

In [None]:
# Descriptive run/group name, e.g. "LunarLander-v2-ppo-2e-03-linear-decay-2e+06steps".
_d = "linear" if config['decay'] else 'no'
experiment_name = f"{ENV_NAME}-ppo-{config['lr']:.0e}-{_d}-decay-{config['total_timesteps']:.0e}steps"

# Number of parallel training environments.
N_ENVS = 16
# Video schedule, expressed in VecVideoRecorder's *vectorised* step units:
# start at the halfway point of training, then one clip every VIDEO_EVERY steps.
VIDEO_START_STEP = config["total_timesteps"] // (2 * N_ENVS)
VIDEO_EVERY = 10_000


def should_record_video(step):
    """Return True on the vectorised steps at which a new video clip should start.

    VecVideoRecorder passes its own step counter, which advances once per
    step of the whole vectorised env, i.e. N_ENVS environment timesteps per
    increment. A plain threshold on total timesteps would therefore be off by
    a factor of N_ENVS, and a bare ``step > threshold`` would start a new clip
    as soon as the previous one finished.
    """
    return step >= VIDEO_START_STEP and (step - VIDEO_START_STEP) % VIDEO_EVERY == 0


with wandb.init(
        project="hf-deep-rl-class",
        name=f"{experiment_name}-1",
        group=experiment_name,
        config=config,
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=False,
    ) as run:
    # Seed envs and model with the seed logged in `config` so runs are reproducible.
    env = make_vec_env(ENV_NAME, n_envs=N_ENVS, seed=config["seed"])
    env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=should_record_video)

    model_name = f"{ENV_NAME}-ppo"
    try:
        model = PPO(
            config["policy_type"],
            env,
            # linear_decay_sched scales config["lr"]; otherwise a constant rate.
            learning_rate=linear_decay_sched if config["decay"] else config['lr'],
            verbose=1,
            tensorboard_log=f"runs/{run.id}",
            batch_size=config["batch_size"],
            seed=config["seed"],
        )
        model.learn(
            total_timesteps=config["total_timesteps"],
            callback=WandbCallback(
                gradient_save_freq=100,
                model_save_path=model_name,
                verbose=2,
            ),
        )
    finally:
        # Closing the VecVideoRecorder flushes a clip that may still be recording.
        env.close()

    # Separate single env, wrapped in Monitor, handed to package_to_hub for evaluation.
    eval_env = DummyVecEnv([lambda: Monitor(gym.make(ENV_NAME))])
    package_to_hub(
        model=model,
        model_name=model_name,
        model_architecture="PPO",
        env_id=ENV_NAME,
        eval_env=eval_env,
        repo_id=f"arampacha/{model_name}",
        commit_message=f"trained model {config['total_timesteps']:.0e} steps"
    )