# Rwanda Health RL Training Notebook

End-to-end experimentation pipeline for DQN, REINFORCE, PPO, and A2C on the clinic environment.

## How to Use
1. Open this notebook in Google Colab or any GPU-enabled environment and make sure the repo is available (e.g., run `!git clone <repo-url>` so `environment.custom_env` can be imported).
2. On a fresh runtime, uncomment and run the dependency cell to install the pinned packages.
3. Configure the toggles (`RUN_DQN`, `RUN_PPO`, `RUN_A2C`, `RUN_REINFORCE`) before launching training.
4. After runs finish, execute the results and visualization cells to compare agents.
5. Export the final table to CSV for your report.

In [None]:
# Uncomment when running on a fresh Colab runtime
# !pip install --quiet "gymnasium==0.29.1" "stable-baselines3==2.3.2" "sb3-contrib==2.3.2" "pygame==2.6.0"
# !pip install --quiet "torch==2.3.0" "numpy" "pandas" "matplotlib" "seaborn"

In [None]:
import os
import json
import time
import itertools
from datetime import datetime

import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces

from typing import Any, Dict, List, Optional, Tuple

from stable_baselines3 import DQN, PPO, A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

import torch
from torch import nn
from torch.distributions import Categorical
from torch.nn.utils import clip_grad_norm_

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style='whitegrid')

# Every artifact is written beneath the directory the notebook was launched from.
PROJECT_ROOT = os.getcwd()
MODEL_DIR, RESULTS_DIR, LOG_DIR = (
    os.path.join(PROJECT_ROOT, leaf) for leaf in ('models', 'results', 'logs')
)

# Create the three top-level output folders (a no-op when they already exist).
for subdir in (MODEL_DIR, RESULTS_DIR, LOG_DIR):
    os.makedirs(subdir, exist_ok=True)

# One checkpoint sub-folder per algorithm under models/.
for algo_name in ('dqn', 'ppo', 'a2c', 'reinforce'):
    os.makedirs(os.path.join(MODEL_DIR, algo_name), exist_ok=True)

# Use CUDA when torch reports it as available, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {DEVICE}')

In [None]:
from environment.custom_env import (
    RwandaHealthEnv,
    ACTION_MEANINGS,
    REQUEST_NCD_TEST,
    REQUEST_INFECTION_TEST,
    DIAGNOSE_CHRONIC,
    DIAGNOSE_INFECTION,
    ALLOCATE_MED,
    REFER_PATIENT,
    WAIT,
    make_env as _base_make_env,
    CONDITION_HEALTHY_OR_MILD,
    CONDITION_CHRONIC,
    CONDITION_INFECTION,
    CONDITION_BOTH_SERIOUS,
    )

In [None]:
# Shared in-memory log of finished runs: record_result() appends one dict per
# run, and the results, plotting and CSV-export cells build DataFrames from it.
experiment_runs: List[Dict[str, Any]] = []

def make_env(seed: Optional[int] = None, render_mode: Optional[str] = None, monitor: bool = True):
    """Notebook-level alias for the project's environment factory.

    All three arguments are forwarded unchanged, by keyword, to
    ``environment.custom_env.make_env``. Callers in this notebook invoke the
    returned object (``make_env(...)()``) to obtain an environment instance.
    """
    factory_kwargs = {'seed': seed, 'render_mode': render_mode, 'monitor': monitor}
    return _base_make_env(**factory_kwargs)

def evaluate_sb3_model(model, eval_episodes: int = 20, seed: int = 10_000) -> Tuple[float, float]:
    """Evaluate a Stable-Baselines3 model on a freshly created environment.

    Args:
        model: Trained SB3 model accepted by ``evaluate_policy``.
        eval_episodes: Number of evaluation episodes to roll out.
        seed: Seed handed to ``make_env`` for the evaluation environment.

    Returns:
        ``(mean_reward, std_reward)`` over the evaluation episodes, as plain
        Python floats. Actions are chosen deterministically.
    """
    eval_env = make_env(seed=seed, monitor=False)()
    try:
        mean_reward, std_reward = evaluate_policy(
            model,
            eval_env,
            n_eval_episodes=eval_episodes,
            deterministic=True,
        )
    finally:
        # Release the environment even when evaluation raises.
        eval_env.close()
    return float(mean_reward), float(std_reward)

def evaluate_reinforce_policy(policy: nn.Module, eval_episodes: int = 20, seed: int = 20_000) -> Tuple[float, float]:
    """Roll out a REINFORCE policy greedily and summarise the episode rewards.

    Each episode uses its own environment, created and reset with
    ``seed + episode``. The action at every step is the argmax of the policy
    output (no sampling).

    Args:
        policy: Network mapping an observation tensor to per-action logits.
        eval_episodes: Number of evaluation episodes.
        seed: Base seed; episode ``i`` uses ``seed + i``.

    Returns:
        ``(mean, std)`` of the undiscounted episode rewards as Python floats.
    """
    # Remember the caller's mode so evaluation does not silently flip it.
    was_training = policy.training
    policy.eval()
    rewards: List[float] = []
    try:
        for episode in range(eval_episodes):
            episode_seed = seed + episode
            env = make_env(seed=episode_seed, monitor=False)()
            try:
                obs, _ = env.reset(seed=episode_seed)
                terminated = False
                truncated = False
                episode_reward = 0.0
                while not (terminated or truncated):
                    obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE)
                    with torch.no_grad():
                        logits = policy(obs_tensor)
                    action = int(torch.argmax(logits).item())
                    obs, reward, terminated, truncated, _ = env.step(action)
                    episode_reward += reward
            finally:
                # Close the per-episode environment even if a step raises.
                env.close()
            rewards.append(episode_reward)
    finally:
        # Restore whichever mode the policy was in before evaluation.
        policy.train(was_training)
    return float(np.mean(rewards)), float(np.std(rewards))

def record_result(
    algorithm: str,
    run_id: int,
    seed: int,
    params: Dict[str, Any],
    mean_reward: float,
    std_reward: float,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Append one finished run to the module-level ``experiment_runs`` list.

    The stored dict holds the identifying fields, the evaluation statistics
    and a shallow copy of ``params`` under ``'hyperparameters'``. Any
    ``metadata`` entries are merged in last, so they become top-level keys
    (and would replace a core key of the same name).
    """
    entry: Dict[str, Any] = dict(
        algorithm=algorithm,
        run_id=run_id,
        seed=seed,
        mean_reward=mean_reward,
        std_reward=std_reward,
        hyperparameters=dict(params),
    )
    if metadata:
        entry.update(metadata)
    experiment_runs.append(entry)

# Build a throwaway environment only to read the observation and action sizes
# used when constructing the REINFORCE network, then release it.
probe_env = RwandaHealthEnv()
obs_space, act_space = probe_env.observation_space, probe_env.action_space
OBS_DIM, ACT_DIM = obs_space.shape[0], act_space.n
probe_env.close()
del probe_env, obs_space, act_space
print(f'Observation dim: {OBS_DIM}, action dim: {ACT_DIM}')

In [None]:
def build_dqn_grid(max_configs: Optional[int] = 12) -> List[Dict[str, Any]]:
    """Return the first ``max_configs`` DQN hyperparameter combinations.

    Combinations are enumerated lazily in ``itertools.product`` order, so the
    full Cartesian product is never materialised just to keep a few entries.

    NOTE: ``product`` varies the right-most axis fastest, so with the default
    cap of 12 only ``target_update_interval``, ``exploration_fraction`` and
    ``exploration_final_eps`` change; the other values stay at their first
    option.

    Args:
        max_configs: Maximum number of configurations to return. ``None``
            returns every combination; values below zero are treated as zero.

    Returns:
        A list of keyword-argument dicts, one per configuration.
    """
    keys = (
        'learning_rate',
        'gamma',
        'buffer_size',
        'batch_size',
        'train_freq',
        'target_update_interval',
        'exploration_fraction',
        'exploration_final_eps',
    )
    combos = itertools.product(
        [1e-3, 5e-4, 3e-4],
        [0.95, 0.98, 0.99],
        [4096, 8192, 16384],
        [32, 64, 128],
        [1, 4, 8],
        [500, 750, 1000],
        [0.2, 0.3],
        [0.01, 0.02],
    )
    limit = None if max_configs is None else max(max_configs, 0)
    return [dict(zip(keys, combo)) for combo in itertools.islice(combos, limit)]

def build_ppo_grid(max_configs: Optional[int] = 12) -> List[Dict[str, Any]]:
    """Return the first ``max_configs`` feasible PPO hyperparameter combinations.

    Combinations are generated lazily in ``itertools.product`` order and any
    candidate whose ``batch_size`` exceeds ``n_steps`` is skipped before the
    cap is applied.

    NOTE: ``product`` varies the right-most axis fastest, so with the default
    cap of 12 only ``n_epochs``, ``gae_lambda`` and ``clip_range`` change.

    Args:
        max_configs: Maximum number of configurations to return. ``None``
            returns every feasible combination; values below zero are treated
            as zero.

    Returns:
        A list of keyword-argument dicts, one per configuration.
    """
    keys = (
        'learning_rate',
        'gamma',
        'n_steps',
        'batch_size',
        'n_epochs',
        'gae_lambda',
        'clip_range',
    )
    combos = itertools.product(
        [3e-4, 1e-4, 5e-4],
        [0.95, 0.98, 0.99],
        [1024, 2048, 3072],
        [64, 128, 256],
        [3, 5, 10],
        [0.9, 0.95, 0.99],
        [0.1, 0.2, 0.3],
    )
    candidates = (dict(zip(keys, combo)) for combo in combos)
    # A minibatch larger than the rollout is not a usable configuration.
    feasible = (cfg for cfg in candidates if cfg['batch_size'] <= cfg['n_steps'])
    limit = None if max_configs is None else max(max_configs, 0)
    return list(itertools.islice(feasible, limit))

def build_a2c_grid(max_configs: Optional[int] = 12) -> List[Dict[str, Any]]:
    """Return the first ``max_configs`` A2C hyperparameter combinations.

    Combinations are enumerated lazily in ``itertools.product`` order, so the
    full Cartesian product is never materialised just to keep a few entries.

    NOTE: ``product`` varies the right-most axis fastest, so with the default
    cap of 12 only ``gae_lambda``, ``n_steps`` and ``ent_coef`` change.

    Args:
        max_configs: Maximum number of configurations to return. ``None``
            returns every combination; values below zero are treated as zero.

    Returns:
        A list of keyword-argument dicts, one per configuration.
    """
    keys = ('learning_rate', 'gamma', 'gae_lambda', 'n_steps', 'ent_coef')
    combos = itertools.product(
        [7e-4, 5e-4, 3e-4],
        [0.95, 0.98, 0.99],
        [0.9, 0.95, 0.99],
        [5, 10, 20],
        [0.0, 0.01, 0.05],
    )
    limit = None if max_configs is None else max(max_configs, 0)
    return [dict(zip(keys, combo)) for combo in itertools.islice(combos, limit)]

def build_reinforce_grid(max_configs: Optional[int] = 12) -> List[Dict[str, Any]]:
    """Return the first ``max_configs`` REINFORCE hyperparameter combinations.

    Combinations are enumerated lazily in ``itertools.product`` order, so the
    full Cartesian product is never materialised just to keep a few entries.

    NOTE: ``product`` varies the right-most axis fastest, so with the default
    cap of 12 only ``entropy_coef``, ``max_grad_norm`` and ``log_interval``
    change; ``hidden_layers`` stays at ``(128, 128)``.

    Args:
        max_configs: Maximum number of configurations to return. ``None``
            returns every combination; values below zero are treated as zero.

    Returns:
        A list of configuration dicts, one per run.
    """
    keys = (
        'learning_rate',
        'gamma',
        'hidden_layers',
        'entropy_coef',
        'max_grad_norm',
        'log_interval',
    )
    combos = itertools.product(
        [1e-3, 5e-4, 3e-4],
        [0.95, 0.98, 0.99],
        [(128, 128), (256, 128), (256, 256), (128, 64)],
        [0.0, 0.01, 0.02],
        [0.5, 1.0],
        [20, 30, 40],
    )
    limit = None if max_configs is None else max(max_configs, 0)
    return [dict(zip(keys, combo)) for combo in itertools.islice(combos, limit)]

# Build each algorithm's search space once, when this cell runs.
DQN_PARAM_GRID, PPO_PARAM_GRID, A2C_PARAM_GRID, REINFORCE_PARAM_GRID = (
    build_dqn_grid(),
    build_ppo_grid(),
    build_a2c_grid(),
    build_reinforce_grid(),
)

# Report how many runs each sweep will launch, one line per algorithm.
print('\n'.join(
    f'{label} grid size: {len(grid)}'
    for label, grid in (
        ('DQN', DQN_PARAM_GRID),
        ('PPO', PPO_PARAM_GRID),
        ('A2C', A2C_PARAM_GRID),
        ('REINFORCE', REINFORCE_PARAM_GRID),
    )
))

In [None]:
# Training budget, in environment timesteps, for each SB3 algorithm; the
# RUN_* cells pass the matching entry to run_sb3_experiments().
TOTAL_TIMESTEPS = dict(dqn=200_000, ppo=600_000, a2c=400_000)

# Episode budget for REINFORCE, number of evaluation episodes per run, and the
# seed every experiment family offsets from.
REINFORCE_EPISODES, EVAL_EPISODES, BASE_SEED = 800, 20, 2024

print('Timesteps per algorithm:', TOTAL_TIMESTEPS)
print(f'Reinforce episodes: {REINFORCE_EPISODES}, evaluation episodes: {EVAL_EPISODES}')

In [None]:
def run_sb3_experiments(
    algo_name: str,
    algo_cls,
    param_grid: List[Dict[str, Any]],
    total_timesteps: int,
    eval_episodes: int = EVAL_EPISODES,
    base_seed: int = BASE_SEED,
) -> None:
    """Train, evaluate, save and record one SB3 model per entry of ``param_grid``.

    Run ``i`` is seeded with ``base_seed + i``, logs TensorBoard data under
    ``LOG_DIR/<algo_name>/run_<i>``, is evaluated with seed ``run_seed + 10_000``
    and is saved to ``MODEL_DIR/<algo_name>/<algo_name>_run_<i>``. Results are
    appended through ``record_result``.

    Args:
        algo_name: Short lowercase name used in paths and in recorded results.
        algo_cls: SB3 algorithm class to instantiate with ``'MlpPolicy'``.
        param_grid: Keyword-argument dicts forwarded to ``algo_cls``.
        total_timesteps: Training budget passed to ``model.learn``.
        eval_episodes: Number of evaluation episodes per run.
        base_seed: Seed of the first run; later runs add their index.
    """
    import importlib.util
    import inspect

    # SB3 documents its progress bar as requiring tqdm and rich. Decide up
    # front whether to request it instead of catching an exception from
    # ``learn``, which could hide an unrelated error and trigger a retrain.
    progress_bar_deps_ok = all(
        importlib.util.find_spec(pkg) is not None for pkg in ('tqdm', 'rich')
    )

    print(f'Starting {algo_name.upper()} experiments ({len(param_grid)} runs)...')
    for run_idx, params in enumerate(param_grid):
        run_seed = base_seed + run_idx
        set_random_seed(run_seed)
        env = DummyVecEnv([make_env(seed=run_seed, monitor=True)])
        try:
            log_dir = os.path.join(LOG_DIR, algo_name, f'run_{run_idx:02d}')
            os.makedirs(log_dir, exist_ok=True)
            model = algo_cls('MlpPolicy', env, verbose=1, tensorboard_log=log_dir, seed=run_seed, **params)

            learn_kwargs: Dict[str, Any] = {}
            # Older SB3 releases have no ``progress_bar`` argument on ``learn``.
            if progress_bar_deps_ok and 'progress_bar' in inspect.signature(model.learn).parameters:
                learn_kwargs['progress_bar'] = True

            start_time = time.time()
            model.learn(total_timesteps=total_timesteps, **learn_kwargs)
            duration = time.time() - start_time

            mean_reward, std_reward = evaluate_sb3_model(model, eval_episodes=eval_episodes, seed=run_seed + 10_000)
            model_path = os.path.join(MODEL_DIR, algo_name, f'{algo_name}_run_{run_idx:02d}')
            model.save(model_path)
        finally:
            # Always release the training environment, even if a run fails.
            env.close()
        record_result(
            algorithm=algo_name,
            run_id=run_idx,
            seed=run_seed,
            params=params,
            mean_reward=mean_reward,
            std_reward=std_reward,
            metadata={'duration_sec': duration, 'timesteps': total_timesteps, 'model_path': model_path},
        )
        print(f'Run {run_idx:02d} -> mean reward {mean_reward:.2f} ± {std_reward:.2f} (saved to {model_path})')

In [None]:
# Opt-in switch for the DQN sweep; leave False to skip training in this cell.
RUN_DQN = False
if RUN_DQN:
    run_sb3_experiments(
        'dqn', DQN, DQN_PARAM_GRID, TOTAL_TIMESTEPS['dqn'],
        eval_episodes=EVAL_EPISODES, base_seed=BASE_SEED,
    )

In [None]:
# Opt-in switch for the PPO sweep; its seeds start 500 above BASE_SEED.
RUN_PPO = False
if RUN_PPO:
    run_sb3_experiments(
        'ppo', PPO, PPO_PARAM_GRID, TOTAL_TIMESTEPS['ppo'],
        eval_episodes=EVAL_EPISODES, base_seed=BASE_SEED + 500,
    )

In [None]:
# Opt-in switch for the A2C sweep; its seeds start 1000 above BASE_SEED.
RUN_A2C = False
if RUN_A2C:
    run_sb3_experiments(
        'a2c', A2C, A2C_PARAM_GRID, TOTAL_TIMESTEPS['a2c'],
        eval_episodes=EVAL_EPISODES, base_seed=BASE_SEED + 1_000,
    )

In [None]:
class ReinforceNetwork(nn.Module):
    """Fully connected policy network producing one raw score per action.

    The network is a stack of ``Linear -> ReLU`` pairs, one per entry of
    ``hidden_layers``, followed by a final ``Linear`` layer of width
    ``act_dim`` with no activation. Callers in this notebook treat the output
    as categorical logits.
    """

    def __init__(self, obs_dim: int, act_dim: int, hidden_layers: Tuple[int, ...]):
        super().__init__()
        # Consecutive entries of ``widths`` are the (in, out) sizes of each hidden layer.
        widths = (obs_dim, *hidden_layers)
        stack: List[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head: last hidden width (or obs_dim when there are no hidden layers).
        stack.append(nn.Linear(widths[-1], act_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        scores = self.model(x)
        return scores

def _discounted_returns(rewards: List[float], gamma: float) -> List[float]:
    """Return the discounted reward-to-go for every step of one episode."""
    returns: List[float] = []
    running = 0.0
    # Accumulate from the last step backwards, then flip once (O(n) overall).
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def train_reinforce(
    param_grid: List[Dict[str, Any]],
    total_episodes: int = REINFORCE_EPISODES,
    eval_episodes: int = EVAL_EPISODES,
    base_seed: int = BASE_SEED + 2_000,
) -> None:
    """Train, evaluate, save and record one REINFORCE policy per grid entry.

    For each configuration a fresh ``ReinforceNetwork`` is trained for
    ``total_episodes`` episodes with one gradient step per episode: sampled
    actions, discounted returns normalised within the episode, an optional
    entropy bonus, and gradient-norm clipping. The trained policy is evaluated
    greedily, saved as a ``.pt`` checkpoint under ``MODEL_DIR/reinforce`` and
    logged through ``record_result``.

    Args:
        param_grid: Configuration dicts. ``learning_rate``, ``gamma`` and
            ``hidden_layers`` are required; ``entropy_coef`` (default 0.0),
            ``max_grad_norm`` (default 1.0) and ``log_interval`` (default 20)
            are optional.
        total_episodes: Number of training episodes per run.
        eval_episodes: Number of evaluation episodes per run.
        base_seed: Seed of the first run; run ``i`` uses ``base_seed + i``.
    """
    print(f'Starting REINFORCE experiments ({len(param_grid)} runs)...')
    for run_idx, params in enumerate(param_grid):
        run_seed = base_seed + run_idx
        set_random_seed(run_seed)
        np.random.seed(run_seed)
        torch.manual_seed(run_seed)
        env = make_env(seed=run_seed, monitor=False)()
        try:
            policy = ReinforceNetwork(OBS_DIM, ACT_DIM, params['hidden_layers']).to(DEVICE)
            policy.train()
            optimizer = torch.optim.Adam(policy.parameters(), lr=params['learning_rate'])
            gamma = params['gamma']
            entropy_coef = params.get('entropy_coef', 0.0)
            max_grad_norm = params.get('max_grad_norm', 1.0)
            log_interval = params.get('log_interval', 20)
            episode_rewards: List[float] = []
            start_time = time.time()
            for episode in range(total_episodes):
                log_probs: List[torch.Tensor] = []
                entropies: List[torch.Tensor] = []
                rewards: List[float] = []
                obs, _ = env.reset(seed=run_seed + episode)
                terminated = False
                truncated = False
                while not (terminated or truncated):
                    obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE)
                    dist = Categorical(logits=policy(obs_tensor))
                    action = dist.sample()
                    log_probs.append(dist.log_prob(action))
                    entropies.append(dist.entropy())
                    obs, reward, terminated, truncated, _ = env.step(int(action.item()))
                    rewards.append(reward)

                returns_tensor = torch.as_tensor(
                    _discounted_returns(rewards, gamma), dtype=torch.float32, device=DEVICE
                )
                if returns_tensor.numel() > 1:
                    weights = (returns_tensor - returns_tensor.mean()) / (returns_tensor.std() + 1e-8)
                else:
                    # The sample std of a single value is NaN, which would turn the
                    # loss and then every weight into NaN. Use the raw return for
                    # one-step episodes instead.
                    weights = returns_tensor
                log_prob_tensor = torch.stack(log_probs)
                entropy_tensor = torch.stack(entropies)
                loss = -(weights.detach() * log_prob_tensor).sum()
                if entropy_coef > 0.0:
                    loss = loss - entropy_coef * entropy_tensor.sum()
                optimizer.zero_grad()
                loss.backward()
                if max_grad_norm is not None:
                    clip_grad_norm_(policy.parameters(), max_grad_norm)
                optimizer.step()

                episode_reward = float(np.sum(rewards))
                episode_rewards.append(episode_reward)
                # A falsy log_interval (0 or None) disables progress logging
                # rather than raising on the modulo.
                if log_interval and (episode + 1) % log_interval == 0:
                    rolling = float(np.mean(episode_rewards[-log_interval:]))
                    print(f'Run {run_idx:02d} | Episode {episode + 1} | rolling reward {rolling:.2f}')
            duration = time.time() - start_time
            mean_reward, std_reward = evaluate_reinforce_policy(policy, eval_episodes=eval_episodes, seed=run_seed + 30_000)
            model_path = os.path.join(MODEL_DIR, 'reinforce', f'reinforce_run_{run_idx:02d}.pt')
            torch.save({'state_dict': policy.state_dict(), 'hyperparameters': dict(params)}, model_path)
        finally:
            # Release the training environment even if the run fails part-way.
            env.close()
        record_result(
            algorithm='reinforce',
            run_id=run_idx,
            seed=run_seed,
            params=params,
            mean_reward=mean_reward,
            std_reward=std_reward,
            metadata={'duration_sec': duration, 'episodes': total_episodes, 'model_path': model_path},
        )
        print(f'Run {run_idx:02d} -> mean reward {mean_reward:.2f} ± {std_reward:.2f} (saved to {model_path})')

In [None]:
# Opt-in switch for the REINFORCE sweep; its seeds start 4000 above BASE_SEED.
RUN_REINFORCE = False
if RUN_REINFORCE:
    train_reinforce(
        REINFORCE_PARAM_GRID,
        total_episodes=REINFORCE_EPISODES,
        eval_episodes=EVAL_EPISODES,
        base_seed=BASE_SEED + 4_000,
    )

In [None]:
# Show every recorded run as a table, best mean reward first.
if not experiment_runs:
    print('No experiments recorded yet. Enable one of the RUN_* toggles and re-run.')
else:
    results_df = (
        pd.DataFrame(experiment_runs)
        .sort_values(by='mean_reward', ascending=False)
        .reset_index(drop=True)
    )
    display(results_df)

In [None]:
# Two views of the recorded runs: a per-algorithm box plot of mean reward and
# a per-run scatter plot coloured and styled by algorithm.
if not experiment_runs:
    print('No results to visualize yet.')
else:
    results_df = pd.DataFrame(experiment_runs)

    plt.figure(figsize=(10, 5))
    sns.boxplot(x='algorithm', y='mean_reward', data=results_df)
    plt.title('Reward distribution per algorithm')
    plt.show()

    plt.figure(figsize=(10, 5))
    sns.scatterplot(
        x='run_id',
        y='mean_reward',
        hue='algorithm',
        style='algorithm',
        data=results_df,
    )
    plt.title('Per-run rewards by algorithm')
    plt.show()

In [None]:
# Write all recorded runs to a UTC-timestamped CSV under RESULTS_DIR.
if experiment_runs:
    from datetime import timezone

    results_df = pd.DataFrame(experiment_runs)
    # datetime.utcnow() is deprecated; the timezone-aware call formats to the
    # same '%Y%m%d_%H%M%S' string.
    timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
    # Re-create the folder in case it was removed since the setup cell ran.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    csv_path = os.path.join(RESULTS_DIR, f'rl_experiment_results_{timestamp}.csv')
    results_df.to_csv(csv_path, index=False)
    print(f'Results saved to {csv_path}')
else:
    print('Run at least one experiment before exporting results.')

In [None]:
def load_and_evaluate_sb3_model(model_path: str, eval_episodes: int = EVAL_EPISODES) -> None:
    """Load a saved SB3 model, evaluate it and print the reward statistics.

    The algorithm is inferred from the path. The file name is checked first so
    that an unrelated directory name containing ``dqn``, ``ppo`` or ``a2c``
    cannot select the wrong loader; if the file name has no match, the whole
    path is searched as a fallback.

    Args:
        model_path: Path to the saved model; ``.zip`` is appended when only
            the suffixed file exists.
        eval_episodes: Number of evaluation episodes.

    Raises:
        ValueError: If no supported algorithm name appears in the path.
    """
    if not os.path.exists(model_path) and os.path.exists(f'{model_path}.zip'):
        model_path = f'{model_path}.zip'

    # Insertion order is the precedence used when several tags match.
    loaders = {'dqn': DQN, 'ppo': PPO, 'a2c': A2C}
    file_name = os.path.basename(model_path).lower()
    whole_path = model_path.lower()

    algo_cls = next((cls for tag, cls in loaders.items() if tag in file_name), None)
    if algo_cls is None:
        # Fallback for files whose algorithm is only named by a parent folder.
        algo_cls = next((cls for tag, cls in loaders.items() if tag in whole_path), None)
    if algo_cls is None:
        raise ValueError('Unsupported SB3 model path. Include algo name in filename.')

    model = algo_cls.load(model_path)
    mean_reward, std_reward = evaluate_sb3_model(model, eval_episodes=eval_episodes, seed=BASE_SEED + 50_000)
    print(f'Loaded {model.__class__.__name__} -> mean reward {mean_reward:.2f} ± {std_reward:.2f}')

def evaluate_saved_reinforce(model_path: str, eval_episodes: int = EVAL_EPISODES) -> None:
    """Rebuild a REINFORCE policy from a checkpoint, evaluate it and print the result.

    The checkpoint is expected to hold a ``'state_dict'`` entry and, optionally,
    a ``'hyperparameters'`` dict whose ``'hidden_layers'`` value sets the
    network shape; ``(128, 128)`` is used when that information is missing.
    """
    # NOTE(review): depending on the installed torch version, torch.load may
    # fully unpickle the file; only open checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=DEVICE)

    fallback_layers = (128, 128)
    saved_params = checkpoint.get('hyperparameters', {'hidden_layers': fallback_layers})
    layer_sizes = tuple(saved_params.get('hidden_layers', fallback_layers))

    policy = ReinforceNetwork(OBS_DIM, ACT_DIM, layer_sizes)
    policy = policy.to(DEVICE)
    policy.load_state_dict(checkpoint['state_dict'])

    stats = evaluate_reinforce_policy(policy, eval_episodes=eval_episodes, seed=BASE_SEED + 60_000)
    mean_reward, std_reward = stats
    print(f'Loaded REINFORCE policy -> mean reward {mean_reward:.2f} ± {std_reward:.2f}')

## Next Steps
- Flip one toggle at a time and monitor the console for each run.
- Use TensorBoard logs saved under `logs/` for qualitative diagnostics.
- Export the aggregated CSV and figures for the PDF report and presentation assets.
- Record the best agent by loading the saved model, creating the environment with `render_mode='human'`, and calling `env.render()` during the rollout.