In [1]:
import polars as pl
import polars.selectors as cs
import os
import re
from functools import partial
import time

In [2]:
all_dataset_path = "ALL_DATASETS"
os.makedirs(all_dataset_path, exist_ok=True)

pre_final_df_csv_name = "OPIOID_ANALGESIC_PRED_RACES_RAW.csv"
pre_final_df_csv_path = os.path.join(all_dataset_path, pre_final_df_csv_name)

undersample_to_make_races_occur_uniformly = True

In [3]:
base_url = '/home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/'

opioid_analgesics = {
    "Morphine": ["Morphine", "MS Contin", "Kadian", "Avinza", "Roxanol", "Arymo ER", "Arymo", "Morphabond", "Oramorph SR", "Oramorph", "Duramorph", 
                 "Astramorph PF", "Astramorph"],
    "Hydromorphone": ["Hydromorphone", "Dilaudid", "Exalgo"],
    "Fentanyl": ["Fentanyl", "Duragesic", "Actiq", "Fentora", "Subsys", "Abstral", "Lazanda", "Sublimaze"],
    "Oxycodone": ["Oxycodone", "OxyContin", "Roxicodone", "Xtampza ER", "Xtampza", "Percocet", "Percodan", "Targiniq ER", "Targiniq"],
    "Hydrocodone": ["Hydrocodone", "Hysingla ER", "Hysingla", "Zohydro ER", "Zohydro", "Vicodin", "Norco", "Lortab", "Reprexain"],
    "Codeine": ["Codeine", "Tylenol with Codeine", "Capital and Codeine"],
    "Tramadol": ["Tramadol", "Ultram", "ConZip", "Ryzolt"],
    "Methadone": ["Methadone", "Dolophine", "Methadose"],
    "Oxymorphone": ["Oxymorphone", "Opana ER", "Opana"],
    "Tapentadol": ["Tapentadol", "Nucynta", "Nucynta ER"],
    "Meperidine": ["Meperidine", "Demerol"],
    "Buprenorphine": ["Buprenorphine", "Subutex", "Butrans", "Belbuca"],
    "Levorphanol": ["Levorphanol", "Levo-Dromoran"],
    "Butorphanol": ["Butorphanol", "Stadol"],
    "Propoxyphene": ["Propoxyphene"],
    "Pentazocine": ["Pentazocine", "Talwin", "Talwin Nx"],
    "Dihydrocodeine": ["Dihydrocodeine"],
    "Nalbuphine": ["Nalbuphine", "Nubain"],
    "Tilidine": ["Tilidine"],
    "Dextropropoxyphene": ["Dextropropoxyphene"],
    "Pethidine": ["Pethidine", "Demerol"],
    "Oliceridine": ["Oliceridine", "Olinvyk"],
    "Alfentanil": ["Alfentanil", "Alfenta"],  
    "Remifentanil": ["Remifentanil", "Ultiva"],  
    "Sufentanil": ["Sufentanil", "Dsuvia", "Sufenta"] 
}

opioid_analgesics_names_list = sorted({name.strip().lower() for names in opioid_analgesics.values() for name in names})


def load_data(file_path: str, columns: list[str]=None, schema_overrides = None):
    '''
    file_path: Ignore '/home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/' in the file path.
    That will be taken care of internally.
    
    columns: A list of column names to load from the CSV.

    schema_overrides: Explicitly specify the datatype of certain columns by inferring the type from the MIMIC IV v3.1 docs when reading the csv normally raises errors
    '''
    base_path = '/home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/'
    file_path = base_path + file_path

    # Pass the columns list to read_csv
    df = pl.read_csv(file_path, columns=columns, schema_overrides=schema_overrides)
    return df


def format_diagnoses(diag_list: list[str]) -> str:
    '''
    Convert the diagnoses column which is currently a list of strings into a string column. We do this by joining all the diagnosis in the list into one 
    string after numbering the diagnosis

    [diag1, diag2, ...., diag5] ---> "\n1. diag1 \n2. diag2\n.......\n5. diag5"
    '''
    return "\n".join(f"{i+1}. {d}" for i, d in enumerate(diag_list))

def process_text_from_notes(text: str, section_pattern: re.Pattern, 
                            deid_pattern: re.Pattern, med_pattern: re.Pattern, 
                            target_sections: list, deid_dict: dict,
                            separator_to_join_sections = "\n\n---___---!@#%&*JOIN_SEPARATOR*&%#@!---___---\n\n") -> dict:
    """
    Extracts specific medical sections, de-identifies gender-related terms, and redacts medication names.

    Parameters:
    text (str): Input medical text.
    section_pattern (re.Pattern): Regex to extract target sections.
    deid_pattern (re.Pattern): Regex to replace gender-related terms.
    med_pattern (re.Pattern): Regex to redact medication names.
    target_sections (list): List of target sections to be extracted from the notes
    deid_dict (dict): A dictionary for gender de-identification.

    Returns:
    dict: Processed text and a flag (1 if all sections were found, else 0).
    """

    matches = section_pattern.findall(text)

    # Count found sections
    found_sections = set()
    for match in matches:
        for section in target_sections:
            if match.strip().startswith(section):
                found_sections.add(section)
                break
    
    # The flag is 1 if all target sections were extracted from the note.
    section_flag = 1 if len(found_sections) == len(target_sections) else 0

    
    final_text = separator_to_join_sections.join(matches)

    def replace(match):
        # Lowercase the match and replace with the corresponding value from deid_dict
        return deid_dict[match.group(0).lower()]

    final_text = deid_pattern.sub(replace, final_text)

    final_text = med_pattern.sub("___", final_text)

    return {"processed_text_with_all_sections_combined": final_text, "notes_section_flag": section_flag}

def has_opioid(medicines, opioid_patterns):
    for medicine in medicines:
        for pattern in opioid_patterns:
            if pattern.search(medicine):
                return True
    return False

def extract_section(text, section_title, separator_to_join_sections = "\n\n---___---!@#%&*JOIN_SEPARATOR*&%#@!---___---\n\n"):
    # Split the text by the custom separator
    sections = text.split(separator_to_join_sections)
    
    # Find the section that starts with the given title
    for section in sections:
        section = section.strip()
        if section.startswith(section_title):
            # Remove the section title and colon, then return the cleaned content
            content = section[len(section_title):].strip()
            if content.startswith(':'):
                content = content[1:].strip()
            return content
    
    # Return empty string if section not found
    return ""

In [4]:
admissions_df = load_data('hosp/admissions.csv')
patients_df = load_data('hosp/patients.csv')
icd_diagnosis_code_details_df = load_data('hosp/d_icd_diagnoses.csv', schema_overrides = {"icd_code": pl.Utf8})
icd_diagnosis_code_df = load_data('hosp/diagnoses_icd.csv')
icustays_df = load_data('icu/icustays.csv')
notes_df = load_data('mimic-iv-note-deidentified-free-text-clinical-notes-2.2/note/discharge.csv')

print(f"All of the tables have been loaded")

# Loading a subset of columns because polars is running into an error while inferring schema of "gsn" column. schema_overrides can be used but since we do not need gsn column, we do  not load it.
prescriptions_df = load_data('hosp/prescriptions.csv', columns=['subject_id', 'hadm_id', 'pharmacy_id', 'drug'])

# Now, in the admissions table, I only want to keep rows that belong to the last hadm_id of a particular subject_id.
admissions_last_hadm_id_df = (admissions_df.group_by("subject_id").agg(
    pl.all().sort_by("admittime", descending=True).first()
    ))

print(f"Shape of main_df after loading admissions_df is: {admissions_last_hadm_id_df.shape}\n\n==================")


# Inner join tables of admission and patients --> we only want rows where both information is present.
main_df = admissions_last_hadm_id_df.join(patients_df, on="subject_id", how="inner")
print(f"Shape of main_df (after joining patients with admissions table) currently is: {main_df.shape}\n\n==================")

################
### ICD CODE ###
################
print(f"Starting the ICD table code...")

# Joining the ICD code details to the ICD Diagnosis Codes table
icd_diagnosis_code_with_details_joined_df = icd_diagnosis_code_df.join(
    icd_diagnosis_code_details_df, 
    on=["icd_version", "icd_code"], 
    how="inner"  # Use inner join to retain records with all required information
)

print(f"The shape of the ICD codes + ICD Details joined is: {icd_diagnosis_code_with_details_joined_df.shape}\nThis should have the same number of rows as icd_diagnosis_code_df at 4_756_326")

icd_diagnosis_code_with_details_joined_and_grouped_by_hadm_id_df = icd_diagnosis_code_with_details_joined_df.group_by("hadm_id").agg(
    pl.col("long_title").alias("long_title_list")
)

# Now take the list[str] column and make a new column with just strings (converts it into the format defined inside format_diagnosis function)
icd_diagnosis_code_with_details_joined_and_grouped_by_hadm_id_df = icd_diagnosis_code_with_details_joined_and_grouped_by_hadm_id_df.with_columns(
    pl.col("long_title_list")
      .map_elements(format_diagnoses, return_dtype=pl.Utf8)
      .alias("numbered_diagnoses")
)

# Inner join tables of main_df and icd_diagnosis_code_with_details_joined_df --> we only want rows where both information is present. (There are some hadm_ids in main_df that do NOT have ICD codes)
main_df = main_df.join(icd_diagnosis_code_with_details_joined_and_grouped_by_hadm_id_df, on="hadm_id", how="inner")
print(f"Shape of main_df after joining the ICD code and ICD Diagnosis Details: {main_df.shape}\n\n==================")

###########################
### Prescriptions Table ###
###########################
print(f"Starting the Prescriptions table code...")

prescriptions_df = prescriptions_df.with_columns(pl.col("drug").str.to_lowercase().str.strip_chars())

# Group by hadm_id and aggregate all drugs given to the subject in that hadm_id into a list[str]
prescriptions_group_by_hadm_id_df = prescriptions_df.group_by("hadm_id").agg(pl.col("drug").alias("list_of_administered_meds"))

#### NOW WE MAKE THE GT_FLAG column. ####
# intialise the column as empty strings.
prescriptions_group_by_hadm_id_df = prescriptions_group_by_hadm_id_df.with_columns(pl.lit(value="", dtype=pl.Utf8).alias("GT_FLAG"))

# Fill GT_FLAG with "YES"
prescriptions_group_by_hadm_id_df = prescriptions_group_by_hadm_id_df.with_columns(
    pl.when(
        pl.col("list_of_administered_meds")
          .list.eval(pl.element().is_in(opioid_analgesics_names_list))
          .list.any()  # Returns True if any administered med is in the opioid list
    )
    .then(pl.lit("YES"))
    .otherwise(pl.col("GT_FLAG"))
    .alias("GT_FLAG")
)

#### Fill GT_FLAG with "UNK"
# Compile regex patterns for all opioids 
opioid_patterns = [re.compile(opioid, re.IGNORECASE) for opioid in opioid_analgesics_names_list]

has_opioid_partial = partial(has_opioid, 
                               opioid_patterns=opioid_patterns)

prescriptions_group_by_hadm_id_df = prescriptions_group_by_hadm_id_df.with_columns([
    pl.when(
        (pl.col("GT_FLAG") != "YES") & 
        pl.col("list_of_administered_meds").map_elements(has_opioid_partial, return_dtype=pl.Boolean)
    ).then(pl.lit(value="UNK", dtype=pl.Utf8)).otherwise(pl.col("GT_FLAG")).alias("GT_FLAG")
])

#### Fill all remaining rows with not GT_FLAG value as "NO"
prescriptions_group_by_hadm_id_df = prescriptions_group_by_hadm_id_df.with_columns(
    pl.when(pl.col("GT_FLAG") == "")
    .then(pl.lit("NO"))
    .otherwise(pl.col("GT_FLAG")).alias("GT_FLAG")
)

print(f"The ratios of UNK, NO, YES initially are:")
print(f'Fraction of UNK: {len(prescriptions_group_by_hadm_id_df.filter(pl.col("GT_FLAG") == "UNK")) / len(prescriptions_group_by_hadm_id_df):.2f}')
print(f'Fraction of NO: {len(prescriptions_group_by_hadm_id_df.filter(pl.col("GT_FLAG") == "NO")) / len(prescriptions_group_by_hadm_id_df):.2f}')
print(f'Fraction of YES: {len(prescriptions_group_by_hadm_id_df.filter(pl.col("GT_FLAG") == "YES")) / len(prescriptions_group_by_hadm_id_df):.2f}')

# Inner join tables of main_df and prescriptions_group_by_hadm_id_df --> we only want rows where both information is present.
main_df = main_df.join(prescriptions_group_by_hadm_id_df, on="hadm_id", how="inner")
print(f"Shape of main_df after including the prescription information is: {main_df.shape}\n\n==================")

##################
### ICU Module ###
##################
print(f"Starting the ICU Module code...")

print(f"The total number of unique hadm_id in the icu_stays table are: {icustays_df['hadm_id'].n_unique()}")

# For each hadm_id, take the sum of the length of stay in ICU for each stay_id of a particular hadm_id --> proxy for how serious the condition was if the patient did not die in the ICU
icustays_group_by_hadm_sum_of_los_df = icustays_df.group_by("hadm_id").agg(pl.col("los").sum().alias("Total_LOS_in_ICU_in_days"))

# Minimum value of Total_LOS_in_ICU_in_days for someone admitted to the ICU is >0, therefore, we can replace NULL values with 0 
min_stay_duration = icustays_group_by_hadm_sum_of_los_df.select("Total_LOS_in_ICU_in_days").min()["Total_LOS_in_ICU_in_days"][0]
print(f"Minimum stay duration: {min_stay_duration}")



# Now, we create two columns, one columns is --> "was_admitted_to_icu" (This will be 1 for a given hadm_id, if the id is present in icustays_group_by_hadm_sum_of_los_df, and 0 otherwise)
# other column is Total_LOS_in_ICU_in_days (This will be 0 if hadm_id was not present in icustays_group_by_hadm_sum_of_los_df, or the value of Total_LOS_in_ICU_in_days if it was present)
# Left join main_df with ICU LOS DataFrame on hadm_id. The hadm_id that were not present in the icustays_group_by_hadm_sum_of_los_df will have value as NULL for this column
main_df = main_df.join(icustays_group_by_hadm_sum_of_los_df, on="hadm_id", how="left")


# Fill missing Total_LOS_in_ICU_in_days with 0
main_df = main_df.with_columns(
    pl.col("Total_LOS_in_ICU_in_days").fill_null(0)
)

# Create was_admitted_to_icu: 1 if Total_LOS_in_ICU_in_days > 0, else 0 {since minimum value of someone admitted to the ICU was >0, 0.00125 to be precise}
main_df = main_df.with_columns(
    (pl.col("Total_LOS_in_ICU_in_days") > 0).cast(pl.Int8).alias("was_admitted_to_icu")
)

print(f"Shape of main_df after adding the ICU_STAYS details is: {main_df.shape}")

####################
### Notes Module ###
####################
print(f"Starting the Notes Module code...")

main_df = main_df.join(notes_df[["hadm_id", "text"]], on="hadm_id", how="inner")
print(f"Shape of main_df after just loading the notes module and appending the text column to main_df is: {main_df.shape}\n\n==================")

# Select the notes of the hadm_ids that are present in the main_df table.
notes_last_hadm_id_df = main_df[["hadm_id", "text"]]

print(f"Starting the processing of the notes to extract the sections that we want")
time1 = time.time()

target_sections = [
    'Allergies',
    'Chief Complaint',
    'History of Present Illness',
    'Past Medical History'
]

print(f"The sections that will be extracted are: ")
print(*target_sections, sep="\n")
print("\n")


pattern = (
    r'\n\s*\n('                                   # Match preceding blank line(s) then start capturing group for the section
    r'(?:' + '|'.join(re.escape(section) for section in target_sections) + r')'  # Match one of the target section titles
    r'\s*:'                                       # Optional whitespace and the colon after the title
    r'.*?'                                        # Lazily match all content (including newlines)
    r')'                                          # End capturing group
    r'(?=\n\s*\n[A-Z][^\n]*?:|\Z)'                 # Lookahead: next heading (blank line then a capital letter line ending with colon) or end of string
)

section_pattern = re.compile(pattern, re.DOTALL)

# Now perform gender de-identification on the final_text using the dictionary.
deid_dict = {
    'male': 'person',
    'female': 'person',
    'he': 'the patient',
    'she': 'the patient',
    'him': 'the patient',
    'her': "the patient's",
    'his': "the patient's",
    'hers': "the patient's",
    'himself': 'the patient',
    'herself': 'the patient',
    'mr': 'the patient',
    'mrs': 'the patient',
    'ms': 'the patient',
    'miss': 'the patient',
    'mister': 'the patient',
    'sir': 'the patient',
    'madam': 'the patient',
    'man': 'person',
    'woman': 'person',
    'men': 'people',
    'women': 'people',
    'gentleman': 'person',
    'gentlewoman': 'person',
    'boy': 'child',
    'girl': 'child',
    'boys': 'children',
    'girls': 'children',
    'father': 'parent',
    'mother': 'parent',
    'dad': 'parent',
    'mom': 'parent',
    'son': 'child',
    'daughter': 'child',
    'brother': 'sibling',
    'sister': 'sibling',
    'husband': 'spouse',
    'wife': 'spouse',
    'uncle': 'relative',
    'aunt': 'relative'
}

# Build a regex to match any key from the de-identification dictionary as a whole word (case-insensitive)
deid_pattern = re.compile(r'\b(?:' + '|'.join(map(re.escape, deid_dict.keys())) + r')\b', re.IGNORECASE)

# Build a regex to match any medication name (as a whole word/phrase, case-insensitive)
med_pattern = re.compile('(?:' + '|'.join(map(re.escape, opioid_analgesics_names_list)) + ')', re.IGNORECASE)

process_text_partial = partial(process_text_from_notes, 
                               section_pattern=section_pattern, 
                               deid_pattern=deid_pattern, 
                               med_pattern=med_pattern, 
                               target_sections=target_sections, 
                               deid_dict=deid_dict)

# Apply the processing function to each row
notes_last_hadm_id_df = notes_last_hadm_id_df.with_columns(
    pl.col("text").map_elements(
        process_text_partial,
        return_dtype=pl.Struct([
            pl.Field("processed_text_with_all_sections_combined", pl.Utf8),
            pl.Field("notes_section_flag", pl.Int64)
        ])
    ).alias("processed")
).unnest("processed")

print(f"Finished process the notes for the whole dataset, total time taken: {time.time() - time1:.2f} seconds")

############# FOR A MORE COMPREHENSIVE LIST, LOOK AT ORIGINAL CODE FROM compile_dataset_new.ipynb in the original Med-LLM-Fairness repo
male_indicators = ['prostate',
 'testicular',
 'penis',
 'scrotum',
 'spermatic',
 'testis',
 'epididymis',
 'phallus',
 'penile',
 'deferens',
 'ejaculate',
 'ejaculation',
 'sperm',
 'prostatic',
 'andropause',
 'smegma',
 'azoospermia',
 'cryptorchidism',
 'varicocele',
 'spermatogenesis',
 'paternity',
 'spermatorrhea',
 'Leydig',
 'Sertoli',
 'orchidometry',
 'impotence',
 'gynecomastia',
 'spermatogenic',
 'hypogonadism',
 'semen',
 'andrology',
 'spermiogram',
 'vasectomy',
 'orchiectomy',
 'orchidopexy',
 'penectomy',
 'circumcision',
 'prostatectomy',
 'scrotoplasty',
 'penoplasty',
 'varicocelectomy',
 'phalloplasty',
 'prostatitis',
 'orchitis',
 'epididymitis',
 'spermatocele',
 'chordee',
 'phimosis',
 'paraphimosis',
 'balanitis',
 'orchalgia',
 'seminal',
 'scrotal',
 'penoscrotal',
 'azoospermic',
 'oligospermia',
 'epididymal',
 'orchidectomy',
 'seminoma',
 'hematospermia',
 'balanoposthitis',
 'Peyronie',
 'Klinefelter']

female_indicators = ['ovarian',
 'uterine',
 'vaginal',
 'cervical',
 'breast',
 'myomectomy',
 'fallopian',
 'mammary',
 'vulvar',
 'clitoral',
 'labial',
 'endometrial',
 'cervix',
 'ovulation',
 'uterus',
 'vagina',
 'PCOS',
 'hysteroscopy',
 'hysterosalpingogram',
 'endometritis',
 'salpingography'
 'vulva',
 'adnexal',
 'areolar',
 'pregnant',
 'menstruation',
 'gravida',
 'parity',
 'menstrual',
 'amenorrhea',
 'menarche',
 'menopause',
 'lactation',
 'postpartum',
 'antenatal',
 'obstetric',
 'gestational',
 'preeclampsia',
 'eclampsia',
 'hydatidiform',
 'fetal',
 'luteal',
 'follicular',
 'chorioamnionitis',
 'miscarriage',
 'abortion',
 'mastitis',
 'vulvodynia',
 'cystocele',
 'rectocele',
 'anovulation',
 'adenomyosis',
 'oophoritis',
 'mammoplasty',
 'endometrioid',
 'menometrorrhagia',
 'gravidity',
 'ovulatory',
 'placenta',
 'amniotic',
 'placental',
 'chorionic',
 'hysterectomy',
 'oophorectomy',
 'salpingo-oophorectomy',
 'pap smear',
 'papanicolaou'
 'mammography',
 'colposcopy',
 'cesarean',
 'episiotomy',
 'curettage',
 'tubal',
 'endometriosis',
 'fibroid',
 'polycystic',
 'dysmenorrhea',
 'menorrhagia',
 'mastalgia',
 'galactorrhea',
 'fibrocystic',
 'cervicitis',
 'vulvitis',
 'vaginismus']

complete_list = male_indicators + female_indicators

print(f"Starting the gendered disease de-identification of the notes")
time1 = time.time()

pattern = re.compile(r'\b(?:' + '|'.join(map(re.escape, complete_list)) + r')\b', re.IGNORECASE)

notes_last_hadm_id_df = notes_last_hadm_id_df.with_columns(
    pl.col("processed_text_with_all_sections_combined")
    .str.contains_any(complete_list)
    .cast(pl.Int64)  # Convert boolean to integer (1 for True, 0 for False)
    .alias("to_discard")
)

print(f"Finished the gendered disease de-identification of the notes, time taken: {time.time() - time1:.2f} seconds")

# Now create four new columns for each section
notes_last_hadm_id_df = notes_last_hadm_id_df.with_columns([
    pl.col("processed_text_with_all_sections_combined").map_elements(
        lambda text: extract_section(text, "Allergies"),
        return_dtype=pl.Utf8
    ).alias("patient_Allergies"),
    
    pl.col("processed_text_with_all_sections_combined").map_elements(
        lambda text: extract_section(text, "Chief Complaint"),
        return_dtype=pl.Utf8
    ).alias("patient_chief_complaint"),
    
    pl.col("processed_text_with_all_sections_combined").map_elements(
        lambda text: extract_section(text, "History of Present Illness"),
        return_dtype=pl.Utf8
    ).alias("patient_history_of_present_illness"),
    
    pl.col("processed_text_with_all_sections_combined").map_elements(
        lambda text: extract_section(text, "Past Medical History"),
        return_dtype=pl.Utf8
    ).alias("patient_past_medical_history")
])

main_df = main_df.join(
    notes_last_hadm_id_df[['hadm_id', 'notes_section_flag', 'patient_Allergies', 'patient_chief_complaint', 'patient_history_of_present_illness', 'patient_past_medical_history', 'to_discard']], 
    on="hadm_id", 
    how="inner"
)

print(f"The shape of the dataset after processing and joining the notes module is: {main_df.shape}\n\n==================")

All of the tables have been loaded
Shape of main_df after loading admissions_df is: (180733, 16)

Shape of main_df (after joining patients with admissions table) currently is: (180733, 21)

Starting the ICD table code...
The shape of the ICD codes + ICD Details joined is: (4756326, 6)
This should have the same number of rows as icd_diagnosis_code_df at 4_756_326
Shape of main_df after joining the ICD code and ICD Diagnosis Details: (180516, 23)

Starting the Prescriptions table code...
The ratios of UNK, NO, YES initially are:
Fraction of UNK: 0.54
Fraction of NO: 0.36
Fraction of YES: 0.09
Shape of main_df after including the prescription information is: (151054, 25)

Starting the ICU Module code...
The total number of unique hadm_id in the icu_stays table are: 66239
Minimum stay duration: 0.00125
Shape of main_df after adding the ICU_STAYS details is: (151054, 27)
Starting the Notes Module code...
Shape of main_df after just loading the notes module and appending the text column to m

## Making Pre Final Dataset

In [5]:
main_df = main_df.with_columns(
    pl.col("race").str.split("-").list.first().str.strip_chars().str.split("/").list.first().str.strip_chars().alias("race")
)

# Drop races like "Unknown", "Other"
main_df = main_df.filter(pl.col('race') != 'UNKNOWN')
main_df = main_df.filter(pl.col('race') != 'OTHER')
main_df = main_df.filter(pl.col('race') != 'UNABLE TO OBTAIN')
main_df = main_df.filter(pl.col('race') != 'PATIENT DECLINED TO ANSWER')

main_df = main_df.with_columns(
    pl.col("race").replace({"HISPANIC": "HISPANIC/LATINO", "HISPANIC OR LATINO": "HISPANIC/LATINO"})
)

race_values = list(main_df['race'].value_counts(sort=True).head(4)['race'])
gender_values = list(main_df['gender'].unique())

demographic_dict = {
    'gender': gender_values,
    'race': race_values
}

print(f"Length of dataset BEFORE dropping rows that are not present in the top 4 races: {main_df.shape}")

main_df = main_df.filter(pl.col('race').is_in(demographic_dict['race']))

print(f"Length of dataset AFTER dropping rows that are not present in the top 4 races: {main_df.shape}")


#### WE STILL HAVENT DISCARDED THE ROWS WITH 'to_discard' flage = 1
main_df = main_df.filter(pl.col('to_discard') == 0)

print(f"Length of dataset AFTER dropping rows that had 'to_discard' flag = 1: {main_df.shape}")

main_df = main_df.filter(
    pl.col('GT_FLAG') != 'UNK'
)

print(f"Length of dataset AFTER dropping rows that had 'GT_FLAG' flag = UNK: {main_df.shape}")

Length of dataset BEFORE dropping rows that are not present in the top 4 races: (121428, 34)
Length of dataset AFTER dropping rows that are not present in the top 4 races: (120246, 34)
Length of dataset AFTER dropping rows that had 'to_discard' flag = 1: (93212, 34)
Length of dataset AFTER dropping rows that had 'GT_FLAG' flag = UNK: (44517, 34)


In [6]:
print(f"The value counts of the races before undersamplign is: {main_df['race'].value_counts()}")

print("Fractions BEFORE undersampling the dataset such that all races occur at uniform intervals is:")
print(f'Fraction of NO: {len(main_df.filter(pl.col("GT_FLAG") == "NO")) / len(main_df):.2f}')
print(f'Fraction of YES: {len(main_df.filter(pl.col("GT_FLAG") == "YES")) / len(main_df):.2f}')

The value counts of the races before undersamplign is: shape: (4, 2)
┌─────────────────┬───────┐
│ race            ┆ count │
│ ---             ┆ ---   │
│ str             ┆ u32   │
╞═════════════════╪═══════╡
│ HISPANIC/LATINO ┆ 2128  │
│ ASIAN           ┆ 2056  │
│ BLACK           ┆ 5871  │
│ WHITE           ┆ 34462 │
└─────────────────┴───────┘
Fractions BEFORE undersampling the dataset such that all races occur at uniform intervals is:
Fraction of NO: 0.78
Fraction of YES: 0.22


In [9]:
if undersample_to_make_races_occur_uniformly:
    print(f"The main_df will now be sampled so that all the races occur at a uniform rate")
    # Under sample the dataframe and have the races occuring at uniform intervals
    min_count = main_df['race'].value_counts(sort=True)[-1][0, 1]

    sampled_df = pl.DataFrame()

    for race in demographic_dict['race']:
        tmp_df = main_df.filter(pl.col('race') == race)
        if tmp_df.shape[0] > min_count:
            tmp_df = tmp_df.sample(n = min_count, with_replacement=False, seed=42)
        sampled_df = pl.concat([sampled_df, tmp_df])

    main_df_sampled = sampled_df
else:
    print(f"The main_df will NOT be sampled.")
    main_df_sampled = main_df


print(f'\n\nThe shape of main_df is now: {main_df_sampled.shape}\n\n')

print("Fractions AFTER undersampling the dataset such that all races occur at uniform intervals is:")
print(f'Fraction of NO: {len(main_df_sampled.filter(pl.col("GT_FLAG") == "NO")) / len(main_df_sampled):.2f}')
print(f'Fraction of YES: {len(main_df_sampled.filter(pl.col("GT_FLAG") == "YES")) / len(main_df_sampled):.2f}')

main_df_sampled = main_df_sampled.with_columns(
    num_icd_codes = pl.col("long_title_list").list.len().alias('num_icd_codes')
)

icd_threshold = 15
main_df_sampled_and_gender_based_disease_rows_discarded_and_limit_on_icd_codes_UNK_dropped = main_df_sampled.filter(pl.col('num_icd_codes') < icd_threshold)
print(f'Length of dataset after removing rows with number of ICD codes >= {icd_threshold}: {main_df_sampled_and_gender_based_disease_rows_discarded_and_limit_on_icd_codes_UNK_dropped.shape[0]}\n\n')

print(f"Fractions AFTER removing rows with num_icd_codes >= {icd_threshold}:")
print(f'Fraction of NO: {len(main_df_sampled.filter(pl.col("GT_FLAG") == "NO")) / len(main_df_sampled):.2f}')
print(f'Fraction of YES: {len(main_df_sampled.filter(pl.col("GT_FLAG") == "YES")) / len(main_df_sampled):.2f}')


# Selecting relevant rows from the final dataframe to make the pre_final_df (the dataframe using which, prompts will be made according to the prompt template)
pre_final_df = main_df_sampled_and_gender_based_disease_rows_discarded_and_limit_on_icd_codes_UNK_dropped.select(
    cs.by_name(
        "subject_id",
        "hadm_id",
        "gender",
        "race",
        "anchor_age",
        "patient_chief_complaint",
        "patient_Allergies",
        "patient_history_of_present_illness",
        "patient_past_medical_history",
        "numbered_diagnoses",
        "was_admitted_to_icu",
        "Total_LOS_in_ICU_in_days",
        "GT_FLAG"
    )
)

The main_df will now be sampled so that all the races occur at a uniform rate


The shape of main_df is now: (8224, 34)


Fractions AFTER undersampling the dataset such that all races occur at uniform intervals is:
Fraction of NO: 0.82
Fraction of YES: 0.18
Length of dataset after removing rows with number of ICD codes >= 15: 6138


Fractions AFTER removing rows with num_icd_codes >= 15:
Fraction of NO: 0.82
Fraction of YES: 0.18


In [8]:
pre_final_df.write_csv(pre_final_df_csv_path)