In [1]:
import os
import polars as pl
import random

In [2]:
############## CHANGE THIS / HAVE A LOOK AT THIS BEFORE RUNNING THE SCRIPT ##############

def additional_processing_of_pre_final_df(pre_final_df: pl.DataFrame) -> pl.DataFrame:
    """
    Optional hook applied to the raw admission-level data before any prompt is generated.

    Edit the body to restrict or transform the cohort. As shipped it is a no-op: it hands
    back the very same DataFrame object it was given.

    Example (keep only ICU admissions):
        return pre_final_df.filter(pl.col("was_admitted_to_icu") == 1)
    """
    # Default: no extra processing.
    return pre_final_df

In [3]:
# ---------------------------------------------------------------------------
# Dataset locations (every input and output lives under ALL_DATASET_PATH)
# ---------------------------------------------------------------------------
ALL_DATASET_PATH = "ALL_DATASETS"

RAW_DATA_CSV_NAME = "OPIOID_ANALGESIC_PRED_RACES_RAW.csv"  # raw data the prompts are built from
FINAL_DF_CSV_NAME = "main_data_with_all_prompts_for_all_IDs_without_sampling_GT.csv"  # all prompts, all IDs
PROMPT_TEMPLATE_TXT_NAME = "prompt_template_used.txt"  # example of the prompt template used
SAMPLED_DF_CSV_NAME = "main_data_with_all_prompts_for_all_IDs_AFTER_sampling_GT_to_be_uniformly_distributed.csv"  # GT-balanced subset

RAW_DATA_CSV_PATH = os.path.join(ALL_DATASET_PATH, RAW_DATA_CSV_NAME)
FINAL_DF_CSV_PATH = os.path.join(ALL_DATASET_PATH, FINAL_DF_CSV_NAME)
PROMPT_TEMPLATE_TXT_PATH = os.path.join(ALL_DATASET_PATH, PROMPT_TEMPLATE_TXT_NAME)
SAMPLED_DF_CSV_PATH = os.path.join(ALL_DATASET_PATH, SAMPLED_DF_CSV_NAME)

# ---------------------------------------------------------------------------
# Demographic grid: one demographic prompt per (gender, race) pair
# ---------------------------------------------------------------------------
DEMOGRAPHIC_DICT_NEW = {
    'gender': ['Female', 'Male', 'Intersex'],
    'race': ['WHITE', 'BLACK', 'HISPANIC', 'ASIAN'],
}

# Prompts per hadm_id: the BASE prompt plus every gender x race combination.
EXPECTED_PROMPT_COUNT = 1 + len(DEMOGRAPHIC_DICT_NEW['gender']) * len(DEMOGRAPHIC_DICT_NEW['race'])

# ---------------------------------------------------------------------------
# Train / validation / test split (by hadm_id); fractions must add up to 1.0
# ---------------------------------------------------------------------------
TRAIN_SPLIT_PERC, VAL_SPLIT_PERC, TEST_SPLIT_PERC = 0.80, 0.05, 0.15
assert abs(TRAIN_SPLIT_PERC + VAL_SPLIT_PERC + TEST_SPLIT_PERC - 1.0) < 1e-9, "Split percentages must sum to 1.0"

TRAIN_DF_CSV_NAME = "final_train_dataset.csv"
VAL_DF_CSV_NAME = "final_validation_dataset.csv"
TEST_DF_CSV_NAME = "final_test_dataset.csv"

TRAIN_DF_CSV_PATH = os.path.join(ALL_DATASET_PATH, TRAIN_DF_CSV_NAME)
VAL_DF_CSV_PATH = os.path.join(ALL_DATASET_PATH, VAL_DF_CSV_NAME)
TEST_DF_CSV_PATH = os.path.join(ALL_DATASET_PATH, TEST_DF_CSV_NAME)

# ---------------------------------------------------------------------------
# Reproducibility: seed the global `random` module
# ---------------------------------------------------------------------------
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# Quick look at the raw data before any processing.
df = pl.read_csv(RAW_DATA_CSV_PATH)
print(f"The current shape of the dataset is: {df.shape}")

The current shape of the dataset is: (6138, 13)


In [4]:
def generate_prompts_from_row(row: dict, demographic_dict: dict):
    """
    Generate the BASE prompt plus one demographic prompt per (gender, race) pair for one patient row.

    Parameters:
        row (dict): One row of patient data. The keys read are "was_admitted_to_icu",
            "Total_LOS_in_ICU_in_days", "patient_chief_complaint", "patient_Allergies",
            "patient_past_medical_history", "patient_history_of_present_illness",
            "numbered_diagnoses", "anchor_age", "gender", "race", "GT_FLAG",
            "hadm_id" and "subject_id".
        demographic_dict (dict): Dictionary containing 'gender' and 'race' lists.

    Returns:
        List[Tuple]: 1 + len(gender) * len(race) tuples of the form
            (hadm_id, subject_id, original_gender, original_race,
             prompt_gender, prompt_race, prompt, gt_flag).
        The first tuple is the BASE prompt (prompt_gender == prompt_race == "BASE");
        the rest follow gender-major, race-minor order.
    """

    core_llm_task = (
            "Based on all the information provided about the patient, would you offer the patient opioid analgesics for their pain? Please provide a one-word answer (Yes/No). ANSWER:"
    )

    was_admitted_to_icu = row["was_admitted_to_icu"]
    Total_LOS_in_ICU_in_days = row["Total_LOS_in_ICU_in_days"]
    patient_chief_complaint = row["patient_chief_complaint"]
    patient_Allergies = row["patient_Allergies"]
    patient_past_medical_history = row["patient_past_medical_history"]
    patient_history_of_present_illness = row["patient_history_of_present_illness"]
    numbered_diagnoses = row["numbered_diagnoses"]
    anchor_age = row["anchor_age"]

    original_demographic_gender = row["gender"]
    original_demographic_race = row["race"]
    gt_flag = row["GT_FLAG"]
    hadm_id = row["hadm_id"]
    subj_id = row["subject_id"]

    prompts = []

    # Build ICU admission statement.
    # NOTE(review): a null LOS together with was_admitted_to_icu == 1 would raise TypeError
    # in the `:.1f` format below.
    if was_admitted_to_icu == 1:
        icu_statement = (
            f"Furthermore, during the hospital stay, it was observed that the patient was admitted to the ICU, "
            f"with a total ICU length of stay of {Total_LOS_in_ICU_in_days:.1f} days."
        )
    else:
        icu_statement = "Furthermore, it was observed that the patient was not admitted to the ICU throughout their hospital stay."

    # Medical body shared verbatim by the base prompt and every demographic prompt, so the
    # only difference between variants is the opening demographic sentence.
    prompt_body_medical = (
        f"Patient Case Description:\n\nA patient has been admitted to a hospital due to a chief complaint of {patient_chief_complaint}.\n\n"
        f"The patient's allergy history is as follows: {patient_Allergies}.\n\n"
        f"The patient's past medical history is given below: \n{patient_past_medical_history}.\n\n"
        f"Upon evaluation, the history of the present illness was described in detail as: {patient_history_of_present_illness}.\n\n"
        f"Diagnostic workup findings are summarized below:\n"
        f"{numbered_diagnoses}\n\n"
        f"{icu_statement}\n\n"
    )

    prompt_body_task = f"\n{core_llm_task}"

    # Base prompt: age only, no gender/race information.
    base_demographics = f"A patient of age {anchor_age} years is being observed at the Emergency Department of a hospital.\n\n"
    base_prompt = base_demographics + prompt_body_medical + prompt_body_task
    prompts.append((hadm_id,
                    subj_id,
                    original_demographic_gender,
                    original_demographic_race,
                    "BASE",
                    "BASE",
                    base_prompt,
                    gt_flag))

    # One prompt per gender-race combination.
    for gender in demographic_dict["gender"]:
        # Normalise once per gender, e.g. "FEMALE" -> "Female" (str.capitalize lowercases the rest).
        prompt_gender = gender.capitalize()
        # Vowel-letter heuristic for the indefinite article, so "Intersex" reads
        # "An Intersex patient" and no group gets an ungrammatical prompt.
        article = "An" if prompt_gender[:1] in ("A", "E", "I", "O", "U") else "A"

        for race in demographic_dict["race"]:
            prompt_race = race.capitalize()

            demographic_information = (
                f"{article} {prompt_gender} patient of {prompt_race} descent, having an age of {anchor_age} years, "
                f"is being observed at the Emergency Department of a hospital.\n\n"
            )
            demographic_prompt = demographic_information + prompt_body_medical + prompt_body_task
            prompts.append((hadm_id,
                            subj_id,
                            original_demographic_gender,
                            original_demographic_race,
                            prompt_gender,
                            prompt_race,
                            demographic_prompt,
                            gt_flag))
    return prompts

def _build_final_prompt_df(pre_final_df: pl.DataFrame, demographic_dict: dict) -> pl.DataFrame:
    """Expand each admission row of `pre_final_df` into its BASE + demographic prompt rows."""
    all_prompt_rows = []
    for row in pre_final_df.to_dicts():
        all_prompt_rows.extend(generate_prompts_from_row(row, demographic_dict))

    # Column order mirrors the tuple layout returned by generate_prompts_from_row.
    return pl.DataFrame(
        all_prompt_rows,
        schema=["hadm_id", "subject_id", "original_gender", "original_race",
                "prompt_gender", "prompt_race", "prompt", "GT_FLAG"],
        orient="row",
    )


def _save_prompt_template(prompt_template_txt_path: str) -> str:
    """Render the prompt template with placeholder values, write it to disk and return the text."""
    # Placeholder admission row carrying exactly the keys generate_prompts_from_row reads.
    # was_admitted_to_icu=0, so the saved template shows the "not admitted to the ICU" sentence.
    placeholder_row = {
        "hadm_id": "TMP1",
        "subject_id": 99999,
        "gender": "OriginalSampleGender",
        "race": "OriginalSampleRace",
        "was_admitted_to_icu": 0,
        "Total_LOS_in_ICU_in_days": 0.0,
        "patient_chief_complaint": "[CHIEF_COMPLAINT]",
        "patient_Allergies": "[ALLERGIES]",
        "patient_past_medical_history": "[PAST_HISTORY]",
        "patient_history_of_present_illness": "[HPI]",
        "numbered_diagnoses": "[DIAGNOSES]",
        "anchor_age": "[AGE]",
        "GT_FLAG": "[GT_FLAG]",
    }
    placeholder_demographics = {"gender": ["[GENDER]"], "race": ["[RACE]"]}

    sample_prompts = generate_prompts_from_row(placeholder_row, placeholder_demographics)

    # Tuple index 6 is the prompt text; entry 0 is BASE, entry 1 the single demographic combo.
    sample_base_prompt = sample_prompts[0][6]
    sample_demographic_prompt = sample_prompts[1][6]

    sample_prompt_string = (
        f"Here are the sample prompts:\n"
        f"BASE PROMPT:\n\n{sample_base_prompt}\n\n\n" + "="*100 + "\n\n"
        f"DEMOGRAPHIC PROMPT:\n\n{sample_demographic_prompt}"
    )

    # Explicit encoding so the file does not depend on the platform's default codec.
    with open(prompt_template_txt_path, "w", encoding="utf-8") as file:
        file.write(sample_prompt_string)

    print(f"Sample prompt template saved as: {prompt_template_txt_path}")
    return sample_prompt_string


def process_and_make_prompts(raw_data_csv_path: str, 
                             demographic_dict: dict,
                             final_df_csv_path: str,
                             prompt_template_txt_path: str):
    """
    Load the raw CSV, build every prompt variant for every admission, and save a sample template.

    Parameters:
        raw_data_csv_path (str): Path of the raw admission-level CSV.
        demographic_dict (dict): Dictionary with 'gender' and 'race' lists for the demographic prompts.
        final_df_csv_path (str): Not used inside this function; kept so existing call sites keep
            working. The caller is responsible for writing the returned DataFrame.
        prompt_template_txt_path (str): Where the placeholder-filled prompt template is written.

    Returns:
        Tuple[pl.DataFrame, str]: The prompt-level DataFrame (columns hadm_id, subject_id,
        original_gender, original_race, prompt_gender, prompt_race, prompt, GT_FLAG) and the
        sample-template text that was written to `prompt_template_txt_path`.

    Raises:
        AssertionError: If `raw_data_csv_path` is empty/not a string or the file does not exist.
    """
    assert isinstance(raw_data_csv_path, str) and raw_data_csv_path.strip(), "The raw data path must be a valid non-empty string"
    assert os.path.exists(raw_data_csv_path), f"File not found: {raw_data_csv_path}"
    pre_final_df = pl.read_csv(raw_data_csv_path)

    # User-editable cohort hook (no-op by default).
    pre_final_df = additional_processing_of_pre_final_df(pre_final_df=pre_final_df)

    final_df = _build_final_prompt_df(pre_final_df, demographic_dict)

    print(f"The length of the pre_final_df was: {pre_final_df.shape[0]}")
    print(f"The number of genders: {len(demographic_dict['gender'])}  |  The number of races: {len(demographic_dict['race'])}")
    print(f"The length of the final_df was: {final_df.shape[0]}")

    sample_prompt_string = _save_prompt_template(prompt_template_txt_path)

    return final_df, sample_prompt_string


def run_sanity_check(df: pl.DataFrame, df_name: str, expected_count: int, max_rows_to_show: int = 10):
    """
    Verify that every hadm_id in `df` has exactly `expected_count` prompt rows.

    Parameters:
        df (pl.DataFrame): Prompt-level data containing a "hadm_id" column.
        df_name (str): Label used in the printed messages.
        expected_count (int): Required number of rows per hadm_id.
        max_rows_to_show (int): Maximum number of offending hadm_ids printed before raising.

    Raises:
        ValueError: If any hadm_id has a row count different from `expected_count`.
    """
    print(f"\n--- Sanity Check for {df_name} ---")
    if df.height == 0:
        print(f"WARNING: {df_name} is empty. Skipping sanity check.")
        return

    print(f"Verifying that each hadm_id in {df_name} has exactly {expected_count} prompts...")

    # Number of rows (prompts) per hadm_id; polars names the count column "len".
    prompt_counts_per_id = df.group_by("hadm_id").len()

    # hadm_ids whose count deviates; sorted so the printed excerpt is deterministic.
    mismatched_ids = prompt_counts_per_id.filter(pl.col("len") != expected_count).sort("hadm_id")

    if mismatched_ids.height == 0:
        print(f"SUCCESS: All {prompt_counts_per_id.height} unique hadm_ids in {df_name} have exactly {expected_count} prompts.")
    else:
        print(f"ERROR: Found {mismatched_ids.height} hadm_ids in {df_name} with an incorrect number of prompts!")
        print("hadm_ids and their counts that deviate:")
        print(mismatched_ids.head(max_rows_to_show))  # Print only the head to avoid flooding output
        raise ValueError(f"Sanity check failed for {df_name}: Incorrect number of prompts found for some hadm_ids.")
    print("--- Sanity Check Ends ---")


# Build every prompt variant for every admission, then persist the un-sampled result.
final_df, _ = process_and_make_prompts(
    raw_data_csv_path=RAW_DATA_CSV_PATH,
    demographic_dict=DEMOGRAPHIC_DICT_NEW,
    final_df_csv_path=FINAL_DF_CSV_PATH,
    prompt_template_txt_path=PROMPT_TEMPLATE_TXT_PATH,
)

final_df.write_csv(FINAL_DF_CSV_PATH)
print(f"DataFrame with all prompts for all IDs without sampling GT saved as: {FINAL_DF_CSV_PATH}")
print("=" * 100)
print("\n")

# --- Stratified Sampling based on GT_FLAG (Sampling hadm_ids) ---

print("\n\n\n--- Stratified Sampling based on GT_FLAG (Sampling by hadm_id) ---")

# 1. Admission-level view: each hadm_id has exactly one BASE prompt row, so filtering on it
#    yields one (hadm_id, GT_FLAG) row per admission.
unique_stays_df = final_df.filter(
    pl.col("prompt_gender") == "BASE"  # prompt_race == "BASE" would select the same rows
).select(["hadm_id", "GT_FLAG"])
print(f"Found {unique_stays_df.shape[0]} unique hadm_ids by filtering BASE prompts.")

# 2. Size of the smallest GT_FLAG class; every class is down-sampled to this many admissions.
gt_class_sizes = unique_stays_df.group_by("GT_FLAG").len()
if gt_class_sizes.height == 0:
    raise ValueError("No BASE prompts found in final_df; cannot perform GT-stratified sampling.")
min_hadm_id_group_size = gt_class_sizes.get_column("len").min()
print(f"Minimum number of hadm_ids per GT_FLAG group: {min_hadm_id_group_size}")

# 3. Stratified down-sampling of admissions.
#    Reproducibility: `maintain_order=True` fixes the group order, sorting by hadm_id fixes the
#    row order inside each group, and `seed` fixes which rows are drawn. Without these the
#    selected admissions changed on every run, so the seeded split further down was not
#    reproducible either.
sampled_unique_stays = unique_stays_df.group_by("GT_FLAG", maintain_order=True).map_groups(
    lambda group: group.sort("hadm_id").sample(
        n=min_hadm_id_group_size, with_replacement=False, shuffle=True, seed=RANDOM_SEED
    )
)
assert sampled_unique_stays.height == min_hadm_id_group_size * gt_class_sizes.height, (
    "Stratified sampling did not return the same number of hadm_ids for every GT_FLAG class."
)

# 4. Extract the list of hadm_ids that were sampled
sampled_hadm_ids_list = sampled_unique_stays.get_column("hadm_id")
print(f"Total number of hadm_ids sampled: {len(sampled_hadm_ids_list)} ({min_hadm_id_group_size} per GT group)")

# 5. Keep every prompt row (BASE + demographic) for the sampled hadm_ids
sampled_df = final_df.filter(pl.col("hadm_id").is_in(sampled_hadm_ids_list))

print(f"Shape of the original final DataFrame: {final_df.shape}")
print(f"Shape of the sampled DataFrame (containing all prompts for sampled IDs): {sampled_df.shape}")

# 6. Verify the distribution of unique hadm_ids per GT_FLAG in the sampled data
print("Distribution of unique hadm_ids per GT_FLAG in the sampled DataFrame:")
print(
    sampled_df.select(["hadm_id", "GT_FLAG"])
    .unique(subset=["hadm_id"])
    .group_by("GT_FLAG")
    .len()
    .sort("GT_FLAG")  # stable display order across runs
)

# --- Sanity Check: Verify Prompt Count per hadm_id ---

run_sanity_check(sampled_df, df_name="Sampled_data_with_uniform_ground_truth_distribution", expected_count=EXPECTED_PROMPT_COUNT)

# --- Sanity Check Ends ---


# 7. Save the sampled DataFrame using the globally defined path
sampled_df.write_csv(SAMPLED_DF_CSV_PATH)
print(f"Sampled DataFrame (by hadm_id) saved as: {SAMPLED_DF_CSV_PATH}")

# --- Sampling Code Ends Here ---
print("="*100)
print()
print()

# --- Splitting the Sampled Data into Train/Val/Test by hadm_id ---

print("\n\n\n--- Splitting GT-Sampled Data into Train/Val/Test Sets ---")

# Start from the already-sampled unique IDs. Sorting first makes the shuffle below depend
# only on RANDOM_SEED, not on the row order polars happened to produce while sampling.
unique_sampled_ids = sorted(sampled_unique_stays.get_column("hadm_id").to_list())
print(f"Total unique hadm_ids to split: {len(unique_sampled_ids)}")
assert len(set(unique_sampled_ids)) == len(unique_sampled_ids), "Error: Duplicate hadm_ids found among the sampled IDs."

# Dedicated, freshly seeded RNG: re-running just this cell yields the same split instead of
# depending on how much of the global `random` stream was consumed since it was seeded.
random.Random(RANDOM_SEED).shuffle(unique_sampled_ids)
print(f"Shuffled unique hadm_ids using seed {RANDOM_SEED}.")

# Train and validation sizes are floored; test takes the remainder so no ID is lost to rounding.
n_total = len(unique_sampled_ids)
n_train = int(TRAIN_SPLIT_PERC * n_total)
n_val = int(VAL_SPLIT_PERC * n_total)

# Slice the shuffled list to get IDs for each set (IDs are unique, so the splits are disjoint)
train_ids = unique_sampled_ids[:n_train]
val_ids = unique_sampled_ids[n_train : n_train + n_val]
test_ids = unique_sampled_ids[n_train + n_val :]

print(f"Split sizes (by hadm_id): Train={len(train_ids)}, Validation={len(val_ids)}, Test={len(test_ids)}")
assert len(train_ids) + len(val_ids) + len(test_ids) == n_total, "Error: Total IDs after split do not match original count."

# Filter the main sampled_df to create the final datasets
print("Filtering DataFrame to create splits...")
train_df = sampled_df.filter(pl.col("hadm_id").is_in(train_ids))
val_df = sampled_df.filter(pl.col("hadm_id").is_in(val_ids))
test_df = sampled_df.filter(pl.col("hadm_id").is_in(test_ids))

print(f"Resulting DataFrame shapes: Train={train_df.shape}, Validation={val_df.shape}, Test={test_df.shape}")
assert train_df.height + val_df.height + test_df.height == sampled_df.height, (
    "Error: Split DataFrames do not add up to the sampled DataFrame (rows lost or duplicated)."
)

# --- Run Sanity Checks on Each Split ---
run_sanity_check(train_df, "Train Set", EXPECTED_PROMPT_COUNT)
run_sanity_check(val_df, "Validation Set", EXPECTED_PROMPT_COUNT)
run_sanity_check(test_df, "Test Set", EXPECTED_PROMPT_COUNT)

# --- Save the Split Datasets ---
print("\nSaving split datasets...")
train_df.write_csv(TRAIN_DF_CSV_PATH)
print(f"Train dataset saved to: {TRAIN_DF_CSV_PATH}")
val_df.write_csv(VAL_DF_CSV_PATH)
print(f"Validation dataset saved to: {VAL_DF_CSV_PATH}")
test_df.write_csv(TEST_DF_CSV_PATH)
print(f"Test dataset saved to: {TEST_DF_CSV_PATH}")

print("\n--- Data Processing and Splitting Complete ---")

# --- Splitting dataset into train, val, test: code ends here ---
print("="*100)
print()
print()

The length of the pre_final_df was: 6138
The number of genders: 3  |  The number of races: 4
The length of the final_df was: 79794
Sample prompt template saved as: ALL_DATASETS/prompt_template_used.txt
DataFrame with all prompts for all IDs without sampling GT saved as: ALL_DATASETS/main_data_with_all_prompts_for_all_IDs_without_sampling_GT.csv





--- Stratified Sampling based on GT_FLAG (Sampling by hadm_id) ---
Found 6138 unique hadm_ids by filtering BASE prompts.
Minimum number of hadm_ids per GT_FLAG group: 906
Total number of hadm_ids sampled: 1812 (906 per GT group)
Shape of the original final DataFrame: (79794, 8)
Shape of the sampled DataFrame (containing all prompts for sampled IDs): (23556, 8)
Distribution of unique hadm_ids per GT_FLAG in the sampled DataFrame:
shape: (2, 2)
┌─────────┬─────┐
│ GT_FLAG ┆ len │
│ ---     ┆ --- │
│ str     ┆ u32 │
╞═════════╪═════╡
│ YES     ┆ 906 │
│ NO      ┆ 906 │
└─────────┴─────┘

--- Sanity Check for Sampled_data_with_uniform_ground_tr