# Iteration 5: includes 4 types of patients, 4 types of LOS distributions and then the % of where they go afterwards: Rehab, ESD, Other. Aslo includes Test: Stroke Patient Journey Validation

---

## **. Patient Admission Sources**  
Patients enter the hospital through **two primary pathways**:  
- **New Admissions**: Direct hospital entry for Stroke, TIA, Complex Neurological, and Other Medical Cases.  


| Patient Type                     | Admissions (n) | Percentage (%) |
|----------------------------------|---------------|--------------|
| Stroke                           | 1,320         | 54%          |
| Transient Ischemic Attack (TIA)  | 158           | 6%           |
| Complex Neurological             | 456           | 19%          |
| Other Medical Cases              | 510           | 21%          |

**Patient admissions are distributed as follows**

| Category                | Mean(Days) |
|-------------------------|----------------------------|
| Stroke                  | 1.2                        |
| TIA (Transient Ischemic Attack) | 9.3                |
| Complex Neurological    | 3.6                        |
| Other                   | 3.2                        |



## **Patient Transfer Probabilities**  

### **From Acute Stroke Unit**  
| Destination  | Stroke | TIA | Complex Neurological | Other |
|-------------|--------|-----|----------------------|-------|
| **Rehab**   | 24%    | 1%  | 11%                  | 5%    |
| **ESD**     | 13%    | 1%  | 5%                   | 10%   |
| **Other***  | 63%    | 98% | 84%                  | 85%   |

**Other includes home, care home, or mortality.**


## 1. Imports 

In [4]:
import numpy as np
import itertools
import simpy
import os
import sys

In [5]:
# Add parent directory to path so we can import distribution.py
sys.path.append(os.path.abspath(".."))  # noqa: E402

from distribution import Exponential, Lognormal  # noqa: E402

## 2. Constants

In [6]:
# default mean inter-arrival times(exp)
IAT_STROKE = 1.2
IAT_TIA = 9.3
IAT_COMPLEX_NEURO = 3.6
IAT_OTHER = 3.2

# Default Length of Stay (LOS) parameters
# (mean, stdev for Lognormal distribution
LOS_STROKE = (7.4, 8.6)
LOS_TIA = (1.8, 2.3)
LOS_COMPLEX_NEURO = (4.0, 5.0)
LOS_OTHER = (3.8, 5.2)
LOS_STROKE_NESD = (7.4, 8.6)
LOS_STROKE_ESD = (4.6, 4.8)
LOS_STROKE_MORTALITY = (7.0, 8.7)

# % where patients go after the Acute Stroke Unit (ASU)
TRANSFER_PROBABILITIES = {
    "stroke": {"rehab": 0.24, "esd": 0.13, "other": 0.63},
    "tia": {"rehab": 0.01, "esd": 0.01, "other": 0.98},
    "complex_neuro": {"rehab": 0.11, "esd": 0.05, "other": 0.84},
    "other": {"rehab": 0.05, "esd": 0.10, "other": 0.85},
}


# sampling settings, 4 for arrivals, 4 for LOS and
# for 4 transfer probabilities
N_STREAMS = 12
DEFAULT_RND_SET = 0

# Boolean switch to simulation results as the model runs
TRACE = False

# run variables (units = days)
RUN_LENGTH = 365 * 3

## 2. Helper classes and functions

In [7]:
def trace(msg):
    """
    Turing printing of events on and off.

    Params:
    -------
    msg: str
        string to print to screen.
    """
    if TRACE:
        print(msg)

## 3. Experiment class

In [8]:
class Experiment:
    """
    Encapsulates the concept of an experiment
    for the Acute Stroke Unit simulation.
    Manages parameters, PRNG streams, and results.
    """

    def __init__(
        self,
        random_number_set=DEFAULT_RND_SET,
        n_streams=N_STREAMS,
        iat_stroke=IAT_STROKE,
        iat_tia=IAT_TIA,
        iat_complex_neuro=IAT_COMPLEX_NEURO,
        iat_other=IAT_OTHER,
        asu_beds=10,
        los_stroke=LOS_STROKE,
        los_tia=LOS_TIA,
        los_complex_neuro=LOS_COMPLEX_NEURO,
        los_other=LOS_OTHER,
        transfer_probabilities=TRANSFER_PROBABILITIES,
    ):
        """
        Initialize default parameters.
        """
        # Sampling settings
        self.random_number_set = random_number_set
        self.n_streams = n_streams

        # Model parameters
        self.iat_stroke = iat_stroke
        self.iat_tia = iat_tia
        self.iat_complex_neuro = iat_complex_neuro
        self.iat_other = iat_other
        self.asu_beds = asu_beds

        # LOS Parameters
        self.los_stroke = los_stroke
        self.los_tia = los_tia
        self.los_complex_neuro = los_complex_neuro
        self.los_other = los_other

        # Transfer probabilities
        self.transfer_probabilities = transfer_probabilities

        # Initialize results storage
        self.init_results_variables()

        # Initialize sampling distributions (RNGs)
        self.init_sampling()

    def set_random_no_set(self, random_number_set):
        """
        Controls the random sampling by re-seeding.
        """
        self.random_number_set = random_number_set
        self.init_sampling()

    def init_sampling(self):
        """
        Creates the distributions used by the model and initializes
        the random seeds of each.
        """
        # Create a new seed sequence
        seed_sequence = np.random.SeedSequence(self.random_number_set)
        # Produce n non-overlapping streams
        self.seeds = seed_sequence.spawn(self.n_streams)

        # Prepare a list of RNGs
        rng_list = [np.random.default_rng(s.entropy) for s in self.seeds]

        # Inter-arrival time distributions
        self.arrival_stroke = Exponential(self.iat_stroke, self.seeds[0])
        self.arrival_tia = Exponential(self.iat_tia, self.seeds[1])
        self.arrival_complex_neuro = Exponential(
            self.iat_complex_neuro, self.seeds[2]
        )
        self.arrival_other = Exponential(self.iat_other, self.seeds[3])

        # LOS distributions using stored parameters
        self.los_distributions = {
            "stroke": Lognormal(
                self.los_stroke[0], self.los_stroke[1], self.seeds[4]
            ),
            "tia": Lognormal(self.los_tia[0], self.los_tia[1], self.seeds[5]),
            "complex_neuro": Lognormal(
                self.los_complex_neuro[0],
                self.los_complex_neuro[1],
                self.seeds[6],
            ),
            "other": Lognormal(
                self.los_other[0], self.los_other[1], self.seeds[7]
            ),
        }

        # RNGs specifically for transfer choices (1 per patient type)
        self.transfer_rngs = {
            "stroke": rng_list[8],
            "tia": rng_list[9],
            "complex_neuro": rng_list[10],
            "other": rng_list[11],
        }

    def init_results_variables(self):
        """
        Initializes all the experiment variables used in results collection.
        """
        self.results = {
            "n_stroke": 0,
            "n_tia": 0,
            "n_complex_neuro": 0,
            "n_other": 0,
            "n_patients": 0,
            "n_stroke_discharged": 0,
            "n_tia_discharged": 0,
            "n_complex_neuro_discharged": 0,
            "n_other_discharged": 0,
            "n_discharged": 0,
            "stroke_transfer": {"rehab": 0, "esd": 0, "other": 0},
            "tia_transfer": {"rehab": 0, "esd": 0, "other": 0},
            "complex_neuro_transfer": {"rehab": 0, "esd": 0, "other": 0},
            "other_transfer": {"rehab": 0, "esd": 0, "other": 0},
            "total_transfers": {"rehab": 0, "esd": 0, "other": 0},
        }

## 3. Patient Class

In [9]:
class Patient:
    """
    Represents a patient in the system.
    """

    def __init__(self, patient_id, env, args, acute_stroke_unit, patient_type):
        self.patient_id = patient_id
        self.env = env
        self.args = args
        self.acute_stroke_unit = acute_stroke_unit
        self.patient_type = patient_type
        self.waiting_time = 0.0  # track how long the patient waited for a bed

    def treatment(self):
        """
        Simulates the patient’s treatment process.
        Upon discharge, calls `transfer()` to determine next destination.
        """
        arrival_time = self.env.now
        los_distribution = self.args.los_distributions[self.patient_type]

        # Arrival message
        trace(
            f"Patient {self.patient_id} ({self.patient_type.upper()})\
            arrives at {arrival_time:.2f}."
        )

        # Request bed from the ASU
        with self.acute_stroke_unit.beds.request() as request:
            yield request
            self.waiting_time = self.env.now - arrival_time
            los = los_distribution.sample()

            # Bed assigned message
            trace(
                f"Patient {self.patient_id} ({self.patient_type.upper()})\
                gets a bed at {self.env.now:.2f}."
                f" (Waited {self.waiting_time:.2f} days)"
            )

            # Simulate length of stay
            yield self.env.timeout(los)

            # Leaving message
            trace(
                f"Patient {self.patient_id} ({self.patient_type.upper()})\
                leaves at {self.env.now:.2f}."
                f" (LOS {los:.2f} days)"
            )

            self.args.results["n_discharged"] += 1
            self.args.results[f"n_{self.patient_type}_discharged"] += 1

            # Transfer the patient after discharge
            self.transfer()

    def transfer(self):
        """
        Determines the patient's next destination
        based on the transfer probabilities.
        Logs and updates the experiment's results.
        """
        # Access the RNG and probabilities for this patient type
        rng = self.args.transfer_rngs[self.patient_type]
        p_dict = self.args.transfer_probabilities[self.patient_type]
        destinations = list(p_dict.keys())
        probs = list(p_dict.values())

        # Random draw for the transfer destination
        destination = rng.choice(destinations, p=probs)
        trace(
            f"Patient {self.patient_id} ({self.patient_type.upper()})\
            is transferred to {destination.upper()}."
        )

        # Update results
        self.args.results["total_transfers"][destination] += 1
        self.args.results[f"{self.patient_type}_transfer"][destination] += 1

### 4. Acute Stroke Unit Class


In [10]:
class AcuteStrokeUnit:
    """
    Models the Acute Stroke Unit (ASU) in the hospital.
    """

    def __init__(self, env, args):
        self.env = env
        self.args = args
        self.beds = simpy.Resource(env, capacity=args.asu_beds)

    def patient_arrivals(self, patient_type, arrival_distribution):
        """
        A generator that creates patients of 'patient_type' according to
        the given 'arrival_distribution' (Exponential).
        """
        for patient_count in itertools.count(start=1):
            inter_arrival_time = arrival_distribution.sample()
            yield self.env.timeout(inter_arrival_time)

            # Track patient count
            self.args.results[f"n_{patient_type}"] += 1
            self.args.results["n_patients"] += 1

            trace(f"{self.env.now:.2f}: {patient_type.upper()} arrival.")

            # Create new patient and start its treatment process
            new_patient = Patient(
                patient_count, self.env, self.args, self, patient_type
            )
            self.env.process(new_patient.treatment())

## 4. A function per arrival source

The first approach we will use is creating an arrival generator per source.  There will be some code redundancy, but it will a clear design for others to understand.

## 5. Single run function

In [11]:
def single_run(experiment, rep=0, run_length=RUN_LENGTH):
    """
    Perform a single run of the model and return the results.

    Parameters
    ----------
    experiment : Experiment
        The experiment/parameters to use with model
    rep : int
        The replication number (used to set random seeds).
    run_length : float
        The run length of the model in days (default = 3650 = 10 years).
    """
    # 1. Reset results for each run
    experiment.init_results_variables()

    # 2. Set the random number set for this run
    experiment.set_random_no_set(rep)

    # 3. Create a fresh environment and an AcuteStrokeUnit
    env = simpy.Environment()
    asu = AcuteStrokeUnit(env, experiment)

    # 4. Create patient arrival processes for different types of patients
    env.process(asu.patient_arrivals("stroke", experiment.arrival_stroke))
    env.process(asu.patient_arrivals("tia", experiment.arrival_tia))
    env.process(
        asu.patient_arrivals("complex_neuro", experiment.arrival_complex_neuro)
    )
    env.process(asu.patient_arrivals("other", experiment.arrival_other))

    # 5. Run the simulation
    env.run(until=run_length)

    # 6. Trace summary of total patients
    total_patients = sum(
        experiment.results[key]
        for key in experiment.results
        if key.startswith("n_")
    )
    trace(f"Final summary for rep={rep}: {total_patients} total patients.")

    # Return the results dictionary
    return experiment.results

In [12]:
TRACE = False
experiment = Experiment()
results = single_run(experiment)
results

{'n_stroke': 934,
 'n_tia': 119,
 'n_complex_neuro': 300,
 'n_other': 368,
 'n_patients': 1721,
 'n_stroke_discharged': 923,
 'n_tia_discharged': 119,
 'n_complex_neuro_discharged': 299,
 'n_other_discharged': 367,
 'n_discharged': 1708,
 'stroke_transfer': {'rehab': 204, 'esd': 107, 'other': 612},
 'tia_transfer': {'rehab': 1, 'esd': 3, 'other': 115},
 'complex_neuro_transfer': {'rehab': 33, 'esd': 13, 'other': 253},
 'other_transfer': {'rehab': 17, 'esd': 37, 'other': 313},
 'total_transfers': {'rehab': 255, 'esd': 160, 'other': 1293}}

## Test: Stroke Patient Journey Validation

### Objective  
To validate the consistency and correctness of the patient flow for individuals diagnosed with stroke within the simulation of the Acute Stroke Unit (ASU). This test ensures that stroke patients follow a logical care pathway from admission to transfer.

### Context  
In this SimPy-based simulation, stroke patients are admitted to the ASU when a bed becomes available. They remain in the unit for a duration based on their length of stay (LoS). After completing their stay, they are transferred to another pathway (e.g., Rehab, ESD, or Other).

Because the simulation ends at a fixed time, some patients may still be in the system when it finishes. Therefore, the number of admitted stroke patients will often be **greater** than the number of transferred patients.

### Test Steps  
1. **Admission Count**: Count the number of stroke patients admitted to the ASU (`n_patients`).
2. **Discharge Count**: Count the number of stroke patients who completed their LoS and left the ASU (`n_discharged`).
3. **Transfer Count**: Sum the number of stroke patients transferred to Rehab, ESD, or Other (`sum(transfer_transfers.values())`).
4. **Validation**:  
   - The number of discharged stroke patients should equal the number of transferred stroke patients:  
     `n_discharged == (sum(transfer_transfers.values()))`.
   - The number of discharged patients should be **less than or equal to** the number admitted:  
     `n_discharged <= n_patients`

### Expected Outcome  
- All stroke patients who complete their stay in the ASU are transferred.
- Some stroke patients may still be in the ASU at the end of the simulation.
- A mismatch between discharged and transferred patients indicates a logic error.

In [13]:
# Calculate percentage of admitted patients per category
stroke_percentage = results["n_stroke"] / results["n_patients"]
tia_percentage = results["n_tia"] / results["n_patients"]
complex_percentage = results["n_complex_neuro"] / results["n_patients"]
other_percentage = results["n_other"] / results["n_patients"]

# Calculate percentage of discharged patients per category
stroke_percentage_discharge = (
    results["n_stroke_discharged"] / results["n_discharged"]
)
tia_percentage_discharge = (
    results["n_tia_discharged"] / results["n_discharged"]
)
complex_percentage_discharge = (
    results["n_complex_neuro_discharged"] / results["n_discharged"]
)
other_percentage_discharge = (
    results["n_other_discharged"] / results["n_discharged"]
)

# Helper function to print transfer info

def print_transfer_info(transfer_data, group_name, total_group_patients):
    rehab = transfer_data["rehab"]
    esd = transfer_data["esd"]
    other = transfer_data["other"]
    total = rehab + esd + other

    def percent(x):
        if total_group_patients > 0:
            return (x / total_group_patients) * 100
        else:
            return 0

    print(f"{group_name} Transfers:")
    print(f"  Rehab: {rehab} ({percent(rehab):.2f}%)")
    print(f"  ESD: {esd} ({percent(esd):.2f}%)")
    print(f"  Other: {other} ({percent(other):.2f}%)")
    print(f"  Total: {total} ({percent(total):.2f}%)\n")

    return total



In [14]:
# Display admissions
print("ADMISSIONS")
print("---------------------------------------------")
print(f"Number of patients admitted: {results['n_patients']}")
print(f"Stroke: {results['n_stroke']} ({stroke_percentage*100:.2f}%)")
print(f"TIA: {results['n_tia']} ({tia_percentage*100:.2f}%)")
print(
    f"Complex Neuro: {results['n_complex_neuro']} \
        ({complex_percentage*100:.2f}%)"
)
print(f"Other: {results['n_other']} ({other_percentage*100:.2f}%)")

ADMISSIONS
---------------------------------------------
Number of patients admitted: 1721
Stroke: 934 (54.27%)
TIA: 119 (6.91%)
Complex Neuro: 300         (17.43%)
Other: 368 (21.38%)


In [15]:
# Display discharges
print("\nDISCHARGES")
print("---------------------------------------------")
print(f"Number of patients discharged: {results['n_discharged']}")
print(
    f"% of admitted patients discharged:\
    {results['n_discharged'] / results['n_patients'] * 100:.2f}%"
)
print(
    f"Stroke: {results['n_stroke_discharged']}\
    ({stroke_percentage_discharge*100:.2f}%)"
)
print(
    f"TIA: {results['n_tia_discharged']} ({tia_percentage_discharge*100:.2f}%)"
)
print(
    f"Complex Neuro: {results['n_complex_neuro_discharged']}\
    ({complex_percentage_discharge*100:.2f}%)"
)
print(
    f"Other: {results['n_other_discharged']} \
    ({other_percentage_discharge*100:.2f}%)"
)


DISCHARGES
---------------------------------------------
Number of patients discharged: 1708
% of admitted patients discharged:    99.24%
Stroke: 923    (54.04%)
TIA: 119 (6.97%)
Complex Neuro: 299    (17.51%)
Other: 367     (21.49%)


In [16]:
# Display transfers
print("\nTRANSFERS")
print("---------------------------------------------")
stroke_transferred = print_transfer_info(
    results["stroke_transfer"], "Stroke", results["n_stroke"]
)
tia_transferred = print_transfer_info(
    results["tia_transfer"], "TIA", results["n_tia"]
)
complex_transferred = print_transfer_info(
    results["complex_neuro_transfer"],
    "Complex Neuro",
    results["n_complex_neuro"],
)
other_transferred = print_transfer_info(
    results["other_transfer"], "Other", results["n_other"]
)

total_transferred = (
    stroke_transferred
    + tia_transferred
    + complex_transferred
    + other_transferred
)
print(f"Total transferred patients across all groups: {total_transferred}")


TRANSFERS
---------------------------------------------
Stroke Transfers:
  Rehab: 204 (21.84%)
  ESD: 107 (11.46%)
  Other: 612 (65.52%)
  Total: 923 (98.82%)

TIA Transfers:
  Rehab: 1 (0.84%)
  ESD: 3 (2.52%)
  Other: 115 (96.64%)
  Total: 119 (100.00%)

Complex Neuro Transfers:
  Rehab: 33 (11.00%)
  ESD: 13 (4.33%)
  Other: 253 (84.33%)
  Total: 299 (99.67%)

Other Transfers:
  Rehab: 17 (4.62%)
  ESD: 37 (10.05%)
  Other: 313 (85.05%)
  Total: 367 (99.73%)

Total transferred patients across all groups: 1708


In [17]:
# Final consistency checks
admitted = results["n_patients"]
discharged = results["n_discharged"]
completed = discharged = total_transferred

print("\nSUMMARY CHECK")
print("---------------------------------------------")
print(f"Admitted patients: {admitted}")
print(f"Discharged (i.e. completed): {discharged}")
print(f"Transferred patients: {total_transferred}")

if total_transferred != discharged:
    print(
        f"WARNING: Discharged patients ≠ \
    Transferred patients! Difference: {abs(discharged - total_transferred)}"
    )
else:
    print("SUCCESS: All discharged patients were transferred correctly.")

if discharged > admitted:
    print("WARNING: More patients discharged than admitted!")
elif discharged < admitted:
    print("NOTE: Some patients may still be in ASU at simulation end.")
else:
    print(
        "INFO: All admitted patients have been discharged (and transferred)."
    )


SUMMARY CHECK
---------------------------------------------
Admitted patients: 1721
Discharged (i.e. completed): 1708
Transferred patients: 1708
SUCCESS: All discharged patients were transferred correctly.
NOTE: Some patients may still be in ASU at simulation end.


## TEST - Should only show TIA patients

In [18]:
M = 1_000_000
experiment = Experiment(iat_stroke=M, iat_complex_neuro=M, iat_other=M)
results = single_run(experiment)
results

{'n_stroke': 0,
 'n_tia': 119,
 'n_complex_neuro': 0,
 'n_other': 0,
 'n_patients': 119,
 'n_stroke_discharged': 0,
 'n_tia_discharged': 119,
 'n_complex_neuro_discharged': 0,
 'n_other_discharged': 0,
 'n_discharged': 119,
 'stroke_transfer': {'rehab': 0, 'esd': 0, 'other': 0},
 'tia_transfer': {'rehab': 1, 'esd': 3, 'other': 115},
 'complex_neuro_transfer': {'rehab': 0, 'esd': 0, 'other': 0},
 'other_transfer': {'rehab': 0, 'esd': 0, 'other': 0},
 'total_transfers': {'rehab': 1, 'esd': 3, 'other': 115}}