In [6]:
from typing import Any
import pandas as pd
import numpy as np
import json
import ast

class Patient:
    """A single patient record paired with its computed eligibility score.

    Args:
        patient_index: Row label of the patient in the source table.
        patient_data: One row of the patient table, indexable by column name
            (the engine passes the ``pd.Series`` rows yielded by
            ``DataFrame.iterrows()``).
    """

    def __init__(self, patient_index: int, patient_data: pd.Series):
        self.patient_index = patient_index
        self.patient_data = patient_data
        # Every patient starts unscored; the engine overwrites this value.
        self.score = 0.0

    def __repr__(self) -> str:
        """Return a short debug representation (index and score only)."""
        return (
            f"Patient(patient_index={self.patient_index!r}, "
            f"score={self.score!r})"
        )

class Engine:
    def __init__(self, rules: dict[str, Any], patient_csv_fp: str):
        self.rules = rules
        self.patient_csv_fp = patient_csv_fp

    def sort_patients(self) -> list[Patient]:
        """
        The engine takes in a set of rules and a CSV file containing patient data.

        It converts the CSV into a pandas DataFrame.

        It then iterates through the patients in the DataFrame and applies the rules to each patient to obtain a score.

        Finally, it sorts the patients by score and returns the patients in a list.
        """
        # Load the patient data from the CSV file
        patient_data = pd.read_csv(self.patient_csv_fp)

        # Create a list to store the patients
        patients = []

        # Iterate through the patients
        for index, row in patient_data.iterrows():
            patient = Patient(index, row)
            patients.append(patient)

        # Apply the appropriate rules to each patient
        for patient in patients:
            score = 0.0

            # Apply the inclusion criteria
            for inclusion_criterium in self.rules["inclusion_criterium"]:
                rule = inclusion_criterium["rule"]
                weight = inclusion_criterium["weight"]

                if rule["type"] == "age":
                    score += self.age(patient, min=rule["min"], max=rule["max"]) * weight
                elif rule["type"] == "gender":
                    score += self.gender(patient, gender=rule["gender"]) * weight
                elif rule["type"] == "medications":
                    score += self.medications(patient, medications=rule["medications"]) * weight
                elif rule["type"] == "preexisting_conditions":
                    score += self.preexisting_conditions(patient, study_icd9_codes=rule["icd9_codes"]) * weight

            patient.score = score
            # Apply the exclusion criteria (set score to 0 if any exclusion criteria are met)
            for exclusion_criterium in self.rules["exclusion_criterium"]:
                rule = exclusion_criterium["rule"]

                if rule["type"] == "age":
                    if self.age(patient, min=rule["min"], max=rule["max"]):
                        patient.score = 0.0
                        print(f'Patient {patient.patient_index} excluded due to age ({patient.patient_data["age"]}, needed {rule["min"]} to {rule["max"]})')
                        break
                elif rule["type"] == "gender":
                    if self.gender(patient, gender=rule["gender"]):
                        patient.score = 0.0
                        print(f'Patient {patient.patient_index} excluded due to gender ({patient.patient_data["gender"]}, needed to be {['M','F'][rule["gender"]]})')
                        break
                elif rule["type"] == "medications":
                    if self.medications(patient, medications=rule["medications"]) > 0.0:
                        patient.score = 0.0
                        print(f'Patient {patient.patient_index} excluded due to medications ({patient.patient_data["prescriptions"]}, but couldn\'t have {rule["medications"]})')
                        break
                elif rule["type"] == "preexisting_conditions":
                    if self.preexisting_conditions(patient, study_icd9_codes=rule["icd9_codes"]) > 0.0:
                        patient.score = 0.0
                        print(f'Patient {patient.patient_index} excluded due to preexisting conditions ({patient.patient_data["icd9_codes"]}, but couldn\'t have {rule["icd9_codes"]})')
                        break

        # Sort the patients by score
        patients.sort(key=lambda x: x.score, reverse=True)
        return patients

    def age(self, patient: Patient, *,
            min: int | None = None, max: int | None = None) -> float:
        """
        Used to specify an age criteria for a patient.

        Args:
            patient (Patient): The patient to evaluate.
            min (int): The minimum age (optional).
            max (int): The maximum age (optional).

        Returns:
            float: 1.0 if the patient's age is within the specified range, 0.0 otherwise.
        """
        age = patient.patient_data['age']
        if min is not None and age < min:
            return 0.0
        if max is not None and age > max:
            return 0.0
        return 1.0

    def gender(self, patient: Patient, gender: int) -> float:
        """
        Used to specify a gender criteria for a patient.

        Args:
            patient_data (Patient): The patient to evaluate
            gender (int): 0 for male, 1 for female, 2 for either

        Returns:
            float: 1.0 if the patient is of the specified gender, 0.0 otherwise.
        """
        if gender == 0:
            if patient.patient_data['gender'] == 'M':
                return 1.0

        elif gender == 1:
            if patient.patient_data['gender'] == 'F':
                return 1.0

        elif gender == 2:
            return 1.0

        else:
            return 0.0

    def medications(self, patient: Patient, medications: list[str] = None) -> float:
        """
        Used to find whether a patient is prescribed certain medications.

        Args:
            patient_data (Patient): The patient to evaluate
            medications (list[str]): List of medications

        Returns:
            float: between 1.0 and 0.0 for the percent the patient's medications that match the listed medications
        """
        # Columns to check in patient data
        columns = ['prescriptions', 'prescriptions_poe', 'prescriptions_generic']
        matches = 0

        for med in medications:
            # convert medication name to all lowercase
            med = med.lower().strip()

            for col in columns:
                for patient_med in patient.patient_data[col]:
                    if  med in patient_med.lower().strip():
                        matches += 1

        total = len(medications)

        return matches/total

    def preexisting_conditions(self, patient: Patient, study_icd9_codes: list[str]) -> float:
        """
        Used to find whether a patient has had prior diagnoses or conditions.

        Args:
            patient_data (Patient): The patient to evaluate
            study_icd9_codes (list[str]): List of ICD-9 codes

        Returns:
            float: between 1.0 and 0.0 for the percent the patient's prior conditions that match the listed ICD-9 codes
        """

        patient_icd9_codes = patient.patient_data['icd9_codes']

        # Standardize ICD-9 codes for clean string matching
        patient_icd9_codes = patient_icd9_codes.strip().lower()
        study_icd9_codes = [x.strip().lower() for x in study_icd9_codes]

        # Convert ICD-9 diagnostic codes to sets
        patient_icd9_codes = set(ast.literal_eval(patient_icd9_codes))
        study_icd9_codes = set(study_icd9_codes)

        # Count how many ICD-9 codes the patient matches to the study
        set_diff = patient_icd9_codes.intersection(study_icd9_codes)

        # Return a binary score of 1 or 0 if patient matches all or no criteria
        # Return a raw score for ranking
        prior_condition_score = len(set_diff) / len(study_icd9_codes)

        return prior_condition_score


# Study rule set consumed by the engine: weighted inclusion criteria plus
# exclusion criteria (any exclusion match resets a patient's score to 0).
rules = {
    "response": "rules",
    "inclusion_criterium": [
        {
            "rule": {"type": "age", "min": 18, "max": 70},
            "weight": 1.0,
        },
    ],
    "exclusion_criterium": [
        {
            "rule": {
                "type": "preexisting_conditions",
                "icd9_codes": ["140-239", "428", "490-496", "585", "571.5", "204-208", "042"],
            },
        },
        {
            "rule": {
                "type": "medications",
                "medications": ["immunosuppressant", "chemotherapy", "corticosteroid"],
            },
        },
        {
            "rule": {
                "type": "preexisting_conditions",
                "icd9_codes": ["415.1", "410", "070", "011"],
            },
        },
        {
            # Free-text criterion with no machine-checkable rule type.
            "rule": {
                "type": "other",
                "description": "Several complex medical conditions that cannot be fully captured with current ruleset, including: severe anemia, uncontrolled bleeding, large area burns, severe hypotension, severe thrombocytopenia, and various other clinical conditions mentioned in the exclusion criteria",
            },
        },
    ],
}

# Build the engine from the rule set and the patient CSV, then rank patients.
engine = Engine(rules, "bigPatientData.csv")
patients = engine.sort_patients()

# Print each patient's index and score, highest score first.
for patient in patients:
    print(patient.patient_index, patient.score)
    #if patient.score > 0:
    #    print(patient.patient_data)

# Summarise how many patients kept a positive score (i.e. were not excluded
# and satisfied at least one inclusion rule).
num_patients = len(patients)
num_included_patients = sum(patient.score > 0 for patient in patients)

if num_patients:
    percentage_included = num_included_patients / num_patients * 100
    print(f"Percentage of patients that meet the inclusion criteria: {percentage_included}%")
    print(f'Percentage of patients that were excluded: {100 - percentage_included}%')
else:
    # An empty CSV would otherwise raise ZeroDivisionError above.
    percentage_included = 0.0
    print("No patients found in the input file.")

15 1.0
row_id                                                                  9502
subject_id                                                             10043
gender                                                                     M
dob                                                      2109-04-07 00:00:00
dod                                                      2191-02-07 00:00:00
dod_hosp                                                                 NaN
dod_ssn                                                  2191-02-07 00:00:00
expire_flag                                                                1
icd9_codes                 ['51881', '486', '49121', '00845', '2875', '42...
diagnoses (long_titles)    ['Acute respiratory failure', 'Pneumonia, orga...
age                                                                59.473473
Name: 15, dtype: object
55 1.0
row_id                                                                 30873
subject_id                            