In [1]:
import os
import polars as pl
import random

# Seed Python's built-in RNG once, up front, so any later use of `random` is repeatable.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# Folder name of the task inside the shared results directory.
TASK_DIR_NAME = "ED_Triage_v2"


# ---------------------------------------------------------------
# Input data locations
# ---------------------------------------------------------------
# Root directory under which the results of every task are stored.
ALL_RES_DIR_PATH = "/home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/Med_LLM_Fairness_3/All_Results"
OG_DATASET_DIR_NAME = os.path.join("Dataset_and_code_for_datasets", "ALL_DATASETS")

# File names expected inside the original dataset directory.
RAW_DATASET_CSV_NAME = "ED_TRIAGE_PREDICTION_RAW.csv"
OG_TEST_CSV_NAME = "final_test_dataset.csv"
OG_VAL_CSV_NAME = "final_validation_dataset.csv"
OG_TRAIN_CSV_NAME = "final_train_dataset.csv"

# Resolve the directory chain first, then every CSV that lives in it.
TASK_DIR_PATH = os.path.join(ALL_RES_DIR_PATH, TASK_DIR_NAME)
OG_DATASET_DIR_PATH = os.path.join(TASK_DIR_PATH, OG_DATASET_DIR_NAME)
RAW_DATA_CSV_PATH, TEST_CSV_PATH, VAL_CSV_PATH, TRAIN_CSV_PATH = (
    os.path.join(OG_DATASET_DIR_PATH, csv_name)
    for csv_name in (RAW_DATASET_CSV_NAME, OG_TEST_CSV_NAME, OG_VAL_CSV_NAME, OG_TRAIN_CSV_NAME)
)

# Echo the resolved locations so the cell output records what was read.
print(
    "\n".join(
        [
            "The input data paths are as follows:",
            f"TASK_DIR_PATH: {TASK_DIR_PATH}",
            f"OG_DATASET_DIR_PATH: {OG_DATASET_DIR_PATH}",
            f"RAW_DATA_CSV_PATH: {RAW_DATA_CSV_PATH}",
            f"TEST_CSV_PATH: {TEST_CSV_PATH}",
            f"VAL_CSV_PATH: {VAL_CSV_PATH}",
            f"TRAIN_CSV_PATH: {TRAIN_CSV_PATH}",
        ]
    )
)
print("\n\n\n")



# ---------------------------------------------------------------
# Output data locations (relative to the current working directory)
# ---------------------------------------------------------------
OUTPUT_ALL_DATASET_DIR_NAME = "ALL_DATASETS"
OUTPUT_TRAIN_CSV_NAME = "final_train_dataset.csv"
OUTPUT_VAL_CSV_NAME = "final_validation_dataset.csv"
OUTPUT_TEST_CSV_NAME = "final_test_dataset.csv"
PROMPT_TEMPLATE_TXT_NAME = "prompt_template_used.txt"

# Make sure the output folder exists before anything is written into it.
os.makedirs(OUTPUT_ALL_DATASET_DIR_NAME, exist_ok=True)

# Every output file lives directly inside the output folder.
PROMPT_TEMPLATE_TXT_PATH, OUTPUT_TRAIN_CSV_PATH, OUTPUT_VAL_CSV_PATH, OUTPUT_TEST_CSV_PATH = (
    os.path.join(OUTPUT_ALL_DATASET_DIR_NAME, file_name)
    for file_name in (
        PROMPT_TEMPLATE_TXT_NAME,
        OUTPUT_TRAIN_CSV_NAME,
        OUTPUT_VAL_CSV_NAME,
        OUTPUT_TEST_CSV_NAME,
    )
)

# Echo the resolved locations so the cell output records what will be written.
print(
    "\n".join(
        [
            "The output data paths are as follows: ",
            f"PROMPT_TEMPLATE_TXT_PATH: {PROMPT_TEMPLATE_TXT_PATH}",
            f"OUTPUT_TRAIN_CSV_PATH: {OUTPUT_TRAIN_CSV_PATH}",
            f"OUTPUT_VAL_CSV_PATH: {OUTPUT_VAL_CSV_PATH}",
            f"OUTPUT_TEST_CSV_PATH: {OUTPUT_TEST_CSV_PATH}",
        ]
    )
)
print("\n\n\n")

# Demographic values that are crossed (every gender with every race) when building prompts.
DEMOGRAPHIC_DICT_NEW = {
    "gender": ["Female", "Male", "Intersex"],
    "race": ["WHITE", "BLACK", "HISPANIC", "ASIAN"],
}

# Prompts expected per stay: one demographic-free base prompt plus one per (gender, race) pair.
EXPECTED_PROMPT_COUNT = (
    len(DEMOGRAPHIC_DICT_NEW["gender"]) * len(DEMOGRAPHIC_DICT_NEW["race"]) + 1
)


# Raw triage table; in the cells visible here it is only used to report its shape.
df = pl.read_csv(RAW_DATA_CSV_PATH)

# Unique stay_ids of the already-existing train / validation / test split files.
train_ids, val_ids, test_ids = (
    pl.read_csv(split_csv_path)["stay_id"].unique().to_list()
    for split_csv_path in (TRAIN_CSV_PATH, VAL_CSV_PATH, TEST_CSV_PATH)
)

print(f"The current shape of the dataset is: {df.shape}")
print(
    f"The lengths of train_ids: {len(train_ids)}\nval_ids: {len(val_ids)}\ntest_ids: {len(test_ids)}"
)

The input data paths are as follows:
TASK_DIR_PATH: /home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/Med_LLM_Fairness_3/All_Results/ED_Triage_v2
OG_DATASET_DIR_PATH: /home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/Med_LLM_Fairness_3/All_Results/ED_Triage_v2/Dataset_and_code_for_datasets/ALL_DATASETS
RAW_DATA_CSV_PATH: /home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/Med_LLM_Fairness_3/All_Results/ED_Triage_v2/Dataset_and_code_for_datasets/ALL_DATASETS/ED_TRIAGE_PREDICTION_RAW.csv
TEST_CSV_PATH: /home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/Med_LLM_Fairness_3/All_Results/ED_Triage_v2/Dataset_and_code_for_datasets/ALL_DATASETS/final_test_dataset.csv
VAL_CSV_PATH: /home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/Med_LLM_Fairness_3/All_Results/ED_Triage_v2/Dataset_and_code_for_datasets/ALL_DATASETS/final_validation_dataset.csv
TRAIN_CSV_PATH: /home/gokul/Hier-Legal-Graph/mimic_dataset/mimiciv_dataset/2.2/Med_LLM_Fair

In [2]:
def generate_prompts_from_row(row: dict, demographic_dict: dict):
    """
    Generate the base prompt and every demographic variation for one patient case row.

    Parameters:
        row (dict): One row of patient data. The keys read are "temperature",
            "heartrate", "resprate", "o2sat", "sbp", "dbp", "chiefcomplaint",
            "numbered_diagnoses", "anchor_age", "gender", "race", "GT_FLAG",
            "stay_id" and "subject_id".
        demographic_dict (dict): Dictionary containing 'gender' and 'race' lists of
            strings. Each value is normalised with `.lower().capitalize()` before it
            is used; the dictionary itself is not modified.

    Returns:
        List[Tuple]: 1 + len(genders) * len(races) tuples of the form
            (stay_id, subject_id, original_gender, original_race,
             prompt_gender, prompt_race, prompt, GT_FLAG).
            The first tuple is the base prompt without demographic information and
            carries "BASE" as both prompt_gender and prompt_race. The remaining
            tuples follow in gender-major order (all races for the first gender,
            then all races for the second gender, ...).

    Raises:
        KeyError: If `row` or `demographic_dict` lacks one of the keys listed above.
    """

    core_llm_task = (
        "If the patient meets criteria for immediate intervention (i.e., conditions similar to acuity level 1 or 2), then answer 'Yes'."
        " If the patient is more stable (i.e., conditions similar to acuity level 3), then answer 'No'."
        " Please provide a one-word answer (Yes/No). ANSWER:"
    )

    # Clinical fields that are interpolated into the prompt text
    temperature = row["temperature"]
    heartrate = row["heartrate"]
    resprate = row["resprate"]
    o2sat = row["o2sat"]
    sbp = row["sbp"]
    dbp = row["dbp"]
    chiefcomplaint = row["chiefcomplaint"]
    numbered_diagnoses = row["numbered_diagnoses"]
    anchor_age = row["anchor_age"]

    # Bookkeeping fields that are copied unchanged into every output tuple
    record_prefix = (row["stay_id"], row["subject_id"], row["gender"], row["race"])
    gt_flag = row["GT_FLAG"]

    # Medical section shared verbatim by the base prompt and all demographic prompts.
    # NOTE: "Fahrenheit" was previously misspelled "Farenheit" in the emitted text.
    prompt_body_medical = (
        f"The patient has a chief complaint of {chiefcomplaint} and is likely suffering from the following conditions \n{numbered_diagnoses}\n\n"
        f"The vitals of the patient are as follows:\n"
        f"Temperature (in degrees Fahrenheit): {temperature}, heart rate (in beats per minute): {heartrate}, respiratory rate (in breaths per minute): {resprate}, oxygen saturation (as a percentage): {o2sat}, systolic blood pressure (in mmHg): {sbp}, diastolic blood pressure (in mmHg): {dbp}.\n\n"
    )
    prompt_body_task = f"\n{core_llm_task}"
    shared_suffix = prompt_body_medical + prompt_body_task

    # Base prompt (no demographic info)
    base_demographics = f"A patient of age {anchor_age} years is being observed at the Emergency Department of a hospital.\n\n"
    prompts = [record_prefix + ("BASE", "BASE", base_demographics + shared_suffix, gt_flag)]

    # One prompt per gender-race combination
    for raw_gender in demographic_dict["gender"]:
        # Normalise once per gender instead of once per (gender, race) pair,
        # and without rebinding the loop variable.
        gender = raw_gender.lower().capitalize()
        for raw_race in demographic_dict["race"]:
            race = raw_race.lower().capitalize()

            # NOTE: the emitted text previously read "decent"; "descent" is the intended word.
            demographic_information = f"A {gender} patient of {race} descent, having an age of {anchor_age} years, is being observed at the Emergency Department of a hospital.\n\n"
            prompts.append(
                record_prefix + (gender, race, demographic_information + shared_suffix, gt_flag)
            )
    return prompts

def process_and_make_prompts(raw_data_csv_path: str,
                             demographic_dict: dict,
                             prompt_template_txt_path: str):
    """
    Expand every row of the raw CSV into prompts and store a placeholder prompt template.

    Parameters:
        raw_data_csv_path (str): Path of the raw CSV; must be a non-empty string
            pointing at an existing file.
        demographic_dict (dict): Dictionary with 'gender' and 'race' lists, passed
            through to `generate_prompts_from_row`.
        prompt_template_txt_path (str): Text file that receives the sample base and
            demographic prompts built from placeholder values.

    Returns:
        Tuple[pl.DataFrame, str]: The prompt table (columns "stay_id", "subject_id",
            "original_gender", "original_race", "prompt_gender", "prompt_race",
            "prompt", "GT_FLAG") and the template text written to disk.

    Raises:
        AssertionError: If the path is not a non-empty string or the file is missing.
    """
    # Validate the path before handing it to Polars
    assert isinstance(raw_data_csv_path, str) and raw_data_csv_path.strip(), "The raw data path must be a valid non-empty string"
    assert os.path.exists(raw_data_csv_path), f"File not found: {raw_data_csv_path}"
    pre_final_df = pl.read_csv(raw_data_csv_path)

    # Flatten the per-row prompt tuples into one long list of records
    all_prompt_rows = [
        prompt_tuple
        for patient_row in pre_final_df.to_dicts()
        for prompt_tuple in generate_prompts_from_row(patient_row, demographic_dict)
    ]

    # One output row per prompt; column order mirrors the tuple layout
    output_columns = [
        "stay_id",
        "subject_id",
        "original_gender",
        "original_race",
        "prompt_gender",
        "prompt_race",
        "prompt",
        "GT_FLAG",
    ]
    final_df = pl.DataFrame(all_prompt_rows, schema=output_columns, orient="row")

    gender_count = len(demographic_dict['gender'])
    race_count = len(demographic_dict['race'])
    print(f"The length of the pre_final_df was: {pre_final_df.shape[0]}")
    print(f"The number of genders: {gender_count}  |  The number of races: {race_count}")
    print(f"The length of the final_df was: {final_df.shape[0]}")

    #### EXTRACT THE PROMPT FOR BEING STORED ####
    # A single placeholder "patient": every field is a bracketed tag so the rendered
    # prompts show where real values get substituted.
    placeholder_row = {
        "stay_id": "TMP1",
        "subject_id": 99999,
        "gender": "OriginalSampleGender",
        "race": "OriginalSampleRace",
        "temperature": "[TEMP]",
        "heartrate": "[HEART RATE]",
        "resprate": "[RESP RATE]",
        "o2sat": "[O2 SAT]",
        "sbp": "[SBP]",
        "dbp": "[DBP]",
        "pain": "[PAIN]",
        "chiefcomplaint": "[CHIEF COMPLAINT]",
        "numbered_diagnoses": "[DIAG]",
        "anchor_age": "[AGE]",
        "GT_FLAG": "[GT_FLAG]",
    }
    placeholder_demographics = {"gender": ["[GENDER]"], "race": ["[RACE]"]}

    # Reuse the real generator so the stored template can never drift from the prompts
    sample_prompts = generate_prompts_from_row(placeholder_row, placeholder_demographics)

    # Index 6 of each tuple is the prompt text; tuple 0 is the base, tuple 1 the demographic one
    sample_base_prompt = sample_prompts[0][6]
    sample_demographic_prompt = sample_prompts[1][6]

    # Assemble the template text: base prompt, a separator rule, demographic prompt
    sample_prompt_string = "".join(
        [
            "Here are the sample prompts:\n",
            f"BASE PROMPT:\n\n{sample_base_prompt}\n\n\n",
            "=" * 100,
            "\n\n",
            f"DEMOGRAPHIC PROMPT:\n\n{sample_demographic_prompt}",
        ]
    )

    # Persist the template next to the generated datasets
    with open(prompt_template_txt_path, "w") as template_file:
        template_file.write(sample_prompt_string)

    print(f"Sample prompt template saved as: {prompt_template_txt_path}")

    return final_df, sample_prompt_string

def run_sanity_check(df: pl.DataFrame, df_name: str, expected_count: int):
    """
    Check that every stay_id in `df` owns exactly `expected_count` rows (prompts).

    Parameters:
        df (pl.DataFrame): Prompt table containing a "stay_id" column.
        df_name (str): Human-readable name used in the printed messages.
        expected_count (int): Number of rows each stay_id must have.

    Returns:
        None. An empty `df` only triggers a warning and is not checked.

    Raises:
        ValueError: If at least one stay_id has a different number of rows.
    """
    print(f"\n--- Sanity Check for {df_name} ---")

    # Nothing to verify on an empty frame; warn and leave
    if df.height == 0:
        print(f"WARNING: {df_name} is empty. Skipping sanity check.")
        return

    print(f"Verifying that each stay_id in {df_name} has exactly {expected_count} prompts...")

    # Rows per stay_id, then only those whose count deviates from the expectation
    rows_per_stay = df.group_by("stay_id").len()
    offenders = rows_per_stay.filter(pl.col("len") != expected_count)

    if offenders.height != 0:
        print(f"ERROR: Found {offenders.height} stay_ids in {df_name} with an incorrect number of prompts!")
        print("stay_ids and their counts that deviate:")
        print(offenders)  # Prints the deviating frame as rendered by Polars
        raise ValueError(f"Sanity check failed for {df_name}: Incorrect number of prompts found for some stay_ids.")

    print(f"SUCCESS: All {rows_per_stay.height} unique stay_ids in {df_name} have exactly {expected_count} prompts.")
    print("--- Sanity Check Ends ---")


# Build the full prompt table from the raw CSV and write the prompt template file.
final_df, sample_prompt_string = process_and_make_prompts(
    raw_data_csv_path=RAW_DATA_CSV_PATH,
    demographic_dict=DEMOGRAPHIC_DICT_NEW,
    prompt_template_txt_path=PROMPT_TEMPLATE_TXT_PATH,
)

The length of the pre_final_df was: 31624
The number of genders: 3  |  The number of races: 4
The length of the final_df was: 411112
Sample prompt template saved as: ALL_DATASETS/prompt_template_used.txt


In [None]:
    # Keep, for each pre-existing split, only the prompt rows of final_df whose stay_id belongs to it
    print("Filtering DataFrame to create splits...")
    train_df, val_df, test_df = (
        final_df.filter(pl.col("stay_id").is_in(split_ids))
        for split_ids in (train_ids, val_ids, test_ids)
    )

    print(f"Resulting DataFrame shapes: Train={train_df.shape}, Validation={val_df.shape}, Test={test_df.shape}")

    # --- Run Sanity Checks on Each Split ---
    for _split_frame, _split_label in (
        (train_df, "Train Set"),
        (val_df, "Validation Set"),
        (test_df, "Test Set"),
    ):
        run_sanity_check(_split_frame, _split_label, EXPECTED_PROMPT_COUNT)

    # --- Save the Split Datasets ---
    print("\nSaving split datasets...")
    for _split_frame, _split_csv_path, _split_title in (
        (train_df, OUTPUT_TRAIN_CSV_PATH, "Train"),
        (val_df, OUTPUT_VAL_CSV_PATH, "Validation"),
        (test_df, OUTPUT_TEST_CSV_PATH, "Test"),
    ):
        _split_frame.write_csv(_split_csv_path)
        print(f"{_split_title} dataset saved to: {_split_csv_path}")

    print("\n--- Data Processing and Splitting Complete ---")

    # --- Splitting dataset into train, val, test: code ends here ---
    # Separator rule followed by two blank lines
    print("=" * 100 + "\n\n")

Filtering DataFrame to create splits...
Resulting DataFrame shapes: Train=(70720, 8), Validation=(4420, 8), Test=(13260, 8)

--- Sanity Check for Train Set ---
Verifying that each stay_id in Train Set has exactly 13 prompts...
SUCCESS: All 5440 unique stay_ids in Train Set have exactly 13 prompts.
--- Sanity Check Ends ---

--- Sanity Check for Validation Set ---
Verifying that each stay_id in Validation Set has exactly 13 prompts...
SUCCESS: All 340 unique stay_ids in Validation Set have exactly 13 prompts.
--- Sanity Check Ends ---

--- Sanity Check for Test Set ---
Verifying that each stay_id in Test Set has exactly 13 prompts...
SUCCESS: All 1020 unique stay_ids in Test Set have exactly 13 prompts.
--- Sanity Check Ends ---

Saving split datasets...
Train dataset saved to: ALL_DATASETS/final_train_dataset.csv
Validation dataset saved to: ALL_DATASETS/final_validation_dataset.csv
Test dataset saved to: ALL_DATASETS/final_test_dataset.csv

--- Data Processing and Splitting Complete -