<a href="https://colab.research.google.com/github/Heather306/test-repo/blob/cursor%2Fcontinue-rl-training-for-project-context-4a2f/notebooks/clinic_scheduling_rl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Clinic Scheduling via Reinforcement Learning (Gymnasium + Stable-Baselines3)

This Colab-ready notebook builds and trains a PPO agent to schedule patients under clinic constraints:
- Monday–Saturday only (Sunday closed)
- Operating hours: 08:00–12:00 and 13:00–16:00
- Lunch break 12:00–13:00 (no scheduling)
- Max 60 scheduled patient slots per day
- Walk-ins accepted until cutoff; excess wait in walk-in queue
- If a scheduled patient is not on-site at their time, move to late list and serve next
- If a late patient arrives later, admin can restore to original position; they get priority next after current patient

We'll define a custom Gymnasium environment, train a PPO policy with Stable-Baselines3, and evaluate/visualize outcomes.

## Keep Notebook Synced With GitHub

Working in Colab and want to make sure you are on the latest GitHub commit? Run the next cell. It will clone (or update) the repo branch into `/content/test-repo`, so you can reopen this notebook directly from the synced copy after it finishes.

In [2]:
# Clone or pull the latest notebook/code from GitHub
import shutil
import subprocess
from pathlib import Path

REPO_URL = "https://github.com/Heather306/test-repo.git"
BRANCH = "cursor/train-system-app-scheduling-with-reinforcement-learning-095a"
LOCAL_DIR = Path("/content/test-repo")


def run(cmd):
    print(">", " ".join(cmd))
    subprocess.run(cmd, check=True)


# If the folder exists without a git repo, clear it so we can clone fresh
if LOCAL_DIR.exists() and not (LOCAL_DIR / ".git").exists():
    shutil.rmtree(LOCAL_DIR)

if (LOCAL_DIR / ".git").exists():
    run(["git", "-C", str(LOCAL_DIR), "fetch", "origin"])
    run(["git", "-C", str(LOCAL_DIR), "checkout", BRANCH])
    run(["git", "-C", str(LOCAL_DIR), "pull", "origin", BRANCH])
else:
    if LOCAL_DIR.exists():
        shutil.rmtree(LOCAL_DIR)
    run(["git", "clone", "--branch", BRANCH, REPO_URL, str(LOCAL_DIR)])

nb_path = LOCAL_DIR / "notebooks" / "clinic_scheduling_rl.ipynb"
print(f"Latest notebook available at: {nb_path}")
print("In Colab, use File → Open notebook → Upload → Browse (left 'Files' pane) to reopen from that path if needed.")

> git clone --branch cursor/train-system-app-scheduling-with-reinforcement-learning-095a https://github.com/Heather306/test-repo.git /content/test-repo
Latest notebook available at: /content/test-repo/notebooks/clinic_scheduling_rl.ipynb
In Colab, use File → Open notebook → Upload → Browse (left 'Files' pane) to reopen from that path if needed.


In [3]:
# If running in Colab, uncomment the next line to install packages
# !pip -q install gymnasium==0.29.1 stable-baselines3==2.3.2 sb3-contrib==2.3.2 shimmy==1.3.0 plotly==5.24.1 numpy pandas

import sys, os
print(sys.version)
print("Working dir:", os.getcwd())

# Ensure proper imports in Colab kernels
import warnings
warnings.filterwarnings("ignore")

3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
Working dir: /content


In [4]:
from dataclasses import dataclass
from typing import Tuple, Dict, Any, Optional
import numpy as np
import gymnasium as gym
from gymnasium import spaces

# Domain constants
MINUTES_OPEN_AM = 8 * 60
MINUTES_LUNCH_START = 12 * 60
MINUTES_LUNCH_END = 13 * 60
MINUTES_CLOSE_PM = 16 * 60
WORK_MINUTES = (MINUTES_LUNCH_START - MINUTES_OPEN_AM) + (MINUTES_CLOSE_PM - MINUTES_LUNCH_END)
MAX_SCHEDULED_PER_DAY = 60
MAX_WALKIN_QUEUE = 200
DAYS_OPEN = set(range(6))  # 0=Mon ... 5=Sat, 6=Sun closed


def is_open_minute(minute_of_day: int) -> bool:
    return (MINUTES_OPEN_AM <= minute_of_day < MINUTES_LUNCH_START) or (MINUTES_LUNCH_END <= minute_of_day < MINUTES_CLOSE_PM)


def minute_to_slot(minute_of_day: int, slot_minutes: int) -> int:
    # Map minute to contiguous slot index excluding lunch
    if minute_of_day < MINUTES_OPEN_AM:
        return 0
    if MINUTES_OPEN_AM <= minute_of_day < MINUTES_LUNCH_START:
        return (minute_of_day - MINUTES_OPEN_AM) // slot_minutes
    if MINUTES_LUNCH_START <= minute_of_day < MINUTES_LUNCH_END:
        return (MINUTES_LUNCH_START - MINUTES_OPEN_AM) // slot_minutes
    if MINUTES_LUNCH_END <= minute_of_day < MINUTES_CLOSE_PM:
        am_slots = (MINUTES_LUNCH_START - MINUTES_OPEN_AM) // slot_minutes
        return am_slots + (minute_of_day - MINUTES_LUNCH_END) // slot_minutes
    return ((MINUTES_LUNCH_START - MINUTES_OPEN_AM) + (MINUTES_CLOSE_PM - MINUTES_LUNCH_END)) // slot_minutes


@dataclass
class Patient:
    id: int
    scheduled_slot: Optional[int]  # None for walk-in
    arrival_time_min: Optional[int]  # None means not arrived yet
    is_late: bool = False


class ClinicSchedulingEnv(gym.Env):
    metadata = {"render.modes": ["human"]}

    def __init__(self,
                 slot_minutes: int = 10,
                 max_scheduled: int = MAX_SCHEDULED_PER_DAY,
                 max_walkin_queue: int = MAX_WALKIN_QUEUE,
                 no_show_prob: float = 0.05,
                 late_prob: float = 0.1,
                 walkin_rate_per_hour: float = 8.0,
                 walkin_cutoff_minute: Optional[int] = None,
                 day_of_week: Optional[int] = None,
                 seed: Optional[int] = None):
        super().__init__()
        self.slot_minutes = slot_minutes
        self.slots_per_day = WORK_MINUTES // slot_minutes
        self.max_scheduled = min(max_scheduled, MAX_SCHEDULED_PER_DAY)
        self.max_walkin_queue = max_walkin_queue
        self.no_show_prob = no_show_prob
        self.late_prob = late_prob
        self.walkin_rate_per_hour = walkin_rate_per_hour
        self.walkin_cutoff_minute = walkin_cutoff_minute or MINUTES_CLOSE_PM
        self._configured_day_of_week = day_of_week
        self.rng = np.random.default_rng(seed)

        # Action space: choose next source to serve
        # 0 = next scheduled on-time, 1 = next walk-in, 2 = recall-priority late (if any)
        self.action_space = spaces.Discrete(3)

        # Observation space (compact):
        # [current_slot_index, scheduled_remaining, walkin_queue_len, late_list_len, next_scheduled_on_site(0/1), time_to_next_arrival_minutes]
        high = np.array([
            self.slots_per_day,
            MAX_SCHEDULED_PER_DAY,
            self.max_walkin_queue,
            MAX_SCHEDULED_PER_DAY,
            1,
            60
        ], dtype=np.float32)
        self.observation_space = spaces.Box(low=0.0, high=high, dtype=np.float32)

        self.reset_state()

    def reset_state(self):
        self.minute = MINUTES_OPEN_AM
        self.current_slot = 0
        # self.day_of_week is set in reset(); don't override here
        self.scheduled: Dict[int, Patient] = {}
        self.walkin_queue: list[Patient] = []
        self.late_list: list[Patient] = []
        self.served_ids: list[int] = []
        self.served_log: list[Dict[str, Any]] = []
        self.generated_patients: Dict[int, Patient] = {}
        self._generate_day_schedule()

    def _generate_day_schedule(self):
        # Pre-generate scheduled patients across slots (max 60)
        max_slots = self.slots_per_day
        chosen_slots = self.rng.choice(max_slots, size=min(self.max_scheduled, max_slots), replace=False)
        pid = 1
        for slot in sorted(chosen_slots.tolist()):
            # arrival: on time or late or no-show
            ontime = self.rng.random() > self.late_prob
            if self.rng.random() < self.no_show_prob:
                arrival = None
            else:
                if ontime:
                    # arrive within slot's first 5 minutes
                    slot_minute = self._slot_to_minute(slot)
                    jitter = int(self.rng.integers(0, min(5, self.slot_minutes)))
                    arrival = slot_minute + jitter
                else:
                    # late: arrive between +5 and +60 minutes later
                    base = self._slot_to_minute(slot) + 5
                    arrival = min(base + int(self.rng.integers(0, 60)), MINUTES_CLOSE_PM - 1)
            p = Patient(id=pid, scheduled_slot=slot, arrival_time_min=arrival)
            self.scheduled[slot] = p
            self.generated_patients[pid] = p
            pid += 1
        self.next_walkin_id = pid

    def _slot_to_minute(self, slot_index: int) -> int:
        am_slots = (MINUTES_LUNCH_START - MINUTES_OPEN_AM) // self.slot_minutes
        if slot_index < am_slots:
            return MINUTES_OPEN_AM + slot_index * self.slot_minutes
        else:
            return MINUTES_LUNCH_END + (slot_index - am_slots) * self.slot_minutes

    def _poisson(self, lam):
        # simple Poisson sampler via numpy
        return self.rng.poisson(lam)

    def _maybe_generate_walkins(self):
        if not is_open_minute(self.minute):
            return
        if self.minute >= self.walkin_cutoff_minute:
            return
        # per-minute rate
        lam = self.walkin_rate_per_hour / 60.0
        arrivals = self._poisson(lam)
        for _ in range(arrivals):
            p = Patient(id=self.next_walkin_id, scheduled_slot=None, arrival_time_min=self.minute)
            self.next_walkin_id += 1
            if len(self.walkin_queue) < self.max_walkin_queue:
                self.walkin_queue.append(p)
                self.generated_patients[p.id] = p

    def _update_late_status(self):
        # mark scheduled as late if slot passed and not on-site
        for slot, p in list(self.scheduled.items()):
            slot_minute = self._slot_to_minute(slot)
            if p.arrival_time_min is None:
                continue  # no-show remains
            if p.arrival_time_min > self.minute and self.minute >= slot_minute:
                p.is_late = True
                # move to late list if not already served and past slot
                if p not in self.late_list and p.id not in self.served_ids and self.minute >= slot_minute:
                    self.late_list.append(p)

        # move any arrived late patients out of late_list priority bucket if they just arrived now
        for p in self.late_list:
            if p.arrival_time_min is not None and p.arrival_time_min <= self.minute:
                p.is_late = True  # keep flag

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        # set day-of-week (exclude Sunday by default)
        if options and "day_of_week" in options:
            self.day_of_week = int(options["day_of_week"])
        else:
            if self._configured_day_of_week is not None:
                self.day_of_week = int(self._configured_day_of_week)
            else:
                self.day_of_week = int(self.rng.integers(0, 6))
                if self.day_of_week not in DAYS_OPEN:
                    # force into open set
                    self.day_of_week = int(self.rng.integers(0, 6))
        # If Sunday (closed), start and immediately end episode by returning a terminal state
        self.reset_state()
        obs = self._get_obs()
        info = {}
        return obs, info

    def _get_obs(self):
        # next scheduled slot that is not served yet
        remaining_slots = [s for s, p in self.scheduled.items() if p.id not in self.served_ids]
        scheduled_remaining = len(remaining_slots)

        # check if the next scheduled patient is on-site now
        next_slot = min(remaining_slots) if remaining_slots else None
        on_site = 0
        if next_slot is not None:
            p = self.scheduled[next_slot]
            if p.arrival_time_min is not None and p.arrival_time_min <= self.minute:
                on_site = 1

        # estimate time to next arrival (scheduled or walk-in), clipped
        next_arrival = 999
        # next scheduled arrival
        sched_arrivals = [p.arrival_time_min for p in self.scheduled.values() if p.arrival_time_min is not None and p.arrival_time_min > self.minute]
        if sched_arrivals:
            next_arrival = min(next_arrival, min(sched_arrivals) - self.minute)
        # approximate next walk-in arrival as inverse rate
        if is_open_minute(self.minute) and self.walkin_rate_per_hour > 0 and self.minute < self.walkin_cutoff_minute:
            expected_walkin_gap = max(1, int(60.0 / self.walkin_rate_per_hour))
            next_arrival = min(next_arrival, expected_walkin_gap)
        next_arrival = int(min(next_arrival, 60))

        obs = np.array([
            minute_to_slot(self.minute, self.slot_minutes),
            scheduled_remaining,
            len(self.walkin_queue),
            len(self.late_list),
            on_site,
            next_arrival
        ], dtype=np.float32)
        return obs

    def step(self, action: int):
        # simulate minute-by-minute until either we serve someone or day ends
        reward = 0.0
        info: Dict[str, Any] = {}

        # At each step, first update arrivals
        self._maybe_generate_walkins()
        self._update_late_status()

        # determine candidate queues based on action
        # 2: late priority (admin recalls) if available
        # 0: scheduled on-time if on-site
        # 1: walk-in otherwise
        served_patient: Optional[Patient] = None
        served_source = None
        served_via_recall = False

        if action == 2 and self.late_list:
            served_patient = self.late_list.pop(0)
            served_source = "late_recall"
            served_via_recall = True
        else:
            # next scheduled
            remaining_slots = sorted([s for s, p in self.scheduled.items() if p.id not in self.served_ids])
            if action == 0 and remaining_slots:
                next_slot = remaining_slots[0]
                p = self.scheduled[next_slot]
                if p.arrival_time_min is not None and p.arrival_time_min <= self.minute:
                    served_patient = p
                    served_source = "scheduled"
            if served_patient is None and self.walkin_queue:
                served_patient = self.walkin_queue.pop(0)
                served_source = "walkin"

        # apply serving and time advance
        if served_patient is not None:
            self.served_ids.append(served_patient.id)
            # reward per patient served; small bonus for serving scheduled/late via recall
            reward += 1.0
            if served_source in ("scheduled", "late_recall"):
                reward += 0.05
            # compute wait
            served_start_minute = self.minute
            arrival = served_patient.arrival_time_min
            wait = None
            if arrival is not None:
                wait = max(0, served_start_minute - arrival)
            # log
            self.served_log.append({
                "id": served_patient.id,
                "served_start_minute": served_start_minute,
                "scheduled_slot": served_patient.scheduled_slot,
                "arrival_time": arrival,
                "is_walkin": served_patient.scheduled_slot is None,
                "is_late": bool(served_patient.is_late),
                "served_via_recall": served_via_recall,
                "wait_minutes": wait,
                "source": served_source,
            })
            advance = self.slot_minutes
        else:
            # idle minute penalty for waiting while queues exist
            advance = 1
            reward -= 0.01

        # advance time respecting lunch/closed periods
        self.minute += advance
        if MINUTES_LUNCH_START <= self.minute < MINUTES_LUNCH_END:
            self.minute = MINUTES_LUNCH_END
        done = self.minute >= MINUTES_CLOSE_PM

        obs = self._get_obs()

        # small penalty for leaving late list unserved to encourage recalls
        reward -= 0.001 * len(self.late_list)

        # Encourage finishing scheduled by end of day
        if done:
            remaining_sched = len([p for p in self.scheduled.values() if p.id not in self.served_ids])
            if remaining_sched > 0:
                reward -= 0.1 * remaining_sched

        return obs, reward, done, False, info

    def render(self):
        print({
            "time": self.minute,
            "served": len(self.served_ids),
            "walkin_q": len(self.walkin_queue),
            "late": len(self.late_list)
        })

In [5]:
!pip -q install stable-baselines3==2.3.2 gymnasium==0.29.1 shimmy==1.3.0 plotly==5.24.1 ipywidgets

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/182.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.3/182.3 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/953.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m30.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m50.1 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.1 which is incompatible.[0m[31m


In [6]:
# Quick smoke test of env dynamics (7-min slots)
env = ClinicSchedulingEnv(slot_minutes=7, seed=42)
obs, info = env.reset()
print("obs shape:", obs.shape, "obs:", obs)
for _ in range(5):
    a = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(a)
    print({"a": int(a), "r": float(reward), "done": bool(terminated or truncated), "obs": obs.tolist()})

obs shape: (6,) obs: [ 0. 60.  0.  0.  0.  1.]
{'a': 1, 'r': -0.011, 'done': False, 'obs': [0.0, 60.0, 0.0, 1.0, 1.0, 7.0]}
{'a': 2, 'r': 1.05, 'done': False, 'obs': [1.0, 59.0, 0.0, 0.0, 0.0, 7.0]}
{'a': 0, 'r': -0.011, 'done': False, 'obs': [1.0, 59.0, 0.0, 1.0, 0.0, 7.0]}
{'a': 0, 'r': 0.999, 'done': False, 'obs': [2.0, 59.0, 0.0, 1.0, 0.0, 7.0]}
{'a': 2, 'r': 1.0490000000000002, 'done': False, 'obs': [3.0, 58.0, 0.0, 1.0, 0.0, 2.0]}


In [7]:
# Training: PPO on the scheduling environment (7-min slots)
import os
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

log_dir = "/content/logs" if os.path.exists("/content") else "/workspace/logs"
os.makedirs(log_dir, exist_ok=True)

# make_vec_env handles Monitor and Gymnasium compatibility wrappers
vec_env = make_vec_env(lambda: ClinicSchedulingEnv(slot_minutes=7), n_envs=1, monitor_dir=log_dir)

model = PPO(
    "MlpPolicy",
    vec_env,
    verbose=1,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=256,
    n_epochs=10,
    gamma=0.995,
    gae_lambda=0.95,
)

# Train
timesteps = 200_000
model.learn(total_timesteps=timesteps)

# Save
model_path = os.path.join(log_dir, "ppo_clinic_scheduling")
model.save(model_path)
print("Saved model to", model_path)

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


Using cpu device


  return datetime.utcnow().replace(tzinfo=utc)


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 110      |
|    ep_rew_mean     | 51.2     |
| time/              |          |
|    fps             | 1564     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------


  return datetime.utcnow().replace(tzinfo=utc)


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 111         |
|    ep_rew_mean          | 51.1        |
| time/                   |             |
|    fps                  | 1426        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010848872 |
|    clip_fraction        | 0.085       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.09       |
|    explained_variance   | 0.07424039  |
|    learning_rate        | 0.0003      |
|    loss                 | 10.5        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00563    |
|    value_loss           | 36          |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 111   

In [7]:
# Evaluation and visualization
import numpy as np
import pandas as pd
import plotly.express as px
from stable_baselines3 import PPO

# load
loaded = PPO.load(model_path)

def run_episode(env_seed=None):
    env = ClinicSchedulingEnv(slot_minutes=10, seed=env_seed)
    obs, info = env.reset()
    done = False
    events = []
    t = 0
    while not done:
        action, _ = loaded.predict(obs, deterministic=True)
        prev_min = env.minute
        prev_walk = len(env.walkin_queue)
        prev_late = len(env.late_list)
        obs, reward, terminated, truncated, info = env.step(int(action))
        done = terminated or truncated
        events.append({
            "t": t,
            "minute": prev_min,
            "action": int(action),
            "reward": float(reward),
            "served": len(env.served_ids),
            "walkin_q": prev_walk,
            "late": prev_late,
        })
        t += 1
    return pd.DataFrame(events)

df = run_episode(env_seed=123)
fig = px.line(df, x="minute", y=["served", "walkin_q", "late"], title="Clinic metrics over time")
fig.show()


datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



In [8]:
# Improved evaluation: metrics and charts (7-min slots)
import numpy as np
import pandas as pd
import plotly.express as px
from stable_baselines3 import PPO

# Load trained model
loaded = PPO.load(model_path)


def run_episode(env_seed=None, walkin_cutoff_minute=None):
    env = ClinicSchedulingEnv(slot_minutes=7, seed=env_seed, walkin_cutoff_minute=walkin_cutoff_minute)
    obs, info = env.reset()
    done = False
    timeline = []
    step_idx = 0
    while not done:
        action, _ = loaded.predict(obs, deterministic=True)
        prev_state = {
            "step": step_idx,
            "minute": env.minute,
            "walkin_q": len(env.walkin_queue),
            "late": len(env.late_list),
            "served": len(env.served_ids),
            "action": int(action),
        }
        obs, reward, terminated, truncated, _ = env.step(int(action))
        done = bool(terminated or truncated)
        prev_state["reward"] = float(reward)
        timeline.append(prev_state)
        step_idx += 1
    # Build DataFrames
    timeline_df = pd.DataFrame(timeline)
    served_df = pd.DataFrame(env.served_log)
    return env, timeline_df, served_df


# Run one evaluation episode
env, timeline_df, served_df = run_episode(env_seed=123)

# Summary metrics
scheduled_total = len(env.scheduled)
served_total = len(served_df)
served_scheduled = int((~served_df["is_walkin"]).sum()) if not served_df.empty else 0
served_walkin = int((served_df["is_walkin"]).sum()) if not served_df.empty else 0
served_via_recall = int((served_df["served_via_recall"]).sum()) if not served_df.empty else 0
avg_wait_scheduled = float(served_df.loc[~served_df["is_walkin"], "wait_minutes"].dropna().mean()) if served_scheduled > 0 else np.nan
avg_wait_walkin = float(served_df.loc[served_df["is_walkin"], "wait_minutes"].dropna().mean()) if served_walkin > 0 else np.nan

print("Scheduled total:", scheduled_total)
print("Served total:", served_total, "(scheduled:", served_scheduled, ", walk-in:", served_walkin, ")")
print("Late recalls served:", served_via_recall)
print("Avg wait (scheduled):", avg_wait_scheduled)
print("Avg wait (walk-in):", avg_wait_walkin)
print("Remaining late list:", len(env.late_list))
print("Remaining walk-in queue:", len(env.walkin_queue))

# Charts
fig1 = px.line(timeline_df, x="minute", y=["served", "walkin_q", "late"], title="Queues and served over time")
fig1.show()

if not served_df.empty and served_df["wait_minutes"].notna().any():
    fig2 = px.histogram(served_df.dropna(subset=["wait_minutes"]), x="wait_minutes", nbins=30, title="Wait time distribution (minutes)")
    fig2.show()
else:
    print("No served patients to plot wait distribution.")

Scheduled total: 60
Served total: 56 (scheduled: 43 , walk-in: 13 )
Late recalls served: 30
Avg wait (scheduled): 14.627906976744185
Avg wait (walk-in): 28.0
Remaining late list: 0
Remaining walk-in queue: 7


  return datetime.utcnow().replace(tzinfo=utc)

datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).




datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).




datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



## 7-minute availability booking (user selects an hour)

This section adds a simple booking planner so users can pick an hour label based on availability. Each hour has 7-minute sub-slots:
- Hours: 08, 09, 10, 11, 13, 14, 15 (12:00–13:00 lunch, closed; 16:00 close)
- Capacity rules per hour (7-min consult):
  - Normal hours (08, 09, 10, 13, 14): 9 bookings (last may bleed a few minutes into next hour)
  - Boundary hours (11 before lunch, 15 before day close): 8 bookings (must finish by 12:00 and 16:00 respectively)
- Example behavior: After 9 bookings at 08:00, availability shows 09–12 and 13–16.

## Env V2: Multi‑provider, variable service times, action masking

Enhancements:
- Multiple providers can serve patients in parallel (concurrent servers)
- Variable service durations (lognormal around 7 minutes, configurable)
- Action masking for invalid choices (late recall only if late patient has arrived; etc.)
- Optional seeding of scheduled appointments from the 7‑min `BookingPlanner`
- Richer logs and provider utilization metrics

In [9]:
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces
from typing import Optional, Dict, Any, List, Tuple
from collections import defaultdict

class ClinicSchedulingEnvV2(gym.Env):
    """Multi-provider scheduling with variable service times and action masking."""
    metadata = {"render.modes": ["human"]}

    def __init__(self,
                 slot_minutes: int = 7,
                 num_providers: int = 2,
                 service_mean_min: float = 7.0,
                 service_sigma_min: float = 2.0,
                 walkin_rate_per_hour: float = 8.0,
                 no_show_prob: float = 0.05,
                 late_prob: float = 0.1,
                 walkin_cutoff_minute: Optional[int] = None,
                 seed: Optional[int] = None,
                 seeded_schedule: Optional[List[Tuple[int, int]]] = None  # list of (patient_id, start_minute)
                 ):
        super().__init__()
        self.slot_minutes = slot_minutes
        self.num_providers = num_providers
        self.service_mean_min = service_mean_min
        self.service_sigma_min = service_sigma_min
        self.walkin_rate_per_hour = walkin_rate_per_hour
        self.no_show_prob = no_show_prob
        self.late_prob = late_prob
        self.walkin_cutoff_minute = walkin_cutoff_minute or MINUTES_CLOSE_PM
        self.rng = np.random.default_rng(seed)

        # Action space: 0 scheduled, 1 walk-in, 2 late recall
        self.action_space = spaces.Discrete(3)
        # Observation: [minute_slot, walkin_q, late_len, scheduled_remaining, free_providers, mask0, mask1, mask2]
        high = np.array([
            WORK_MINUTES // self.slot_minutes,
            500,
            MAX_SCHEDULED_PER_DAY,
            MAX_SCHEDULED_PER_DAY,
            self.num_providers,
            1, 1, 1
        ], dtype=np.float32)
        self.observation_space = spaces.Box(low=0.0, high=high, dtype=np.float32)

        # Providers state: remaining service time if busy else 0
        self.provider_busy_remaining: List[int] = [0 for _ in range(self.num_providers)]

        # Queues
        self.walkin_queue: List[Patient] = []
        self.late_list: List[Patient] = []
        self.scheduled_slots: Dict[int, List[Patient]] = defaultdict(list)  # slot index -> patients
        self.served_ids: List[int] = []
        self.served_log: List[Dict[str, Any]] = []
        self.generated_patients: Dict[int, Patient] = {}
        self.minute = MINUTES_OPEN_AM
        self.next_walkin_id = 1
        self.seeded_schedule = seeded_schedule
        self._build_scheduled_from_seed()

    def _build_scheduled_from_seed(self):
        self.scheduled_slots.clear()
        pid = 1
        if self.seeded_schedule:
            for pid_seed, start_min in self.seeded_schedule:
                slot = minute_to_slot(start_min, self.slot_minutes)
                arrival = start_min  # assume on-time unless randomized below
                if self.rng.random() < self.no_show_prob:
                    arrival = None
                else:
                    # lateness
                    if self.rng.random() < self.late_prob:
                        arrival = min(start_min + int(self.rng.integers(5, 30)), MINUTES_CLOSE_PM - 1)
                p = Patient(id=pid_seed, scheduled_slot=slot, arrival_time_min=arrival)
                self.scheduled_slots[slot].append(p)
                self.generated_patients[pid_seed] = p
                pid = max(pid, pid_seed + 1)
        self.next_walkin_id = pid

    def _service_duration(self) -> int:
        # lognormal-ish clamp around mean
        val = max(1.0, self.rng.lognormal(mean=np.log(max(1e-6, self.service_mean_min)), sigma=self.service_sigma_min / max(1.0, self.service_mean_min)))
        return int(max(1, round(val)))

    def _maybe_generate_walkins(self):
        if not is_open_minute(self.minute) or self.minute >= self.walkin_cutoff_minute:
            return
        lam = self.walkin_rate_per_hour / 60.0
        arrivals = self.rng.poisson(lam)
        for _ in range(arrivals):
            p = Patient(id=self.next_walkin_id, scheduled_slot=None, arrival_time_min=self.minute)
            self.walkin_queue.append(p)
            self.generated_patients[p.id] = p
            self.next_walkin_id += 1

    def _update_late_status(self):
        for slot, patients in list(self.scheduled_slots.items()):
            slot_minute = self._slot_to_minute(slot)
            for p in patients:
                if p.arrival_time_min is None:
                    continue
                if p.arrival_time_min > self.minute and self.minute >= slot_minute:
                    p.is_late = True
                    if p not in self.late_list and p.id not in self.served_ids and self.minute >= slot_minute:
                        self.late_list.append(p)

    def _slot_to_minute(self, slot_index: int) -> int:
        am_slots = (MINUTES_LUNCH_START - MINUTES_OPEN_AM) // self.slot_minutes
        if slot_index < am_slots:
            return MINUTES_OPEN_AM + slot_index * self.slot_minutes
        else:
            return MINUTES_LUNCH_END + (slot_index - am_slots) * self.slot_minutes

    def _mask(self) -> np.ndarray:
        mask = np.array([0, 0, 0], dtype=np.int8)
        # scheduled on-site available?
        slot, patient = self._next_on_site_patient()
        if patient is not None:
            mask[0] = 1
        # walk-in available?
        if len(self.walkin_queue) > 0:
            mask[1] = 1
        # late recall available (arrived late and in list)?
        if any(p.arrival_time_min is not None and p.arrival_time_min <= self.minute for p in self.late_list):
            mask[2] = 1
        # if no providers free, no actions available
        if self.free_providers() == 0:
            mask[:] = 0
        return mask

    def _remaining_scheduled(self):
        for slot in sorted(self.scheduled_slots.keys()):
            for patient in self.scheduled_slots[slot]:
                if patient.id not in self.served_ids:
                    yield slot, patient

    def _remaining_scheduled_count(self) -> int:
        return sum(1 for _ in self._remaining_scheduled())

    def _next_on_site_patient(self):
        for slot, patient in self._remaining_scheduled():
            if patient.arrival_time_min is not None and patient.arrival_time_min <= self.minute:
                return slot, patient
        return None, None

    def free_providers(self) -> int:
        return sum(1 for t in self.provider_busy_remaining if t <= 0)

    def _get_obs(self):
        obs = np.array([
            minute_to_slot(self.minute, self.slot_minutes),
            len(self.walkin_queue),
            len(self.late_list),
            self._remaining_scheduled_count(),
            self.free_providers(),
            *self._mask().tolist()
        ], dtype=np.float32)
        return obs

    def step(self, action: int):
        reward = 0.0
        info: Dict[str, Any] = {}

        # process arrivals and late updates
        self._maybe_generate_walkins()
        self._update_late_status()

        served_patients: List[Patient] = []
        mask = self._mask()
        # serve up to number of free providers, consistent with intended action preference
        to_serve = self.free_providers()
        for _ in range(to_serve):
            candidate: Optional[Patient] = None
            if action == 2 and mask[2]:
                # serve arrived late first
                for i, p in enumerate(self.late_list):
                    if p.arrival_time_min is not None and p.arrival_time_min <= self.minute:
                        candidate = self.late_list.pop(i)
                        break
            if candidate is None and action == 0 and mask[0]:
                _, on_site = self._next_on_site_patient()
                if on_site is not None:
                    candidate = on_site
            if candidate is None and action in (1, 0, 2) and mask[1] and self.walkin_queue:
                candidate = self.walkin_queue.pop(0)
            if candidate is not None:
                self.served_ids.append(candidate.id)
                served_patients.append(candidate)
                # assign to a provider
                for i in range(self.num_providers):
                    if self.provider_busy_remaining[i] <= 0:
                        self.provider_busy_remaining[i] = self._service_duration()
                        break

        # reward and logs
        reward += 1.0 * len(served_patients)
        for p in served_patients:
            wait = None
            if p.arrival_time_min is not None:
                wait = max(0, self.minute - p.arrival_time_min)
            self.served_log.append({
                "id": p.id,
                "served_minute": self.minute,
                "arrival_time": p.arrival_time_min,
                "is_walkin": p.scheduled_slot is None,
                "is_late": bool(p.is_late),
                "wait_minutes": wait,
            })
            if p.scheduled_slot is not None:
                reward += 0.05

        # idle penalty if queues but no serve due to mask or no providers free
        if len(served_patients) == 0 and (self.walkin_queue or self.late_list or self._remaining_scheduled_count() > 0):
            reward -= 0.01

        # advance time by 1 minute; decrement providers
        self.minute += 1
        for i in range(self.num_providers):
            if self.provider_busy_remaining[i] > 0:
                self.provider_busy_remaining[i] -= 1
        if MINUTES_LUNCH_START <= self.minute < MINUTES_LUNCH_END:
            self.minute = MINUTES_LUNCH_END
        done = self.minute >= MINUTES_CLOSE_PM

        # terminal penalty for unserved scheduled
        if done:
            remaining_sched = self._remaining_scheduled_count()
            reward -= 0.1 * remaining_sched

        obs = self._get_obs()
        return obs, reward, done, False, info

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        self.minute = MINUTES_OPEN_AM
        self.provider_busy_remaining = [0 for _ in range(self.num_providers)]
        self.walkin_queue = []
        self.late_list = []
        self.served_ids = []
        self.served_log = []
        self.generated_patients = {}
        self._build_scheduled_from_seed()
        return self._get_obs(), {}

    def render(self):
        print({
            "minute": self.minute,
            "free_providers": self.free_providers(),
            "walkin_q": len(self.walkin_queue),
            "late": len(self.late_list)
        })

In [10]:
# Calibration/config cell for Env V2
CALIB = {
    "slot_minutes": 7,
    "num_providers": 2,
    "service_mean_min": 7.0,
    "service_sigma_min": 2.0,
    "walkin_rate_per_hour": 10.0,
    "no_show_prob": 0.07,
    "late_prob": 0.12,
    "walkin_cutoff_minute": MINUTES_CLOSE_PM,
}

# Example: seed scheduled patients from BookingPlanner (first N bookings at 8:00 then 9:00)
seed_schedule = []  # list of (patient_id, start_minute)
if "BookingPlanner" in globals():
    planner = BookingPlanner()
    patient_id = 1
    for hour in [8, 9]:
        for _ in range(5):  # first 5 bookings per hour for demo
            booked_label = planner.book(hour)
            if booked_label is None:
                break
            slot_index = planner.booked_count_by_hour[hour] - 1
            minute = hour * 60 + slot_index * planner.slot_minutes
            seed_schedule.append((patient_id, minute))
            patient_id += 1
else:
    print("BookingPlanner not yet defined; using deterministic slot seeding below.")

# Fallback: deterministic seed schedule at minute marks if planner wasn't available or yielded no seeds
if not seed_schedule:
    seed_schedule = [(i + 1, 8 * 60 + i * CALIB["slot_minutes"]) for i in range(9)]  # 8:00 to ~8:56
    seed_schedule += [(10 + i, 9 * 60 + i * CALIB["slot_minutes"]) for i in range(5)]

print("Seeds (first 3):", seed_schedule[:3])

BookingPlanner not yet defined; using deterministic slot seeding below.
Seeds (first 3): [(1, 480), (2, 487), (3, 494)]


In [11]:
# Training with MaskablePPO if available, else PPO
from stable_baselines3 import PPO
try:
    from sb3_contrib import MaskablePPO
    from sb3_contrib.common.wrappers import ActionMasker
    MASKABLE = True
except Exception:
    MASKABLE = False
    ActionMasker = None

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback
import os

log_dir_v2 = "/content/logs_v2" if os.path.exists("/content") else "/workspace/logs_v2"
os.makedirs(log_dir_v2, exist_ok=True)

# build env factory with seeded schedule

def make_env_v2():
    return ClinicSchedulingEnvV2(
        slot_minutes=CALIB["slot_minutes"],
        num_providers=CALIB["num_providers"],
        service_mean_min=CALIB["service_mean_min"],
        service_sigma_min=CALIB["service_sigma_min"],
        walkin_rate_per_hour=CALIB["walkin_rate_per_hour"],
        no_show_prob=CALIB["no_show_prob"],
        late_prob=CALIB["late_prob"],
        walkin_cutoff_minute=CALIB["walkin_cutoff_minute"],
        seeded_schedule=seed_schedule,
    )


def _mask_fn(env: ClinicSchedulingEnvV2):
    return env._mask().astype(bool)


def make_env_v2_masked():
    base = make_env_v2()
    return ActionMasker(base, _mask_fn)

env_factory = make_env_v2_masked if MASKABLE and ActionMasker is not None else make_env_v2

vec_env_v2 = make_vec_env(env_factory, n_envs=1, monitor_dir=log_dir_v2)

if MASKABLE:
    model_v2 = MaskablePPO(
        "MlpPolicy",
        vec_env_v2,
        verbose=1,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=256,
        n_epochs=10,
        gamma=0.995,
        gae_lambda=0.95,
    )
else:
    model_v2 = PPO(
        "MlpPolicy",
        vec_env_v2,
        verbose=1,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=256,
        n_epochs=10,
        gamma=0.995,
        gae_lambda=0.95,
    )

# Optional evaluation callback
# eval_env_v2 = make_vec_env(make_env_v2, n_envs=1)
# eval_cb = EvalCallback(eval_env_v2, best_model_save_path=log_dir_v2, log_path=log_dir_v2, eval_freq=10_000)

timesteps_v2 = 200_000
model_v2.learn(total_timesteps=timesteps_v2)  # , callback=eval_cb)

model_v2_path = os.path.join(log_dir_v2, "model_v2")
model_v2.save(model_v2_path)
print("Saved V2 model to", model_v2_path)

Using cpu device
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 420      |
|    ep_rew_mean     | 80.9     |
| time/              |          |
|    fps             | 1544     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 420         |
|    ep_rew_mean          | 86          |
| time/                   |             |
|    fps                  | 914         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010026465 |
|    clip_fraction        | 0.0397      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.09       |
|    explained_variance   | -0.2419256  |
|    learning

In [12]:
# Evaluation for Env V2
import pandas as pd
from statistics import mean

loaded_v2 = None
try:
    from sb3_contrib import MaskablePPO
    loaded_v2 = MaskablePPO.load(model_v2_path)
except Exception:
    from stable_baselines3 import PPO
    loaded_v2 = PPO.load(model_v2_path)


def evaluate_v2(episodes=3):
    metrics = []
    for ep in range(episodes):
        env = make_env_v2()
        obs, info = env.reset()
        done = False
        while not done:
            action, _ = loaded_v2.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = env.step(int(action))
            done = bool(terminated or truncated)
        # compute metrics
        served_df = pd.DataFrame(env.served_log)
        scheduled_served = int((served_df["is_walkin"] == False).sum()) if not served_df.empty else 0
        walkin_served = int((served_df["is_walkin"] == True).sum()) if not served_df.empty else 0
        avg_wait = float(served_df["wait_minutes"].dropna().mean()) if not served_df.empty else float("nan")
        util = 1.0 - (sum(1 for t in env.provider_busy_remaining if t <= 0) / env.num_providers)
        metrics.append({
            "scheduled_served": scheduled_served,
            "walkin_served": walkin_served,
            "avg_wait": avg_wait,
            "late_remaining": len(env.late_list),
            "walkins_remaining": len(env.walkin_queue),
            "providers": env.num_providers,
        })
    return pd.DataFrame(metrics)

m = evaluate_v2(episodes=3)
print(m.describe(include="all"))

       scheduled_served  walkin_served  avg_wait  late_remaining  \
count               3.0       3.000000  3.000000        3.000000   
mean               14.0      70.000000  4.117820        5.333333   
std                 0.0       4.358899  2.114557        0.577350   
min                14.0      67.000000  2.444444        5.000000   
25%                14.0      67.500000  2.929539        5.000000   
50%                14.0      68.000000  3.414634        5.000000   
75%                14.0      71.500000  4.954508        5.500000   
max                14.0      75.000000  6.494382        6.000000   

       walkins_remaining  providers  
count           3.000000        3.0  
mean            0.666667        2.0  
std             1.154701        0.0  
min             0.000000        2.0  
25%             0.000000        2.0  
50%             0.000000        2.0  
75%             1.000000        2.0  
max             2.000000        2.0  



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



## Data-driven calibration from CSV logs
Provide CSVs with historical data to fit key parameters:
- arrivals.csv: columns [timestamp, type(scheduled|walkin), booked_minute, arrival_minute, no_show(0/1), late(0/1)]
- service.csv: columns [timestamp, provider_id, service_minutes]

This cell loads CSVs (if provided) from Colab/Drive or local path and estimates walk-in hourly rates by time-of-day, no-show and late probabilities, and service-time distribution parameters. Fallbacks are used if files are absent.

In [15]:
# Calibration from CSV logs (optional)
import pandas as pd
import numpy as np
from math import log

# Set these paths in Colab or mount Drive
ARRIVALS_CSV = '/content/drive/MyDrive/clinic_logs/arrival.csv'
SERVICE_CSV = '/content/drive/MyDrive/clinic_logs/service_time.csv'

calib = CALIB.copy()
try:
    if ARRIVALS_CSV:
        arr = pd.read_csv(ARRIVALS_CSV)
        # Time-of-day walk-in rate estimation
        arr_walk = arr[arr['type'].str.lower() == 'walkin'].copy()
        arr_walk['hour'] = (arr_walk['arrival_minute'] // 60).astype(int)
        hourly_counts = arr_walk.groupby('hour').size()
        # default to mean if missing hours
        default_rate = hourly_counts.mean() if len(hourly_counts) else calib['walkin_rate_per_hour']
        walkin_rate_by_hour = {h: hourly_counts.get(h, default_rate) for h in [8,9,10,11,13,14,15]}
        # normalize to per-hour arrival rate
        calib['walkin_rate_per_hour'] = float(np.mean(list(walkin_rate_by_hour.values())))
        # No-show and late
        sched = arr[arr['type'].str.lower() == 'scheduled']
        if len(sched) > 0:
            calib['no_show_prob'] = float((sched['no_show'] == 1).mean())
            calib['late_prob'] = float((sched['late'] == 1).mean())
    if SERVICE_CSV:
        svc = pd.read_csv(SERVICE_CSV)
        durations = svc['service_minutes'].dropna().values
        if len(durations) > 5:
            calib['service_mean_min'] = float(np.mean(durations))
            # set sigma as std/mean capped
            std = float(np.std(durations))
            calib['service_sigma_min'] = float(min(std, calib['service_mean_min']))
    print('Calibrated parameters:', calib)
except Exception as e:
    print('Calibration failed, using defaults. Reason:', e)
    calib = CALIB.copy()

Calibration failed, using defaults. Reason: [Errno 2] No such file or directory: '/content/drive/MyDrive/clinic_logs/arrival.csv'


## Rule baseline and safety wrapper
- Rule baseline: a straightforward heuristic policy (serve priority order: late recall if arrived > scheduled on-site > walk-in). Supports multi-provider.
- OR-Tools (optional): If installed, can compute a simple re-optimization for remaining scheduled patients under time windows (demonstration-level).
- Safety wrapper: masks invalid actions, enforces hard constraints (no service during lunch/after close), and provides a deterministic fallback if the RL suggests an invalid action.

In [16]:
# Heuristic baseline and optional OR-Tools demo
from typing import Optional

class HeuristicPolicy:
    def select_action(self, env: ClinicSchedulingEnvV2) -> int:
        base_env = env.unwrapped if hasattr(env, "unwrapped") else env
        if not hasattr(base_env, "_mask"):
            raise AttributeError("Underlying environment must expose _mask()")
        mask = base_env._mask()
        if mask[2]:
            return 2  # recall late
        if mask[0]:
            return 0  # serve scheduled on-site
        if mask[1]:
            return 1  # serve walk-in
        return 1  # default (no-ops will idle)

# Safety wrapper for RL decisions
class SafePolicy:
    def __init__(self, env: ClinicSchedulingEnvV2, rl_model):
        self.env = env
        self.rl_model = rl_model
        self.fallback = HeuristicPolicy()

    def predict(self, obs, deterministic: bool = True) -> tuple[int, None]:
        # Allow same signature as SB3 models and return (action, None)
        try:
            action, _ = self.rl_model.predict(obs, deterministic=deterministic)
            action = int(action)
        except Exception:
            action = self.fallback.select_action(self.env)
        # mask invalid
        mask = self.env._mask()
        if action < 0 or action > 2 or mask[action] == 0:
            action = self.fallback.select_action(self.env)
        return action, None

# OR-Tools mini-demo (optional)
try:
    from ortools.sat.python import cp_model
    HAS_OR_TOOLS = True
except Exception:
    HAS_OR_TOOLS = False


def reoptimize_schedule_demo(env: ClinicSchedulingEnvV2):
    if not HAS_OR_TOOLS:
        print("OR-Tools not installed; skipping demo.")
        return None
    # Build a minimal CP-SAT to assign remaining scheduled patients to time slots after now
    model = cp_model.CpModel()
    remaining = [
        (slot, patient)
        for slot, patients in env.scheduled_slots.items()
        for patient in patients
        if patient.id not in env.served_ids
    ]
    if not remaining:
        return None
    slots = [s for s, _ in remaining]
    # binary vars: assign each to current slot or next few slots
    vars = {}
    for idx, (s, p) in enumerate(remaining):
        for delta in range(0, 5):
            ss = s + delta
            vars[(idx, ss)] = model.NewBoolVar(f"x_{idx}_{ss}")
        # each patient assigned once
        model.Add(sum(vars[(idx, s + d)] for d in range(0, 5)) == 1)
    # capacity: per slot up to number of providers
    for ss in range(min(slots), max(slots) + 5):
        model.Add(sum(vars[(idx, ss)] for idx, _ in enumerate(remaining) if (idx, ss) in vars) <= env.num_providers)
    model.Minimize(0)
    solver = cp_model.CpSolver()
    solver.parameters.max_time_in_seconds = 2.0
    solver.Solve(model)
    assignment = {idx: None for idx in range(len(remaining))}
    for key, var in vars.items():
        if solver.Value(var) == 1:
            idx, ss = key
            assignment[idx] = ss
    return assignment

## Off-policy evaluation (OPE) and backtesting
Given historical logs with state-action pairs and outcomes, estimate policy performance without deploying it:
- Replay logged trajectories to compute value under baseline policy.
- Importance sampling/weighted importance sampling (WIS) using a logged behavior policy if known.
- For simplicity here, we backtest the heuristic and RL policies on multiple simulated days and compare metrics to logs when available.

In [17]:
# Simple backtest runner comparing Heuristic vs RL (simulated OPE)
import pandas as pd

def run_policy(env_fn, policy, episodes=10):
    rows = []
    for ep in range(episodes):
        env = env_fn()
        obs, info = env.reset()
        done = False
        steps = 0
        served = 0
        while not done:
            if isinstance(policy, HeuristicPolicy):
                action = policy.select_action(env)
            else:
                action, _ = policy.predict(obs, deterministic=True)
                action = int(action)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = bool(terminated or truncated)
            steps += 1
        served_df = pd.DataFrame(env.served_log)
        served_total = len(served_df)
        avg_wait = float(served_df['wait_minutes'].dropna().mean()) if not served_df.empty else float('nan')
        rows.append({
            'served_total': served_total,
            'avg_wait': avg_wait,
            'late_remaining': len(env.late_list),
            'walkins_remaining': len(env.walkin_queue),
        })
    return pd.DataFrame(rows)

heur = HeuristicPolicy()
heur_df = run_policy(make_env_v2, heur, episodes=5)
print('Heuristic summary:')
print(heur_df.describe())

safe_rl = SafePolicy(make_env_v2(), loaded_v2)
rl_df = run_policy(make_env_v2, safe_rl, episodes=5)
print('RL summary:')
print(rl_df.describe())

Heuristic summary:
       served_total   avg_wait  late_remaining  walkins_remaining
count      5.000000   5.000000             5.0           5.000000
mean      85.200000   7.452141             0.0           0.400000
std        5.540758   3.857161             0.0           0.894427
min       78.000000   3.956522             0.0           0.000000
25%       82.000000   5.376471             0.0           0.000000
50%       85.000000   5.743902             0.0           0.000000
75%       89.000000   8.487179             0.0           0.000000
max       92.000000  13.696629             0.0           2.000000



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



RL summary:
       served_total   avg_wait  late_remaining  walkins_remaining
count      5.000000   5.000000         5.00000           5.000000
mean      85.600000   8.632792         6.80000           1.800000
std        6.618157   4.843861         1.30384           1.788854
min       79.000000   4.777778         6.00000           0.000000
25%       81.000000   6.298851         6.00000           0.000000
50%       85.000000   6.417722         6.00000           2.000000
75%       87.000000   8.752941         7.00000           3.000000
max       96.000000  16.916667         9.00000           4.000000



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



## Domain randomization
During training, randomize environment parameters each episode to improve robustness (sim-to-real). This wrapper perturbs arrival rates, no-show/late probabilities, and service-time parameters within bounds.

In [18]:
# Domain randomization wrapper
class DomainRandomizedFactory:
    def __init__(self, base_calib: dict, ranges: dict):
        self.base = base_calib
        self.ranges = ranges

    def sample(self) -> dict:
        cfg = self.base.copy()
        rng = np.random.default_rng()
        for k, (low, high) in self.ranges.items():
            if isinstance(low, float) or isinstance(high, float):
                cfg[k] = float(rng.uniform(low, high))
            else:
                cfg[k] = int(rng.integers(low, high + 1))
        return cfg

    def make_env(self):
        cfg = self.sample()
        return ClinicSchedulingEnvV2(
            slot_minutes=cfg.get('slot_minutes', CALIB['slot_minutes']),
            num_providers=cfg.get('num_providers', CALIB['num_providers']),
            service_mean_min=cfg.get('service_mean_min', CALIB['service_mean_min']),
            service_sigma_min=cfg.get('service_sigma_min', CALIB['service_sigma_min']),
            walkin_rate_per_hour=cfg.get('walkin_rate_per_hour', CALIB['walkin_rate_per_hour']),
            no_show_prob=cfg.get('no_show_prob', CALIB['no_show_prob']),
            late_prob=cfg.get('late_prob', CALIB['late_prob']),
            walkin_cutoff_minute=cfg.get('walkin_cutoff_minute', CALIB['walkin_cutoff_minute']),
            seeded_schedule=seed_schedule,
        )

# Example ranges for randomization
RAND_RANGES = {
    'walkin_rate_per_hour': (6.0, 14.0),
    'no_show_prob': (0.03, 0.12),
    'late_prob': (0.05, 0.2),
    'service_mean_min': (6.0, 10.0),
    'service_sigma_min': (1.5, 3.0),
}

rand_factory = DomainRandomizedFactory(calib, RAND_RANGES)

# Train with randomization (short run)
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

vec_env_rand = make_vec_env(rand_factory.make_env, n_envs=1)
model_rand = PPO('MlpPolicy', vec_env_rand, verbose=1)
model_rand.learn(total_timesteps=100_000)
print('Domain-randomized model trained.')

Using cpu device
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 420      |
|    ep_rew_mean     | 57.7     |
| time/              |          |
|    fps             | 1610     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 420          |
|    ep_rew_mean          | 61.4         |
| time/                   |              |
|    fps                  | 1140         |
|    iterations           | 2            |
|    time_elapsed         | 3            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.008999734  |
|    clip_fraction        | 0.049        |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.09        |
|    explained_variance   | -0.056109667 

## Shadow-mode integration scaffold (FastAPI)
This example shows how to serve availability and decisions in shadow mode: the system returns rule-based decisions to operators while logging RL suggestions for comparison. In Colab, this is illustrative; deploy as a microservice in production.

In [19]:
# FastAPI shadow-mode scaffold (illustrative)
try:
    from fastapi import FastAPI
    from pydantic import BaseModel
    import uvicorn

    app = FastAPI()

    class DecisionRequest(BaseModel):
        obs: list

    class DecisionResponse(BaseModel):
        action_rule: int
        action_rl: int
        mask: list

    # Build single env for masking and a safe policy wrapper
    env_for_api = make_env_v2()
    safe_rl_policy = SafePolicy(env_for_api, loaded_v2)
    heuristic_policy = HeuristicPolicy()

    @app.get("/availability")
    def availability():
        # Return hour labels from the planner demo
        planner = BookingPlanner()
        return {"availability": planner.availability_label()}

    @app.post("/decide", response_model=DecisionResponse)
    def decide(req: DecisionRequest):
        obs = np.array(req.obs, dtype=np.float32)
        # Rule action
        action_rule = heuristic_policy.select_action(env_for_api)
        # RL suggestion (masked)
        action_rl = safe_rl_policy.predict(obs)
        return DecisionResponse(action_rule=action_rule, action_rl=action_rl, mask=env_for_api._mask().tolist())

    print("To run locally (not in Colab runtime):\nuvicorn main:app --reload")
except Exception as e:
    print("FastAPI not available or running in limited environment:", e)

To run locally (not in Colab runtime):
uvicorn main:app --reload



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



## MaskablePPO with Action Masking + Reward Shaping
Train with action masking and a shaping wrapper to improve learning signal and safety. Requires `sb3-contrib` for MaskablePPO and ActionMasker.

In [20]:
# ActionMasker + Reward shaping wrappers
try:
    from sb3_contrib.ppo_mask import MaskablePPO
    from sb3_contrib.common.wrappers import ActionMasker
    HAS_MASKABLE = True
except Exception:
    HAS_MASKABLE = False

class RewardShapingWrapper(gym.Wrapper):
    def __init__(self, env: ClinicSchedulingEnvV2):
        super().__init__(env)
        self.env: ClinicSchedulingEnvV2
    def step(self, action):
        obs, reward, done, trunc, info = self.env.step(action)
        # Encourage serving scheduled first when on-site
        mask = self.env._mask()
        if mask[0] and action == 0:
            reward += 0.02
        # Penalize long waits more strongly
        if self.env.served_log:
            w = self.env.served_log[-1].get('wait_minutes')
            if w is not None and w > 30:
                reward -= 0.02
        return obs, reward, done, trunc, info


def mask_fn(env: ClinicSchedulingEnvV2):
    return env._mask().astype(bool)


def make_masked_env():
    base = make_env_v2()
    shaped = RewardShapingWrapper(base)
    if HAS_MASKABLE:
        return ActionMasker(shaped, mask_fn)
    return shaped

if HAS_MASKABLE:
    vec_masked = make_vec_env(make_masked_env, n_envs=1)
    model_masked = MaskablePPO('MlpPolicy', vec_masked, verbose=1)
    model_masked.learn(total_timesteps=200_000)
    print('MaskablePPO training complete.')
else:
    print('sb3-contrib not available; skipping MaskablePPO training.')

sb3-contrib not available; skipping MaskablePPO training.


## Optuna hyperparameter tuning (optional)
Run a short study to search PPO/MaskablePPO hyperparameters. May be compute-intensive in Colab.

In [21]:
# Optuna tuning demo
try:
    import optuna
    HAS_OPTUNA = True
except Exception:
    HAS_OPTUNA = False

if HAS_OPTUNA:
    def objective(trial):
        lr = trial.suggest_float('lr', 1e-5, 5e-4, log=True)
        n_steps = trial.suggest_int('n_steps', 512, 4096, log=True)
        batch_size = trial.suggest_categorical('batch_size', [128, 256, 512])
        gamma = trial.suggest_float('gamma', 0.98, 0.999)

        env = make_masked_env() if HAS_MASKABLE else make_env_v2()
        from stable_baselines3 import PPO
        model = (MaskablePPO('MlpPolicy', env, verbose=0, learning_rate=lr, n_steps=n_steps, batch_size=batch_size, gamma=gamma)
                 if HAS_MASKABLE else
                 PPO('MlpPolicy', env, verbose=0, learning_rate=lr, n_steps=n_steps, batch_size=batch_size, gamma=gamma))
        model.learn(total_timesteps=30_000)

        # Evaluate quickly
        def eval_once():
            e = make_env_v2()
            obs, _ = e.reset()
            done = False
            served = 0
            while not done:
                a, _ = model.predict(obs, deterministic=True)
                obs, r, term, trunc, _ = e.step(int(a))
                done = bool(term or trunc)
            return len(e.served_log)
        return eval_once()

    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=10)
    print('Best params:', study.best_params)
else:
    print('Optuna not installed; skipping tuning.')

Optuna not installed; skipping tuning.


## Statistical evaluation with confidence intervals
Compute means and 95% CIs for key metrics (served_total, avg_wait) across multiple episodes to quantify reliability.

In [22]:
# Statistical evaluation (95% CI)
import numpy as np
import pandas as pd
from math import sqrt


def ci95(series: pd.Series):
    x = series.dropna().values
    if len(x) < 2:
        return (float('nan'), float('nan'))
    mean = float(np.mean(x))
    se = float(np.std(x, ddof=1) / sqrt(len(x)))
    return (mean - 1.96 * se, mean + 1.96 * se)

# Example: compare heuristic vs RL with CIs
heur_df = run_policy(make_env_v2, HeuristicPolicy(), episodes=20)
print('Heuristic served_total mean, 95% CI:', np.mean(heur_df['served_total']), ci95(heur_df['served_total']))
print('Heuristic avg_wait mean, 95% CI:', np.mean(heur_df['avg_wait']), ci95(heur_df['avg_wait']))

safe_rl = SafePolicy(make_env_v2(), loaded_v2)
rl_df = run_policy(make_env_v2, safe_rl, episodes=20)
print('RL served_total mean, 95% CI:', np.mean(rl_df['served_total']), ci95(rl_df['served_total']))
print('RL avg_wait mean, 95% CI:', np.mean(rl_df['avg_wait']), ci95(rl_df['avg_wait']))

Heuristic served_total mean, 95% CI: 84.1 (81.36153189192436, 86.83846810807563)
Heuristic avg_wait mean, 95% CI: 4.561222580543133 (3.793222460567078, 5.329222700519187)
RL served_total mean, 95% CI: 81.0 (76.7294625014741, 85.2705374985259)
RL avg_wait mean, 95% CI: 5.557506824777314 (3.8904633501300054, 7.224550299424623)


## Import Clinic Logs
Use this section to connect Google Drive (in Colab) or point to locally available CSV files that contain historical scheduling data. The calibration cells below expect arrival/service files in the formats described here.

In [23]:
# Configure data sources for calibration
import pathlib

# --- Step 1: (Optional) Mount Google Drive when running in Colab ---
try:
    from google.colab import drive  # type: ignore
    DRIVE_MOUNT = pathlib.Path('/content/drive')
    if not DRIVE_MOUNT.exists():
        drive.mount('/content/drive')
    print('Drive mounted at /content/drive')
except Exception:
    print('Google Drive not detected; skipping mount. Set paths manually below.')

# --- Step 2: Set file paths for historical data ---
# Replace these with the location of your exported logs.
# Expectation:
#   ARRIVALS_CSV: each row has booked_minute (int), timestamp (datetime), arrival_minute,
#                 no_show (0/1), late (0/1), and any categorical features.
#   SERVICE_CSV: each row has service_start_minute, service_duration_min, provider_id, etc.
ARRIVALS_CSV = ARRIVALS_CSV if 'ARRIVALS_CSV' in globals() else ''
SERVICE_CSV = SERVICE_CSV if 'SERVICE_CSV' in globals() else ''

# Example (uncomment and edit):
ARRIVALS_CSV = '/content/drive/MyDrive/clinic_logs/arrival.csv'
SERVICE_CSV = '/content/drive/MyDrive/clinic_logs/service_time.csv'

print('ARRIVALS_CSV =', ARRIVALS_CSV)
print('SERVICE_CSV  =', SERVICE_CSV)

Mounted at /content/drive
Drive mounted at /content/drive
ARRIVALS_CSV = /content/drive/MyDrive/clinic_logs/arrival.csv
SERVICE_CSV  = /content/drive/MyDrive/clinic_logs/service_time.csv


In [24]:
# Quick preview of data (optional)
import pandas as pd

def preview_csv(path, n=5):
    if not path:
        print('Path is empty; skip preview.')
        return
    try:
        df = pd.read_csv(path)
        print(f'{path} -> {len(df)} rows')
        display(df.head(n))
    except FileNotFoundError:
        print(f'File not found: {path}')
    except Exception as exc:
        print(f'Could not read {path}: {exc}')

preview_csv(ARRIVALS_CSV)
preview_csv(SERVICE_CSV)

/content/drive/MyDrive/clinic_logs/arrival.csv -> 619 rows



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



Unnamed: 0,appointment_id,patient_id,provider_id,date,scheduled_time,arrival_time,late_by_mins,no_show_flag,status,scheduler_id
0,A5000,P2174,D02,2025-11-06,2025-11-06 13:00:00+08:00,2025-11-06 13:00:00+08:00,0.0,0,completed,sched_01
1,A5001,P2267,D02,2025-11-06,2025-11-06 09:00:00+08:00,2025-11-06 09:08:00+08:00,8.0,0,completed,sched_02
2,A5002,P2463,D03,2025-11-06,2025-11-06 14:00:00+08:00,2025-11-06 14:00:00+08:00,0.0,0,completed,sched_01
3,A5003,P2134,D03,2025-11-06,2025-11-06 08:00:00+08:00,,,0,cancelled,sched_01
4,A5004,P2682,D01,2025-11-06,2025-11-06 11:00:00+08:00,2025-11-06 11:00:00+08:00,0.0,0,completed,sched_01


/content/drive/MyDrive/clinic_logs/service_time.csv -> 619 rows



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



Unnamed: 0,appointment_id,patient_id,provider_id,date,start_time,end_time,service_time_mins,status,scheduler_id
0,A5000,P2174,D02,2025-11-06,2025-11-06 13:02:00+08:00,2025-11-06 13:09:00+08:00,7.0,completed,sched_01
1,A5001,P2267,D02,2025-11-06,2025-11-06 09:04:00+08:00,2025-11-06 09:11:00+08:00,7.0,completed,sched_02
2,A5002,P2463,D03,2025-11-06,2025-11-06 14:04:00+08:00,2025-11-06 14:11:00+08:00,7.0,completed,sched_01
3,A5003,P2134,D03,2025-11-06,,,,cancelled,sched_01
4,A5004,P2682,D01,2025-11-06,2025-11-06 11:04:00+08:00,2025-11-06 11:11:00+08:00,7.0,completed,sched_01


## Full calibration pipeline and tuned hyperparameters
This section fits no-show and late-arrival models from CSV logs, calibrates environment parameters and domain-randomization ranges, and runs a short Optuna study to output tuned hyperparameters. Provide paths in `ARRIVALS_CSV` and `SERVICE_CSV`, then run the cells.

In [25]:
# Fit no-show and late models (logistic) + set calibrated params
import pandas as pd
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

ARRIVALS_CSV = ARRIVALS_CSV or None
SERVICE_CSV = SERVICE_CSV or None

noshow_model = None
late_model = None

try:
    if ARRIVALS_CSV:
        df = pd.read_csv(ARRIVALS_CSV)
        # Basic feature set: hour, dow, visit_type (if present)
        df['hour'] = (df['booked_minute'] // 60).astype(int)
        df['dow'] = pd.to_datetime(df['timestamp']).dt.dayofweek
        feature_cols = ['hour', 'dow'] + ([ 'visit_type' ] if 'visit_type' in df.columns else [])
        X = df[feature_cols]
        # No-show model (scheduled only)
        sched = df[df['type'].str.lower() == 'scheduled'].copy()
        y_ns = sched['no_show'].astype(int)
        X_ns = sched[feature_cols]
        pre = ColumnTransformer([
            ('ohe', OneHotEncoder(handle_unknown='ignore'), feature_cols)
        ])
        noshow_model = Pipeline([
            ('pre', pre),
            ('clf', LogisticRegression(max_iter=1000))
        ])
        noshow_model.fit(X_ns, y_ns)
        p_ns = float(y_ns.mean())
        calib['no_show_prob'] = p_ns

        # Late model (scheduled only, arrived)
        sched_arrived = sched[sched['no_show'] == 0].copy()
        y_late = sched_arrived['late'].astype(int)
        X_lt = sched_arrived[feature_cols]
        late_model = Pipeline([
            ('pre', pre),
            ('clf', LogisticRegression(max_iter=1000))
        ])
        late_model.fit(X_lt, y_late)
        p_lt = float(y_late.mean())
        calib['late_prob'] = p_lt

        print('Fitted no-show baseline prob:', round(p_ns, 3), '| late baseline prob:', round(p_lt, 3))
    else:
        print('ARRIVALS_CSV not set; skipping model fit.')
except Exception as e:
    print('Model fitting failed; using default calibration. Reason:', e)

# Set domain randomization ranges around calibrated values
RAND_RANGES = {
    'walkin_rate_per_hour': (max(2.0, calib['walkin_rate_per_hour'] * 0.7), calib['walkin_rate_per_hour'] * 1.4),
    'no_show_prob': (max(0.0, calib['no_show_prob'] * 0.7), min(0.5, calib['no_show_prob'] * 1.4)),
    'late_prob': (max(0.0, calib['late_prob'] * 0.7), min(0.6, calib['late_prob'] * 1.4)),
    'service_mean_min': (max(4.0, calib['service_mean_min'] * 0.8), calib['service_mean_min'] * 1.2),
    'service_sigma_min': (max(0.5, calib['service_sigma_min'] * 0.8), calib['service_sigma_min'] * 1.3),
}
print('Updated RAND_RANGES:', RAND_RANGES)

Model fitting failed; using default calibration. Reason: 'booked_minute'
Updated RAND_RANGES: {'walkin_rate_per_hour': (7.0, 14.0), 'no_show_prob': (0.049, 0.098), 'late_prob': (0.08399999999999999, 0.16799999999999998), 'service_mean_min': (5.6000000000000005, 8.4), 'service_sigma_min': (1.6, 2.6)}



datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



In [26]:
# Short Optuna tuning on calibrated env (adjust n_trials for compute)
try:
    import optuna
    HAS_OPTUNA = True
except Exception:
    HAS_OPTUNA = False

best_params = None

if HAS_OPTUNA:
    def make_env_for_tuning():
        # use calibrated params
        return ClinicSchedulingEnvV2(
            slot_minutes=calib['slot_minutes'],
            num_providers=calib['num_providers'],
            service_mean_min=calib['service_mean_min'],
            service_sigma_min=calib['service_sigma_min'],
            walkin_rate_per_hour=calib['walkin_rate_per_hour'],
            no_show_prob=calib['no_show_prob'],
            late_prob=calib['late_prob'],
            walkin_cutoff_minute=calib['walkin_cutoff_minute'],
            seeded_schedule=seed_schedule,
        )

    def objective(trial):
        lr = trial.suggest_float('lr', 1e-5, 5e-4, log=True)
        n_steps = trial.suggest_int('n_steps', 512, 4096, log=True)
        batch_size = trial.suggest_categorical('batch_size', [128, 256, 512])
        gamma = trial.suggest_float('gamma', 0.98, 0.999)
        from stable_baselines3 import PPO
        env = make_env_for_tuning()
        model = PPO('MlpPolicy', env, verbose=0, learning_rate=lr, n_steps=n_steps, batch_size=batch_size, gamma=gamma)
        model.learn(total_timesteps=50_000)
        # quick metric: scheduled served
        e = make_env_for_tuning()
        obs, _ = e.reset()
        done = False
        while not done:
            a, _ = model.predict(obs, deterministic=True)
            obs, r, term, trunc, _ = e.step(int(a))
            done = bool(term or trunc)
        return len([1 for row in e.served_log if not row['is_walkin']])

    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=10)
    best_params = study.best_params
    print('Tuned hyperparameters:', best_params)
else:
    print('Optuna not installed; skipping hyperparameter tuning.')

Optuna not installed; skipping hyperparameter tuning.


## Calibrated training setup (Drive mount + env + training)
This cell mounts Google Drive, sets file paths, redefines the environment factory to use calibrated parameters (`calib`), and runs calibrated training with MaskablePPO if available.

In [27]:
# Mount Drive and set base dir
try:
    from google.colab import drive
    drive.mount('/content/drive')
    BASE_DIR = '/content/drive/MyDrive/clinic_scheduling_rl'
except Exception:
    import os
    BASE_DIR = '/workspace/drive'

import os
os.makedirs(BASE_DIR, exist_ok=True)

# Optional: set CSV paths here if not set earlier
ARRIVALS_CSV = ARRIVALS_CSV or None
SERVICE_CSV  = SERVICE_CSV  or None

# Redefine env factory to use calibrated params

def make_env_v2():
    return ClinicSchedulingEnvV2(
        slot_minutes=calib["slot_minutes"],
        num_providers=calib["num_providers"],
        service_mean_min=calib["service_mean_min"],
        service_sigma_min=calib["service_sigma_min"],
        walkin_rate_per_hour=calib["walkin_rate_per_hour"],
        no_show_prob=calib["no_show_prob"],
        late_prob=calib["late_prob"],
        walkin_cutoff_minute=calib["walkin_cutoff_minute"],
        seeded_schedule=seed_schedule,
    )

# Calibrated training
from stable_baselines3.common.env_util import make_vec_env
vec_env_v2 = make_vec_env(make_env_v2, n_envs=1)

try:
    from sb3_contrib import MaskablePPO
    hp = best_params or {'lr':3e-4,'n_steps':2048,'batch_size':256,'gamma':0.995}
    model_v2 = MaskablePPO("MlpPolicy", vec_env_v2, verbose=1,
                           learning_rate=hp['lr'], n_steps=hp['n_steps'],
                           batch_size=hp['batch_size'], gamma=hp['gamma'])
except Exception:
    from stable_baselines3 import PPO
    hp = best_params or {'lr':3e-4,'n_steps':2048,'batch_size':256,'gamma':0.995}
    model_v2 = PPO("MlpPolicy", vec_env_v2, verbose=1,
                   learning_rate=hp['lr'], n_steps=hp['n_steps'],
                   batch_size=hp['batch_size'], gamma=hp['gamma'])

model_v2.learn(total_timesteps=500_000)

# Save to Drive
model_v2_path = os.path.join(BASE_DIR, 'model_v2.zip')
model_v2.save(model_v2_path)
print('Saved calibrated model to', model_v2_path)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
------------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 420         |
|    ep_rew_mean          | 82.8        |
| time/                   |             |
|    fps                  | 1119        |
|    iterations           | 8           |
|    time_elapsed         | 14          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.010830843 |
|    clip_fraction        | 0.00117     |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.08       |
|    explained_variance   | 0.6776651   |
|    learning_rate        | 0.0003      |
|    loss                 | 1.85        |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.00137    |
|    value_loss           | 4.59        |
-----------------------------------------
----------

## Richer observation wrapper, Behavior Cloning (BC) pretrain, RCPO, and robust training
Adds:
- ObservationAugmentWrapper: time-to-lunch/close, provider load stats, near-term backlog, predicted no-show/late probs
- BC pretrain on heuristic rollouts
- RCPO-style constrained training (SLA≤30 min)
- VecNormalize + multiple envs training
- 100-episode confidence interval evaluation

In [28]:
# Richer observation augmentation
class ObservationAugmentWrapper(gym.ObservationWrapper):
    def __init__(self, env: ClinicSchedulingEnvV2, noshow_model=None, late_model=None):
        super().__init__(env)
        self.env: ClinicSchedulingEnvV2
        self.noshow_model = noshow_model
        self.late_model = late_model
        orig = env.observation_space.shape[0]
        # extra features: time_to_lunch, time_to_close, provider_busy_frac, backlog_next_30min, p_no_show_next, p_late_next
        self.extra_dim = 6
        high = np.concatenate([
            env.observation_space.high,
            np.array([300, 600, 1.0, 50, 1.0, 1.0], dtype=np.float32)
        ])
        self.observation_space = spaces.Box(low=0.0, high=high, dtype=np.float32)

    def observation(self, obs):
        minute = self.env.minute
        ttl = max(0, MINUTES_LUNCH_START - minute) if minute < MINUTES_LUNCH_START else max(0, MINUTES_CLOSE_PM - minute)
        ttc = max(0, MINUTES_CLOSE_PM - minute)
        busy = sum(1 for t in self.env.provider_busy_remaining if t > 0)
        busy_frac = busy / max(1, self.env.num_providers)
        # approximate backlog in next 30 min (scheduled + walk-ins)
        near_sched = sum(
            1
            for slot, patient in self.env._remaining_scheduled()
            if self.env._slot_to_minute(slot) < minute + 30
        )
        near_backlog = near_sched + len(self.env.walkin_queue)
        # simple next scheduled probs
        p_ns = 0.0; p_lt = 0.0
        try:
            next_slot, _ = next(self.env._remaining_scheduled(), (None, None))
            if next_slot is not None:
                booked_minute = self.env._slot_to_minute(next_slot)
                df_feat = pd.DataFrame([{
                    'hour': booked_minute // 60,
                    'dow': 0
                }])
                if self.noshow_model is not None:
                    p_ns = float(self.noshow_model.predict_proba(df_feat)[0,1])
                if self.late_model is not None:
                    p_lt = float(self.late_model.predict_proba(df_feat)[0,1])
        except Exception:
            pass
        extra = np.array([ttl, ttc, busy_frac, near_backlog, p_ns, p_lt], dtype=np.float32)
        return np.concatenate([obs, extra], axis=0)


## Offline Replay (Historical Days)
Provide one JSON/CSV path per clinic day with the sequence of scheduled arrivals, walk-ins, and staff actions.
Configure the list below, then run the replay cell to compare heuristic vs. RL policies on actual demand.

In [37]:
# Offline replay driver (CSV or JSON day files)
import json
from pathlib import Path

HISTORICAL_DAY_FILES = ["/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-20.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-21.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-22.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-23.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-24.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-25.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-27.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-28.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-29.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-30.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-10-31.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-11-01.csv",
                        "/content/drive/MyDrive/clinic_logs_filled_regen/day_2025-11-03.csv"]


class HistoricalClinicEnv(ClinicSchedulingEnvV2):
    """Environment that replays observed schedules/walk-ins instead of sampling."""
    def __init__(self, day_data, **kwargs):
        super().__init__(**kwargs)
        self._scheduled_plan = day_data.get('scheduled', [])
        self._walkin_queue_plan = day_data.get('walkins', [])
        self._action_log = day_data.get('actions', [])
        self._action_idx = 0
        self._build_scheduled_from_observed()

    def _build_scheduled_from_observed(self):
        self.scheduled_slots.clear()
        pid_max = 0
        for entry in self._scheduled_plan:
            slot = int(entry['slot'])
            pid = entry['patient_id']
            arrival = entry.get('arrival_minute')
            if arrival == '' or arrival is None:
                arrival = None
            else:
                arrival = int(arrival)
            patient = Patient(id=pid, scheduled_slot=slot, arrival_time_min=arrival)
            patient.is_late = bool(int(entry.get('late', 0))) if entry.get('late') not in (None, '') else False
            self.scheduled_slots[slot].append(patient)
            self.generated_patients[pid] = patient
            pid_max = max(pid_max, pid)
        self.next_walkin_id = pid_max + 1

    def _maybe_generate_walkins(self):
        while self._walkin_queue_plan and int(self._walkin_queue_plan[0]['minute']) <= self.minute:
            entry = self._walkin_queue_plan.pop(0)
            pid = entry.get('patient_id')
            if pid in (None, ''):
                pid = self.next_walkin_id
                self.next_walkin_id += 1
            minute = int(entry['minute'])
            patient = Patient(id=pid, scheduled_slot=None, arrival_time_min=minute)
            self.walkin_queue.append(patient)
            self.generated_patients[pid] = patient

    def step(self, action):
        if self._action_log:
            action = int(self._action_log[min(self._action_idx, len(self._action_log)-1)])
            self._action_idx += 1
        return super().step(action)


def _load_day_file(day_path):
    day_path = Path(day_path)
    if day_path.suffix.lower() == '.json':
        data = json.loads(day_path.read_text())
    else:
        import pandas as pd
        df = pd.read_csv(day_path)
        if 'type' not in df.columns:
            raise ValueError('Historical CSV must include a "type" column (scheduled/walkin/action).')
        data = {
            'scheduled': df[df['type'].str.lower() == 'scheduled'].to_dict('records'),
            'walkins': df[df['type'].str.lower() == 'walkin'].to_dict('records'),
            'actions': df[df['type'].str.lower() == 'action']['action'].dropna().tolist()
        }
    return data


def load_historical_env(day_path, calib_overrides=None):
    data = _load_day_file(day_path)
    overrides = calib.copy()
    if calib_overrides:
        overrides.update(calib_overrides)
    env = HistoricalClinicEnv(
        data,
        slot_minutes=overrides['slot_minutes'],
        num_providers=overrides['num_providers'],
        service_mean_min=overrides['service_mean_min'],
        service_sigma_min=overrides['service_sigma_min'],
        walkin_rate_per_hour=overrides['walkin_rate_per_hour'],
        no_show_prob=overrides['no_show_prob'],
        late_prob=overrides['late_prob'],
        walkin_cutoff_minute=overrides['walkin_cutoff_minute'],
    )
    return env


def evaluate_historical_days(day_files, policy_factories=None):
    if not day_files:
        print('HISTORICAL_DAY_FILES is empty; add paths to replay.')
        return None
    if policy_factories is None:
        policy_factories = {
            'Heuristic': lambda env: HeuristicPolicy(),
            'Safe RL (fresh)': lambda env: SafePolicy(env, model_robust)
        }
    all_metrics = {}
    for policy_name, factory in policy_factories.items():
        rows = []
        for path in day_files:
            env = load_historical_env(path)
            policy = factory(env)
            obs, _ = env.reset()
            done = False
            while not done:
                if hasattr(policy, 'predict'):
                    action, _ = policy.predict(obs, deterministic=True)
                else:
                    action = policy.select_action(env)
                obs, reward, term, trunc, _ = env.step(int(action))
                done = bool(term or trunc)
            df = pd.DataFrame(env.served_log)
            rows.append(episode_metrics(env, df))
        all_metrics[policy_name] = pd.DataFrame(rows)
        print(f"=== Offline replay: {policy_name} ===")
        print(all_metrics[policy_name].describe().loc[['mean','std','min','max']])
        print()
    return all_metrics

# Example run (set HISTORICAL_DAY_FILES first):
# historical_results = evaluate_historical_days(HISTORICAL_DAY_FILES)


In [38]:
# Run offline replay (execute after populating HISTORICAL_DAY_FILES)
if HISTORICAL_DAY_FILES:
    historical_results = evaluate_historical_days(HISTORICAL_DAY_FILES)
    print('Replay completed for', len(HISTORICAL_DAY_FILES), 'day(s).')
else:
    print('HISTORICAL_DAY_FILES is empty; set file paths before running the replay.')

KeyError: 'action'


datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).



In [None]:
# Behavior Cloning (BC) pretrain on heuristic rollouts
from torch import nn
import torch

class BCPolicy(nn.Module):
    def __init__(self, obs_dim, act_dim=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, act_dim)
        )
    def forward(self, x):
        return self.net(x)


def collect_heuristic_dataset(n_episodes=50):
    ds_obs, ds_act = [], []
    for _ in range(n_episodes):
        env = make_env_v2()
        env = ObservationAugmentWrapper(env, noshow_model, late_model)
        policy = HeuristicPolicy()
        obs, _ = env.reset()
        done = False
        while not done:
            action = policy.select_action(env)
            ds_obs.append(obs.copy())
            ds_act.append(action)
            obs, r, term, trunc, _ = env.step(action)
            done = bool(term or trunc)
    X = torch.tensor(np.array(ds_obs), dtype=torch.float32)
    y = torch.tensor(np.array(ds_act), dtype=torch.long)
    return X, y


def bc_pretrain(epochs=5, batch=256):
    env_tmp = ObservationAugmentWrapper(make_env_v2(), noshow_model, late_model)
    obs_dim = env_tmp.observation_space.shape[0]
    model = BCPolicy(obs_dim)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    X, y = collect_heuristic_dataset(n_episodes=50)
    for ep in range(epochs):
        idx = torch.randperm(X.size(0))
        for i in range(0, X.size(0), batch):
            b = idx[i:i+batch]
            logits = model(X[b])
            loss = loss_fn(logits, y[b])
            opt.zero_grad(); loss.backward(); opt.step()
        print(f"BC epoch {ep+1}/{epochs} loss {loss.item():.4f}")
    return model

# Run BC
bc_model = bc_pretrain(epochs=5)
print("BC pretrain complete")

## Deployment Readiness Validation Harness
Use this section to run offline acceptance tests before shadowing or deploying: compare heuristic vs. RL, verify SLA compliance, and ensure hard constraints (no lunch service, no leftover queues) are respected.

In [None]:
# Evaluation harness for policy readiness
import numpy as np
import pandas as pd
from math import sqrt


def make_base_eval_env(randomized: bool = False):
    if randomized and 'rand_factory' in globals():
        return rand_factory.make_env()
    return make_env_v2()


def unwrap_to_base(env):
    current = env
    while hasattr(current, 'env'):
        current = current.env
    if hasattr(current, 'unwrapped'):
        return current.unwrapped
    return current


def episode_metrics(base_env, df: pd.DataFrame):
    metrics = {}
    metrics['served_total'] = len(df)
    metrics['avg_wait'] = float(df['wait_minutes'].dropna().mean()) if not df.empty else float('nan')
    metrics['p95_wait'] = float(df['wait_minutes'].dropna().quantile(0.95)) if not df.empty else float('nan')
    metrics['sla_compliance'] = float((df['wait_minutes'].fillna(999) <= 30).mean()) if not df.empty else float('nan')
    metrics['scheduled_unserved'] = int(sum(1 for _ in base_env._remaining_scheduled()))
    metrics['walkins_remaining'] = int(len(base_env.walkin_queue))
    metrics['late_remaining'] = int(sum(1 for p in base_env.late_list if p.arrival_time_min is not None))

    alerts = []
    if not df.empty and df['served_minute'].between(MINUTES_LUNCH_START, MINUTES_LUNCH_END - 1).any():
        alerts.append('served_during_lunch')
    if base_env.minute > MINUTES_CLOSE_PM:
        alerts.append(f'closed_overrun={base_env.minute - MINUTES_CLOSE_PM}')
    if metrics['scheduled_unserved']:
        alerts.append(f'unserved_scheduled={metrics["scheduled_unserved"]}')
    if metrics['walkins_remaining']:
        alerts.append(f'walkins_remaining={metrics["walkins_remaining"]}')
    if metrics['late_remaining']:
        alerts.append(f'late_remaining={metrics["late_remaining"]}')
    metrics['alerts'] = alerts
    return metrics


def ci95(series: pd.Series):
    x = series.dropna().values
    if len(x) < 2:
        return (float('nan'), float('nan'))
    mean = float(np.mean(x))
    se = float(np.std(x, ddof=1) / np.sqrt(len(x)))
    return (mean - 1.96 * se, mean + 1.96 * se)


def evaluate_policy_suite(policy_factories: dict, episodes: int = 30, randomized: bool = False, seed: int | None = None):
    results = {}
    for name, factory in policy_factories.items():
        metrics = []
        for ep in range(episodes):
            env = make_base_eval_env(randomized=randomized)
            policy = factory(env)
            obs, _ = env.reset(seed=seed + ep if seed is not None else None)
            done = False
            while not done:
                if hasattr(policy, 'predict'):
                    action, _ = policy.predict(obs, deterministic=True)
                else:
                    action = policy.select_action(env)
                obs, reward, term, trunc, _ = env.step(int(action))
                done = bool(term or trunc)

            base_env = unwrap_to_base(env)
            df = pd.DataFrame(base_env.served_log)
            metrics.append(episode_metrics(base_env, df))

        df_metrics = pd.DataFrame(metrics)
        results[name] = df_metrics

        print(f"=== {name} ({'randomized' if randomized else 'calibrated'} env) ===")
        summary_cols = ['served_total', 'avg_wait', 'p95_wait', 'sla_compliance']
        print(df_metrics[summary_cols].describe().loc[['mean', 'std', 'min', 'max']])
        for col in summary_cols:
            print(f"CI {col}: {ci95(df_metrics[col])}")
        alert_count = df_metrics['alerts'].apply(bool).sum()
        if alert_count:
            print(f"Alerts triggered in {alert_count} / {episodes} episodes")
            print(df_metrics.loc[df_metrics['alerts'].apply(bool), 'alerts'].head())
        else:
            print('No safety alerts triggered.')
        print()
    return results


policy_factories = {'Heuristic': lambda env: HeuristicPolicy()}
if 'loaded_v2' in globals():
    policy_factories['Safe RL (loaded_v2)'] = lambda env: SafePolicy(env, loaded_v2)

_ = evaluate_policy_suite(policy_factories, episodes=30, randomized=False)
if 'rand_factory' in globals():
    _ = evaluate_policy_suite(policy_factories, episodes=30, randomized=True)


In [None]:
# RCPO-style constrained RL (SLA <= 30 min)
class RCPOWrapper(gym.Wrapper):
    def __init__(self, env: ClinicSchedulingEnvV2, sla_minutes=30, lam=0.01, lam_lr=1e-4):
        super().__init__(env)
        self.sla = sla_minutes
        self.lam = lam
        self.lam_lr = lam_lr
        self.violations = 0

    def step(self, action):
        base_env = self.env.unwrapped if hasattr(self.env, "unwrapped") else self.env
        served_log = getattr(base_env, "served_log", None)
        prev_len = len(served_log) if served_log is not None else 0

        obs, reward, done, trunc, info = self.env.step(action)

        served_log = getattr(base_env, "served_log", None)
        if served_log:
            new_entries = served_log[prev_len:]
            for entry in new_entries:
                w = entry.get("wait_minutes")
                if w is not None and w > self.sla:
                    self.violations += 1
                    reward -= self.lam

        if done:
            total = max(1, len(served_log) if served_log is not None else 1)
            target = 0.15  # target 15% or less over-SLA
            rate = self.violations / total
            self.lam += self.lam_lr * (rate - target)
            self.lam = max(0.0, min(1.0, self.lam))
            self.violations = 0

        return obs, reward, done, trunc, info


def make_augmented_env():
    base = make_env_v2()
    aug = ObservationAugmentWrapper(base, noshow_model, late_model)
    rcpo = RCPOWrapper(aug)
    return rcpo

In [None]:
# Robust training: VecNormalize + multiple envs
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO

n_envs = 4
vec_env = DummyVecEnv([make_augmented_env for _ in range(n_envs)])
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True)

model_robust = PPO('MlpPolicy', vec_env, verbose=1, n_steps=2048, batch_size=512, gamma=0.995, learning_rate=3e-4)
model_robust.learn(total_timesteps=1_000_000)
print('Robust training complete')

In [None]:
# 100-episode CI evaluation (aligned with augmented + normalized training env)
import numpy as np
import pandas as pd
from math import sqrt
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

def ci95(arr):
    x = np.array(arr, dtype=float)
    x = x[~np.isnan(x)]
    if len(x) < 2:
        return (float('nan'), float('nan'))
    m = float(np.mean(x))
    se = float(np.std(x, ddof=1) / sqrt(len(x)))
    return (m - 1.96 * se, m + 1.96 * se)


def make_eval_env(model):
    base = DummyVecEnv([make_augmented_env])
    eval_env = VecNormalize(base, training=False, norm_obs=True, norm_reward=False)
    training_env = model.get_env()
    if training_env is not None and hasattr(training_env, 'obs_rms'):
        eval_env.obs_rms = training_env.obs_rms.copy()
        eval_env.clip_obs = getattr(training_env, 'clip_obs', eval_env.clip_obs)
    return eval_env


def unwrap_env(env):
    current = env
    while hasattr(current, 'env'):
        current = current.env
    if hasattr(current, 'unwrapped'):
        return current.unwrapped
    return current


def evaluate_model(model, episodes=100):
    env = make_eval_env(model)
    served_totals, avg_waits, sla_compliance, p95_waits, violations = [], [], [], [], []

    for ep in range(episodes):
        obs = env.reset()
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done_vec, info = env.step(action)
            done = bool(done_vec[0])

        base_env = unwrap_env(env.envs[0])
        df = pd.DataFrame(base_env.served_log)
        served_totals.append(len(df))
        avg_waits.append(float(df['wait_minutes'].dropna().mean()) if not df.empty else float('nan'))
        p95_waits.append(float(df['wait_minutes'].dropna().quantile(0.95)) if not df.empty else float('nan'))
        sla_compliance.append(float((df['wait_minutes'].fillna(999) <= 30).mean()) if not df.empty else float('nan'))

        alerts = []
        if not df.empty and (df['served_minute'].between(MINUTES_LUNCH_START, MINUTES_LUNCH_END - 1).any()):
            alerts.append('served_during_lunch')
        remaining_sched = sum(1 for _ in base_env._remaining_scheduled())
        if remaining_sched:
            alerts.append(f'unserved_scheduled={remaining_sched}')
        if len(base_env.walkin_queue):
            alerts.append(f'walkins_remaining={len(base_env.walkin_queue)}')
        violations.append(alerts)

    def summarize(name, data):
        print(f"{name}: mean={np.nanmean(data):.2f}, CI={ci95(data)}")

    print('=== Robust PPO evaluation (augmented + normalized env) ===')
    summarize('Served total', served_totals)
    summarize('Average wait', avg_waits)
    summarize('95th percentile wait', p95_waits)
    summarize('SLA<=30min compliance', sla_compliance)
    flagged = [v for v in violations if v]
    if flagged:
        print(f"Alerts triggered in {len(flagged)} / {episodes} episodes -> {flagged[:3]}")
    else:
        print('No safety alerts triggered across evaluated episodes.')


evaluate_model(model_robust, episodes=100)

## End-to-End Training Runner
This cell installs any missing dependencies, trains the robust PPO policy with observation augmentation, behavior-cloned warm-start, and RCPO constraints, then runs the evaluation harnesses defined above.

- Call `train_end_to_end(total_timesteps=300_000, eval_randomized=True)` to kick off training + evaluations.
- Adjust `total_timesteps`, `n_envs`, or `log_dir` as needed; artifacts and VecNormalize stats write to the chosen log directory.
- The function automatically runs the calibrated/domain-randomized evaluation suites and the 100-episode CI check when those helpers are defined.

In [None]:
# End-to-end training orchestrator (install deps, train, evaluate)
import os
import sys
import json
import time
import pathlib
import importlib
import subprocess


def _ensure_rl_dependencies():
    required = [
        ("stable_baselines3", "stable-baselines3==2.3.2"),
        ("sb3_contrib", "sb3-contrib==2.3.2"),
        ("shimmy", "shimmy==1.3.0"),
        ("gymnasium", "gymnasium==0.29.1"),
    ]
    missing = [pkg for pkg, pip_name in required if importlib.util.find_spec(pkg) is None]
    if not missing:
        return
    print("Installing missing RL dependencies:", missing)
    cmd = [sys.executable, "-m", "pip", "install"] + [pip_name for pkg, pip_name in required if pkg in missing]
    subprocess.run(cmd, check=True)


def train_end_to_end(total_timesteps: int = 200_000,
                     n_envs: int = 4,
                     log_dir: str = None,
                     eval_episodes: int = 30,
                     eval_randomized: bool = False,
                     save_name: str = "model_robust"):
    """Full pipeline: dependency check, training on augmented+RCPO env, evaluation."""
    _ensure_rl_dependencies()
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
    try:
        from sb3_contrib import MaskablePPO as Algo
        algo_name = "MaskablePPO"
        algo_kwargs = {}
    except Exception:
        from stable_baselines3 import PPO as Algo
        algo_name = "PPO"
        algo_kwargs = {}

    if log_dir is None:
        base_dir = pathlib.Path("/content") if pathlib.Path("/content").exists() else pathlib.Path.cwd()
        log_dir = base_dir / "logs_end_to_end"
    log_dir = pathlib.Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    print(f"Using algorithm: {algo_name} | total_timesteps={total_timesteps:,} | n_envs={n_envs}")

    vec_env = DummyVecEnv([make_augmented_env for _ in range(n_envs)])
    vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)

    model = Algo(
        "MlpPolicy",
        vec_env,
        verbose=1,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=512,
        gamma=0.995,
        gae_lambda=0.95,
        **algo_kwargs,
    )

    model.learn(total_timesteps=total_timesteps)
    print("Training complete.")

    globals()['model_robust'] = model

    model_path = log_dir / f"{save_name}.zip"
    model.save(model_path)
    print(f"Saved policy to {model_path}")

    try:
        vec_env.save(log_dir / "vecnormalize.pkl")
        print(f"Saved VecNormalize statistics to {log_dir / 'vecnormalize.pkl'}")
    except Exception as exc:
        print("Warning: could not save VecNormalize stats:", exc)

    if 'evaluate_policy_suite' in globals():
        factories = {'Heuristic': lambda env: HeuristicPolicy()}
        if 'loaded_v2' in globals():
            factories['Safe RL (loaded_v2)'] = lambda env: SafePolicy(env, loaded_v2)
        factories['Safe RL (fresh)'] = lambda env: SafePolicy(env, model)
        print("\nRunning calibrated evaluation suite ...")
        evaluate_policy_suite(factories, episodes=eval_episodes, randomized=False)
        if eval_randomized and 'rand_factory' in globals():
            print("\nRunning domain-randomized evaluation suite ...")
            evaluate_policy_suite(factories, episodes=eval_episodes, randomized=True)

    if 'evaluate_model' in globals():
        print("\nRunning 100-episode CI evaluation ...")
        evaluate_model(model, episodes=max(20, eval_episodes))
    else:
        print("Note: evaluate_model() not defined yet; skip 100-episode CI run.")

    return model


# Example usage (commented to avoid accidental long runs):
# trained_model = train_end_to_end(total_timesteps=300_000, n_envs=4, eval_episodes=20, eval_randomized=True)


In [None]:
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple

SLOT_MINUTES_BOOKING = 7
OPEN_HOURS = [8, 9, 10, 11, 13, 14, 15]


def _capacity_for_hour(hour: int, slot_minutes: int = SLOT_MINUTES_BOOKING) -> int:
    # Hours just before a closed boundary (12:00 lunch, 16:00 close) cannot bleed over
    if hour == 11:
        return (60) // slot_minutes  # end by 12:00
    if hour == 15:
        return (60) // slot_minutes  # end by 16:00
    # Other hours may bleed into next hour a bit
    from math import ceil
    return int(ceil(60.0 / slot_minutes))


def _minute_to_str(minute_of_day: int) -> str:
    h = minute_of_day // 60
    m = minute_of_day % 60
    suffix = "am" if h < 12 else "pm"
    h12 = h if 1 <= h <= 12 else (h - 12 if h > 12 else 12)
    return f"{h12}:{m:02d}{suffix}"


def _format_hour_ranges(hours: List[int]) -> str:
    if not hours:
        return "(none)"
    hours = sorted(hours)
    ranges: List[Tuple[int, int]] = []  # [start, end_exclusive]
    start = hours[0]
    prev = hours[0]
    for h in hours[1:]:
        if h == prev + 1 or (prev == 11 and h == 13):
            # treat lunch gap as break; so 11->13 is not consecutive
            if prev == 11 and h == 13:
                ranges.append((start, prev + 1))
                start = h
            # else keep extending
        else:
            ranges.append((start, prev + 1))
            start = h
        prev = h
    ranges.append((start, prev + 1))

    def to_12h(h):
        suffix = "am" if h < 12 else "pm"
        h12 = h if 1 <= h <= 12 else (h - 12 if h > 12 else 12)
        return f"{h12}{suffix}"

    return ", ".join([f"{to_12h(s)}–{to_12h(e)}" for s, e in ranges])


@dataclass
class BookingPlanner:
    slot_minutes: int = SLOT_MINUTES_BOOKING
    open_hours: List[int] = field(default_factory=lambda: OPEN_HOURS.copy())
    capacity_by_hour: Dict[int, int] = field(init=False)
    booked_count_by_hour: Dict[int, int] = field(init=False)

    def __post_init__(self):
        self.capacity_by_hour = {h: _capacity_for_hour(h, self.slot_minutes) for h in self.open_hours}
        self.booked_count_by_hour = {h: 0 for h in self.open_hours}

    def available_hours(self) -> List[int]:
        return [h for h in self.open_hours if self.booked_count_by_hour[h] < self.capacity_by_hour[h]]

    def book(self, hour: int) -> Optional[str]:
        if hour not in self.open_hours:
            return None
        cap = self.capacity_by_hour[hour]
        used = self.booked_count_by_hour[hour]
        if used >= cap:
            return None
        # Assign next 7-min slot within the hour
        start_minute = hour * 60 + used * self.slot_minutes
        self.booked_count_by_hour[hour] += 1
        return _minute_to_str(start_minute)

    def availability_label(self) -> str:
        return _format_hour_ranges(self.available_hours())


# Demo: fill 8:00 hour and show remaining availability
planner = BookingPlanner()
print("Initial availability:", planner.availability_label())
assigned = []
for i in range(9):
    assigned.append(planner.book(8))
print("Assigned times at 8am:", assigned)
print("Availability after filling 8am:", planner.availability_label())

In [None]:
# Optional: Interactive booking widget (Colab)
try:
    import ipywidgets as widgets
    from IPython.display import display, clear_output

    planner_widget = BookingPlanner()

    def hour_options():
        return [(f"{h}:00 ({planner_widget.capacity_by_hour[h] - planner_widget.booked_count_by_hour[h]} left)", h)
                for h in planner_widget.available_hours()]

    hour_dd = widgets.Dropdown(options=hour_options(), description='Hour:')
    out = widgets.Output()

    def on_book(_):
        with out:
            clear_output()
            if not hour_dd.options:
                print("No hours available.")
                return
            hour = hour_dd.value
            assigned_time = planner_widget.book(hour)
            if assigned_time is None:
                print("Selected hour is full. Choose another.")
            else:
                print(f"Booked at {assigned_time}")
                print("Remaining availability:", planner_widget.availability_label())
            # refresh dropdown
            hour_dd.options = hour_options()

    btn = widgets.Button(description='Book')
    btn.on_click(on_book)

    display(widgets.VBox([hour_dd, btn, out]))
except Exception as e:
    print("Widgets unavailable:", e)