In [35]:
# !pip install gymnasium==0.29.1 stable-baselines3==2.3.2 sb3-contrib==2.3.2 pygame==2.6.0


In [36]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces

from typing import Optional, Tuple, Dict, Any


In [None]:
# Import rendering module for pygame visualization
from rendering import ClinicRenderer

  from pkg_resources import resource_stream, resource_exists


In [38]:
# Discrete action IDs understood by RwandaHealthEnv.step (consecutive, 0..6).
(
    REQUEST_NCD_TEST,
    REQUEST_INFECTION_TEST,
    DIAGNOSE_CHRONIC,
    DIAGNOSE_INFECTION,
    ALLOCATE_MED,
    REFER_PATIENT,
    WAIT,
) = range(7)

# Human-readable label for every action ID (used when rendering).
ACTION_MEANINGS = dict(
    zip(
        (
            REQUEST_NCD_TEST,
            REQUEST_INFECTION_TEST,
            DIAGNOSE_CHRONIC,
            DIAGNOSE_INFECTION,
            ALLOCATE_MED,
            REFER_PATIENT,
            WAIT,
        ),
        (
            "Request NCD test",
            "Request infection test",
            "Diagnose chronic condition",
            "Diagnose infection",
            "Allocate medication",
            "Refer patient",
            "Wait / no action",
        ),
    )
)

# Hidden ("true") patient condition IDs -- never placed in the observation.
(
    CONDITION_HEALTHY_OR_MILD,
    CONDITION_CHRONIC,
    CONDITION_INFECTION,
    CONDITION_BOTH_SERIOUS,
) = range(4)


In [39]:
class RwandaHealthEnv(gym.Env):
    """
    Custom Gymnasium environment for a Rwandan clinic.

    - The agent acts like a decision support system in the clinic.
    - It balances diagnosis accuracy against resource usage (tests, meds).

    Only ``__init__`` lives in the class body; the remaining methods
    (``step``, ``reset``, ``render``, ...) are attached in later cells.
    """

    metadata = {"render_modes": ["human", "ansi"], "render_fps": 4}

    def __init__(
        self,
        max_steps: int = 50,
        initial_test_kits: int = 10,
        initial_meds: int = 20,
        max_queue: int = 30,
        render_mode: Optional[str] = None,
    ):
        """
        Args:
            max_steps: number of steps after which an episode is truncated.
            initial_test_kits: test kits available at the start of an episode.
            initial_meds: medication units available at the start of an episode.
            max_queue: queue length at which the episode terminates (overload);
                also the normaliser for the queue penalty and queue feature.
                Must be >= 1.
            render_mode: "human" (pygame), "ansi" (returns a string), or None
                (render() then prints a summary to stdout).

        Raises:
            ValueError: if ``max_queue`` < 1 -- the environment divides by it
                on every reset/step.
        """
        super().__init__()

        if max_queue < 1:
            raise ValueError(f"max_queue must be >= 1, got {max_queue}")

        self.max_steps = max_steps
        self.initial_test_kits = initial_test_kits
        self.initial_meds = initial_meds
        self.max_queue = max_queue
        self.render_mode = render_mode

        # Discrete action space: action IDs 0..6
        self.action_space = spaces.Discrete(7)

        # Observation space: 12 continuous features, all normalized [0, 1]
        # [0] age_norm
        # [1] symptom_severity (0-1)
        # [2] chronic_risk_feature (0-1)  -- proxy
        # [3] infection_risk_feature (0-1)
        # [4] comorbidity_flag (0 or 1)
        # [5] available_test_kits_norm
        # [6] available_meds_norm
        # [7] queue_length_norm
        # [8] time_step_norm
        # [9] last_action_norm
        # [10] ncd_tested (0 or 1) - whether NCD test was done
        # [11] infection_tested (0 or 1) - whether infection test was done
        low = np.zeros(12, dtype=np.float32)
        high = np.ones(12, dtype=np.float32)
        self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)

        # Internal clinic state
        self.test_kits: int = 0
        self.meds: int = 0
        self.queue_length: int = 0
        self.step_count: int = 0
        self.last_action: int = WAIT  # default

        # Patient workflow tracking
        self.ncd_tested: bool = False
        self.infection_tested: bool = False
        self.diagnosed: bool = False
        self.treated: bool = False

        # True patient condition (not part of obs)
        self.true_condition: int = CONDITION_HEALTHY_OR_MILD

        # Visible features of the current patient. These are placeholders until
        # reset() samples a real patient; defining them here keeps _get_obs()
        # and render() from raising AttributeError if called before reset().
        self.patient_age: int = 15
        self.symptom_severity: float = 0.0
        self.chronic_risk: float = 0.0
        self.infection_risk: float = 0.0
        self.comorbidity_flag: float = 0.0

        # For rendering
        self.last_reward: float = 0.0

        # Legacy RandomState rather than a Generator: the helper methods call
        # randint()/rand(), which only exist on RandomState.
        self.np_random = np.random.RandomState()


In [None]:
# Helper methods for rewards and rendering
def _reward_for_diagnosis(self, action: int) -> float:
    """
    Score a diagnosis action against the patient's hidden condition.

    Any DIAGNOSE_* action marks the current patient as diagnosed. The payoff
    depends on whether the matching test (NCD test for a chronic diagnosis,
    infection test for an infection diagnosis) was run beforehand:

        =============  =======  =========
                       correct  incorrect
        tested          +10.0     -6.0
        not tested       +3.0    -10.0
        =============  =======  =========

    Any other action returns 0.0 and leaves the state untouched.
    """
    if action == DIAGNOSE_CHRONIC:
        was_tested = self.ncd_tested
        matching_conditions = (CONDITION_CHRONIC, CONDITION_BOTH_SERIOUS)
    elif action == DIAGNOSE_INFECTION:
        was_tested = self.infection_tested
        matching_conditions = (CONDITION_INFECTION, CONDITION_BOTH_SERIOUS)
    else:
        return 0.0

    # The diagnosis is recorded whether or not it is correct or was tested.
    self.diagnosed = True
    is_correct = self.true_condition in matching_conditions

    if was_tested:
        return 10.0 if is_correct else -6.0
    # Untested: a correct guess earns less, a wrong guess costs more.
    return 3.0 if is_correct else -10.0

def _reward_for_treatment(self) -> float:
    """
    Reward for allocating medication, based only on the patient's true condition.

    The diagnosis that was made is NOT consulted. ``step`` only requires that
    *some* diagnosis exists before calling this helper, so treating after a
    wrong diagnosis earns the same reward as treating after a correct one.

    Returns:
        8.0 for CONDITION_BOTH_SERIOUS (critical treatment),
        6.0 for CONDITION_CHRONIC or CONDITION_INFECTION (good treatment),
        -4.0 otherwise (unnecessary treatment).
    """
    if self.true_condition == CONDITION_BOTH_SERIOUS:
        return 8.0
    if self.true_condition in (CONDITION_CHRONIC, CONDITION_INFECTION):
        return 6.0
    return -4.0

def _update_queue(self):
    """
    Simulate queue dynamics: some patients are served, some arrive.
    """
    # some patients served (1 per step) and random arrivals
    served = 1
    arrivals = self.np_random.randint(0, 3)  # 0-2 new patients

    self.queue_length = max(0, self.queue_length - served + arrivals)

def render(self):
    """
    Render the current clinic state according to ``self.render_mode``.

    - ``"ansi"``: return a one-line text summary (nothing is printed).
    - ``"human"``: draw the state with a ``ClinicRenderer`` (from the local
      ``rendering`` module), created lazily on the first call.
    - any other value, including None: print a one-line summary to stdout.

    The text summaries include the hidden ``true_condition``.
    """
    if self.render_mode == "ansi":
        return ", ".join(
            [
                f"Step: {self.step_count}",
                f"Condition: {self.true_condition}",
                f"Test kits: {self.test_kits}",
                f"Meds: {self.meds}",
                f"Queue: {self.queue_length}",
                f"Last action: {ACTION_MEANINGS[self.last_action]}",
                f"Last reward: {self.last_reward:.2f}",
            ]
        )

    if self.render_mode == "human":
        if not hasattr(self, "renderer"):
            self.renderer = ClinicRenderer()

        # Every key the renderer expects is also the attribute's own name.
        snapshot_fields = (
            "patient_age",
            "symptom_severity",
            "chronic_risk",
            "infection_risk",
            "comorbidity_flag",
            "ncd_tested",
            "infection_tested",
            "diagnosed",
            "treated",
            "test_kits",
            "initial_test_kits",
            "meds",
            "initial_meds",
            "queue_length",
            "max_queue",
            "last_action",
            "last_reward",
            "step_count",
            "max_steps",
        )
        snapshot = {field: getattr(self, field) for field in snapshot_fields}
        self.renderer.render(snapshot)
        return None

    # Console fallback
    print(
        " | ".join(
            [
                f"[Step {self.step_count}] Cond={self.true_condition}",
                f"Tests={self.test_kits}, Meds={self.meds}, Queue={self.queue_length}",
                f"Action={ACTION_MEANINGS[self.last_action]}",
                f"Reward={self.last_reward:.2f}",
            ]
        )
    )

def close(self):
    """
    Shut down the pygame renderer, if one was created.

    The ``renderer`` attribute is removed after closing. A later ``render()``
    then builds a fresh renderer instead of reusing the closed one, and
    calling ``close()`` twice is harmless.
    """
    renderer = getattr(self, "renderer", None)
    if renderer is None:
        return
    try:
        renderer.close()
    finally:
        # render() only checks hasattr(self, "renderer"), so deleting the
        # attribute makes it create a new renderer on the next call.
        del self.renderer

# Attach all helper methods to the environment class, each under its own name.
for _helper_fn in (
    _reward_for_diagnosis,
    _reward_for_treatment,
    _update_queue,
    render,
    close,
):
    setattr(RwandaHealthEnv, _helper_fn.__name__, _helper_fn)
del _helper_fn


In [41]:
# Main step method
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
    """
    Apply one agent action and advance the clinic by one time step.

    Order of operations (it matters for both the reward and the RNG stream):
      1. queue-pressure penalty from the queue length *before* this step,
      2. the action's effect and shaped reward,
      3. +5.0 completion bonus and a new patient once the current patient is
         both diagnosed and treated,
      4. queue update (random arrivals),
      5. truncation at ``max_steps``, and termination (each with a -10.0
         penalty) when both resources are exhausted or the queue overflows.

    Args:
        action: action ID from ``self.action_space`` (0-6, see ACTION_MEANINGS).

    Returns:
        ``(obs, reward, terminated, truncated, info)`` in Gymnasium order.
        ``info`` carries the hidden ``true_condition`` plus resource counts
        and workflow flags.

    Raises:
        AssertionError: if ``action`` is outside the action space.
    """
    assert self.action_space.contains(action), f"Invalid action: {action}"

    self.step_count += 1
    self.last_action = int(action)

    reward = 0.0
    info: Dict[str, Any] = {}

    # Queue-pressure penalty: -1.0 * queue occupancy (i.e. -1.0 when the queue
    # is at max_queue), using the queue length from before this step.
    queue_penalty = -1.0 * (self.queue_length / self.max_queue)
    reward += queue_penalty

    # ---- Action effects & reward shaping ----
    # Both test types draw from the same shared test-kit pool.
    if action == REQUEST_NCD_TEST:
        if self.ncd_tested:
            reward -= 3.0  # Already tested, wasteful
        elif self.test_kits > 0:
            self.test_kits -= 1
            self.ncd_tested = True
            reward -= 0.5  # Small resource cost
            # Net +1.5 if the patient really has a chronic condition, else -1.0
            if self.true_condition in [CONDITION_CHRONIC, CONDITION_BOTH_SERIOUS]:
                reward += 2.0  # Good test
                # Positive result raises the visible chronic-risk feature (capped at 1.0)
                self.chronic_risk = np.clip(self.chronic_risk + 0.2, 0.0, 1.0)
            else:
                reward -= 0.5  # Unnecessary but not severely penalized
        else:
            reward -= 5.0  # No kits available

    elif action == REQUEST_INFECTION_TEST:
        if self.infection_tested:
            reward -= 3.0  # Already tested, wasteful
        elif self.test_kits > 0:
            self.test_kits -= 1
            self.infection_tested = True
            reward -= 0.5  # Small resource cost
            # Net +1.5 if the patient really has an infection, else -1.0
            if self.true_condition in [CONDITION_INFECTION, CONDITION_BOTH_SERIOUS]:
                reward += 2.0  # Good test
                # Positive result raises the visible infection-risk feature (capped at 1.0)
                self.infection_risk = np.clip(self.infection_risk + 0.2, 0.0, 1.0)
            else:
                reward -= 0.5  # Unnecessary but not severely penalized
        else:
            reward -= 5.0  # No kits available

    elif action in [DIAGNOSE_CHRONIC, DIAGNOSE_INFECTION]:
        # Only one diagnosis per patient; a repeat is penalized and changes nothing.
        if self.diagnosed:
            reward -= 5.0  # Already diagnosed
        else:
            reward += self._reward_for_diagnosis(action)

    elif action == ALLOCATE_MED:
        # Treatment requires a prior diagnosis (of any kind) and a med unit.
        if self.treated:
            reward -= 3.0  # Already treated
        elif not self.diagnosed:
            reward -= 5.0  # Cannot treat without diagnosis
        elif self.meds > 0:
            self.meds -= 1
            self.treated = True
            reward += self._reward_for_treatment()
        else:
            reward -= 5.0  # No meds available

    elif action == REFER_PATIENT:
        if self.diagnosed:
            # Referral after diagnosis; the reward depends only on the true
            # condition, not on which diagnosis was made.
            if self.true_condition == CONDITION_BOTH_SERIOUS:
                reward += 8.0  # Excellent decision
            elif self.true_condition in [CONDITION_CHRONIC, CONDITION_INFECTION]:
                reward += 3.0  # Good decision
            else:
                reward -= 2.0  # Unnecessary referral
            # Patient complete, generate new one (this also resets the workflow flags)
            self._generate_new_patient()
        else:
            reward -= 5.0  # Cannot refer without diagnosis

    elif action == WAIT:
        # Waiting might be reasonable if lacking resources
        if self.test_kits == 0 or self.meds == 0:
            reward -= 0.2  # Small penalty
        else:
            reward -= 1.0  # Larger penalty if resources available

    # Check if patient workflow is complete (diagnosed + treated).
    # REFER_PATIENT is excluded because a referral already replaced the patient above.
    if self.diagnosed and self.treated and action != REFER_PATIENT:
        # Bonus for completing a patient
        reward += 5.0
        # Generate new patient
        self._generate_new_patient()

    # queue change: each step some patients leave / arrive (draws from np_random
    # after any new-patient draws above)
    self._update_queue()

    # Terminal conditions; truncation (time limit) and termination can both be
    # set on the same step.
    terminated = False
    truncated = False

    if self.step_count >= self.max_steps:
        truncated = True

    # Episode ends only when BOTH test kits and meds are depleted
    if self.test_kits == 0 and self.meds == 0:
        terminated = True
        reward -= 10.0  # Severe penalty for running out

    # Episode ends when the queue reaches capacity
    if self.queue_length >= self.max_queue:
        terminated = True
        reward -= 10.0  # Severe penalty for overload

    self.last_reward = reward

    obs = self._get_obs()

    info.update(
        {
            "true_condition": self.true_condition,
            "test_kits": self.test_kits,
            "meds": self.meds,
            "queue_length": self.queue_length,
            "ncd_tested": self.ncd_tested,
            "infection_tested": self.infection_tested,
            "diagnosed": self.diagnosed,
            "treated": self.treated,
        }
    )

    return obs, reward, terminated, truncated, info

RwandaHealthEnv.step = step


In [42]:
# Observation method
def _get_obs(self) -> np.ndarray:
    age_norm = (self.patient_age - 15) / (80 - 15)  # 0-1
    test_kits_norm = np.clip(self.test_kits / max(self.initial_test_kits, 1), 0, 1)
    meds_norm = np.clip(self.meds / max(self.initial_meds, 1), 0, 1)
    queue_norm = np.clip(self.queue_length / self.max_queue, 0, 1)
    time_norm = self.step_count / max(self.max_steps, 1)
    last_action_norm = self.last_action / 6.0  # since 0-6

    obs = np.array(
        [
            age_norm,
            self.symptom_severity,
            self.chronic_risk,
            self.infection_risk,
            self.comorbidity_flag,
            test_kits_norm,
            meds_norm,
            queue_norm,
            time_norm,
            last_action_norm,
            float(self.ncd_tested),
            float(self.infection_tested),
        ],
        dtype=np.float32,
    )
    return obs

# Bind the observation builder onto the environment class.
setattr(RwandaHealthEnv, "_get_obs", _get_obs)


In [43]:
# Patient generation method
def _generate_new_patient(self):
    """
    Replace the current patient with a freshly sampled one.

    Samples the hidden ``true_condition``, then visible features that are
    correlated with it but overlap between classes. Finally, clears the
    per-patient workflow flags (tests, diagnosis, treatment).
    """
    # Hidden condition: 40% healthy/mild, 25% chronic, 25% infection,
    # 10% both/serious. Chronic and infection are equally likely.
    self.true_condition = self.np_random.choice(
        [CONDITION_HEALTHY_OR_MILD, CONDITION_CHRONIC, CONDITION_INFECTION, CONDITION_BOTH_SERIOUS],
        p=[0.4, 0.25, 0.25, 0.10]
    )

    # Visible features. RandomState.randint excludes its upper bound, so 81
    # yields ages 15..80 inclusive, the range _get_obs normalises over.
    self.patient_age = self.np_random.randint(15, 81)
    # NOTE(review): severity is drawn independently of true_condition, so it
    # carries no diagnostic signal -- confirm this is intended.
    self.symptom_severity = self.np_random.uniform(0.0, 1.0)

    # Risk proxies - not perfect (ranges overlap in 0.6-0.7), but correlated with condition
    if self.true_condition in [CONDITION_CHRONIC, CONDITION_BOTH_SERIOUS]:
        self.chronic_risk = self.np_random.uniform(0.6, 1.0)
    else:
        self.chronic_risk = self.np_random.uniform(0.0, 0.7)

    if self.true_condition in [CONDITION_INFECTION, CONDITION_BOTH_SERIOUS]:
        self.infection_risk = self.np_random.uniform(0.6, 1.0)
    else:
        self.infection_risk = self.np_random.uniform(0.0, 0.7)

    # Comorbidity: always set for both/serious, otherwise 20% chance
    if self.true_condition == CONDITION_BOTH_SERIOUS:
        self.comorbidity_flag = 1.0
    else:
        self.comorbidity_flag = float(self.np_random.rand() < 0.2)

    # Reset patient workflow state
    self.ncd_tested = False
    self.infection_tested = False
    self.diagnosed = False
    self.treated = False

# Bind the patient sampler onto the environment class.
setattr(RwandaHealthEnv, "_generate_new_patient", _generate_new_patient)


In [44]:
# Reset and Seed methods
def seed(self, seed: Optional[int] = None):
    """
    Recreate the environment RNG as a legacy ``np.random.RandomState``.

    Args:
        seed: seed for the new RandomState; None lets NumPy pick fresh entropy.

    Returns:
        A one-element list holding ``seed`` (legacy gym ``seed()`` convention).
    """
    fresh_rng = np.random.RandomState(seed)
    self.np_random = fresh_rng
    return [seed]

def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[Dict[str, Any]] = None
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """
    Start a new episode.

    Restores full resources, draws a small initial queue and generates the
    first patient.

    Args:
        seed: if given, reseeds the environment RNG via ``self.seed`` so the
            episode is reproducible.
        options: accepted for Gymnasium API compatibility; ignored.

    Returns:
        ``(observation, info)``, where ``info`` is an empty dict.
    """
    # gymnasium.Env.reset runs first; self.seed() then installs the legacy
    # RandomState that the helper methods draw from.
    super(RwandaHealthEnv, self).reset(seed=seed)
    if seed is not None:
        self.seed(seed)

    # Per-episode counters and resources.
    self.step_count = 0
    self.last_action = WAIT
    self.last_reward = 0.0
    self.test_kits = self.initial_test_kits
    self.meds = self.initial_meds
    # 0-4 patients already waiting; drawn before the first patient is sampled.
    self.queue_length = self.np_random.randint(0, 5)

    self._generate_new_patient()
    return self._get_obs(), {}

# Bind the seeding and reset entry points onto the environment class.
setattr(RwandaHealthEnv, "seed", seed)
setattr(RwandaHealthEnv, "reset", reset)


In [45]:
# Smoke test: a few random steps, rendering each one.
env = RwandaHealthEnv(render_mode="human")
try:
    obs, info = env.reset(seed=42)
    # reset(seed=...) does not seed the action space; seed it too so the
    # sampled actions (and therefore the run) are reproducible.
    env.action_space.seed(42)
    print("Initial observation shape:", obs.shape)
    print("Initial info:", info)

    for i in range(5):
        random_action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(random_action)
        env.render()
        if terminated or truncated:
            print("Episode ended early at step", i + 1)
            break
finally:
    # Always release the pygame window, even if a step raises.
    env.close()


Initial observation shape: (12,)
Initial info: {}


In [46]:
def run_random_policy(env, num_episodes: int = 3, max_steps: int = 50):
    """
    Roll out a uniformly random policy and report each episode's return.

    Episode ``ep`` (0-based) resets the environment with ``seed=ep`` and seeds
    the action space with the same value. For an environment that is
    deterministic given its seed, repeated calls then sample the same actions
    and produce the same totals.

    Args:
        env: Gymnasium-style environment exposing ``reset``/``step``/``render``
            and an ``action_space`` with ``seed``/``sample``.
        num_episodes: number of episodes to run.
        max_steps: per-episode step cap applied by this loop (the environment
            may end an episode sooner).

    Returns:
        list[float]: total reward of each episode, in order.
    """
    episode_returns = []
    for ep in range(num_episodes):
        print(f"\n=== Episode {ep + 1} ===")
        obs, info = env.reset(seed=ep)
        env.action_space.seed(ep)
        total_reward = 0.0

        for t in range(max_steps):
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
            env.render()
            if terminated or truncated:
                print(f"Episode finished after {t + 1} steps, total reward = {total_reward:.2f}")
                break
        else:
            # Loop ran to completion without the environment ending the episode.
            print(f"Episode reached max_steps, total reward = {total_reward:.2f}")

        episode_returns.append(total_reward)
    return episode_returns

env = RwandaHealthEnv(render_mode="human")
try:
    run_random_policy(env, num_episodes=2, max_steps=30)
finally:
    # Release the pygame window even if an episode raises.
    env.close()



=== Episode 1 ===
Episode reached max_steps, total reward = -39.33

=== Episode 2 ===
Episode reached max_steps, total reward = -39.33

=== Episode 2 ===
Episode reached max_steps, total reward = -51.23
Episode reached max_steps, total reward = -51.23


In [47]:
# Pygame Visualization Demo - Random Policy
import time

env = RwandaHealthEnv(render_mode="human")

print("Starting pygame visualization...")
print("Close the pygame window to stop the simulation.")
# NOTE(review): this loop itself only stops early when the episode ends or on
# KeyboardInterrupt; whether closing the window stops it depends on
# ClinicRenderer -- confirm.

try:
    obs, info = env.reset(seed=42)
    # Seed the action space too so the demo replays the same actions each run.
    env.action_space.seed(42)

    for i in range(20):  # Run for 20 steps
        random_action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(random_action)
        env.render()
        time.sleep(0.5)  # Slow down for better viewing

        if terminated or truncated:
            print(f"\nEpisode ended at step {i + 1}")
            time.sleep(2)  # Pause to see final state
            break

    print("Visualization complete!")
    time.sleep(2)

except KeyboardInterrupt:
    print("\nStopped by user")

finally:
    env.close()
    print("Environment closed.")

Starting pygame visualization...
Close the pygame window to stop the simulation.
Visualization complete!
Visualization complete!
Environment closed.
Environment closed.


## Pygame Visualization Demo

The demo above runs the environment with a random policy and draws each step with pygame. The visualization shows:
- **Patient Information**: Age, symptoms, risk factors, and workflow status
- **Clinic Resources**: Test kits and medication levels with progress bars
- **Patient Queue**: Current queue length with color-coded warnings
- **Agent Actions**: Last action taken and reward received
- **Episode Progress**: Step counter and completion percentage