# Clinic Scheduling via Reinforcement Learning (Gymnasium + Stable-Baselines3)

This Colab-ready notebook builds and trains a PPO agent to schedule patients under clinic constraints:
- Monday–Saturday only (Sunday closed)
- Operating hours: 08:00–12:00 and 13:00–16:00
- Lunch break 12:00–13:00 (no scheduling)
- Max 60 scheduled patient slots per day
- Walk-ins accepted until cutoff; excess wait in walk-in queue
- If a scheduled patient is not on-site at their appointment time, they are moved to the late list and the next patient is served
- If a late patient arrives afterwards, the admin can restore them to their original position; they are then served next, right after the current patient

We'll define a custom Gymnasium environment, train a PPO policy with Stable-Baselines3, and evaluate/visualize outcomes.

In [None]:
# If running in Colab, uncomment the next line to install packages
# !pip -q install gymnasium==0.29.1 stable-baselines3==2.3.2 sb3-contrib==2.3.2 shimmy==1.3.0 plotly==5.24.1 numpy pandas

import sys, os
# Report the interpreter version and the kernel's working directory so a
# misconfigured runtime is visible before anything else runs.
print(sys.version)
print("Working dir:", os.getcwd())

# Ensure proper imports in Colab kernels
import warnings
# Suppress every warning for the rest of the session to keep cell output short.
warnings.filterwarnings("ignore")

In [None]:
from dataclasses import dataclass
from typing import Tuple, Dict, Any, Optional
import numpy as np
import gymnasium as gym
from gymnasium import spaces

# Domain constants
# All clock values below are minutes since midnight.
MINUTES_OPEN_AM = 8 * 60  # 08:00, morning session opens
MINUTES_LUNCH_START = 12 * 60  # 12:00, lunch break begins
MINUTES_LUNCH_END = 13 * 60  # 13:00, afternoon session opens
MINUTES_CLOSE_PM = 16 * 60  # 16:00, clinic closes
# Total serviceable minutes per day: morning session plus afternoon session.
WORK_MINUTES = (MINUTES_LUNCH_START - MINUTES_OPEN_AM) + (MINUTES_CLOSE_PM - MINUTES_LUNCH_END)
MAX_SCHEDULED_PER_DAY = 60  # upper bound on pre-booked patients per day
MAX_WALKIN_QUEUE = 200  # default capacity of the walk-in queue
DAYS_OPEN = set(range(6))  # 0=Mon ... 5=Sat, 6=Sun closed


def is_open_minute(minute_of_day: int) -> bool:
    """Return True when *minute_of_day* lies in the morning or afternoon session.

    Both sessions are half-open intervals: the opening minute counts as open,
    while the first minute of lunch and the closing minute do not.
    """
    in_morning = MINUTES_OPEN_AM <= minute_of_day < MINUTES_LUNCH_START
    in_afternoon = MINUTES_LUNCH_END <= minute_of_day < MINUTES_CLOSE_PM
    return in_morning or in_afternoon


def minute_to_slot(minute_of_day: int, slot_minutes: int) -> int:
    """Convert a clock minute into a contiguous slot index that skips lunch.

    Minutes before opening map to slot 0, every minute of the lunch break maps
    to the first afternoon slot, and minutes at or after closing map to the
    total number of slots in the day.
    """
    morning_span = MINUTES_LUNCH_START - MINUTES_OPEN_AM
    afternoon_span = MINUTES_CLOSE_PM - MINUTES_LUNCH_END

    if minute_of_day < MINUTES_OPEN_AM:
        return 0
    if minute_of_day < MINUTES_LUNCH_START:
        return (minute_of_day - MINUTES_OPEN_AM) // slot_minutes

    morning_slots = morning_span // slot_minutes
    if minute_of_day < MINUTES_LUNCH_END:
        # Lunch break: report the index of the first afternoon slot.
        return morning_slots
    if minute_of_day < MINUTES_CLOSE_PM:
        return morning_slots + (minute_of_day - MINUTES_LUNCH_END) // slot_minutes

    # At or after closing time.
    return (morning_span + afternoon_span) // slot_minutes


@dataclass
class Patient:
    """Record for one patient, either pre-scheduled or a walk-in."""

    # Unique identifier within a simulated day.
    id: int
    # Index of the booked slot; None for walk-in
    scheduled_slot: Optional[int]  # None for walk-in
    # Minute of day at which the patient arrives; the schedule generator
    # assigns None to no-shows, i.e. patients who never arrive.
    arrival_time_min: Optional[int]
    # Set to True by the environment once the slot time has passed while the
    # patient was still absent.
    is_late: bool = False


class ClinicSchedulingEnv(gym.Env):
    """Gymnasium environment simulating one clinic day of patient service.

    Every call to :meth:`step` asks the agent which queue to serve from next.
    Serving a patient consumes ``slot_minutes`` of clock time; when nobody can
    be served the clock advances by one idle minute.  The episode terminates
    once the clock reaches closing time.

    Actions (``Discrete(3)``):
        0 -- serve the earliest-slot scheduled patient who is on-site
        1 -- serve the next walk-in
        2 -- recall the first late-listed patient who has since arrived
    When the requested source has nobody to serve, the head of the walk-in
    queue is served instead (if there is one).

    Observation (``Box`` of six float32 values):
        ``[current_slot_index, scheduled_remaining, walkin_queue_len,
        late_list_len, scheduled_patient_on_site (0/1),
        minutes_to_next_arrival (clipped to 60)]``
    """

    # Gymnasium reads the "render_modes" key; "render.modes" was the legacy Gym key.
    metadata = {"render_modes": ["human"]}

    def __init__(self,
                 slot_minutes: int = 10,
                 max_scheduled: int = MAX_SCHEDULED_PER_DAY,
                 max_walkin_queue: int = MAX_WALKIN_QUEUE,
                 no_show_prob: float = 0.05,
                 late_prob: float = 0.1,
                 walkin_rate_per_hour: float = 8.0,
                 walkin_cutoff_minute: Optional[int] = None,
                 day_of_week: Optional[int] = None,
                 seed: Optional[int] = None):
        """Configure the simulation.

        Args:
            slot_minutes: Length of one service slot in minutes (must be > 0).
            max_scheduled: Requested number of pre-booked patients; capped at
                ``MAX_SCHEDULED_PER_DAY`` and at the number of slots in a day.
            max_walkin_queue: Capacity of the walk-in queue; arrivals beyond it
                are turned away.
            no_show_prob: Probability that a scheduled patient never arrives.
            late_prob: Probability that a scheduled patient arrives late.
            walkin_rate_per_hour: Mean walk-in arrivals per open hour.
            walkin_cutoff_minute: Minute of day after which walk-ins are no
                longer accepted; ``None`` means closing time.
            day_of_week: Fixed weekday (0=Mon ... 6=Sun) or ``None`` to sample
                an open day on every reset.
            seed: Seed for the environment's random generator.

        Raises:
            ValueError: If ``slot_minutes`` is not positive.
        """
        super().__init__()
        if slot_minutes <= 0:
            raise ValueError("slot_minutes must be a positive number of minutes")
        self.slot_minutes = slot_minutes
        self.slots_per_day = WORK_MINUTES // slot_minutes
        self.max_scheduled = min(max_scheduled, MAX_SCHEDULED_PER_DAY)
        self.max_walkin_queue = max_walkin_queue
        self.no_show_prob = no_show_prob
        self.late_prob = late_prob
        self.walkin_rate_per_hour = walkin_rate_per_hour
        # Compare against None explicitly so that a cutoff of 0 (no walk-ins
        # at all) is honoured instead of being replaced by the default.
        self.walkin_cutoff_minute = (
            MINUTES_CLOSE_PM if walkin_cutoff_minute is None else walkin_cutoff_minute
        )
        self._configured_day_of_week = day_of_week
        self.rng = np.random.default_rng(seed)

        # Action space: choose next source to serve
        # 0 = next scheduled on-site, 1 = next walk-in, 2 = recall-priority late (if any)
        self.action_space = spaces.Discrete(3)

        # Observation space (compact):
        # [current_slot_index, scheduled_remaining, walkin_queue_len, late_list_len,
        #  next_scheduled_on_site(0/1), time_to_next_arrival_minutes]
        high = np.array([
            self.slots_per_day,
            MAX_SCHEDULED_PER_DAY,
            self.max_walkin_queue,
            MAX_SCHEDULED_PER_DAY,
            1,
            60
        ], dtype=np.float32)
        self.observation_space = spaces.Box(low=0.0, high=high, dtype=np.float32)

        self.reset_state()

    def reset_state(self):
        """Reset the clock and all queues, then generate a fresh day schedule."""
        self.minute = MINUTES_OPEN_AM
        self.current_slot = 0
        # self.day_of_week is set in reset(); don't override here
        self.scheduled: Dict[int, Patient] = {}
        self.walkin_queue: list[Patient] = []
        self.late_list: list[Patient] = []
        self.served_ids: list[int] = []
        # Mirror of served_ids for O(1) membership checks.
        self._served_lookup: set[int] = set()
        self.served_log: list[Dict[str, Any]] = []
        self.generated_patients: Dict[int, Patient] = {}
        self._generate_day_schedule()

    def _generate_day_schedule(self):
        """Pre-generate scheduled patients and their arrival times for the day."""
        # Pre-generate scheduled patients across slots (max 60)
        max_slots = self.slots_per_day
        chosen_slots = self.rng.choice(max_slots, size=min(self.max_scheduled, max_slots), replace=False)
        pid = 1
        for slot in sorted(chosen_slots.tolist()):
            # arrival: on time or late or no-show
            ontime = self.rng.random() > self.late_prob
            if self.rng.random() < self.no_show_prob:
                arrival = None
            else:
                if ontime:
                    # arrive within slot's first 5 minutes
                    slot_minute = self._slot_to_minute(slot)
                    jitter = int(self.rng.integers(0, min(5, self.slot_minutes)))
                    arrival = slot_minute + jitter
                else:
                    # late: arrive between +5 and +64 minutes after the slot
                    # starts, but never after the last open minute
                    base = self._slot_to_minute(slot) + 5
                    arrival = min(base + int(self.rng.integers(0, 60)), MINUTES_CLOSE_PM - 1)
            p = Patient(id=pid, scheduled_slot=slot, arrival_time_min=arrival)
            self.scheduled[slot] = p
            self.generated_patients[pid] = p
            pid += 1
        # Walk-in ids continue after the scheduled ids.
        self.next_walkin_id = pid

    def _slot_to_minute(self, slot_index: int) -> int:
        """Return the minute of day at which slot ``slot_index`` starts."""
        am_slots = (MINUTES_LUNCH_START - MINUTES_OPEN_AM) // self.slot_minutes
        if slot_index < am_slots:
            return MINUTES_OPEN_AM + slot_index * self.slot_minutes
        return MINUTES_LUNCH_END + (slot_index - am_slots) * self.slot_minutes

    def _poisson(self, lam):
        """Draw one Poisson sample with mean ``lam`` from the env's generator."""
        return self.rng.poisson(lam)

    def _maybe_generate_walkins(self, minute: Optional[int] = None):
        """Sample the walk-ins arriving during one minute and enqueue them.

        Args:
            minute: Minute of day to sample arrivals for; defaults to the
                current clock minute.

        Nothing is generated outside opening hours or at/after the walk-in
        cutoff.  Arrivals that find the queue full are turned away.
        """
        at_minute = self.minute if minute is None else minute
        if not is_open_minute(at_minute):
            return
        if at_minute >= self.walkin_cutoff_minute:
            return
        # per-minute rate
        lam = self.walkin_rate_per_hour / 60.0
        arrivals = int(self._poisson(lam))
        for _ in range(arrivals):
            p = Patient(id=self.next_walkin_id, scheduled_slot=None, arrival_time_min=at_minute)
            self.next_walkin_id += 1
            if len(self.walkin_queue) < self.max_walkin_queue:
                self.walkin_queue.append(p)
                self.generated_patients[p.id] = p

    def _is_on_site(self, patient: Patient) -> bool:
        """Return True when ``patient`` has arrived by the current minute."""
        arrival = patient.arrival_time_min
        return arrival is not None and arrival <= self.minute

    def _next_servable_scheduled(self) -> Optional[Patient]:
        """Return the earliest-slot scheduled patient who is unserved and on-site.

        Absent patients (no-shows or not yet arrived) are skipped so that they
        cannot block the scheduled patients behind them.
        """
        for slot in sorted(self.scheduled):
            candidate = self.scheduled[slot]
            if candidate.id not in self._served_lookup and self._is_on_site(candidate):
                return candidate
        return None

    def _pop_arrived_late(self) -> Optional[Patient]:
        """Remove and return the first late-listed patient who is on-site."""
        for index, candidate in enumerate(self.late_list):
            if self._is_on_site(candidate):
                return self.late_list.pop(index)
        return None

    def _mark_served(self, patient: Patient) -> None:
        """Record ``patient`` as served and drop any stale late-list entry."""
        self.served_ids.append(patient.id)
        self._served_lookup.add(patient.id)
        # A late patient served through the scheduled path must not remain
        # recallable, otherwise they could be served twice.
        self.late_list[:] = [q for q in self.late_list if q.id != patient.id]

    def _update_late_status(self):
        """Move scheduled patients whose slot started while they were absent to the late list.

        No-shows (no arrival time) are never added to the late list.
        """
        for slot, p in self.scheduled.items():
            if p.arrival_time_min is None:
                continue  # no-show remains
            if p.id in self._served_lookup:
                continue
            if self.minute >= self._slot_to_minute(slot) and p.arrival_time_min > self.minute:
                p.is_late = True
                if p not in self.late_list:
                    self.late_list.append(p)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new day and return ``(observation, info)``.

        Args:
            seed: Optional seed; re-seeds the environment's generator.
            options: May contain ``"day_of_week"`` (0=Mon ... 6=Sun) to force
                the weekday for this episode.

        On a day outside ``DAYS_OPEN`` nothing is scheduled and the clock
        starts at closing time, so the first :meth:`step` ends the episode.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        # set day-of-week (exclude Sunday by default)
        if options and "day_of_week" in options:
            self.day_of_week = int(options["day_of_week"])
        elif self._configured_day_of_week is not None:
            self.day_of_week = int(self._configured_day_of_week)
        else:
            # integers(0, 6) draws from 0..5, i.e. Monday to Saturday.
            self.day_of_week = int(self.rng.integers(0, 6))
        self.reset_state()
        if self.day_of_week not in DAYS_OPEN:
            # Closed day: no appointments, and the clock is already at closing.
            self.scheduled.clear()
            self.generated_patients.clear()
            self.minute = MINUTES_CLOSE_PM
            self.current_slot = minute_to_slot(self.minute, self.slot_minutes)
        obs = self._get_obs()
        info = {}
        return obs, info

    def _get_obs(self):
        """Build the six-element float32 observation for the current state."""
        scheduled_remaining = sum(
            1 for p in self.scheduled.values() if p.id not in self._served_lookup
        )

        # 1 when action 0 would find a scheduled patient to serve right now
        on_site = 1 if self._next_servable_scheduled() is not None else 0

        # estimate time to next arrival (scheduled or walk-in), clipped
        next_arrival = 999
        # next scheduled arrival
        sched_arrivals = [
            p.arrival_time_min
            for p in self.scheduled.values()
            if p.arrival_time_min is not None and p.arrival_time_min > self.minute
        ]
        if sched_arrivals:
            next_arrival = min(next_arrival, min(sched_arrivals) - self.minute)
        # approximate next walk-in arrival as inverse rate
        if is_open_minute(self.minute) and self.walkin_rate_per_hour > 0 and self.minute < self.walkin_cutoff_minute:
            expected_walkin_gap = max(1, int(60.0 / self.walkin_rate_per_hour))
            next_arrival = min(next_arrival, expected_walkin_gap)
        next_arrival = int(min(next_arrival, 60))

        obs = np.array([
            minute_to_slot(self.minute, self.slot_minutes),
            scheduled_remaining,
            len(self.walkin_queue),
            len(self.late_list),
            on_site,
            next_arrival
        ], dtype=np.float32)
        return obs

    def step(self, action: int):
        """Serve one patient (or idle one minute) and advance the clock.

        Args:
            action: 0 = scheduled, 1 = walk-in, 2 = late recall.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` becomes True once the clock reaches closing time;
            ``truncated`` is always False.
        """
        action = int(action)
        reward = 0.0
        info: Dict[str, Any] = {}

        # At each step, first update arrivals
        self._maybe_generate_walkins()
        self._update_late_status()

        # determine candidate queues based on action
        # 2: late priority (admin recalls) if a late patient has arrived
        # 0: earliest scheduled patient who is on-site
        # 1 (or nothing available above): walk-in
        served_patient: Optional[Patient] = None
        served_source = None
        served_via_recall = False

        if action == 2:
            served_patient = self._pop_arrived_late()
            if served_patient is not None:
                served_source = "late_recall"
                served_via_recall = True
        elif action == 0:
            served_patient = self._next_servable_scheduled()
            if served_patient is not None:
                served_source = "scheduled"
        if served_patient is None and self.walkin_queue:
            served_patient = self.walkin_queue.pop(0)
            served_source = "walkin"

        # apply serving and time advance
        if served_patient is not None:
            self._mark_served(served_patient)
            # reward per patient served; small bonus for serving scheduled/late via recall
            reward += 1.0
            if served_source in ("scheduled", "late_recall"):
                reward += 0.05
            # compute wait
            served_start_minute = self.minute
            arrival = served_patient.arrival_time_min
            wait = None
            if arrival is not None:
                wait = max(0, served_start_minute - arrival)
            # log
            self.served_log.append({
                "id": served_patient.id,
                "served_start_minute": served_start_minute,
                "scheduled_slot": served_patient.scheduled_slot,
                "arrival_time": arrival,
                "is_walkin": served_patient.scheduled_slot is None,
                "is_late": bool(served_patient.is_late),
                "served_via_recall": served_via_recall,
                "wait_minutes": wait,
                "source": served_source,
            })
            advance = self.slot_minutes
        else:
            # idle minute penalty for waiting while queues exist
            advance = 1
            reward -= 0.01

        # Walk-ins keep arriving during the minutes spent serving.  The minute
        # the clock lands on is sampled at the start of the next step, so only
        # the minutes strictly in between are sampled here.
        started_at = self.minute
        for elapsed_minute in range(started_at + 1, started_at + advance):
            self._maybe_generate_walkins(elapsed_minute)

        # advance time respecting lunch/closed periods
        self.minute = started_at + advance
        if MINUTES_LUNCH_START <= self.minute < MINUTES_LUNCH_END:
            self.minute = MINUTES_LUNCH_END
        self.current_slot = minute_to_slot(self.minute, self.slot_minutes)
        done = self.minute >= MINUTES_CLOSE_PM

        obs = self._get_obs()

        # small penalty for leaving late list unserved to encourage recalls
        reward -= 0.001 * len(self.late_list)

        # Encourage finishing scheduled by end of day
        if done:
            remaining_sched = sum(
                1 for p in self.scheduled.values() if p.id not in self._served_lookup
            )
            if remaining_sched > 0:
                reward -= 0.1 * remaining_sched

        return obs, reward, done, False, info

    def render(self):
        """Print a one-line summary of the clock and queue sizes."""
        print({
            "time": self.minute,
            "served": len(self.served_ids),
            "walkin_q": len(self.walkin_queue),
            "late": len(self.late_list)
        })

In [None]:
# Quick smoke test of env dynamics
# Builds a seeded environment, resets it, and takes five randomly sampled
# actions, printing the reward, termination flag and observation each time.
env = ClinicSchedulingEnv(slot_minutes=10, seed=42)
obs, info = env.reset()
print("obs shape:", obs.shape, "obs:", obs)
for _ in range(5):
    a = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(a)
    print({"a": int(a), "r": float(reward), "done": bool(terminated or truncated), "obs": obs.tolist()})

In [None]:
# Training: PPO on the scheduling environment
import os
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Write logs under /content when that directory exists, otherwise under
# /workspace; the directory is created if missing.
log_dir = "/content/logs" if os.path.exists("/content") else "/workspace/logs"
os.makedirs(log_dir, exist_ok=True)

# make_vec_env handles Monitor and Gymnasium compatibility wrappers
# A single environment instance is used (n_envs=1); monitor files go to log_dir.
vec_env = make_vec_env(lambda: ClinicSchedulingEnv(slot_minutes=10), n_envs=1, monitor_dir=log_dir)

# PPO with an MLP policy; hyperparameters are set explicitly below.
model = PPO(
    "MlpPolicy",
    vec_env,
    verbose=1,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=256,
    n_epochs=10,
    gamma=0.995,
    gae_lambda=0.95,
)

# Train
timesteps = 200_000
model.learn(total_timesteps=timesteps)

# Save
# model_path is read again by the evaluation cells when loading the model.
model_path = os.path.join(log_dir, "ppo_clinic_scheduling")
model.save(model_path)
print("Saved model to", model_path)

In [None]:
# Evaluation and visualization
import numpy as np
import pandas as pd
import plotly.express as px
from stable_baselines3 import PPO

# Load the trained PPO model from model_path (set by the training cell)
loaded = PPO.load(model_path)

def run_episode(env_seed=None):
    """Play one full episode with the loaded policy and return a per-step log.

    Each row holds the step index, the clock minute and queue lengths observed
    before the step, the chosen action, the reward, and the cumulative number
    of patients served after the step.
    """
    env = ClinicSchedulingEnv(slot_minutes=10, seed=env_seed)
    obs, info = env.reset()
    records = []
    step_no = 0
    finished = False
    while not finished:
        action, _ = loaded.predict(obs, deterministic=True)
        chosen = int(action)
        # Snapshot the pre-step state; env.step() mutates it.
        minute_before = env.minute
        walkins_before = len(env.walkin_queue)
        late_before = len(env.late_list)
        obs, reward, terminated, truncated, info = env.step(chosen)
        finished = terminated or truncated
        records.append({
            "t": step_no,
            "minute": minute_before,
            "action": chosen,
            "reward": float(reward),
            "served": len(env.served_ids),
            "walkin_q": walkins_before,
            "late": late_before,
        })
        step_no += 1
    return pd.DataFrame(records)

 df = run_episode(env_seed=123)
fig = px.line(df, x="minute", y=["served", "walkin_q", "late"], title="Clinic metrics over time")
fig.show()

In [None]:
# Improved evaluation: metrics and charts
import numpy as np
import pandas as pd
import plotly.express as px
from stable_baselines3 import PPO

# Load trained model from model_path (set by the training cell)
loaded = PPO.load(model_path)


def run_episode(env_seed=None, walkin_cutoff_minute=None):
    """Play one episode with the loaded policy and collect evaluation data.

    Returns a tuple ``(env, timeline_df, served_df)``: the finished
    environment, a per-step DataFrame of the state seen before each step plus
    the action and reward, and a DataFrame built from the environment's
    served-patient log.
    """
    env = ClinicSchedulingEnv(
        slot_minutes=10,
        seed=env_seed,
        walkin_cutoff_minute=walkin_cutoff_minute,
    )
    obs, info = env.reset()
    rows = []
    counter = 0
    finished = False
    while not finished:
        action, _ = loaded.predict(obs, deterministic=True)
        chosen = int(action)
        # Capture the state before stepping; the reward is filled in afterwards.
        row = {
            "step": counter,
            "minute": env.minute,
            "walkin_q": len(env.walkin_queue),
            "late": len(env.late_list),
            "served": len(env.served_ids),
            "action": chosen,
        }
        obs, reward, terminated, truncated, _ = env.step(chosen)
        finished = bool(terminated or truncated)
        row["reward"] = float(reward)
        rows.append(row)
        counter += 1
    # Build DataFrames
    return env, pd.DataFrame(rows), pd.DataFrame(env.served_log)


# Run one evaluation episode
env, timeline_df, served_df = run_episode(env_seed=123)

# Summary metrics
# Every count falls back to 0 (and every average to NaN) when nobody was
# served, because an empty served_df has no columns to index.
scheduled_total = len(env.scheduled)
served_total = len(served_df)
served_scheduled = int((~served_df["is_walkin"]).sum()) if not served_df.empty else 0
served_walkin = int((served_df["is_walkin"]).sum()) if not served_df.empty else 0
served_via_recall = int((served_df["served_via_recall"]).sum()) if not served_df.empty else 0
# Mean waits ignore rows whose wait_minutes is missing.
avg_wait_scheduled = float(served_df.loc[~served_df["is_walkin"], "wait_minutes"].dropna().mean()) if served_scheduled > 0 else np.nan
avg_wait_walkin = float(served_df.loc[served_df["is_walkin"], "wait_minutes"].dropna().mean()) if served_walkin > 0 else np.nan

print("Scheduled total:", scheduled_total)
print("Served total:", served_total, "(scheduled:", served_scheduled, ", walk-in:", served_walkin, ")")
print("Late recalls served:", served_via_recall)
print("Avg wait (scheduled):", avg_wait_scheduled)
print("Avg wait (walk-in):", avg_wait_walkin)
print("Remaining late list:", len(env.late_list))
print("Remaining walk-in queue:", len(env.walkin_queue))

# Charts
# Line chart of cumulative served count and both queue lengths per clock minute.
fig1 = px.line(timeline_df, x="minute", y=["served", "walkin_q", "late"], title="Queues and served over time")
fig1.show()

# Histogram of wait times, drawn only when at least one wait value exists.
if not served_df.empty and served_df["wait_minutes"].notna().any():
    fig2 = px.histogram(served_df.dropna(subset=["wait_minutes"]), x="wait_minutes", nbins=30, title="Wait time distribution (minutes)")
    fig2.show()
else:
    print("No served patients to plot wait distribution.")