# Connect Four – Maskable PPO (Colab-ready)

Run this notebook on Colab or Kaggle (a GPU is optional). It trains an out-of-the-box PPO agent with `sb3-contrib`'s `MaskablePPO`, collects evaluation metrics, and plots ablation results and training losses.

**Steps:**
1. Install deps (first cell).
2. Clone (or pull) the repo and `cd` into `rl_connect4`.
3. Run the training/ablation cells.

In [None]:
# Install dependencies (rerun on a fresh Colab runtime).
# sb3-contrib provides MaskablePPO / ActionMasker used below.
# NOTE(review): stable-baselines3[extra] presumably supplies the packages that
# model.learn(progress_bar=True) needs further down — confirm if trimming this list.
!pip install -q stable-baselines3[extra] sb3-contrib gymnasium pygame numpy torch pandas matplotlib seaborn

In [None]:
import os, sys, json, pathlib

REPO_URL = "https://github.com/UmerSR/Connect-Four-RL.git"
WORKSPACE = "/content/Connect-Four-RL"

# First run: clone the repo. Later runs: update the existing checkout;
# --ff-only makes git refuse to create a merge commit if local history diverged.
if not os.path.exists(WORKSPACE):
    !git clone $REPO_URL $WORKSPACE
else:
    %cd $WORKSPACE
    !git pull --ff-only

# All later cells run with rl_connect4 as the current working directory.
%cd $WORKSPACE/rl_connect4
# Put the repo root on sys.path.
# NOTE(review): the import `envs.connect_four_env` may actually resolve through the
# CWD (rl_connect4) rather than through this entry, depending on where `envs/`
# lives in the repo — confirm which directory holds it.
sys.path.insert(0, WORKSPACE)
print("CWD:", os.getcwd())

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

from envs.connect_four_env import ConnectFourEnv
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure

sns.set_theme(style="whitegrid")

# Train on the GPU when the runtime exposes one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device  # last expression: shows the selected device as the cell output

In [None]:
# Env builders with action masking
def mask_fn(obs):
    """Return the boolean legal-move mask stored in a ConnectFourEnv observation.

    Args:
        obs: Observation dict from ``ConnectFourEnv``; it must contain an
            ``"action_mask"`` entry (printed at the end of this cell).

    Returns:
        ``np.ndarray`` of ``bool`` — True where the action is allowed.
    """
    return np.asarray(obs["action_mask"]).astype(bool)


class _ObservationActionMasker(ActionMasker):
    """``ActionMasker`` whose mask function receives the latest observation.

    sb3-contrib's ``ActionMasker.action_masks()`` calls its mask function with the
    wrapped *environment*, not with an observation, so passing the obs-based
    ``mask_fn`` straight to ``ActionMasker`` fails as soon as MaskablePPO requests
    masks. This wrapper remembers the observation returned by the most recent
    ``reset``/``step`` and hands that to ``obs_mask_fn`` instead.
    """

    def __init__(self, env, obs_mask_fn):
        super().__init__(env, obs_mask_fn)
        self._obs_mask_fn = obs_mask_fn
        self._last_obs = None

    def reset(self, **kwargs):
        obs, info = super().reset(**kwargs)
        self._last_obs = obs
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        self._last_obs = obs
        return obs, reward, terminated, truncated, info

    def action_masks(self):
        # Queried by MaskablePPO (rollouts) and by maskable evaluation.
        if self._last_obs is None:
            raise RuntimeError("action_masks() called before reset(); no observation available yet.")
        return self._obs_mask_fn(self._last_obs)


def make_env(seed=None):
    """Return a zero-arg factory that builds one masked ConnectFourEnv (for DummyVecEnv)."""
    def _init():
        return _ObservationActionMasker(ConnectFourEnv(seed=seed), mask_fn)
    return _init


def make_vec_envs(n_envs=4, seed=0):
    """Build ``n_envs`` masked envs in a DummyVecEnv wrapped by VecMonitor.

    Env ``i`` is seeded with ``seed + i``; pass ``seed=None`` to leave all unseeded.
    """
    env_fns = [make_env(seed + i if seed is not None else None) for i in range(n_envs)]
    vec = DummyVecEnv(env_fns)
    vec = VecMonitor(vec)
    return vec


# Shared single (non-vectorized) env used for periodic evaluation in every run.
eval_env = _ObservationActionMasker(ConnectFourEnv(), mask_fn)
obs, _ = eval_env.reset()
print("Obs keys:", obs.keys(), "action_mask shape:", obs["action_mask"].shape)
# Sanity check that the masking path MaskablePPO uses actually works.
print("Initial legal-move mask:", eval_env.action_masks())

In [None]:
# Callback to record periodic evaluation
class EvalRecorder(BaseCallback):
    """Periodically evaluate the training model (with action masking) and record results.

    Args:
        eval_env: Environment used for evaluation. It must expose ``action_masks()``
            (an ``ActionMasker``-style wrapper) so illegal moves are masked out.
        eval_freq: Evaluate every ``eval_freq`` *environment timesteps*, summed over
            all parallel training envs — the same unit as ``total_timesteps``.
            A value ``<= 0`` disables evaluation.
        n_eval_episodes: Episodes per evaluation.

    Attributes:
        history: list of dicts with keys ``timesteps``, ``mean_reward``, ``std_reward``.
        best_mean_reward: highest ``mean_reward`` seen (``-inf`` if none yet).
    """

    def __init__(self, eval_env, eval_freq=5000, n_eval_episodes=20):
        super().__init__()
        self.eval_env = eval_env
        self.eval_freq = eval_freq
        self.n_eval_episodes = n_eval_episodes
        self.history = []
        self.best_mean_reward = -np.inf
        # Timestep count at which the next evaluation is due.
        self._next_eval_at = eval_freq
        # Timestep count of the latest evaluation (avoids a duplicate final eval).
        self._last_eval_at = None

    def _evaluate(self) -> None:
        """Run one masked evaluation and append it to ``history``."""
        # SB3's own evaluate_policy calls model.predict() without action masks,
        # so the agent could pick illegal moves; sb3-contrib's variant applies them.
        from sb3_contrib.common.maskable.evaluation import evaluate_policy as maskable_evaluate_policy

        mean_r, std_r = maskable_evaluate_policy(
            self.model,
            self.eval_env,
            n_eval_episodes=self.n_eval_episodes,
            warn=False,
            deterministic=True,
            use_masking=True,
        )
        self.history.append({
            "timesteps": self.num_timesteps,
            "mean_reward": mean_r,
            "std_reward": std_r,
        })
        self.best_mean_reward = max(self.best_mean_reward, mean_r)
        self._last_eval_at = self.num_timesteps

    def _on_step(self) -> bool:
        # n_calls counts vec-env steps (n_envs timesteps each); schedule on
        # num_timesteps so eval_freq is independent of the number of envs.
        if self.eval_freq > 0 and self.num_timesteps >= self._next_eval_at:
            self._evaluate()
            self._next_eval_at = (self.num_timesteps // self.eval_freq + 1) * self.eval_freq
        return True

    def _on_training_end(self) -> None:
        # Always score the final policy so short runs still yield a data point.
        if self.eval_freq > 0 and self._last_eval_at != self.num_timesteps:
            self._evaluate()

def train_maskable_ppo(run_name: str, total_timesteps: int, hyperparams: dict, n_envs: int = 4, eval_freq: int = 5000, n_eval_episodes: int = 20):
    """Train one MaskablePPO run and return the model, its eval history and log dir.

    Args:
        run_name: Subdirectory under ``<WORKSPACE>/outputs`` that receives this run's logs.
        total_timesteps: Environment steps to train for (summed over all envs).
        hyperparams: Keyword arguments forwarded verbatim to ``MaskablePPO``; every
            key must be a valid ``MaskablePPO`` constructor argument.
        n_envs: Number of parallel training environments.
        eval_freq: Evaluation interval handed to ``EvalRecorder``.
        n_eval_episodes: Episodes per evaluation handed to ``EvalRecorder``.

    Returns:
        ``(model, history, log_dir)``: the trained model, ``EvalRecorder.history``
        and the log directory (``str``) where the CSV logger writes ``progress.csv``.

    Evaluation uses the module-level ``eval_env`` shared by all runs.
    """
    # Derive from WORKSPACE (set in the setup cell) instead of repeating the path.
    log_dir = str(pathlib.Path(WORKSPACE) / "outputs" / run_name)
    pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

    vec_env = make_vec_envs(n_envs=n_envs)
    model = MaskablePPO(
        "MultiInputPolicy",  # observations are dicts (they include "action_mask")
        vec_env,
        verbose=0,
        device=device,
        **hyperparams,
    )
    # A custom logger replaces SB3's default one, so a constructor `tensorboard_log`
    # would be silently ignored; progress goes to stdout and progress.csv only.
    model.set_logger(configure(log_dir, ["stdout", "csv"]))

    callback = EvalRecorder(eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes)
    model.learn(total_timesteps=total_timesteps, callback=callback, progress_bar=True)
    return model, callback.history, log_dir

# Baseline MaskablePPO hyperparameters (used by the smoke run below).
hyperparams_example = dict(
    learning_rate=3e-4,
    gamma=0.99,
    clip_range=0.2,
    ent_coef=0.01,
    vf_coef=0.5,
    n_steps=1024,
    batch_size=1024,
    gae_lambda=0.95,
)
hyperparams_example  # display as the cell output

In [None]:
# Quick smoke run: a short training job to check the pipeline end to end.
# Lower smoke_timesteps for faster iteration.
smoke_timesteps = 20000
smoke_model, smoke_history, smoke_logdir = train_maskable_ppo(
    "smoke",
    smoke_timesteps,
    hyperparams_example,
    n_envs=4,
    eval_freq=5000,
    n_eval_episodes=10,
)
# One row per recorded evaluation.
pd.DataFrame(smoke_history)

In [None]:
# Ablation over common hyperparameters
# Each entry's "tag" names the run; every other key is a MaskablePPO argument.
ablation_grid = [
    {"tag": "lr_3e-4_clip_0.2", "learning_rate": 3e-4, "clip_range": 0.2, "gamma": 0.99, "gae_lambda": 0.95},
    {"tag": "lr_1e-3_clip_0.2", "learning_rate": 1e-3, "clip_range": 0.2, "gamma": 0.99, "gae_lambda": 0.95},
    {"tag": "lr_3e-4_clip_0.1", "learning_rate": 3e-4, "clip_range": 0.1, "gamma": 0.99, "gae_lambda": 0.95},
    {"tag": "lr_3e-4_gamma_0.995", "learning_rate": 3e-4, "clip_range": 0.2, "gamma": 0.995, "gae_lambda": 0.95},
    {"tag": "lr_3e-4_ent_0.02", "learning_rate": 3e-4, "clip_range": 0.2, "gamma": 0.99, "gae_lambda": 0.95, "ent_coef": 0.02},
]

# Shared params (a grid entry overrides any key it repeats, e.g. ent_coef)
common_params = {
    "n_steps": 1024,
    "batch_size": 1024,
    "vf_coef": 0.5,
    "ent_coef": 0.01,
}

ablation_results = []
all_histories = []

total_timesteps = 50000  # adjust up for deeper training
for cfg in ablation_grid:
    # "tag" is only a label; MaskablePPO rejects unknown keyword arguments.
    ppo_overrides = {key: value for key, value in cfg.items() if key != "tag"}
    hyper = {**common_params, **ppo_overrides}
    run_name = f"ablate_{cfg['tag']}"
    print(f"\n=== Training {run_name} ===")
    model, history, log_dir = train_maskable_ppo(
        run_name=run_name,
        total_timesteps=total_timesteps,
        hyperparams=hyper,
        n_envs=8,
        eval_freq=10000,
        n_eval_episodes=15,
    )
    best_mean = max((h["mean_reward"] for h in history), default=np.nan)

    ablation_results.append({
        "run": run_name,
        "best_mean_reward": best_mean,
        "log_dir": log_dir,
        "tag": cfg["tag"],
        **hyper,  # record the effective hyperparameters, not just the overrides
    })
    if history:
        hist_df = pd.DataFrame(history)
        hist_df["run"] = run_name
        all_histories.append(hist_df)
    else:
        # An empty history has no timesteps/mean_reward columns and would break plotting.
        print(f"No evaluations were recorded for {run_name}.")

ablation_df = pd.DataFrame(ablation_results)
ablation_df

In [None]:
# Plot eval curves
# Skip runs that recorded no evaluations: their frames lack the plotted columns.
_required_cols = {"timesteps", "mean_reward"}
plottable_histories = [
    h for h in all_histories if not h.empty and _required_cols.issubset(h.columns)
]

if plottable_histories:
    hist_all_df = pd.concat(plottable_histories, ignore_index=True)
    plt.figure(figsize=(10, 6))
    sns.lineplot(data=hist_all_df, x="timesteps", y="mean_reward", hue="run", marker="o")
    plt.title("Evaluation mean reward vs timesteps")
    plt.show()

    plt.figure(figsize=(8, 5))
    sns.barplot(data=ablation_df, x="run", y="best_mean_reward")
    plt.xticks(rotation=30, ha="right")
    plt.title("Best mean reward per run")
    plt.tight_layout()  # keep the rotated run labels inside the figure
    plt.show()
else:
    print("No history to plot yet.")

In [None]:
# Plot training losses from the first run's CSV logs (if present)
loss_keys = [
    "train/policy_gradient_loss",
    "train/value_loss",
    "train/entropy_loss",
    "train/approx_kl",
]
x_key = "time/total_timesteps"

# Prefer the first ablation run; fall back to the smoke run.
_runs = globals().get("ablation_results") or []
if _runs:
    sample_log = pathlib.Path(_runs[0]["log_dir"]) / "progress.csv"
elif "smoke_logdir" in globals():
    sample_log = pathlib.Path(smoke_logdir) / "progress.csv"
else:
    sample_log = None

if sample_log is not None and sample_log.exists():
    df_log = pd.read_csv(sample_log)
    # Ignore keys that are absent or never logged (all NaN).
    available = [k for k in loss_keys if k in df_log.columns and df_log[k].notna().any()]
    if x_key not in df_log.columns:
        print(f"'{x_key}' column not found in progress.csv")
    elif available:
        # One subplot per metric: their magnitudes differ by orders of magnitude,
        # so a shared y-axis would flatten the small ones.
        fig, axes = plt.subplots(
            len(available), 1, figsize=(10, 2.5 * len(available)), sharex=True, squeeze=False
        )
        for ax, k in zip(axes[:, 0], available):
            series = df_log[[x_key, k]].dropna()
            ax.plot(series[x_key], series[k])
            ax.set_title(k)
        axes[-1, 0].set_xlabel("Timesteps")
        fig.suptitle(f"Training losses / diagnostics ({sample_log.parent.name})")
        fig.tight_layout()
        plt.show()
    else:
        print("No loss keys available in progress.csv")
else:
    print("No progress.csv found yet.")