## Setup

In [1]:
# from google.colab import auth
# from google.cloud import bigquery
# from google.colab import data_table
# from google.colab import widgets

# from collections import Counter
# import re
# import numpy as np
# import pandas as pd
# import math



In [None]:
# import os
# from google.colab import drive
# import sys

# drive.mount('/content/drive')
# os.chdir('/content/drive/MyDrive/HML/Final Project')
# sys.path.append(os.path.abspath('/content/drive/MyDrive/HML/Final Project'))

In [3]:
# !pip install duckdb --quiet
# import duckdb

In [4]:
# drive_path = '/content/drive/MyDrive/HML/Final Project/MIMIC-III'
# con = duckdb.connect(f'{drive_path}/mimiciii.duckdb')

In [5]:
import pandas as pd
import numpy as np
from tqdm import tqdm 

import torch
from torch_geometric.nn import GATv2Conv
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tensorflow.keras.preprocessing.sequence import pad_sequences


2025-08-25 12:58:33.506851: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756115913.520552 1853399 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756115913.524836 1853399 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1756115913.535862 1853399 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1756115913.535875 1853399 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1756115913.535877 1853399 computation_placer.cc:177] computation placer alr

## Loading data

In [None]:
subject_ids = pd.read_csv('data/initial_cohort.csv')['subject_id'].to_list()
lavbevent_meatdata = pd.read_csv('data/labs_metadata.csv')
vital_meatdata = pd.read_csv('data/vital_metadata.csv')

In [None]:
ICUQ = \
"""--sql
SELECT admissions.subject_id::INTEGER AS subject_id, admissions.hadm_id::INTEGER AS hadm_id
, admissions.admittime::DATE AS admittime, admissions.dischtime::DATE AS dischtime
, admissions.ethnicity, admissions.deathtime::DATE AS deathtime
, patients.gender, patients.dob::DATE AS dob, icustays.icustay_id::INTEGER AS icustay_id, patients.dod::DATE as dod,
icustays.intime::DATE AS intime,icustays.outtime::DATE AS outtime
FROM admissions
INNER JOIN patients
    ON admissions.subject_id = patients.subject_id
LEFT JOIN icustays
    ON admissions.hadm_id = icustays.hadm_id

WHERE admissions.has_chartevents_data = 1
AND admissions.subject_id::INTEGER IN ?
ORDER BY admissions.subject_id, admissions.hadm_id, admissions.admittime;
"""

# icu =  con.execute(ICUQ, [subject_ids]).fetchdf().rename(str.lower, axis='columns')


In [None]:
LABQUERY = \
f"""--sql
SELECT labevents.subject_id::INTEGER AS subject_id\
      , labevents.hadm_id::INTEGER AS hadm_id\
      , labevents.charttime::DATE AS charttime
      , labevents.itemid::INTEGER AS itemid\
      , labevents.valuenum::DOUBLE AS valuenum
      , admissions.admittime::DATE AS admittime
FROM labevents
          INNER JOIN admissions
                    ON labevents.subject_id = admissions.subject_id
                        AND labevents.hadm_id = admissions.hadm_id
                        AND labevents.charttime::DATE between
                            (admissions.admittime::DATE)
                            AND (admissions.admittime::DATE + interval 48 hour)
                        AND itemid::INTEGER IN ? \
                        """

VITQUERY = f"""--sql
        SELECT chartevents.subject_id::INTEGER AS subject_id\
             , chartevents.hadm_id::INTEGER AS hadm_id\
             , chartevents.charttime::DATE AS charttime\
             , chartevents.itemid::INTEGER AS itemid\
             , chartevents.valuenum::DOUBLE AS valuenum\
             , admissions.admittime::DATE AS admittime\
        FROM chartevents
                 INNER JOIN admissions
                            ON chartevents.subject_id = admissions.subject_id
                                AND chartevents.hadm_id = admissions.hadm_id
                                AND chartevents.charttime::DATE between
                                   (admissions.admittime::DATE)
                                   AND (admissions.admittime::DATE + interval 48 hour)
                                AND itemid::INTEGER in ?
      -- exclude rows marked as error
      AND chartevents.error::INTEGER IS DISTINCT \
        FROM 1 \
        """


# lab = con.execute(LABQUERY, [lavbevent_meatdata['itemid'].tolist()]).fetchdf().rename(str.lower, axis='columns')
# vit = con.execute(VITQUERY, [vital_meatdata['itemid'].tolist()]).fetchdf().rename(str.lower, axis='columns')

In [9]:
pred_window = 13*24      # duration of prediction window (hours)
pred_gap = 24            # minimal gap between prediction and target (hours)
min_los = 24             # minimal length of stay (hours)
min_target_onset = 2*24  # minimal time of target since admission (hours)
pred_freq = '4H'        # prediction frequency

labs = pd.read_csv('data/labs.csv')
vits = pd.read_csv('data/vits.csv')
hosps = pd.read_csv('data/icu.csv')

for col in ['admittime', 'dischtime', 'dob', 'dod', 'intime', 'outtime']:
    hosps[col] = pd.to_datetime(hosps[col].str.strip(), errors='coerce')

In [10]:
# Create labels for each subject_id and hadm_id combination
def create_labels(hosps):
    """
    Create three labels for each subject_id and hadm_id:
    1. 30-day mortality (died within 30 days of discharge or during admission)
    2. Prolonged stay (length of stay > 7 days)
    3. 30-day readmission (readmitted within 30 days of discharge)
    """
    
    hosps_sorted = hosps.sort_values(['subject_id', 'admittime'])[['subject_id','hadm_id', 'admittime','dischtime','dod']].drop_duplicates().copy()
    hosps_sorted['los_hosp_hr'] = (hosps_sorted['dischtime'] - hosps_sorted['admittime']).dt.total_seconds()/3600
    hosps_sorted['mort_30day'] = 0
    
    died_during_admission = (~hosps_sorted['dod'].isna()) & (hosps_sorted['dod'] <= hosps_sorted['dischtime'])
    hosps_sorted.loc[died_during_admission, 'mort_30day'] = 1
    days_to_death_post_discharge = (hosps_sorted['dod'] - hosps_sorted['dischtime']).dt.total_seconds() / (24 * 3600)
    died_within_30_days = (~hosps_sorted['dod'].isna()) & (days_to_death_post_discharge <= 30) & (days_to_death_post_discharge >= 0)
    hosps_sorted.loc[died_within_30_days, 'mort_30day'] = 1
    
    hosps_sorted['prolonged_stay'] = (hosps_sorted['los_hosp_hr'] > 7 * 24).astype(int)
    hosps_sorted['readmission_30day'] = 0
    
    hosps_sorted['next_admittime'] = hosps_sorted.groupby('subject_id')['admittime'].shift(-1)
    days_between = (hosps_sorted['next_admittime'] - hosps_sorted['dischtime']).dt.total_seconds() / (24 * 3600)
    hosps_sorted['readmission_30day'] = ((days_between > 0) & (days_between <= 30)).astype(int)
    
    return hosps_sorted[['subject_id', 'hadm_id', 'mort_30day', 'prolonged_stay', 'readmission_30day']]

labels_df = create_labels(hosps)
print("Labels created successfully!")
print(f"Shape: {labels_df.shape}")
print("\nLabel distributions:")
print(f"30-day mortality: {labels_df['mort_30day'].sum()} / {len(labels_df)} ({labels_df['mort_30day'].mean():.3f})")
print(f"Prolonged stay: {labels_df['prolonged_stay'].sum()} / {len(labels_df)} ({labels_df['prolonged_stay'].mean():.3f})")
print(f"30-day readmission: {labels_df['readmission_30day'].sum()} / {len(labels_df)} ({labels_df['readmission_30day'].mean():.3f})")

Labels created successfully!
Shape: (40156, 5)

Label distributions:
30-day mortality: 5671 / 40156 (0.141)
Prolonged stay: 17308 / 40156 (0.431)
30-day readmission: 2147 / 40156 (0.053)


In [None]:
bios = pd.read_csv('data/bios.csv')
bios['charttime'] = pd.to_datetime(bios['charttime'].str.strip(), errors='coerce')

## preprocessing

In [12]:
# ethnicity  - to category
hosps.ethnicity = hosps.ethnicity.str.lower()
hosps.loc[(hosps.ethnicity.str.contains('^white')),'ethnicity'] = 'white'
hosps.loc[(hosps.ethnicity.str.contains('^black')),'ethnicity'] = 'black'
hosps.loc[(hosps.ethnicity.str.contains('^hisp')) | (hosps.ethnicity.str.contains('^latin')),'ethnicity'] = 'hispanic'
hosps.loc[(hosps.ethnicity.str.contains('^asia')),'ethnicity'] = 'asian'
hosps.loc[~(hosps.ethnicity.str.contains('|'.join(['white', 'black', 'hispanic', 'asian']))),'ethnicity'] = 'other'

# ethnicity - one hot encoding
hosps['eth_white'] = (hosps['ethnicity'] == 'white').astype(int)
hosps['eth_black'] = (hosps['ethnicity'] == 'black').astype(int)
hosps['eth_hispanic'] = (hosps['ethnicity'] == 'hispanic').astype(int)
hosps['eth_asian'] = (hosps['ethnicity'] == 'asian').astype(int)
hosps['eth_other'] = (hosps['ethnicity'] == 'other').astype(int)
hosps.drop(['ethnicity', 'deathtime'], inplace=True, axis=1)

In [13]:

# Generate feature columns for los, age and mortality
def age(admittime, dob):
    if admittime < dob:
      return 0
    return admittime.year - dob.year - ((admittime.month, admittime.day) < (dob.month, dob.day))

hosps['age'] = hosps.apply(lambda row: age(row['admittime'], row['dob']), axis=1)
hosps['los_hosp_hr'] = (hosps.dischtime - hosps.admittime).dt.total_seconds()/3600
hosps['mort'] = np.where(~np.isnat(hosps.dod),1,0)

# Gender to binary
hosps['gender'] = np.where(hosps['gender'] == "M", 1, 0)

# @title Q1.1 - Patient Exclusion Criteria
hosps = hosps.sort_values('admittime').groupby('subject_id').first().reset_index()
print(f"1. Include only first admissions: N={hosps.shape[0]}")

hosps = hosps[hosps.age.between(18,90)]
print(f"2. Exclusion by ages: N={hosps.shape[0]}")

# Exclude patients hospitalized for less than 24 hours
hosps = hosps[hosps['los_hosp_hr'] >= min_los]
print(f"3. Include only patients who admitted for at least {min_los} hours: N={hosps.shape[0]}")

# Exclude patients that died in the first 48 hours of admission
hours_to_death = (hosps['dod'] - hosps['admittime']).dt.total_seconds() / 3600
hosps = hosps[~((hosps['mort'].astype(bool)) & (hours_to_death < min_target_onset))]
print(f"4. Exclude patients who died within {min_target_onset}-hours of admission: N={hosps.shape[0]}")

1. Include only first admissions: N=32513
2. Exclusion by ages: N=25548
3. Include only patients who admitted for at least 24 hours: N=25168
4. Exclude patients who died within 48-hours of admission: N=24825


In [14]:
labs = labs[labs['hadm_id'].isin(hosps['hadm_id'])]
labs = pd.merge(labs,lavbevent_meatdata,on='itemid')
labs = labs[labs['valuenum'].between(labs['min'],labs['max'],  inclusive='both')]

vits = vits[vits['hadm_id'].isin(hosps['hadm_id'])]
vits = pd.merge(vits,vital_meatdata,on='itemid')
vits = vits[vits['valuenum'].between(vits['min'],vits['max'], inclusive='both')]

vits.loc[(vits['feature name'] == 'TempF'),'valuenum'] = (vits[vits['feature name'] == 'TempF']['valuenum']-32)/1.8
vits.loc[vits['feature name'] == 'TempF','feature name'] = 'TempC'

merged = pd.concat([vits, labs])
merged['charttime'] = pd.to_datetime(merged['charttime'], errors='coerce')

pivot = pd.pivot_table(merged, index=['subject_id', 'hadm_id', pd.Grouper(key='charttime', freq=pred_freq)],
                       columns=['feature name'], values='valuenum', aggfunc=['mean', 'max', 'min', 'std'])
pivot.columns = [f'{c[1]}_{c[0]}' for c in pivot.columns.to_flat_index()]

# temp = merged.copy()

merged = pd.merge(hosps, pivot.reset_index(), on=['subject_id', 'hadm_id'])
merged[pivot.columns] = merged.groupby(['subject_id', 'hadm_id'])[pivot.columns].ffill()


  pivot = pd.pivot_table(merged, index=['subject_id', 'hadm_id', pd.Grouper(key='charttime', freq=pred_freq)],


In [15]:
bios_onehot = pd.get_dummies(bios, columns=['org_itemid'], prefix='org')
groupby_cols = ['subject_id', 'hadm_id']
onehot_cols = [col for col in bios_onehot.columns if col.startswith('org_')]
bios_table = bios_onehot.groupby(groupby_cols)[onehot_cols].sum().reset_index()
for col in onehot_cols:
    bios_table[col] = (bios_table[col] > 0).astype(int)

In [16]:
# Left join merged with bios_table on (subject_id, hadm_id, charttime)
merged_with_bios = merged[['subject_id', 'hadm_id']].merge(
    bios_table,
    on=['subject_id', 'hadm_id'],
    how='left'
)

onehot_cols = [col for col in merged_with_bios.columns if col.startswith('org_')]
merged_with_bios[onehot_cols] = merged_with_bios[onehot_cols].fillna(0).astype(int)
bios = merged_with_bios.drop(columns=['hadm_id']).drop_duplicates().set_index('subject_id')


In [17]:
merged = merged.sort_values(['subject_id', 'hadm_id', 'charttime'])
labs_features_names = set(labs['feature name'])
vits_features_names = set(vits['feature name'])
labs_features = [col for col in merged.columns if col.split('_')[0] in labs_features_names]
vits_features = [col for col in merged.columns if col.split('_')[0] in vits_features_names]

lab_diff_cols = {}
for col in labs_features:
    if col.find("mean") >= 0:
      base = merged.groupby(['subject_id', 'hadm_id'])[col].transform('first')
      lab_diff_cols[f'{col}_diff'] = merged[col] - base

lab_diff_df = pd.DataFrame(lab_diff_cols)

vital_diff_cols = {}
for col in vits_features:
  if col.find("mean") >= 0:
      diff_series = merged.groupby(['subject_id', 'hadm_id'])[col].diff()
      vital_diff_cols[f'{col}_diff'] = diff_series

vital_diff_df = pd.DataFrame(vital_diff_cols)

# Concatenate back to original DataFrame
merged = pd.concat([merged, lab_diff_df, vital_diff_df], axis=1)

In [18]:
merged['charttime'] = pd.to_datetime(merged['charttime'], errors='coerce')

time_to_death = (merged['dod'] - merged['charttime']).dt.total_seconds() / (60 * 60)
merged['target'] = (pd.notnull(time_to_death) & (time_to_death <= pred_window + pred_gap)).astype(int)
merged = merged[time_to_death.isna() | (time_to_death >= pred_gap)]

In [None]:
prescriptions = pd.read_csv('data/prescriptions.csv', index_col=0)
prescriptions = prescriptions.loc[prescriptions.hadm_id.isin(merged.hadm_id) & prescriptions.subject_id.isin(merged.subject_id)]

# Convert date columns to datetime
prescriptions['startdate'] = pd.to_datetime(prescriptions['startdate'])
prescriptions['enddate'] = pd.to_datetime(prescriptions['enddate'])
prescriptions['startdate'] = prescriptions['startdate'].fillna(prescriptions['enddate'])
prescriptions['enddate'] = prescriptions['enddate'].fillna(prescriptions['startdate'])
prescriptions = prescriptions.dropna(subset=['startdate', 'enddate'], how='all')

item_counts = prescriptions[["subject_id", "drug"]].drop_duplicates()["drug"].value_counts()
valid_items = item_counts[item_counts > 240].index
prescriptions_filtered = prescriptions[prescriptions["drug"].isin(valid_items)][['subject_id', "hadm_id", 'drug', 'startdate', 'enddate']].drop_duplicates()

# Merge on subject_id and hadm_id once (cartesian product within each admission)
merged_all = merged.merge(
    prescriptions_filtered,
    on=["subject_id", "hadm_id"],
    suffixes=("", "_presc")
)

# Filter by charttime between prescription window
mask = (merged_all["charttime"].dt.date >= merged_all["startdate"].dt.date) & \
       (merged_all["charttime"].dt.date <= merged_all["enddate"].dt.date)

merged_all = merged_all[mask]

merged_all['value'] = 1
drug_pivot = merged_all.pivot_table(
    index=['subject_id', 'hadm_id', 'charttime'],
    columns='drug',
    values='value',
    aggfunc='count',
    fill_value=0
)

tmp = merged.merge(drug_pivot.reset_index().drop_duplicates(), on=['subject_id', 'hadm_id', 'charttime'], how='left')

(248568, 273)


Unnamed: 0_level_0,Unnamed: 1_level_0,drug,0.45% Sodium Chloride,0.83% Sodium Chloride,0.9% Sodium Chloride,0.9% Sodium Chloride (Mini Bag Plus),1/2 NS,5% Dextrose,5% Dextrose (EXCEL BAG),Acetaminophen,Acetaminophen (Liquid),Acetaminophen IV,...,Valsartan,Vancomycin,Vancomycin HCl,Vasopressin,Vial,Vitamin D,Warfarin,Zolpidem Tartrate,traZODONE,traZODONE HCl
subject_id,hadm_id,charttime,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
4,185777,2191-03-16 00:00:00,0,0,0,0,0,0,0,1,0,0,...,0,0,1,0,0,0,0,0,0,0
4,185777,2191-03-16 04:00:00,0,0,0,0,0,0,0,1,0,0,...,0,0,1,0,0,0,0,0,0,0
4,185777,2191-03-16 08:00:00,0,0,0,0,0,0,0,1,0,0,...,0,0,1,0,0,0,0,0,0,0
4,185777,2191-03-16 12:00:00,0,0,0,0,0,0,0,1,0,0,...,0,0,1,0,0,0,0,0,0,0
4,185777,2191-03-16 16:00:00,0,0,0,0,0,0,0,1,0,0,...,0,0,1,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99999,113369,2117-12-31 20:00:00,0,0,2,0,1,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
99999,113369,2118-01-01 00:00:00,0,0,1,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
99999,113369,2118-01-01 04:00:00,0,0,1,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
99999,113369,2118-01-01 08:00:00,0,0,1,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0


In [None]:
# merged.shape, drug_pivot.reset_index().drop_duplicates().shape, merged_all.shape
# tmp = merged.merge(drug_pivot.reset_index().drop_duplicates(), on=['subject_id', 'hadm_id', 'charttime'], how='left')


False

In [19]:
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

merged_clean = merged.reset_index(drop=True)

np.random.seed(0)

#Split to train & test (all data of a single patient needs to be in the same group)
X = merged_clean
X = X.merge(labels_df, on=['subject_id', 'hadm_id'], how='inner')
groups = merged_clean['subject_id']

gss = GroupShuffleSplit(n_splits=1, train_size=.8, test_size=0.1)
train_index, test_index = next(gss.split(X, groups=groups))
val_index = list(set(X.index.to_list()) - (set(train_index.tolist()) | set(test_index.tolist())))

X_train = X.iloc[train_index]
X_val = X.iloc[val_index]
X_test = X.iloc[test_index]

y_train = X_train[["subject_id", "mort_30day", "prolonged_stay", "readmission_30day"]].drop_duplicates()
y_train = y_train.groupby('subject_id',as_index=False).max()

X_train.drop(columns=["target", "mort_30day", "prolonged_stay", "readmission_30day"], axis=1, inplace=True)

y_val = X_val[["subject_id", "mort_30day", "prolonged_stay", "readmission_30day"]].drop_duplicates()
y_val = y_val.groupby('subject_id',as_index=False).max()
X_val.drop(columns=["target","mort_30day","prolonged_stay","readmission_30day"], axis=1, inplace=True)

y_test = X_test[["subject_id","mort_30day","prolonged_stay","readmission_30day"]].drop_duplicates()
y_test = y_test.groupby('subject_id',as_index=False).max()
X_test.drop(columns=["target","mort_30day","prolonged_stay","readmission_30day"], axis=1, inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  X_train.drop(columns=["target", "mort_30day", "prolonged_stay", "readmission_30day"], axis=1, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  X_val.drop(columns=["target","mort_30day","prolonged_stay","readmission_30day"], axis=1, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  X_test.drop(columns=["target","mort_30day","prolonged_stay","readmission_30day"], axis=1, inplace=True)


In [20]:
import sys, numpy
import pickle
# Create an alias so pickle can find the old path
sys.modules["numpy._core.numeric"] = numpy.core.numeric  

with open("data/notes_with_embeddings.pkl", "rb") as f:
    notes = pickle.load(f)
    notes_ordered = X[['subject_id']].drop_duplicates().merge(
    notes[["subject_id","embeddings"]], 
    on='subject_id', 
    how='left'
)

    embeddings_dict = {}

for idx, row in notes_ordered.iterrows():
    subject_id = row['subject_id']
    embeddings = row['embeddings']
    try:
        if isinstance(embeddings, np.ndarray) :
            # Convert to tensor if it's not already
            if not isinstance(embeddings, torch.Tensor):
                embeddings = torch.tensor(embeddings, dtype=torch.float32)
            
            # Perform average pooling across the sequence dimension
            # Assuming embeddings shape is (sequence_length, embedding_dim)
            pooled_embedding = torch.mean(embeddings, dim=0)  # Shape: (embedding_dim,)
            embeddings_dict[subject_id] = pooled_embedding
        else:
            # Handle missing embeddings with zero vector
            # Assuming embedding dimension is 768 (common for transformers)
            embeddings_dict[subject_id] = torch.zeros(768, dtype=torch.float32)
    except Exception as e:
        flag = 1

# Convert to a tensor where each row corresponds to a subject_id
subject_ids_list = notes_ordered['subject_id'].tolist()
pooled_embeddings = [embeddings_dict[subject_id] for subject_id in subject_ids_list]

print(f"Pooled embeddings shape: {len(pooled_embeddings)}")
print(f"Number of subjects: {len(subject_ids_list)}")

notes_df = pd.DataFrame({
    'subject_id': subject_ids_list,
    'embeddings': pooled_embeddings}).set_index('subject_id')

Pooled embeddings shape: 24796
Number of subjects: 24796


In [21]:
from sklearn.preprocessing import StandardScaler

num_cols = X_train.select_dtypes(include='float').columns
scaler = StandardScaler()

X_train.loc[:, num_cols] = scaler.fit_transform(X_train[num_cols])
X_val.loc[:, num_cols] = scaler.transform(X_val[num_cols])
X_test.loc[:, num_cols] = scaler.transform(X_test[num_cols])

baseline_df = X_train[X_train.charttime.dt.date == X_train.admittime.dt.date].mean(axis=0).fillna(0)
X_train.loc[:, num_cols] = X_train[num_cols].fillna(baseline_df)
X_val.loc[:, num_cols] = X_val[num_cols].fillna(baseline_df)
X_test.loc[:, num_cols] = X_test[num_cols].fillna(baseline_df)

In [22]:
to_drop = ['hadm_id','icustay_id','intime','outtime','admittime', 'dischtime', 'dod','dob', 'mort', 'los_hosp_hr', 'charttime','adm_to_death']

to_keep = ~X_train.columns.isin(to_drop)
to_keep = X_train.columns[to_keep]
X_train = X_train[to_keep]
X_test = X_test[to_keep]
X_val = X_val[to_keep]

In [23]:
from sklearn.cluster import KMeans
import numpy as np

def cluster_and_select_subjects(X_train, num_clusters=10, random_state=42):
    """
    Calculate the first row of each subject_id in X_train, cluster it to num_clusters 
    and choose one subject_id from each cluster.
    
    Parameters:
    X_train: DataFrame with subject_id column
    num_clusters: int, number of clusters to create
    random_state: int, for reproducibility
    
    Returns:
    list: selected subject_ids, one from each cluster
    """
    first_rows = X_train.groupby('subject_id').first().reset_index()
    
    features_for_clustering = first_rows.drop('subject_id', axis=1)
    

    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state, n_init=10)
    cluster_labels = kmeans.fit_predict(features_for_clustering)
    

    first_rows['cluster'] = cluster_labels
    

    np.random.seed(random_state)
    selected_subjects = []
    
    for cluster_id in range(num_clusters):
        cluster_subjects = first_rows[first_rows['cluster'] == cluster_id]['subject_id'].values
        if len(cluster_subjects) > 0:
            selected_subject = np.random.choice(cluster_subjects)
            selected_subjects.append(selected_subject)
    
    return selected_subjects

selected_subjects = cluster_and_select_subjects(X_train, num_clusters=100, random_state=42)

In [24]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create X_core with rows where subject_id is in selected_subjects
X_core = X_train[X_train['subject_id'].isin(selected_subjects)]
y_core = y_train[y_train['subject_id'].isin(selected_subjects)]

# Update X_train to exclude the selected subjects
X_train = X_train[~X_train['subject_id'].isin(selected_subjects)]
y_train = y_train[~y_train['subject_id'].isin(selected_subjects)]

train_labels = torch.tensor(y_train[['mort_30day', 'prolonged_stay', 'readmission_30day']].values, dtype=torch.float32).to(DEVICE)
val_labels = torch.tensor(y_val[['mort_30day', 'prolonged_stay', 'readmission_30day']].values, dtype=torch.float32).to(DEVICE)
test_labels = torch.tensor(y_test[['mort_30day', 'prolonged_stay', 'readmission_30day']].values, dtype=torch.float32).to(DEVICE)

notes_df_train = notes_df.loc[X_train.subject_id.unique()]
notes_df_val = notes_df.loc[X_val.subject_id.unique()]
notes_df_test = notes_df.loc[X_test.subject_id.unique()]

bios_train = bios.loc[X_train.subject_id.unique()]
bios_val = bios.loc[X_val.subject_id.unique()]
bios_test = bios.loc[X_test.subject_id.unique()]

In [25]:
def generate_series_data(df, group_col="subject_id", maxlen=18):
  grouped = df.groupby(group_col)
  subject_sequences = [group.values[:, 1:] for _, group in grouped]
  padded_tensor = pad_sequences(subject_sequences, padding='post', dtype='float32')
  sequence_lengths = [len(seq) for seq in subject_sequences]
  padding_mask = np.zeros((len(sequence_lengths), maxlen), dtype=np.float32)
  for i, length in enumerate(sequence_lengths):
      padding_mask[i, :length] = 1.0
  padded_tensor = torch.tensor(padded_tensor, dtype=torch.float32)
  padding_mask = torch.tensor(padding_mask, dtype=torch.float32)
  return padded_tensor, padding_mask

In [26]:
#import NoteEmbedder
#import imp
#os.environ["TRANSFORMERS_CACHE"] = "/home/bnet/ronsheinin/MLHC/MLHC/cache"
#imp.reload(NoteEmbedder)
#import os
#notes = pd.read_csv('data/notes.csv')
#note_embeddings = NoteEmbedder.run_embeeding(notes)

## Model

In [31]:
import torch
from torch.utils.data import Dataset


class PatientDataset(Dataset):
    def __init__(self, X, y, core, padding_mask, padding_mask_core, notes, bios, k=5):
        self.core = core
        self.X = X
        self.y = y
        self.padding_mask = padding_mask
        self.padding_mask_core = padding_mask_core
        self.k = k
        self.notes = notes
        self.bios = torch.tensor(bios)
        self.node_to_neighbors = self.build_knn_graph(X, core, padding_mask, padding_mask_core, k=k)


    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.padding_mask[idx], idx, self.notes[idx], self.bios[idx]


    def get_edge_index(self, batch, padding_mask_batch, batch_indices):
        """
        Returns a tensor representing the edge index of the graph.
        
        Returns:
            edge_index: tensor of shape (2, num_edges) representing graph edges
        """
        batch_size, seq_len, _ = batch.shape
        core_size = self.core.shape[0]
        total_patients = batch_size + core_size

        batch_size = batch.shape[0]
        # all_patients = torch.cat([batch, self.core], dim=0)
        all_padding_mask = torch.cat([padding_mask_batch, self.padding_mask_core.to(padding_mask_batch.device)], dim=0)

        edges = []
        
        for patient_idx in range(total_patients):
            for t in range(seq_len - 1):
                if all_padding_mask[patient_idx, t] > 0 and all_padding_mask[patient_idx, t + 1] > 0:
                    node_curr = patient_idx * seq_len + t
                    node_next = patient_idx * seq_len + t + 1
   
                    edges.append([node_curr, node_next])
                    edges.append([node_next, node_curr])
        
        edges = [torch.tensor(edges).t()]

        neighbors_list = []
        for idx in batch_indices:
            neighbors_list.extend(self.node_to_neighbors[idx.item() * seq_len: (idx.item() + 1) * seq_len])
            # for i in range(seq_len):
            #     neighbors_list.append(list(self.node_to_neighbors[idx.item() * seq_len + i]))
        
        # for i in range(self.X.shape[0] * seq_len, self.X.shape[0] * seq_len + core_size * seq_len):
        #     neighbors_list.append(list(self.node_to_neighbors[i]))
        neighbors_list.extend(self.node_to_neighbors[self.X.shape[0] * seq_len:])
    
        for t in range(seq_len):
            valid_patients = all_padding_mask[:, t] > 0
            valid_indices = torch.where(valid_patients)[0] 

            if len(valid_indices) > 1:
                for i, patient_idx in enumerate(valid_indices):
                    node_curr = patient_idx * seq_len + t
                    to_append = torch.tensor(neighbors_list[node_curr], dtype=torch.long) - (self.X.shape[0] - batch_size) * seq_len  # Adjust index for core patients
                    edges.append(torch.stack([to_append, torch.full_like(to_append, node_curr)], dim=0))

        
        if edges:
            edge_index = torch.cat(edges, dim=1).to(torch.long)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        
        return edge_index


        
    @staticmethod
    def build_knn_graph(batch, core, padding_mask_batch, padding_mask_core, k=5):
        """
        Build a KNN graph from batch and core tensors.
        
        Args:
            batch: 3D tensor (batch_size, seq_len, features)
            core: 3D tensor (core_size, seq_len, features)
            padding_mask_batch: 2D tensor (batch_size, seq_len) indicating valid time points
            padding_mask_core: 2D tensor (core_size, seq_len) indicating valid time points
            k: number of nearest neighbors for patient connections
        
        Returns:
            edge_index: tensor of shape (2, num_edges) representing graph edges
        """

        batch_size, seq_len, _ = batch.shape
        core_size = core.shape[0]
        total_patients = batch_size + core_size
        batch_size = batch.shape[0]

        all_patients = torch.cat([batch, core], dim=0)
        all_padding_mask = torch.cat([padding_mask_batch, padding_mask_core], dim=0)
        
        node_to_neighbors = {i: set() for i in range(total_patients * seq_len)}

        edges = []
           
        for t in range(seq_len):
            valid_patients = all_padding_mask[:, t] > 0
            valid_indices = torch.where(valid_patients)[0] 

            core_patients = all_padding_mask[batch_size:, t] > 0
            core_indices = torch.where(core_patients)[0] + batch_size

            if len(valid_indices) > 1:
                features_t = all_patients[valid_indices, t, :]
                features_t_to = all_patients[core_indices, t, :]

                distances = torch.cdist(features_t, features_t_to, p=2)
                
                for i, patient_idx in enumerate(valid_indices):
                    num_neighbors = min(k, len(core_indices))
                    _, nearest_indices = torch.topk(distances[i], num_neighbors, largest=False)
                    
                    for j in nearest_indices:
                        neighbor_idx = core_indices[j]
                        node_curr = patient_idx * seq_len + t
                        node_neighbor = neighbor_idx * seq_len + t

                        edges.append([node_neighbor, node_curr])
                        
                        node_to_neighbors[node_curr.item()].add(node_neighbor.item())
        
        result = [list(node_to_neighbors[i]) for i in range(total_patients * seq_len)]
        return result

# dataset = PatientDataset(core=padd_tensor_core, X=padded_tensor_train, y=labels_train, padding_mask=padding_mask_train, padding_mask_core=padding_mask_core, k=5)
# dataset.cal_graphs()

In [66]:
from sklearn.metrics import average_precision_score, roc_auc_score


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#DEVICE = torch.device('cpu')  # For testing purposes, use CPU


class GraphGRUMortalityModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, n1_gat_layers, n2_gru_layers, X_core, core_padding_mask,
                 num_heads=4, dropout=0.1, seq_len=18, k=5, notes_embedding_dim = 768, 
                 bios_embedding_dim = 128, bio_num_embeddings = 1000,
                 gnn_flag=True):
        """
        Mortality prediction model with Graph Attention + GRU layers
        
        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden dimension for GAT and GRU layers
            n1_gat_layers: Number of Graph Attention layers
            n2_gru_layers: Number of GRU layers
            X_core_dim: Core set dimension (number of core patients)
            num_heads: Number of attention heads for GAT
            dropout: Dropout rate
            seq_len: Sequence length
        """
        super(GraphGRUMortalityModel, self).__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n1_gat_layers = n1_gat_layers
        self.n2_gru_layers = n2_gru_layers
        self.X_core = X_core.to(DEVICE)
        self.core_padding_mask = core_padding_mask.to(DEVICE)
        self.seq_len = seq_len
        self.num_heads = num_heads
        self.gnn_flag = gnn_flag
        
        self.notes_layer = nn.Linear(notes_embedding_dim, hidden_dim).to(DEVICE)

        self.bios_embedding_bag = nn.EmbeddingBag(num_embeddings=bio_num_embeddings, embedding_dim=bios_embedding_dim, mode='mean').to(DEVICE)
        self.bios_embedding_bag.weight.requires_grad = True
        self.bios_layer = nn.Linear(bios_embedding_dim, hidden_dim).to(DEVICE)

        self.gat_layers = nn.ModuleList().to(DEVICE)
        
        self.gat_layers.append(
            GATv2Conv(input_dim, hidden_dim // num_heads, heads=num_heads, dropout=dropout, concat=True)
        )
        
        for _ in range(n1_gat_layers - 1):
            self.gat_layers.append(
                GATv2Conv(hidden_dim, hidden_dim // num_heads, heads=num_heads, dropout=dropout, concat=True)
            )
        
        if self.gnn_flag:
            self.gru = nn.GRU(hidden_dim, hidden_dim, n2_gru_layers, batch_first=True, dropout=dropout)
        else:
            self.gru = nn.GRU(input_dim, hidden_dim, n2_gru_layers, batch_first=True, dropout=dropout)
        
        self.classifier1 = nn.Sequential(
            nn.Linear(5 * hidden_dim, 2 * hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.classifier2 = nn.Sequential(
            nn.Linear(5 * hidden_dim, 2 * hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.classifier3 = nn.Sequential(
            nn.Linear(5 * hidden_dim, 2 * hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )
        
        self.dropout = nn.Dropout(dropout).to(DEVICE)
        self.k = k

        self.best_model = None

        # initialize weights
        for p in self.parameters():
          try:
              if p.requires_grad:
                nn.init.xavier_uniform_(p) #, -0.15, 0.15)
          except ValueError:
              pass

        
    def forward(self, x ,padding_mask, edge_index, notes, bios):
        batch_size = x.size(0)
        
        if self.gnn_flag:
            all_patients = torch.cat([x, self.X_core], dim=0)  # (batch_size + X_core_dim, seq_len, input_dim)
            total_patients = batch_size + self.X_core.shape[0]
            
            # Reshape for graph processing: (total_patients * seq_len, input_dim)
            graph_input = all_patients.view(total_patients * self.seq_len, -1)
            # Apply GAT layers
            for gat_layer in self.gat_layers:
                graph_input = F.relu(gat_layer(graph_input, edge_index))
            
            # Reshape back to sequence format: (total_patients, seq_len, hidden_dim)
            graph_output = graph_input.view(total_patients, self.seq_len, -1)

            # Extract only batch patients (exclude core)
        
            batch_output = graph_output[:batch_size]  # (batch_size, seq_len, hidden_dim)
        
        else:
            batch_output = x
        # Apply GRU layers
        # Pack sequences for efficient processing
        # lengths from mask (True = pad) → count valid steps
        lengths = (padding_mask.to(bool)).sum(dim=1)                       # (batch,)
        lengths = lengths.clamp(min=1).cpu()

        mask_index = padding_mask.sum(dim=1).long() - 1  # Get the last valid index for each sequence
        mask_expanded = padding_mask.unsqueeze(-1)    
        gru_output, _ = self.gru(batch_output)
        out = torch.cat([
            gru_output[torch.arange(gru_output.size(0)), mask_index, :],
            (gru_output * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1),
            (gru_output * mask_expanded).max(dim=1)[0]
        ], dim=-1) 
        notes = F.relu(self.notes_layer(notes))
        
        indices_list = [row.nonzero(as_tuple=True)[0] for row in bios]
        indices = torch.cat(indices_list).to(torch.long).to(DEVICE)
        # offsets = where each sample starts
        lengths = [len(i) for i in indices_list]
        offsets = torch.tensor([0] + torch.cumsum(torch.tensor(lengths[:-1]), dim=0).tolist(), dtype=torch.long).to(DEVICE)
        bios_embeddings = self.bios_embedding_bag(indices, offsets)
        bios = F.relu(self.bios_layer(bios_embeddings))

        X_concat = torch.cat([out, notes, bios], dim=-1)  # (batch_size, seq_len, 2*hidden_dim)
        preds1 = self.classifier1(X_concat)  # (batch_size, 1)
        preds2 = self.classifier2(X_concat)
        preds3 = self.classifier3(X_concat)
        predictions = torch.concat([preds1, preds2, preds3], dim=1)
        return predictions
    
    
    def masked_bce_loss(self, logits, targets, mask, pos_weight=None):
        T = min(logits.shape[1], targets.shape[1], mask.shape[1])
        logits, targets, mask = logits[:, :T], targets[:, :T], mask[:, :T]
        loss = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)(logits, targets)
        return (loss * mask).sum() / mask.sum()
    

    def train_all(self, dataloaders, datasets, epochs: int = 10, learning_rate: float = 1e-3, pos_lambda : float = 1):
        self.train()
        optim = torch.optim.Adam(self.parameters(), lr=learning_rate)
        best_validation_accuracy = - float('inf')
        losses = []
        for i in range(datasets['train'].y.shape[1]):
            targets = datasets['train'].y[:,i] # Get the max target for each patien
            pos_weight = (targets == 0).sum() / (targets == 1).sum()  # Adjust pos_weight as needed
            losses.append(nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight))
            print(f'Pos weight {i}: {pos_weight:.4f}')


        for epoch in range(epochs):
            print(f'Starting epoch {epoch + 1}/{epochs}')
            total = 0
            for x, y, padding_mask, idx, notes, bios in tqdm(dataloaders['train']):
                optim.zero_grad()

                x, padding_mask, y, notes, bios = x.to(DEVICE), padding_mask.to(DEVICE), y.to(DEVICE), notes.to(DEVICE), bios.to(DEVICE)
                edge_index = datasets['train'].get_edge_index(x, padding_mask, idx).to(DEVICE)
                predictions = self.forward(x, padding_mask, edge_index, notes, bios)
                loss = losses[0](predictions[:, 0], y[:, 0]) + \
                       losses[1](predictions[:, 1], y[:, 1]) + \
                       losses[2](predictions[:, 2], y[:, 2])
                #loss = F.binary_cross_entropy_with_logits(predictions.view(-1), y.view(-1))
                loss.backward()
                optim.step()
                
                total += loss.item()
            avg_loss = total / len(dataloaders['train'])
            print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')
            validation_accuracy = self.validate(dataloaders['val'], datasets['val'])
            # 'mort_30day', 'prolonged_stay', 'readmission_30day'
            print(f'Validation Accuracy mort_30day: AUC - {validation_accuracy[0]:.4f} | AP - {validation_accuracy[1]:.4f}')
            print(f'prolonged_stay: AUC - {validation_accuracy[2]:.4f} | AP - {validation_accuracy[3]:.4f}')
            print(f'readmission_30day: AUC - {validation_accuracy[4]:.4f} | AP - {validation_accuracy[5]:.4f}')
            if validation_accuracy[1] > best_validation_accuracy:
                best_validation_accuracy = validation_accuracy[1]
                self.best_model = self.state_dict()
                print("Best model updated")
        self.load_state_dict(self.best_model)

    def validate(self, dataloader, dataset):
        self.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            all_true_labels = {i: [] for i in range(3)}
            all_predicted_labels = {i: [] for i in range(3)}
            for x, y, padding_mask, idx, notes, bios in tqdm(dataloader):
                x, padding_mask, y, notes, bios = x.to(DEVICE), padding_mask.to(DEVICE), y.to(DEVICE), notes.to(DEVICE), bios.to(DEVICE)
                # edge_index = dataset.build_knn_graph(x, self.X_core, padding_mask, self.core_padding_mask, k=self.k).to(DEVICE)
                edge_index = dataset.get_edge_index(x, padding_mask, idx).to(DEVICE)
                predictions = self.forward(x, padding_mask, edge_index, notes, bios)
                for i in range(3):
                    all_true_labels[i].extend(y[:, i].cpu().numpy())
                    all_predicted_labels[i].extend(torch.sigmoid(predictions[:, i]).cpu().numpy().flatten())
                # all_true_labels.extend(y.cpu().numpy())
                # all_predicted_labels.extend(torch.sigmoid(predictions).cpu().numpy().flatten())
                predicted_labels = (torch.sigmoid(predictions[:,0]) > 0.5).float()
                correct += (predicted_labels == y[:,0]).sum().item()
                total += y[:,0].shape[0]
        accuracy = correct / total if total > 0 else 0
        self.train()
        return roc_auc_score(all_true_labels[0], all_predicted_labels[0]), average_precision_score(all_true_labels[0], all_predicted_labels[0]), \
                roc_auc_score(all_true_labels[1], all_predicted_labels[1]), average_precision_score(all_true_labels[1], all_predicted_labels[1]), \
                roc_auc_score(all_true_labels[2], all_predicted_labels[2]), average_precision_score(all_true_labels[2], all_predicted_labels[2])


    


In [62]:
batch_size = 256
hidden_dim = 64
k = 7

padded_tensor_train, padding_mask_train = generate_series_data(X_train, group_col="subject_id", maxlen=18)
padded_tensor_val, padding_mask_val = generate_series_data(X_val, group_col="subject_id", maxlen=18)
padded_tensor_test, padding_mask_test = generate_series_data(X_test, group_col="subject_id", maxlen=18)
padd_tensor_core, padding_mask_core = generate_series_data(X_core, group_col="subject_id", maxlen=18)

datasets = {x: PatientDataset(d, y, core=padd_tensor_core, padding_mask=m, padding_mask_core=padding_mask_core, k=k ,notes=n, bios=b) for x, d, y, m, n, b in
        zip(['train', 'val', 'test'], [padded_tensor_train, padded_tensor_val, padded_tensor_test], 
            [train_labels, val_labels, test_labels], 
            [padding_mask_train, padding_mask_val, padding_mask_test],
            [torch.stack(notes_df_train.embeddings.values.tolist()), 
             torch.stack(notes_df_val.embeddings.values.tolist()), 
             torch.stack(notes_df_test.embeddings.values.tolist())],
             [bios_train.values, bios_val.values, bios_test.values])}
dataloaders = {x: DataLoader(datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'val', 'test']}




In [67]:
model = GraphGRUMortalityModel(input_dim=padded_tensor_train.shape[2], hidden_dim=hidden_dim, n1_gat_layers=1, n2_gru_layers=1, 
                               X_core=padd_tensor_core, core_padding_mask=padding_mask_core, 
                               num_heads=4, dropout=0.1, seq_len=padded_tensor_train.shape[1], k=k, 
                                notes_embedding_dim=768, bios_embedding_dim=64, bio_num_embeddings=bios_train.values.shape[1],
                               gnn_flag=True).to(DEVICE)

model.train_all(dataloaders, datasets, epochs=15, learning_rate=1e-3)
print("Training completed. Validating on test set...")

auc1, pr1, auc2, pr2, auc3, pr3 =  model.validate(dataloaders['test'], datasets['test'])
print(f'Test Accuracy in model: AUC1: {auc1:.4f}, AP1: {pr1:.4f}, AUC2: {auc2:.4f}, AP2: {pr2:.4f}, AUC3: {auc3:.4f}, AP3: {pr3:.4f}')



Pos weight 0: 6.8976
Pos weight 1: 1.1203
Pos weight 2: 21.9756
Starting epoch 1/15


100%|██████████| 78/78 [00:46<00:00,  1.66it/s]


Epoch 1/15, Loss: 3.1464


100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Validation Accuracy mort_30day: AUC - 0.8033 | AP - 0.3666
prolonged_stay: AUC - 0.6586 | AP - 0.5978
readmission_30day: AUC - 0.5744 | AP - 0.0516
Best model updated
Starting epoch 2/15


100%|██████████| 78/78 [00:46<00:00,  1.67it/s]


Epoch 2/15, Loss: 2.9294


100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Validation Accuracy mort_30day: AUC - 0.8132 | AP - 0.3801
prolonged_stay: AUC - 0.6790 | AP - 0.6259
readmission_30day: AUC - 0.5873 | AP - 0.0507
Best model updated
Starting epoch 3/15


100%|██████████| 78/78 [00:47<00:00,  1.64it/s]


Epoch 3/15, Loss: 2.8442


100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Validation Accuracy mort_30day: AUC - 0.8172 | AP - 0.4058
prolonged_stay: AUC - 0.7014 | AP - 0.6516
readmission_30day: AUC - 0.6022 | AP - 0.0574
Best model updated
Starting epoch 4/15


100%|██████████| 78/78 [00:47<00:00,  1.66it/s]


Epoch 4/15, Loss: 2.7833


100%|██████████| 10/10 [00:06<00:00,  1.59it/s]


Validation Accuracy mort_30day: AUC - 0.8120 | AP - 0.3968
prolonged_stay: AUC - 0.7222 | AP - 0.6690
readmission_30day: AUC - 0.6104 | AP - 0.0573
Starting epoch 5/15


100%|██████████| 78/78 [00:47<00:00,  1.63it/s]


Epoch 5/15, Loss: 2.7385


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Validation Accuracy mort_30day: AUC - 0.8066 | AP - 0.3977
prolonged_stay: AUC - 0.7229 | AP - 0.6731
readmission_30day: AUC - 0.6131 | AP - 0.0599
Starting epoch 6/15


100%|██████████| 78/78 [00:45<00:00,  1.70it/s]


Epoch 6/15, Loss: 2.6829


100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Validation Accuracy mort_30day: AUC - 0.8109 | AP - 0.4068
prolonged_stay: AUC - 0.7289 | AP - 0.6816
readmission_30day: AUC - 0.6171 | AP - 0.0617
Best model updated
Starting epoch 7/15


100%|██████████| 78/78 [00:46<00:00,  1.68it/s]


Epoch 7/15, Loss: 2.6617


100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Validation Accuracy mort_30day: AUC - 0.8000 | AP - 0.3966
prolonged_stay: AUC - 0.7166 | AP - 0.6666
readmission_30day: AUC - 0.6212 | AP - 0.0642
Starting epoch 8/15


100%|██████████| 78/78 [00:46<00:00,  1.67it/s]


Epoch 8/15, Loss: 2.6205


100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Validation Accuracy mort_30day: AUC - 0.8110 | AP - 0.3993
prolonged_stay: AUC - 0.7276 | AP - 0.6795
readmission_30day: AUC - 0.6024 | AP - 0.0623
Starting epoch 9/15


100%|██████████| 78/78 [00:47<00:00,  1.65it/s]


Epoch 9/15, Loss: 2.5928


100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Validation Accuracy mort_30day: AUC - 0.8078 | AP - 0.3963
prolonged_stay: AUC - 0.7265 | AP - 0.6779
readmission_30day: AUC - 0.6390 | AP - 0.0712
Starting epoch 10/15


100%|██████████| 78/78 [00:45<00:00,  1.71it/s]


Epoch 10/15, Loss: 2.5751


100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Validation Accuracy mort_30day: AUC - 0.8019 | AP - 0.4042
prolonged_stay: AUC - 0.7202 | AP - 0.6697
readmission_30day: AUC - 0.6271 | AP - 0.0737
Starting epoch 11/15


100%|██████████| 78/78 [00:44<00:00,  1.74it/s]


Epoch 11/15, Loss: 2.5274


100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Validation Accuracy mort_30day: AUC - 0.8003 | AP - 0.4154
prolonged_stay: AUC - 0.7142 | AP - 0.6625
readmission_30day: AUC - 0.6395 | AP - 0.0801
Best model updated
Starting epoch 12/15


100%|██████████| 78/78 [00:46<00:00,  1.68it/s]


Epoch 12/15, Loss: 2.4972


100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Validation Accuracy mort_30day: AUC - 0.7965 | AP - 0.4082
prolonged_stay: AUC - 0.7240 | AP - 0.6759
readmission_30day: AUC - 0.6319 | AP - 0.0742
Starting epoch 13/15


100%|██████████| 78/78 [00:46<00:00,  1.68it/s]


Epoch 13/15, Loss: 2.4619


100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Validation Accuracy mort_30day: AUC - 0.7994 | AP - 0.4028
prolonged_stay: AUC - 0.7202 | AP - 0.6744
readmission_30day: AUC - 0.6389 | AP - 0.0777
Starting epoch 14/15


100%|██████████| 78/78 [00:45<00:00,  1.70it/s]


Epoch 14/15, Loss: 2.4072


100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Validation Accuracy mort_30day: AUC - 0.8003 | AP - 0.3941
prolonged_stay: AUC - 0.7263 | AP - 0.6772
readmission_30day: AUC - 0.6330 | AP - 0.0718
Starting epoch 15/15


100%|██████████| 78/78 [00:47<00:00,  1.66it/s]


Epoch 15/15, Loss: 2.4107


100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Validation Accuracy mort_30day: AUC - 0.7971 | AP - 0.3946
prolonged_stay: AUC - 0.7119 | AP - 0.6703
readmission_30day: AUC - 0.6299 | AP - 0.0776
Training completed. Validating on test set...


100%|██████████| 10/10 [00:05<00:00,  1.72it/s]

Test Accuracy in model: AUC1: 0.8553, AP1: 0.5122, AUC2: 0.8022, AP2: 0.7665, AUC3: 0.6503, AP3: 0.0833





In [None]:

acc, auc, pr =  model.validate(dataloaders['test'], datasets['test'])
print(f'Test Accuracy in model: {acc:.4f}, AUC: {auc:.4f}, AP: {pr:.4f}')

100%|██████████| 10/10 [00:04<00:00,  2.13it/s]

Test Accuracy in model: 0.7863, AUC: 0.8395, AP: 0.4951


100%|██████████| 77/77 [00:50<00:00,  1.54it/s]


Epoch 1/15, Loss: 1.5944


100%|██████████| 10/10 [00:06<00:00,  1.62it/s]


Validation Accuracy: 0.3713 | AUC: 0.7740 | AP: 0.1687
Best model updated
Starting epoch 2/15


100%|██████████| 77/77 [00:49<00:00,  1.56it/s]


Epoch 2/15, Loss: 1.4701


100%|██████████| 10/10 [00:06<00:00,  1.45it/s]


Validation Accuracy: 0.5278 | AUC: 0.7918 | AP: 0.1959
Best model updated
Starting epoch 3/15


100%|██████████| 77/77 [00:48<00:00,  1.57it/s]


Epoch 3/15, Loss: 1.4253


100%|██████████| 10/10 [00:06<00:00,  1.49it/s]


Validation Accuracy: 0.5637 | AUC: 0.8071 | AP: 0.2007
Best model updated
Starting epoch 4/15


100%|██████████| 77/77 [00:48<00:00,  1.58it/s]


Epoch 4/15, Loss: 1.4124


100%|██████████| 10/10 [00:06<00:00,  1.49it/s]


Validation Accuracy: 0.6510 | AUC: 0.8170 | AP: 0.2071
Best model updated
Starting epoch 5/15


100%|██████████| 77/77 [00:48<00:00,  1.58it/s]


Epoch 5/15, Loss: 1.3996


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.7417 | AUC: 0.7936 | AP: 0.2050
Starting epoch 6/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 6/15, Loss: 1.3517


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.7804 | AUC: 0.7955 | AP: 0.2005
Starting epoch 7/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 7/15, Loss: 1.3370


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.7522 | AUC: 0.8047 | AP: 0.2153
Best model updated
Starting epoch 8/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 8/15, Loss: 1.3255


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.5780 | AUC: 0.8325 | AP: 0.2263
Best model updated
Starting epoch 9/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 9/15, Loss: 1.3074


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.6640 | AUC: 0.8306 | AP: 0.2243
Starting epoch 10/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 10/15, Loss: 1.2698


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.6267 | AUC: 0.8325 | AP: 0.2238
Starting epoch 11/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 11/15, Loss: 1.2637


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.7628 | AUC: 0.8211 | AP: 0.2248
Starting epoch 12/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 12/15, Loss: 1.2288


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.6178 | AUC: 0.8314 | AP: 0.2235
Starting epoch 13/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 13/15, Loss: 1.1946


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.7001 | AUC: 0.8297 | AP: 0.2405
Best model updated
Starting epoch 14/15


100%|██████████| 77/77 [00:49<00:00,  1.55it/s]


Epoch 14/15, Loss: 1.1705


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.6207 | AUC: 0.8285 | AP: 0.2405
Best model updated
Starting epoch 15/15


100%|██████████| 77/77 [00:49<00:00,  1.54it/s]


Epoch 15/15, Loss: 1.1669


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]


Validation Accuracy: 0.5813 | AUC: 0.8371 | AP: 0.2379
Training completed. Validating on test set...


100%|██████████| 10/10 [00:06<00:00,  1.62it/s]

Test Accuracy in BaseLine model: 0.5691, AUC: 0.8350, AP: 0.2529





In [None]:
import optuna
import gc
import joblib
def objective_mlp(trial):
    hidden_dim = trial.suggest_int('hidden_dim', 32, 128, step=16)
    n1_gat_layers = trial.suggest_int('n1_gat_layers', 1, 5)
    n2_gru_layers = trial.suggest_int('n2_gru_layers', 1, 5)
    dropout = trial.suggest_float('dropout', 0.1, 0.5, step=0.1)
    num_heads = trial.suggest_int('num_heads', 2, 8, step=2)
    k = trial.suggest_int('k', 5, 15, step=2)
    lr= trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True)

    model = GraphGRUMortalityModel(input_dim=padded_tensor_train.shape[2], hidden_dim=hidden_dim, n1_gat_layers=n1_gat_layers,
                                   n2_gru_layers=n2_gru_layers, X_core=padd_tensor_core, core_padding_mask=padding_mask_core,
                                   num_heads=num_heads, dropout=dropout, seq_len=padded_tensor_train.shape[1], k=k, gnn_flag=True).to(DEVICE)

    model.train_all(dataloaders, datasets, epochs=10, learning_rate=lr)
    
    acc, auc, pr = model.validate(dataloaders['test'], datasets['test'])
    
    del model 
    torch.cuda.empty_cache() 
    gc.collect()
    return pr

batch_size = 256

study = optuna.create_study(direction='maximize')
study.optimize(objective_mlp, n_trials=20)  
print("Best trial:")
print(study.best_trial)
print("Best hyperparameters:", study.best_trial.params) 
joblib.dump(study, 'optuna_study.pkl')


[I 2025-08-18 16:49:33,775] A new study created in memory with name: no-name-909f0b03-19ae-42e6-aaf8-4896d952bb34


Pos weight: 11.4991
Starting epoch 1/10


 32%|███▏      | 25/78 [01:23<02:58,  3.36s/it]

In [27]:
print(study.best_trial)
print("Best hyperparameters:", study.best_trial.params) 

FrozenTrial(number=12, state=TrialState.COMPLETE, values=[0.4732600299259349], datetime_start=datetime.datetime(2025, 8, 21, 1, 10, 11, 209685), datetime_complete=datetime.datetime(2025, 8, 21, 1, 34, 42, 427077), params={'hidden_dim': 96, 'n1_gat_layers': 1, 'n2_gru_layers': 1, 'dropout': 0.1, 'num_heads': 6, 'k': 5, 'learning_rate': 0.002245162564243561}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'hidden_dim': IntDistribution(high=128, log=False, low=32, step=16), 'n1_gat_layers': IntDistribution(high=5, log=False, low=1, step=1), 'n2_gru_layers': IntDistribution(high=5, log=False, low=1, step=1), 'dropout': FloatDistribution(high=0.5, log=False, low=0.1, step=0.1), 'num_heads': IntDistribution(high=8, log=False, low=2, step=2), 'k': IntDistribution(high=15, log=False, low=5, step=2), 'learning_rate': FloatDistribution(high=0.01, log=True, low=0.0001, step=None)}, trial_id=12, value=None)
Best hyperparameters: {'hidden_dim': 96, 'n1_gat_layers': 1, 'n2_gr

In [None]:
hidden_dim = study.best_trial.params['hidden_dim']
n1_gat_layers = study.best_trial.params['n1_gat_layers']
n2_gru_layers = study.best_trial.params['n2_gru_layers']
dropout = study.best_trial.params['dropout']
num_heads = study.best_trial.params['num_heads']
k = study.best_trial.params['k']

model = GraphGRUMortalityModel(input_dim=padded_tensor_train.shape[2], hidden_dim=hidden_dim, n1_gat_layers=n1_gat_layers,
                            n2_gru_layers=n2_gru_layers, X_core=padd_tensor_core, core_padding_mask=padding_mask_core,
                            num_heads=num_heads, dropout=dropout, seq_len=padded_tensor_train.shape[1], k=k, gnn_flag=True).to(DEVICE)


model.train_all(dataloaders, datasets, epochs=15, learning_rate=study.best_trial.params['learning_rate'])
print("Training completed. Validating on test set...")

acc, auc, pr =  model.validate(dataloaders['test'], datasets['test'])
print(f'Test Accuracy in GNN model: {acc:.4f}, AUC: {auc:.4f}, AP: {pr:.4f}')



Pos weight: 6.8976
Starting epoch 1/15


  0%|          | 0/78 [00:00<?, ?it/s]

100%|██████████| 78/78 [00:38<00:00,  2.02it/s]


Epoch 1/15, Loss: 1.0871


100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Validation Accuracy: 0.7996 | AUC: 0.8001 | AP: 0.3765
Best model updated
Starting epoch 2/15


100%|██████████| 78/78 [00:38<00:00,  2.02it/s]


Epoch 2/15, Loss: 0.9812


100%|██████████| 10/10 [00:05<00:00,  1.86it/s]


Validation Accuracy: 0.7407 | AUC: 0.8194 | AP: 0.4082
Best model updated
Starting epoch 3/15


100%|██████████| 78/78 [00:39<00:00,  1.98it/s]


Epoch 3/15, Loss: 0.9296


100%|██████████| 10/10 [00:04<00:00,  2.16it/s]


Validation Accuracy: 0.7956 | AUC: 0.8275 | AP: 0.4212
Best model updated
Starting epoch 4/15


100%|██████████| 78/78 [00:38<00:00,  2.05it/s]


Epoch 4/15, Loss: 0.9071


100%|██████████| 10/10 [00:04<00:00,  2.15it/s]


Validation Accuracy: 0.7060 | AUC: 0.8355 | AP: 0.4343
Best model updated
Starting epoch 5/15


100%|██████████| 78/78 [00:38<00:00,  2.04it/s]


Epoch 5/15, Loss: 0.8704


100%|██████████| 10/10 [00:04<00:00,  2.14it/s]


Validation Accuracy: 0.8319 | AUC: 0.8412 | AP: 0.4420
Best model updated
Starting epoch 6/15


 28%|██▊       | 22/78 [00:10<00:27,  2.04it/s]


KeyboardInterrupt: 

In [None]:
model = GraphGRUMortalityModel(input_dim=padded_tensor_train.shape[2], hidden_dim=hidden_dim, n1_gat_layers=2, n2_gru_layers=2, X_core=padd_tensor_core, 
                               core_padding_mask=padding_mask_core, num_heads=4, dropout=0.1, seq_len=padded_tensor_train.shape[1], k=k, gnn_falg=False).to(DEVICE)

model.train_all(dataloaders, datasets, epochs=15, learning_rate=1e-3)
print("Training completed. Validating on test set...")

acc, auc, pr =  model.validate(dataloaders['test'], datasets['test'])
print(f'Test Accuracy in GNN model: {acc:.4f}, AUC: {auc:.4f}, AP: {pr:.4f}')