In [1]:
import polars as pl
import polars.selectors as cs
import os

In [2]:
base_url = '/home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/'
max_seq_num = 5

all_dataset_path = "ALL_DATASETS"
os.makedirs(all_dataset_path, exist_ok=True)

pre_final_df_csv_name = "ED_TRIAGE_PREDICTION_RAW.csv"
pre_final_df_csv_path = os.path.join(all_dataset_path, pre_final_df_csv_name)

def load_data(file_path: str, columns: list[str]=None, schema_overrides = None):
    '''
    file_path: Ignore '/home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/' in the file path.
    That will be taken care of internally.
    
    columns: A list of column names to load from the CSV.

    schema_overrides: Explicitly specify the datatype of certain columns by inferring the type from the MIMIC IV v3.1 docs when reading the csv normally raises errors
    '''
    base_path = '/home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/'
    file_path = base_path + file_path

    # Pass the columns list to read_csv
    df = pl.read_csv(file_path, columns=columns, schema_overrides=schema_overrides)
    return df


def format_diagnoses(diag_list: list[str]) -> str:
    '''
    Convert the diagnoses column which is currently a list of strings into a string column. We do this by joining all the diagnosis in the list into one 
    string after numbering the diagnosis

    [diag1, diag2, ...., diag5] ---> "\n1. diag1 \n2. diag2\n.......\n5. diag5"
    '''
    return "\n".join(f"{i+1}. {d}" for i, d in enumerate(diag_list))



In [3]:
edstays = load_data("ed/edstays.csv")
ed_diagnosis = load_data("ed/diagnosis.csv")
# medrecon = load_data("ed/medrecon.csv")
triage = load_data("ed/triage.csv")
# vitalsign = load_data("ed/vitalsign.csv")
admissions = load_data("hosp/admissions.csv")
patients_df = load_data('hosp/patients.csv')

main_df = edstays.join(patients_df, on="subject_id", how="inner")
main_df = main_df.drop(["anchor_year", "anchor_year_group"])

triage = triage.drop_nulls(subset=["temperature", "heartrate", "resprate", "o2sat", "sbp", "dbp", "pain", "acuity", "chiefcomplaint"])
main_df = main_df.join(triage, on="stay_id", how="inner")
main_df = main_df.drop("subject_id_right")

ed_diagnosis = ed_diagnosis.sort(["stay_id", "seq_num"])

print(f"ED_Diag shape (length) BEFORE filtering rows with more than {max_seq_num} diagnosis: {ed_diagnosis.shape[0]}")

ed_diagnosis = ed_diagnosis.filter(
    pl.col("seq_num") < max_seq_num
)
print(f"ED_Diag shape (length) AFTER filtering rows with more than {max_seq_num} diagnosis: {ed_diagnosis.shape[0]}")


ed_diagnosis_grouped = ed_diagnosis.group_by("stay_id").agg(
    pl.col("icd_title").alias("icd_title_list")
)

ed_diagnosis_grouped = ed_diagnosis_grouped.with_columns(
                                                        pl.col("icd_title_list")
                                                        .map_elements(format_diagnoses, return_dtype=pl.Utf8)
                                                        .alias("numbered_diagnoses")
                                                    ).drop("icd_title_list")

main_df = main_df.join(ed_diagnosis_grouped, on="stay_id", how="inner")

print(f"Shape of main_df BEFORE dropping non-alphanumeric ChiefComplaint: {main_df.shape[0]}")
main_df = main_df.filter(
    pl.col("chiefcomplaint").str.contains(r"^[a-zA-Z0-9\s,.!?;:'\"()-]+$")
)
print(f"Shape of main_df AFTER dropping non-alphanumeric ChiefComplaint: {main_df.shape[0]}")

main_df_copy = main_df.clone()

main_df = main_df.filter(pl.col("pain").is_in([i/10 for i in range(0, 105, 5)]))

main_df.drop(["hadm_id", "intime", "outtime", "dod"])
main_df = main_df.with_columns(
    pl.when(pl.col("acuity") < 3.0)
    .then(pl.lit("YES"))
    .otherwise(pl.lit("NO"))
    .alias("GT_FLAG")
)
main_df = main_df.join(admissions[["subject_id", "race"]].unique(subset="subject_id"), on="subject_id", how="inner")

print()

tdf = main_df["acuity"].value_counts()

rows_with_acuity_being_1_or_2 = tdf.filter(pl.col("acuity").is_in([1.0, 2.0]))["count"].sum()
rows_with_acuity_being_3 = tdf.filter(pl.col("acuity").is_in([3.0, 4.0, 5.0]))["count"].sum()

print(f"Sum of rows with acuity 1.0 or 2.0: {rows_with_acuity_being_1_or_2}")
print(f"Sum of rows with acuity 3.0: {rows_with_acuity_being_3}")

print()

main_df = main_df.drop("acuity")



ED_Diag shape (length) BEFORE filtering rows with more than 5 diagnosis: 949172
ED_Diag shape (length) AFTER filtering rows with more than 5 diagnosis: 912855
Shape of main_df BEFORE dropping non-alphanumeric ChiefComplaint: 379468
Shape of main_df AFTER dropping non-alphanumeric ChiefComplaint: 332438

Sum of rows with acuity 1.0 or 2.0: 101445
Sum of rows with acuity 3.0: 134903



## Making PreFInal Dataset

In [4]:
main_df = main_df.with_columns(
    pl.col("race").str.split("-").list.first().str.strip_chars().str.split("/").list.first().str.strip_chars().alias("race")
)

# Drop races like "Unknown", "Other"
main_df = main_df.filter(pl.col('race') != 'UNKNOWN')
main_df = main_df.filter(pl.col('race') != 'OTHER')
main_df = main_df.filter(pl.col('race') != 'UNABLE TO OBTAIN')
main_df = main_df.filter(pl.col('race') != 'PATIENT DECLINED TO ANSWER')

main_df = main_df.with_columns(
    pl.col("race").replace({"HISPANIC": "HISPANIC/LATINO", "HISPANIC OR LATINO": "HISPANIC/LATINO"})
)

race_values = list(main_df['race'].value_counts(sort=True).head(4)['race'])
gender_values = list(main_df['gender'].unique())

demographic_dict = {
    'gender': gender_values,
    'race': race_values
}

df = main_df
print(f"Length BEFORE dropping rows that are not present in the top 4 races: {df.shape[0]}")
print("Fractions BEFORE dropping rows that are not present in the top 4 races")
print(f'Fraction of NO: {len(df.filter(pl.col("GT_FLAG") == "NO")) / len(df):.2f}')
print(f'Fraction of YES: {len(df.filter(pl.col("GT_FLAG") == "YES")) / len(df):.2f}')

main_df = main_df.filter(pl.col('race').is_in(demographic_dict['race']))

df = main_df
print(f"Length AFTER dropping rows that are not present in the top 4 races: {df.shape[0]}")
print("Fractions AFTER dropping rows that are not present in the top 4 races")
print(f'Fraction of NO: {len(df.filter(pl.col("GT_FLAG") == "NO")) / len(df):.2f}')
print(f'Fraction of YES: {len(df.filter(pl.col("GT_FLAG") == "YES")) / len(df):.2f}')

print(f"\n\nThe shape of main_df currently is: {main_df.shape}")

Length BEFORE dropping rows that are not present in the top 4 races: 223986
Fractions BEFORE dropping rows that are not present in the top 4 races
Fraction of NO: 0.57
Fraction of YES: 0.43
Length AFTER dropping rows that are not present in the top 4 races: 221575
Fractions AFTER dropping rows that are not present in the top 4 races
Fraction of NO: 0.57
Fraction of YES: 0.43


The shape of main_df currently is: (221575, 19)


In [5]:
print(f"The value counts of the races before undersamplign is: {main_df['race'].value_counts()}")

df = main_df
print("Fractions BEFORE undersampling the dataset such that all races occur at uniform intervals is:")
print(f'Fraction of NO: {len(df.filter(pl.col("GT_FLAG") == "NO")) / len(df):.2f}')
print(f'Fraction of YES: {len(df.filter(pl.col("GT_FLAG") == "YES")) / len(df):.2f}')

# Under sample the dataframe and have the races occuring at uniform intervals
min_count = main_df['race'].value_counts(sort=True)[-1][0, 1]

sampled_df = pl.DataFrame()

for race in demographic_dict['race']:
    tmp_df = main_df.filter(pl.col('race') == race)
    if tmp_df.shape[0] > min_count:
        tmp_df = tmp_df.sample(n = min_count, with_replacement=False, seed=42)
    sampled_df = pl.concat([sampled_df, tmp_df])

main_df_sampled = sampled_df
print(f'\n\nThe shape of main_df is now: {main_df_sampled.shape}\n\n')

df = main_df_sampled
print("Fractions AFTER undersampling the dataset such that all races occur at uniform intervals is:")
print(f'Fraction of NO: {len(df.filter(pl.col("GT_FLAG") == "NO")) / len(df):.2f}')
print(f'Fraction of YES: {len(df.filter(pl.col("GT_FLAG") == "YES")) / len(df):.2f}')

The value counts of the races before undersamplign is: shape: (4, 2)
┌─────────────────┬────────┐
│ race            ┆ count  │
│ ---             ┆ ---    │
│ str             ┆ u32    │
╞═════════════════╪════════╡
│ HISPANIC/LATINO ┆ 19529  │
│ WHITE           ┆ 138808 │
│ ASIAN           ┆ 7906   │
│ BLACK           ┆ 55332  │
└─────────────────┴────────┘
Fractions BEFORE undersampling the dataset such that all races occur at uniform intervals is:
Fraction of NO: 0.57
Fraction of YES: 0.43


The shape of main_df is now: (31624, 19)


Fractions AFTER undersampling the dataset such that all races occur at uniform intervals is:
Fraction of NO: 0.60
Fraction of YES: 0.40


In [6]:
# Selecting relevant rows from the final dataframe to make the pre_final_df (the dataframe using which, prompts will be made according to the prompt template)
pre_final_df = main_df_sampled.select(
    cs.by_name(
        "subject_id",
        "stay_id",
        "gender",
        "anchor_age",
        "race",
        "temperature",
        "heartrate",
        "resprate",
        "o2sat",
        "sbp",
        "dbp",
        "pain",
        "chiefcomplaint",
        "numbered_diagnoses",
        "GT_FLAG"
    )
)

In [7]:
pre_final_df.write_csv(pre_final_df_csv_path)