<a href="https://colab.research.google.com/github/NoureldinAyman/Drug-Recommendation/blob/main/Drug_Recommendation_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Drug Recommendation System using MIMIC-IV

This project implements a machine learning pipeline for a **drug recommendation system** intended as a clinical decision support tool. The objective is to predict which medications a patient is likely to be prescribed during a hospital admission, based on their clinical profile.

### Methodology and Data

The project uses the **MIMIC-IV (Medical Information Mart for Intensive Care)** dataset, a large, de-identified clinical database. The notebook draws on its `hosp`, `ed`, and `note` modules.

A key innovation of this project is its prediction target. Instead of merely predicting a drug's name, the model predicts a **composite drug label**. This granular label combines three critical pieces of information:
* **NDC (National Drug Code):** The specific drug product.
* **Dosage Strength:** The concentration of the medication.
* **Prescription Duration:** The length of the treatment course.

This approach provides a much more clinically actionable prediction compared to generic drug recommendations.
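
As a rough illustration of how such a label could be assembled, the sketch below bins a duration and joins the three parts with underscores, mirroring the construction used later in the notebook. The sample rows, NDC values, and strengths are invented for the example.

```python
import pandas as pd

# Hypothetical prescriptions; the NDC values and strengths are invented for illustration.
rx = pd.DataFrame({
    "ndc": ["00000000001", "00000000002", "00000000001"],
    "prod_strength": ["500 mg Tablet", "10 mg Vial ", "500 mg Tablet"],
    "duration_days": [0.5, 5.0, 20.0],
})

# Same duration bins as the target-engineering step (right-closed intervals).
bins = [-float("inf"), 1, 3, 7, 14, 30, float("inf")]
labels = ["<=1d", "1-3d", "3-7d", "7-14d", "14-30d", ">30d"]
rx["duration_category"] = pd.cut(rx["duration_days"], bins=bins, labels=labels).astype(str)

# Normalise the strength text, then join NDC, strength, and duration bin.
rx["strength"] = rx["prod_strength"].str.lower().str.strip()
rx["composite_label"] = rx["ndc"] + "_" + rx["strength"] + "_" + rx["duration_category"]
print(rx["composite_label"].tolist())
```

Because the same drug product with a different strength or duration bin becomes a different label, the label space grows quickly, which is presumably why the notebook keeps only the most frequent composite labels.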

#### Feature Engineering and Modeling

The model's predictive power is built on a rich, multi-modal feature set that fuses both structured and unstructured data.

* **Structured Clinical Data:** Patient demographics (age, gender), admission details (type, location), insurance status, codified clinical events such as diagnoses and procedures (ICD codes), and emergency department triage data.
* **Unstructured Clinical Notes:** To capture the narrative of a patient's condition, discharge summaries are embedded with **Bio_ClinicalBERT**, a transformer model pre-trained on biomedical and clinical text. Each note is truncated to 512 tokens, and the `[CLS]` token embedding serves as its representation.

These features are then fed into a **deep neural network** architected for multi-label classification, allowing it to predict a unique set of multiple potential medications for each patient.

To ensure the model is robust and generalizable, the data is split on a **patient-level basis**. This prevents data leakage, which occurs when information about the same patient appears in both the training and testing sets. This method results in a more realistic and reliable evaluation of the model's performance on truly unseen patients.
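
A minimal sketch of a patient-level split on toy data, using scikit-learn's `GroupShuffleSplit` (the same splitter the notebook imports). The patient IDs and array shapes are made up; the point is only the check that no patient lands on both sides.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Toy data: 10 admissions belonging to 5 hypothetical patients.
X = np.arange(20).reshape(10, 2)
patient_ids = np.array([1, 1, 2, 2, 2, 3, 4, 4, 5, 5])

splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(splitter.split(X, groups=patient_ids))

# No patient should appear on both sides of the split.
shared = set(patient_ids[train_idx]) & set(patient_ids[test_idx])
print("patients in both sets:", shared)
```

Note that `test_size` applies to the number of groups (patients), so the share of admissions in each split is only approximately the requested fraction.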

# Table of Contents

- [Setup and Imports](#Setup-and-Imports)
- [Data Loading](#Data-Loading)
- [Data Cleaning and Type Conversion](#Data-Cleaning-and-Type-Conversion)
- [Base Dataframe Creation](#Base-Dataframe-Creation)
- [Target Engineering (Composite Drug Labels)](#Target-Engineering-Composite-Drug-Labels)
- [Feature Engineering](#Feature-Engineering)
    - [Feature Engineering - Demographics and Admissions Data](#FE---Demographics-and-Admissions-Data)
    - [Feature Engineering - Diagnoses (ICD Codes)](#FE---Diagnoses-ICD-Codes)
    - [Feature Engineering - Procedures (ICD Codes)](#FE---Procedures-ICD-Codes)
    - [Feature Engineering - Emergency Department (ED) Data](#FE---Emergency-Department-ED-Data)
- [Text Preprocessing and Transformer Embeddings](#Text-Preprocessing-and-Transformer-Embeddings)
- [Final Data](#Final-Data)
- [Data Loading for Modeling](#Data-Loading-for-Modeling)
- [Train/Validation/Test Split](#TrainValidationTest-Split)
- [Model Development](#Model-Development)
    - [Model Definition](#Model-Definition)
    - [Model Compilation and Training](#Model-Compilation-and-Training)
    - [Model Evaluation](#Model-Evaluation)

# Setup and Imports

In [1]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import gc
import re

In [None]:
# Import ML libraries
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Mount Google Drive to access files stored there


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Define the base path for the project directory in Google Drive


In [None]:
project_base_path = "/content/drive/MyDrive/AIS302 Project/"

# Data Loading

In [None]:
# Configure paths for all data files and specify if they should be read in chunks
files_config = {
    "patients": {"path": "Data/hosp/patients.csv.gz", "chunk": False},
    "admissions": {"path": "Data/hosp/admissions.csv.gz", "chunk": False},
    "diagnoses_icd": {"path": "Data/hosp/diagnoses_icd.csv.gz", "chunk": False},
    "d_icd_diagnoses": {"path": "Data/hosp/d_icd_diagnoses.csv.gz", "chunk": False},
    "prescriptions": {"path": "Data/hosp/prescriptions.csv.gz", "chunk": True, "chunksize": 500000},
    "labevents": {"path": "Data/hosp/labevents.csv.gz", "chunk": True, "chunksize": 500000},
    "d_labitems": {"path": "Data/hosp/d_labitems.csv.gz", "chunk": False},
    "procedures_icd": {"path": "Data/hosp/procedures_icd.csv.gz", "chunk": False},
    "emar": {"path": "Data/hosp/emar.csv.gz", "chunk": True, "chunksize": 500000},
    "discharge_notes": {"path": "Data/note/discharge.csv.gz", "chunk": False},
    "medrecon": {"path": "Data/ed/medrecon.csv.gz", "chunk": False},
    "triage_ed": {"path": "Data/ed/triage.csv.gz", "chunk": False},
    "edstays": {"path": "Data/ed/edstays.csv.gz", "chunk": False}
}

# Dictionary to hold the loaded dataframes
dfs = {}

# Loop through the file configuration to load each csv file
for name, config in files_config.items():
    full_path = os.path.join(project_base_path, config["path"])

    # If the file is large, read it in chunks to manage memory usage
    if config["chunk"]:
        chunk_list = []
        reader = pd.read_csv(full_path, compression='gzip', low_memory=False, chunksize=config["chunksize"])

        # Iterate over chunks and append them to a list
        for i, chunk_df in enumerate(reader):
            chunk_list.append(chunk_df)

        # Concatenate all chunks into a single dataframe
        dfs[name] = pd.concat(chunk_list, ignore_index=True)
        del chunk_list # Free up memory
        gc.collect()
    else:
        # Read smaller files directly into a dataframe
        dfs[name] = pd.read_csv(full_path, compression='gzip', low_memory=False)
    print(f"  Successfully loaded {name}. Shape: {dfs[name].shape}")
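
The chunked loop above concatenates every chunk, so the full table still ends up in memory. If only a subset of rows is needed, filtering each chunk before storing it can reduce the footprint. The sketch below shows the idea on a small in-memory CSV; the column names and the `wanted_hadm_ids` set are hypothetical stand-ins.

```python
import io
import pandas as pd

# Small in-memory CSV standing in for a large file such as prescriptions.csv.gz.
csv_text = "hadm_id,drug\n1,a\n2,b\n3,c\n1,d\n4,e\n2,f\n"
wanted_hadm_ids = {1, 2}  # hypothetical set of admissions of interest

kept = []
for chunk in pd.read_csv(io.StringIO(csv_text), chunksize=2):
    # Keep only the needed rows so each chunk shrinks before it is stored.
    kept.append(chunk[chunk["hadm_id"].isin(wanted_hadm_ids)])

filtered = pd.concat(kept, ignore_index=True)
print(filtered)
```

Passing `usecols` to `pd.read_csv` to load only the required columns is another option along the same lines.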

# Data Cleaning and Type Conversion

In [None]:
# Convert date of death column in patients table to datetime objects
dfs['patients']['dod'] = pd.to_datetime(dfs['patients']['dod'], errors='coerce')

# Convert time-related columns in the admissions table to datetime objects
adm_time_cols = ['admittime', 'dischtime', 'deathtime', 'edregtime', 'edouttime']
for col in adm_time_cols:
    dfs['admissions'][col] = pd.to_datetime(dfs['admissions'][col], errors='coerce')

# Convert time-related columns in the prescriptions table to datetime objects
presc_time_cols = ['starttime', 'stoptime']
for col in presc_time_cols:
    dfs['prescriptions'][col] = pd.to_datetime(dfs['prescriptions'][col], errors='coerce')

# Convert time-related columns in the labevents table to datetime objects
lab_time_cols = ['charttime', 'storetime']
for col in lab_time_cols:
    dfs['labevents'][col] = pd.to_datetime(dfs['labevents'][col], errors='coerce')

# Convert chart date in procedures table to datetime objects
dfs['procedures_icd']['chartdate'] = pd.to_datetime(dfs['procedures_icd']['chartdate'], errors='coerce')

# Convert time-related columns in the emar (medication administration) table to datetime objects
emar_time_cols = ['charttime', 'scheduletime', 'storetime']
for col in emar_time_cols:
    dfs['emar'][col] = pd.to_datetime(dfs['emar'][col], errors='coerce')

# Convert time-related columns in discharge notes to datetime objects
discharge_time_cols = ['charttime', 'storetime']
for col in discharge_time_cols:
    dfs['discharge_notes'][col] = pd.to_datetime(dfs['discharge_notes'][col], errors='coerce')

# Convert chart time in medrecon (medication reconciliation) table to datetime objects
dfs['medrecon']['charttime'] = pd.to_datetime(dfs['medrecon']['charttime'], errors='coerce')

# Convert time-related columns in edstays (emergency department stays) to datetime objects
edstays_time_cols = ['intime', 'outtime']
for col in edstays_time_cols:
    dfs['edstays'][col] = pd.to_datetime(dfs['edstays'][col], errors='coerce')

# Base Dataframe Creation

In [None]:
# Create a base dataframe with admission-level information
base_df = dfs['admissions'][['subject_id', 'hadm_id', 'admittime', 'dischtime', 'admission_type', 'admission_location', 'discharge_location', 'insurance', 'language', 'marital_status', 'race', 'hospital_expire_flag']].copy()

# Merge with patient demographic data (gender, age, date of death)
base_df = pd.merge(base_df, dfs['patients'][['subject_id', 'gender', 'anchor_age', 'dod']], on='subject_id', how='left')

# Rename anchor_age to 'age' (in MIMIC-IV this is the patient's age in their anchor_year, not at each admission)
base_df.rename(columns={'anchor_age': 'age'}, inplace=True)

# Calculate the length of stay in days for each admission
base_df['los_days'] = (base_df['dischtime'] - base_df['admittime']).dt.total_seconds() / (24 * 60 * 60)

# Target Engineering (Composite Drug Labels)

In [None]:
# Start of refined target variable engineering
prescriptions_temp = dfs['prescriptions'].copy()

# Filter for relevant admissions and convert time columns
prescriptions_temp = prescriptions_temp[prescriptions_temp['hadm_id'].isin(base_df['hadm_id'].unique())]
prescriptions_temp['starttime'] = pd.to_datetime(prescriptions_temp['starttime'], errors='coerce')
prescriptions_temp['stoptime'] = pd.to_datetime(prescriptions_temp['stoptime'], errors='coerce')

# Clean ndc codes by handling nans and removing trailing '.0'
prescriptions_temp['ndc'] = prescriptions_temp['ndc'].fillna('__ORIGINAL_NDC_NAN__')
prescriptions_temp['ndc'] = prescriptions_temp['ndc'].astype(str)
prescriptions_temp['ndc'] = prescriptions_temp['ndc'].str.replace(r'\.0$', '', regex=True)

# Remove records with invalid ndc placeholders
invalid_ndc_placeholders = ['0', '__ORIGINAL_NDC_NAN__', 'nan', 'NaN', 'MISSING_NDC', 'UNKNOWN_NDC']
prescriptions_temp = prescriptions_temp[~prescriptions_temp['ndc'].isin(invalid_ndc_placeholders)]
prescriptions_temp.dropna(subset=['starttime', 'ndc'], inplace=True)

# Calculate prescription duration in days
duration_hours = (prescriptions_temp['stoptime'] - prescriptions_temp['starttime']).dt.total_seconds() / 3600
prescriptions_temp['duration_days'] = duration_hours / 24
prescriptions_temp.loc[prescriptions_temp['duration_days'] < 0, 'duration_days'] = 0 # Handle negative durations

# Categorize the prescription duration into bins
duration_bins = [-float('inf'), 1, 3, 7, 14, 30, float('inf')]
duration_labels = ['<=1d', '1-3d', '3-7d', '7-14d', '14-30d', '>30d']
prescriptions_temp['duration_category'] = pd.cut(prescriptions_temp['duration_days'], bins=duration_bins, labels=duration_labels, right=True)
prescriptions_temp['duration_category'] = prescriptions_temp['duration_category'].cat.add_categories('unknown_duration').fillna('unknown_duration')
prescriptions_temp['duration_category'] = prescriptions_temp['duration_category'].astype(str)

# Clean the product strength information
prescriptions_temp['prod_strength'] = prescriptions_temp['prod_strength'].astype(str)
prescriptions_temp['dosage_form_strength'] = prescriptions_temp['prod_strength'].str.lower().str.strip()
placeholders_for_strength = ['nan', '', 'none', 'unknown_strength']
prescriptions_temp.loc[prescriptions_temp['dosage_form_strength'].isin(placeholders_for_strength) | prescriptions_temp['dosage_form_strength'].isnull(), 'dosage_form_strength'] = 'unknown_strength_cleaned'

# Copy the processed prescriptions dataframe and free up memory
prescriptions_filtered = prescriptions_temp.copy()
del prescriptions_temp

# Ensure data types are correct for creating the composite key
prescriptions_filtered['ndc'] = prescriptions_filtered['ndc'].astype(str)
prescriptions_filtered['dosage_form_strength'] = prescriptions_filtered['dosage_form_strength'].astype(str)
prescriptions_filtered['duration_category'] = prescriptions_filtered['duration_category'].astype(str)

# Filter for prescriptions with valid ndc, strength, and duration
valid_ndc_prescriptions = prescriptions_filtered[ prescriptions_filtered['ndc'].str.match(r'^[0-9\-]+$') & (prescriptions_filtered['ndc'].str.len() >= 4)].copy()
valid_ndc_prescriptions = valid_ndc_prescriptions[valid_ndc_prescriptions['dosage_form_strength'] != 'unknown_strength_cleaned']
valid_ndc_prescriptions = valid_ndc_prescriptions[valid_ndc_prescriptions['duration_category'] != 'unknown_duration']

# Create a composite target label by combining ndc, dosage strength, and duration category
valid_ndc_prescriptions['composite_target_label'] = \
    valid_ndc_prescriptions['ndc'] + "_" + \
    valid_ndc_prescriptions['dosage_form_strength'] + "_" + \
    valid_ndc_prescriptions['duration_category']

# Determine the top N composite targets to use for prediction
TOP_N_COMPOSITE_TARGETS = 200
if len(valid_ndc_prescriptions['composite_target_label'].unique()) < TOP_N_COMPOSITE_TARGETS:
    TOP_N_COMPOSITE_TARGETS = len(valid_ndc_prescriptions['composite_target_label'].unique())
composite_target_counts = valid_ndc_prescriptions['composite_target_label'].value_counts()
if TOP_N_COMPOSITE_TARGETS == 0 :
    final_target_labels_to_predict = []
else:
    final_target_labels_to_predict = composite_target_counts.head(TOP_N_COMPOSITE_TARGETS).index.tolist()

# Filter for prescriptions that fall into our final target labels
relevant_prescriptions_for_y = valid_ndc_prescriptions[
    valid_ndc_prescriptions['composite_target_label'].isin(final_target_labels_to_predict) &
    valid_ndc_prescriptions['hadm_id'].isin(base_df['hadm_id'].unique())
]
# Group by admission to get all composite labels for each
hadm_composite_drugs = relevant_prescriptions_for_y.groupby('hadm_id')['composite_target_label'].apply(lambda x: list(set(x))).reset_index()

# Build the new multi-label target dataframe based on the composite labels
new_target_y_list = []
for hadm_id_val in base_df['hadm_id'].unique():
    labels_for_hadm = hadm_composite_drugs[hadm_composite_drugs['hadm_id'] == hadm_id_val]
    current_hadm_label_vector = {label: 0 for label in final_target_labels_to_predict}
    if not labels_for_hadm.empty:
        prescribed_labels_for_hadm = labels_for_hadm.iloc[0]['composite_target_label']
        for label in prescribed_labels_for_hadm:
            if label in current_hadm_label_vector:
                current_hadm_label_vector[label] = 1
    current_hadm_label_vector['hadm_id'] = hadm_id_val
    new_target_y_list.append(current_hadm_label_vector)
new_target_y_df = pd.DataFrame(new_target_y_list)
new_target_y_df = new_target_y_df.set_index('hadm_id')

# Ensure the base dataframe is indexed by hadm_id for merging
if base_df.index.name != 'hadm_id':
    base_df_indexed = base_df.set_index('hadm_id')
else:
    base_df_indexed = base_df.copy()

# Merge the new composite target labels into the analytical dataframe
analytical_df = pd.merge(base_df_indexed, new_target_y_df, left_index=True, right_index=True, how='left')
# Fill any missing values in the new target columns with 0
analytical_df[final_target_labels_to_predict] = analytical_df[final_target_labels_to_predict].fillna(0).astype(int)
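
The per-admission loop above filters a dataframe once per `hadm_id`, which can be slow on a table of this size. One possible vectorized alternative, sketched here on toy data, builds the same multi-hot matrix with `pd.crosstab`. The toy frame and label names are assumptions standing in for `relevant_prescriptions_for_y` and `final_target_labels_to_predict`.

```python
import pandas as pd

# Toy stand-ins for relevant_prescriptions_for_y and the admission index.
rx = pd.DataFrame({
    "hadm_id": [10, 10, 11, 10],
    "composite_target_label": ["A", "B", "A", "A"],
})
all_hadm_ids = [10, 11, 12]      # admission 12 has no target prescriptions
target_labels = ["A", "B", "C"]  # label "C" never occurs in the toy data

# Count label occurrences per admission, cap at 1, and align to the full grid.
multi_hot = (
    pd.crosstab(rx["hadm_id"], rx["composite_target_label"])
    .clip(upper=1)
    .reindex(index=all_hadm_ids, columns=target_labels, fill_value=0)
)
print(multi_hot)
```

The same pattern could in principle replace the diagnosis and procedure loops in the feature-engineering cells.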

# Feature Engineering

## FE - Demographics and Admissions Data

In [None]:
# Preprocess categorical features: fill missing values with 'Unknown'
cols_to_fill_na_unknown = ['admission_location', 'discharge_location', 'insurance', 'language', 'marital_status']
for col in cols_to_fill_na_unknown:
    if col in analytical_df.columns:
        analytical_df[col] = analytical_df[col].fillna('Unknown')

# Define categorical columns to be one-hot encoded
categorical_cols_to_one_hot = ['admission_type', 'admission_location', 'discharge_location', 'insurance', 'language', 'marital_status', 'race', 'gender']
# Apply one-hot encoding to convert categorical variables into a numerical format
analytical_df = pd.get_dummies(analytical_df, columns=categorical_cols_to_one_hot, prefix=categorical_cols_to_one_hot, dummy_na=False)

# Drop columns that are no longer needed
analytical_df = analytical_df.drop(columns=['subject_id', 'dod'])

## FE - Diagnoses (ICD Codes)

In [None]:
# Add diagnosis information (ICD codes)
diagnoses = dfs['diagnoses_icd'][['hadm_id', 'icd_code']].copy()

# Identify the top 100 most common diagnosis codes
TOP_N_ICD_CODES = 100
common_icd_codes = diagnoses['icd_code'].value_counts().head(TOP_N_ICD_CODES).index.tolist()

# Filter for diagnoses relevant to our admissions and common codes
diag_filtered_for_hot_encode = diagnoses[diagnoses['icd_code'].isin(common_icd_codes) & diagnoses['hadm_id'].isin(analytical_df.index)]

# Group by admission to get a list of diagnoses for each
hadm_icd_codes = diag_filtered_for_hot_encode.groupby('hadm_id')['icd_code'].apply(list).reset_index()

# Create binary features for each of the top 100 diagnoses
diag_feature_list = []
for hadm_id_val in analytical_df.index:
    codes_for_hadm = hadm_icd_codes[hadm_icd_codes['hadm_id'] == hadm_id_val]
    current_hadm_icd_vector = {f"diag_{code}": 0 for code in common_icd_codes}
    if not codes_for_hadm.empty:
        icd_list_for_hadm = codes_for_hadm.iloc[0]['icd_code']
        for code in icd_list_for_hadm:
            if f"diag_{code}" in current_hadm_icd_vector:
                current_hadm_icd_vector[f"diag_{code}"] = 1
    current_hadm_icd_vector['hadm_id'] = hadm_id_val
    diag_feature_list.append(current_hadm_icd_vector)

diag_features_df = pd.DataFrame(diag_feature_list)
diag_features_df = diag_features_df.set_index('hadm_id')
# Merge diagnosis features into the analytical dataframe
analytical_df = analytical_df.merge(diag_features_df, on='hadm_id', how='left')
for col in diag_features_df.columns:
    analytical_df[col] = analytical_df[col].fillna(0).astype(int)

## FE - Procedures (ICD Codes)

In [None]:
# Add procedure information (ICD codes)
procedures_df = dfs['procedures_icd'][['hadm_id', 'icd_code', 'icd_version']].copy()

# Identify the top 50 most common procedure codes
TOP_N_PROC_CODES = 50
common_proc_codes = procedures_df['icd_code'].value_counts().head(TOP_N_PROC_CODES).index.tolist()

# Filter for procedures relevant to our admissions and common codes
proc_filtered_for_hot_encode = procedures_df[procedures_df['icd_code'].isin(common_proc_codes) & procedures_df['hadm_id'].isin(analytical_df.index)]

# Group by admission to get a list of procedures for each
hadm_proc_codes = proc_filtered_for_hot_encode.groupby('hadm_id')['icd_code'].apply(list).reset_index()

# Create binary features for each of the top 50 procedures
proc_feature_list = []
for hadm_id_val in analytical_df.index:
    codes_for_hadm = hadm_proc_codes[hadm_proc_codes['hadm_id'] == hadm_id_val]
    current_hadm_proc_vector = {f"proc_{str(code).replace('.', '_')}": 0 for code in common_proc_codes}

    if not codes_for_hadm.empty:
        proc_list_for_hadm = codes_for_hadm.iloc[0]['icd_code']
        for code in proc_list_for_hadm:
            clean_code_col = f"proc_{str(code).replace('.', '_')}"

            if clean_code_col in current_hadm_proc_vector:
                current_hadm_proc_vector[clean_code_col] = 1
    current_hadm_proc_vector['hadm_id'] = hadm_id_val
    proc_feature_list.append(current_hadm_proc_vector)

proc_features_df = pd.DataFrame(proc_feature_list)
proc_features_df = proc_features_df.set_index('hadm_id')

# Merge procedure features into the analytical dataframe
analytical_df = analytical_df.merge(proc_features_df, on='hadm_id', how='left')
for col in proc_features_df.columns:
    if col in analytical_df.columns:
        analytical_df[col] = analytical_df[col].fillna(0).astype(int)

## FE - Emergency Department (ED) Data

In [None]:
# Add emergency department (ED) data
edstays_df = dfs['edstays'][['hadm_id', 'stay_id', 'intime', 'outtime']].copy()
edstays_df.dropna(subset=['hadm_id'], inplace=True)
edstays_df['hadm_id'] = edstays_df['hadm_id'].astype(analytical_df.index.dtype)

# Calculate length of stay in the ED in hours
edstays_df['intime'] = pd.to_datetime(edstays_df['intime'], errors='coerce')
edstays_df['outtime'] = pd.to_datetime(edstays_df['outtime'], errors='coerce')
edstays_df['ed_los_hours'] = (edstays_df['outtime'] - edstays_df['intime']).dt.total_seconds() / 3600
edstays_df.loc[edstays_df['ed_los_hours'] < 0, 'ed_los_hours'] = np.nan

# Handle cases where a single hospital admission is linked to multiple ED stays by keeping the latest one
edstays_df = edstays_df.sort_values(by=['hadm_id', 'outtime'], ascending=[True, False])
edstays_df = edstays_df.drop_duplicates(subset=['hadm_id'], keep='first')
edstays_processed = edstays_df[['hadm_id', 'stay_id', 'ed_los_hours']].set_index('hadm_id')

# Process triage data from the ED
triage_df = dfs['triage_ed'].copy()
vital_cols = ['temperature', 'heartrate', 'resprate', 'o2sat', 'sbp', 'dbp', 'pain']

# Clean and impute missing vital signs with the median value
for col in vital_cols:
    triage_df[col] = pd.to_numeric(triage_df[col], errors='coerce')
    median_val = triage_df[col].median()
    triage_df[col] = triage_df[col].fillna(median_val)
    triage_df.rename(columns={col: f"ed_{col}_triage"}, inplace=True)

# Process patient acuity score from triage
triage_df['ed_acuity'] = pd.to_numeric(triage_df['acuity'], errors='coerce')
triage_df['ed_acuity'] = triage_df['ed_acuity'].fillna(triage_df['ed_acuity'].mode()[0] if not triage_df['ed_acuity'].mode(dropna=True).empty else 0).astype(int)

# Select the processed triage features
processed_vital_cols_renamed = [f"ed_{col}_triage" for col in vital_cols if f"ed_{col}_triage" in triage_df.columns]
acuity_col_name = ['ed_acuity'] if 'ed_acuity' in triage_df.columns else []
triage_features_cols_to_select = ['stay_id'] + processed_vital_cols_renamed + acuity_col_name
triage_features_cols_to_select = [col for col in triage_features_cols_to_select if col in triage_df.columns or col == 'stay_id']
triage_features = triage_df[triage_features_cols_to_select].copy()

# Process medication reconciliation data from the ED
medrecon_df = dfs['medrecon'].copy()

# Count the number of medications recorded for each ED stay
medrecon_counts = medrecon_df.groupby('stay_id').size().reset_index(name='ed_medrecon_count')

# Combine all ED features (length of stay, triage, medrecon)
ed_features_combined = edstays_processed.reset_index().merge(triage_features, on='stay_id', how='left')
ed_features_combined = ed_features_combined.merge(medrecon_counts, on='stay_id', how='left')
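# Assign the result back: calling fillna(inplace=True) on a selected column may not update the dataframe in recent pandas versions
ed_features_combined['ed_medrecon_count'] = ed_features_combined['ed_medrecon_count'].fillna(0)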

# Clean up and set index for merging
ed_features_combined = ed_features_combined.drop(columns=['stay_id'])
ed_features_final = ed_features_combined.set_index('hadm_id')

# One-hot encode the ED acuity score
ed_features_final = pd.get_dummies(ed_features_final, columns=['ed_acuity'], prefix='ed_acuity', dummy_na=False)

# Merge the final ED features into the main analytical dataframe
analytical_df = analytical_df.merge(ed_features_final, on='hadm_id', how='left')

# Fill missing values for the newly added ED features
new_ed_cols = ed_features_final.columns.tolist()
for col in new_ed_cols:
    if analytical_df[col].dtype == 'bool' or col.startswith('ed_acuity_'):
        analytical_df[col] = analytical_df[col].fillna(False).astype(bool)
    else:
        analytical_df[col] = analytical_df[col].fillna(0)

# Text Preprocessing and Transformer Embeddings

In [None]:
# Define a function to clean clinical notes
def clean_clinical_text(text):
    if not isinstance(text, str):
        return ""

    # Convert text to lowercase
    text = text.lower()

    # Remove de-identification placeholders like [** ... **]
    text = re.sub(r'\[\*\*.*?\*\*\]', ' ', text)
    text = text.replace('___', ' ')

    # Remove boilerplate headers like patient name, admission date, etc.
    boilerplate_patterns = [
        r"^\s*name\s*:.*?\n", r"^\s*unit no\s*:.*?\n", r"^\s*admission date\s*:.*?\n",
        r"^\s*discharge date\s*:.*?\n", r"^\s*date of birth\s*:.*?\n", r"^\s*sex\s*:.*?\n",
        r"^\s*service\s*:.*?\n", r"^\s*allergies\s*:.*?\n", r"^\s*attending\s*:.*?\n"
    ]
    for pattern in boilerplate_patterns:
        text = re.sub(pattern, '', text, flags=re.IGNORECASE | re.MULTILINE)

    # Remove section headers by replacing colon with a space
    text = re.sub(r'([a-z\s]+):\s*', r'\1 ', text)

    # Normalize whitespace and newlines
    text = text.replace('\n', ' ')
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()

    # Remove special characters
    text = text.replace('*', '')
    text = text.replace('#', '')
    return text

# Set up paths for processing clinical notes
cleaned_notes_parquet_path = os.path.join(project_base_path, "cleaned_discharge_notes.parquet")
notes_for_processing = None

# Check if cleaned notes already exist to save processing time
if os.path.exists(cleaned_notes_parquet_path):
    print("Loading pre-cleaned notes from parquet file...")
    notes_for_processing = pd.read_parquet(cleaned_notes_parquet_path)
    if 'analytical_df' in locals() and analytical_df.index.name == 'hadm_id' and 'hadm_id' in notes_for_processing.columns:
        notes_for_processing['hadm_id'] = notes_for_processing['hadm_id'].astype(analytical_df.index.dtype)
else:
    # If not, load raw notes, clean them, and save the result
    print("Cleaning notes and saving to parquet file...")
    tqdm.pandas() # Enable progress bar for pandas apply
    temp_notes_df = dfs['discharge_notes'][['hadm_id', 'text']].copy()
    temp_notes_df.dropna(subset=['text'], inplace=True)
    if 'analytical_df' in locals():
      temp_notes_df['hadm_id'] = temp_notes_df['hadm_id'].astype(analytical_df.index.dtype)

    # Aggregate and clean notes
    notes_for_processing = temp_notes_df.groupby('hadm_id')['text'].progress_apply(lambda x: ' '.join(x)).reset_index()
    notes_for_processing['cleaned_text'] = notes_for_processing['text'].progress_apply(clean_clinical_text)

    # Save the cleaned text to a parquet file for future use
    notes_for_processing[['hadm_id', 'cleaned_text']].to_parquet(cleaned_notes_parquet_path, index=False)

# Feature engineering with transformer embeddings (Bio_ClinicalBERT)
# This approach replaces TF-IDF features with more semantically rich embeddings
notes_with_cleaned_text = pd.read_parquet(cleaned_notes_parquet_path)
if analytical_df.index.name == 'hadm_id' and 'hadm_id' in notes_with_cleaned_text.columns:
    notes_with_cleaned_text['hadm_id'] = notes_with_cleaned_text['hadm_id'].astype(analytical_df.index.dtype)
notes_with_cleaned_text.set_index('hadm_id', inplace=True)

# Set up the model and tokenizer from Hugging Face
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval() # Set model to evaluation mode

# Define a function to get embeddings for a batch of texts
def get_embeddings_batch(texts_batch, tokenizer, model, device, max_length=512):
    # Tokenize the text batch
    inputs = tokenizer(texts_batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Get model outputs without calculating gradients
    with torch.no_grad():
        outputs = model(**inputs)

        # Use the embedding of the [CLS] token as the representation for the entire text
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
    return cls_embeddings

# Align the cleaned notes with the main analytical dataframe to ensure correct order and matching
aligned_notes_series = notes_with_cleaned_text['cleaned_text'].reindex(analytical_df.index).fillna('')

# Generate embeddings in batches to manage memory
all_texts = aligned_notes_series.tolist()
all_embeddings = []
batch_size = 32 # Adjust batch size based on GPU memory
for i in tqdm(range(0, len(all_texts), batch_size), desc="Generating Embeddings"):
    batch_texts = all_texts[i:i + batch_size]
    batch_embeddings = get_embeddings_batch(batch_texts, tokenizer, model, device)
    all_embeddings.append(batch_embeddings)

# Combine embeddings from all batches into a single numpy array
final_embeddings_array = np.vstack(all_embeddings)

# Create a dataframe from the embeddings
embedding_dim = final_embeddings_array.shape[1]
embedding_feature_names = [f"emb_{j}" for j in range(embedding_dim)]
embeddings_df = pd.DataFrame(final_embeddings_array, columns=embedding_feature_names, index=analytical_df.index)

# Merge the new transformer embedding features into the analytical dataframe
analytical_df = analytical_df.merge(embeddings_df, on='hadm_id', how='left')
analytical_df[embedding_feature_names] = analytical_df[embedding_feature_names].fillna(0)
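
Because `get_embeddings_batch` truncates at `max_length=512`, text beyond the first 512 tokens of a discharge summary does not contribute to its embedding. One possible alternative, which this notebook does not use, is to split each note into overlapping token windows, embed each window, and average the results. The sketch below covers only the windowing step on a plain list of token IDs; the `window` and `stride` values are assumptions.

```python
def sliding_windows(token_ids, window=510, stride=255):
    """Split a token-id list into overlapping windows of at most `window` ids."""
    if not token_ids:
        return [[]]
    windows = []
    start = 0
    while True:
        windows.append(token_ids[start:start + window])
        if start + window >= len(token_ids):
            break
        start += stride
    return windows

# Hypothetical note of 1,000 token ids; 510 leaves room for [CLS] and [SEP].
chunks = sliding_windows(list(range(1000)))
print([len(c) for c in chunks])
```

Each window would then need the special tokens added before being passed to the model, and embedding every window multiplies the compute cost per note.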

# Final Data

In [None]:
# Save the final analytical dataframe and the list of target labels to files
analytical_df_save_path = os.path.join(project_base_path, "analytical_df_with_transformer_embeddings.parquet")
analytical_df.to_parquet(analytical_df_save_path, index=True)
target_labels_save_path = os.path.join(project_base_path, "final_target_labels_composite.txt")
with open(target_labels_save_path, 'w') as f:
    for label in final_target_labels_to_predict:
        f.write(f"{label}\n")
print("Saved analytical dataframe with embeddings and target labels.")

# Data Loading for Modeling

Load the preprocessed data from the saved files

In [None]:
analytical_df_load_path = os.path.join(project_base_path, "analytical_df_with_transformer_embeddings.parquet")
analytical_df = pd.read_parquet(analytical_df_load_path)
target_labels_load_path = os.path.join(project_base_path, "final_target_labels_composite.txt")

with open(target_labels_load_path, 'r') as f:
    final_target_labels_to_predict = [line.strip() for line in f]

print("Loaded preprocessed data and target labels.")

Separate features (X) and targets (Y)

In [None]:
valid_target_labels = [label for label in final_target_labels_to_predict if label in analytical_df.columns]
Y = analytical_df[valid_target_labels]
X = analytical_df.drop(columns=valid_target_labels)

Drop datetime columns from the feature set as they are not directly used in the model


In [None]:
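datetime_cols_to_drop = ['admittime', 'dischtime']
actual_cols_to_drop_from_X = [col for col in datetime_cols_to_drop if col in X.columns]

if actual_cols_to_drop_from_X:
    X = X.drop(columns=actual_cols_to_drop_from_X)

# Cast the remaining features to float32: recent pandas versions emit boolean dummy columns,
# and a frame mixing bool, int, and float columns may not convert cleanly to a tensor
X = X.astype('float32')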

# Train/Validation/Test Split

Prepare for group-based data splitting to prevent data leakage. Patients can have multiple admissions, so all admissions for a single patient must fall in the same split (train, validation, or test).

In [None]:
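# Reuse the admissions table if it is already in memory; otherwise read it from disk
# (dict.get would evaluate its default eagerly and re-read the file every time)
if 'dfs' in globals() and 'admissions' in dfs:
    admissions_df = dfs['admissions']
else:
    admissions_df = pd.read_csv(os.path.join(project_base_path, "Data/hosp/admissions.csv.gz"))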
subject_id_map_df = admissions_df[['hadm_id', 'subject_id']].copy()
subject_id_map_df.dropna(subset=['hadm_id', 'subject_id'], inplace=True)
subject_id_map_df['hadm_id'] = subject_id_map_df['hadm_id'].astype(X.index.dtype)

# Create a series that maps each admission (hadm_id) to its patient (subject_id)
groups_series = X.index.map(subject_id_map_df.set_index('hadm_id')['subject_id'])
groups_for_split_array = np.array(groups_series)

Split data into training/validation (80%) and test (20%) sets, keeping patient data together


In [None]:
gss_tv_test = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
train_val_idx, test_idx = next(gss_tv_test.split(X, Y, groups_for_split_array))
X_train_val, X_test = X.iloc[train_val_idx], X.iloc[test_idx]
Y_train_val, Y_test = Y.iloc[train_val_idx], Y.iloc[test_idx]
groups_of_train_val_set = groups_for_split_array[train_val_idx]

Split the training/validation set further into training (75% of this set) and validation (25%)


In [None]:
gss_train_val = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=42) # 0.25 * 0.8 = 0.2
train_idx, val_idx = next(gss_train_val.split(X_train_val, Y_train_val, groups_of_train_val_set))
X_train, X_val = X_train_val.iloc[train_idx], X_train_val.iloc[val_idx]
Y_train, Y_val = Y_train_val.iloc[train_idx], Y_train_val.iloc[val_idx]

In [None]:
print(f"Training set shape: {X_train.shape}")
print(f"Validation set shape: {X_val.shape}")
print(f"Test set shape: {X_test.shape}")

# Model Development

## Model Definition

Define the neural network architecture

In [None]:
input_features = X_train.shape[1]
output_labels = Y_train.shape[1]
model = keras.Sequential([
    layers.Input(shape=(input_features,), name="input_layer"),
    layers.Dense(512, activation="relu", name="dense_1"),
    layers.Dropout(0.4, name="dropout_1"),
    layers.Dense(256, activation="relu", name="dense_2"),
    layers.Dropout(0.3, name="dropout_2"),
    layers.Dense(output_labels, activation="sigmoid", name="output_layer") # Sigmoid for multi-label classification
], name="drug_recommender_model")

model.summary()

## Model Compilation and Training

In [None]:
# Define the metrics to monitor during training and evaluation
METRICS = [
    keras.metrics.BinaryAccuracy(name='accuracy'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'),
    keras.metrics.AUC(name='auc', multi_label=True, num_labels=output_labels),
]

In [None]:
# Compile the model with an optimizer, loss function, and metrics
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0005),
    loss="binary_crossentropy", # Appropriate for multi-label classification
    metrics=METRICS
)

# Set up callbacks to improve the training process
model_checkpoint_path = os.path.join(project_base_path, "best_drug_model.keras")
callbacks = [
    # Save the best version of the model based on validation AUC
    keras.callbacks.ModelCheckpoint(
        filepath=model_checkpoint_path,
        save_best_only=True,
        monitor="val_auc",
        mode="max",
        verbose=1
    ),
    # Stop training early if the validation AUC doesn't improve for a number of epochs
    keras.callbacks.EarlyStopping(
        monitor="val_auc",
        patience=10,
        mode="max",
        restore_best_weights=True,
        verbose=1
    )
]
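
For reference, binary cross-entropy treats each label as an independent yes/no problem and averages the per-label losses. The sketch below is a plain NumPy re-implementation for illustration only, not Keras's internal code; the probabilities are made up.

```python
import numpy as np

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over all samples and labels."""
    y_prob = np.clip(y_prob, eps, 1 - eps)  # avoid log(0)
    losses = -(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    return losses.mean()

# One admission, three hypothetical drug labels.
y_true = np.array([[1, 0, 1]])
y_prob = np.array([[0.9, 0.2, 0.4]])
print(binary_crossentropy(y_true, y_prob))
```

With many rare labels, most targets are 0, so a model can reach a low loss and high binary accuracy by predicting mostly zeros; this is one reason the notebook monitors AUC instead of accuracy.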

Define training parameters


In [None]:
EPOCHS = 50
BATCH_SIZE = 128

In [None]:
# Train the model
history = model.fit(
    X_train,
    Y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val, Y_val),
    callbacks=callbacks,
    verbose=1
)

## Model Evaluation

In [None]:
# Evaluate the final model on the held-out test set
print("\nEvaluating model on the test set...")
test_loss, test_accuracy, test_precision, test_recall, test_auc = model.evaluate(X_test, Y_test, verbose=0)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")
print(f"Test AUC: {test_auc:.4f}")
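
To turn the sigmoid outputs into actual recommendations, the per-label probabilities have to be converted into a set of labels. The sketch below shows two common options on made-up probabilities: a fixed threshold and a top-k cut. The label names, the probabilities, and the value of `k` are hypothetical; in the notebook the probabilities would come from `model.predict(X_test)`.

```python
import numpy as np

# Hypothetical sigmoid outputs for 2 admissions over 4 composite labels.
label_names = np.array(["drug_a", "drug_b", "drug_c", "drug_d"])
probs = np.array([
    [0.10, 0.80, 0.55, 0.05],
    [0.60, 0.20, 0.30, 0.70],
])

# Option 1: fixed threshold (0.5 is the default of the Keras Precision/Recall metrics).
thresholded = [label_names[row >= 0.5].tolist() for row in probs]

# Option 2: top-k labels per admission, highest probability first.
k = 2
top_k_idx = np.argsort(-probs, axis=1)[:, :k]
top_k = [label_names[idx].tolist() for idx in top_k_idx]

print(thresholded)
print(top_k)
```

A threshold yields a variable number of drugs per admission, while top-k always returns exactly `k`; which is preferable depends on how the recommendations would be presented.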