## Setup

In [184]:
# from google.colab import auth
# from google.cloud import bigquery
# from google.colab import data_table
# from google.colab import widgets

# from collections import Counter
# import re
# import numpy as np
# import pandas as pd
# import math



In [185]:
# import os
# from google.colab import drive
# import sys

# drive.mount('/content/drive')
# os.chdir('/content/drive/MyDrive/HML/Final Project')
# sys.path.append(os.path.abspath('/content/drive/MyDrive/HML/Final Project'))

In [186]:
# !pip install duckdb --quiet
# import duckdb

In [187]:
# drive_path = '/content/drive/MyDrive/HML/Final Project/MIMIC-III'
# con = duckdb.connect(f'{drive_path}/mimiciii.duckdb')

In [188]:
import pandas as pd
import numpy as np


import torch
from torch_geometric.nn import GATConv
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import torch.nn as nn
import torch.nn.functional as F

## Loading data

In [189]:
subject_ids = pd.read_csv('data/initial_cohort.csv')['subject_id'].to_list()
lavbevent_meatdata = pd.read_csv('data/labs_metadata.csv')
vital_meatdata = pd.read_csv('data/vital_metadata.csv')

In [190]:
ICUQ = \
"""--sql
SELECT admissions.subject_id::INTEGER AS subject_id, admissions.hadm_id::INTEGER AS hadm_id
, admissions.admittime::DATE AS admittime, admissions.dischtime::DATE AS dischtime
, admissions.ethnicity, admissions.deathtime::DATE AS deathtime
, patients.gender, patients.dob::DATE AS dob, icustays.icustay_id::INTEGER AS icustay_id, patients.dod::DATE as dod,
icustays.intime::DATE AS intime,icustays.outtime::DATE AS outtime
FROM admissions
INNER JOIN patients
    ON admissions.subject_id = patients.subject_id
LEFT JOIN icustays
    ON admissions.hadm_id = icustays.hadm_id

WHERE admissions.has_chartevents_data = 1
AND admissions.subject_id::INTEGER IN ?
ORDER BY admissions.subject_id, admissions.hadm_id, admissions.admittime;
"""

# icu =  con.execute(ICUQ, [subject_ids]).fetchdf().rename(str.lower, axis='columns')


In [191]:
LABQUERY = \
f"""--sql
SELECT labevents.subject_id::INTEGER AS subject_id\
      , labevents.hadm_id::INTEGER AS hadm_id\
      , labevents.charttime::DATE AS charttime
      , labevents.itemid::INTEGER AS itemid\
      , labevents.valuenum::DOUBLE AS valuenum
      , admissions.admittime::DATE AS admittime
FROM labevents
          INNER JOIN admissions
                    ON labevents.subject_id = admissions.subject_id
                        AND labevents.hadm_id = admissions.hadm_id
                        AND labevents.charttime::DATE between
                            (admissions.admittime::DATE)
                            AND (admissions.admittime::DATE + interval 48 hour)
                        AND itemid::INTEGER IN ? \
                        """

VITQUERY = f"""--sql
        SELECT chartevents.subject_id::INTEGER AS subject_id\
             , chartevents.hadm_id::INTEGER AS hadm_id\
             , chartevents.charttime::DATE AS charttime\
             , chartevents.itemid::INTEGER AS itemid\
             , chartevents.valuenum::DOUBLE AS valuenum\
             , admissions.admittime::DATE AS admittime\
        FROM chartevents
                 INNER JOIN admissions
                            ON chartevents.subject_id = admissions.subject_id
                                AND chartevents.hadm_id = admissions.hadm_id
                                AND chartevents.charttime::DATE between
                                   (admissions.admittime::DATE)
                                   AND (admissions.admittime::DATE + interval 48 hour)
                                AND itemid::INTEGER in ?
      -- exclude rows marked as error
      AND chartevents.error::INTEGER IS DISTINCT \
        FROM 1 \
        """


# lab = con.execute(LABQUERY, [lavbevent_meatdata['itemid'].tolist()]).fetchdf().rename(str.lower, axis='columns')
# vit = con.execute(VITQUERY, [vital_meatdata['itemid'].tolist()]).fetchdf().rename(str.lower, axis='columns')

In [192]:
pred_window = 13*24      # duration of prediction window (hours)
pred_gap = 24            # minimal gap between prediction and target (hours)
min_los = 24             # minimal length of stay (hours)
min_target_onset = 2*24  # minimal time of target since admission (hours)
pred_freq = '4H'        # prediction frequency

labs = pd.read_csv('data/labs.csv')
vits = pd.read_csv('data/vits.csv')
hosps = pd.read_csv('data/icu.csv')

for col in ['admittime', 'dischtime', 'dob', 'dod', 'intime', 'outtime']:
    hosps[col] = pd.to_datetime(hosps[col].str.strip(), errors='coerce')

## preprocessing

In [193]:
# ethnicity  - to category
hosps.ethnicity = hosps.ethnicity.str.lower()
hosps.loc[(hosps.ethnicity.str.contains('^white')),'ethnicity'] = 'white'
hosps.loc[(hosps.ethnicity.str.contains('^black')),'ethnicity'] = 'black'
hosps.loc[(hosps.ethnicity.str.contains('^hisp')) | (hosps.ethnicity.str.contains('^latin')),'ethnicity'] = 'hispanic'
hosps.loc[(hosps.ethnicity.str.contains('^asia')),'ethnicity'] = 'asian'
hosps.loc[~(hosps.ethnicity.str.contains('|'.join(['white', 'black', 'hispanic', 'asian']))),'ethnicity'] = 'other'

# ethnicity - one hot encoding
hosps['eth_white'] = (hosps['ethnicity'] == 'white').astype(int)
hosps['eth_black'] = (hosps['ethnicity'] == 'black').astype(int)
hosps['eth_hispanic'] = (hosps['ethnicity'] == 'hispanic').astype(int)
hosps['eth_asian'] = (hosps['ethnicity'] == 'asian').astype(int)
hosps['eth_other'] = (hosps['ethnicity'] == 'other').astype(int)
hosps.drop(['ethnicity', 'deathtime'], inplace=True, axis=1)

In [194]:

# Generate feature columns for los, age and mortality
def age(admittime, dob):
    if admittime < dob:
      return 0
    return admittime.year - dob.year - ((admittime.month, admittime.day) < (dob.month, dob.day))

hosps['age'] = hosps.apply(lambda row: age(row['admittime'], row['dob']), axis=1)
hosps['los_hosp_hr'] = (hosps.dischtime - hosps.admittime).dt.total_seconds()/3600
hosps['mort'] = np.where(~np.isnat(hosps.dod),1,0)

# Gender to binary
hosps['gender'] = np.where(hosps['gender']=="M", 1, 0)

# @title Q1.1 - Patient Exclusion Criteria
hosps = hosps.sort_values('admittime').groupby('subject_id').first().reset_index()
print(f"1. Include only first admissions: N={hosps.shape[0]}")

hosps = hosps[hosps.age.between(18,90)]
print(f"2. Exclusion by ages: N={hosps.shape[0]}")

# Exclude patients hospitalized for less than 24 hours
hosps = hosps[hosps['los_hosp_hr'] >= min_los]
print(f"3. Include only patients who admitted for at least {min_los} hours: N={hosps.shape[0]}")

# Exclude patients that died in the first 48 hours of admission
hours_to_death = (hosps['dod'] - hosps['admittime']).dt.total_seconds() / 3600
hosps = hosps[~((hosps['mort'].astype(bool)) & (hours_to_death < min_target_onset))]
print(f"4. Exclude patients who died within {min_target_onset}-hours of admission: N={hosps.shape[0]}")

1. Include only first admissions: N=32513
2. Exclusion by ages: N=25548
3. Include only patients who admitted for at least 24 hours: N=25168
4. Exclude patients who died within 48-hours of admission: N=24825


In [195]:
labs = labs[labs['hadm_id'].isin(hosps['hadm_id'])]
labs = pd.merge(labs,lavbevent_meatdata,on='itemid')
labs = labs[labs['valuenum'].between(labs['min'],labs['max'],  inclusive='both')]

vits = vits[vits['hadm_id'].isin(hosps['hadm_id'])]
vits = pd.merge(vits,vital_meatdata,on='itemid')
vits = vits[vits['valuenum'].between(vits['min'],vits['max'], inclusive='both')]

vits.loc[(vits['feature name'] == 'TempF'),'valuenum'] = (vits[vits['feature name'] == 'TempF']['valuenum']-32)/1.8
vits.loc[vits['feature name'] == 'TempF','feature name'] = 'TempC'

merged = pd.concat([vits, labs])
merged['charttime'] = pd.to_datetime(merged['charttime'], errors='coerce')

pivot = pd.pivot_table(merged, index=['subject_id', 'hadm_id', pd.Grouper(key='charttime', freq=pred_freq)],
                       columns=['feature name'], values='valuenum', aggfunc=['mean', 'max', 'min', 'std'])
pivot.columns = [f'{c[1]}_{c[0]}' for c in pivot.columns.to_flat_index()]

# temp = merged.copy()

merged = pd.merge(hosps, pivot.reset_index(), on=['subject_id', 'hadm_id'])
merged[pivot.columns] = merged.groupby(['subject_id', 'hadm_id'])[pivot.columns].ffill()


  pivot = pd.pivot_table(merged, index=['subject_id', 'hadm_id', pd.Grouper(key='charttime', freq=pred_freq)],


In [196]:
merged = merged.sort_values(['subject_id', 'hadm_id', 'charttime'])
labs_features_names = set(labs['feature name'])
vits_features_names = set(vits['feature name'])
labs_features = [col for col in merged.columns if col.split('_')[0] in labs_features_names]
vits_features = [col for col in merged.columns if col.split('_')[0] in vits_features_names]

lab_diff_cols = {}
for col in labs_features:
    if col.find("mean") >= 0:
      base = merged.groupby(['subject_id', 'hadm_id'])[col].transform('first')
      lab_diff_cols[f'{col}_diff'] = merged[col] - base

lab_diff_df = pd.DataFrame(lab_diff_cols)

vital_diff_cols = {}
for col in vits_features:
  if col.find("mean") >= 0:
      diff_series = merged.groupby(['subject_id', 'hadm_id'])[col].diff()
      vital_diff_cols[f'{col}_diff'] = diff_series

vital_diff_df = pd.DataFrame(vital_diff_cols)

# Concatenate back to original DataFrame
merged = pd.concat([merged, lab_diff_df, vital_diff_df], axis=1)

In [197]:
merged['charttime'] = pd.to_datetime(merged['charttime'], errors='coerce')

time_to_death = (merged['dod'] - merged['charttime']).dt.total_seconds() / (60 * 60)
merged['target'] = (pd.notnull(time_to_death) & (time_to_death <= pred_window + pred_gap)).astype(int)
merged = merged[time_to_death.isna() | (time_to_death >= pred_gap)]

In [198]:
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

merged_clean = merged.reset_index(drop=True)

np.random.seed(0)

#Split to train & test (all data of a single patient needs to be in the same group)
X = merged_clean
y = merged_clean['target']
groups = merged_clean['subject_id']

gss = GroupShuffleSplit(n_splits=1, train_size=.8, test_size=0.1)
train_index, test_index = next(gss.split(X, y, groups))
val_index = list(set(X.index.to_list()) - (set(train_index.tolist()) | set(test_index.tolist())))

X_train, y_train = X.iloc[train_index], y.iloc[train_index]
X_val, y_val = X.iloc[val_index], y.iloc[val_index]
X_test, y_test = X.iloc[test_index], y.iloc[test_index]

In [199]:
from sklearn.preprocessing import StandardScaler



num_cols = X_train.select_dtypes(include='float').columns
scaler = StandardScaler()

X_train.loc[:, num_cols] = scaler.fit_transform(X_train[num_cols])
X_val.loc[:, num_cols] = scaler.transform(X_val[num_cols])
X_test.loc[:, num_cols] = scaler.transform(X_test[num_cols])

baseline_df = X_train[X_train.charttime.dt.date == X_train.admittime.dt.date].mean(axis=0).fillna(0)
X_train.loc[:, num_cols] = X_train[num_cols].fillna(baseline_df)
X_val.loc[:, num_cols] = X_val[num_cols].fillna(baseline_df)
X_test.loc[:, num_cols] = X_test[num_cols].fillna(baseline_df)

In [200]:
to_drop = ['hadm_id','admittime', 'dischtime', 'dod','dob', 'mort', 'los_hosp_hr', 'charttime','adm_to_death', 'target']

to_keep = ~X_train.columns.isin(to_drop)
to_keep = X_train.columns[to_keep]
X_train = X_train[to_keep]
X_test = X_test[to_keep]
X_val = X_val[to_keep]

In [201]:
from sklearn.cluster import KMeans
import numpy as np

def cluster_and_select_subjects(X_train, num_clusters=10, random_state=42):
    """
    Calculate the first row of each subject_id in X_train, cluster it to num_clusters 
    and choose one subject_id from each cluster.
    
    Parameters:
    X_train: DataFrame with subject_id column
    num_clusters: int, number of clusters to create
    random_state: int, for reproducibility
    
    Returns:
    list: selected subject_ids, one from each cluster
    """
    first_rows = X_train.groupby('subject_id').first().reset_index()
    
    features_for_clustering = first_rows.drop('subject_id', axis=1)
    

    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state, n_init=10)
    cluster_labels = kmeans.fit_predict(features_for_clustering)
    

    first_rows['cluster'] = cluster_labels
    

    np.random.seed(random_state)
    selected_subjects = []
    
    for cluster_id in range(num_clusters):
        cluster_subjects = first_rows[first_rows['cluster'] == cluster_id]['subject_id'].values
        if len(cluster_subjects) > 0:
            selected_subject = np.random.choice(cluster_subjects)
            selected_subjects.append(selected_subject)
    
    return selected_subjects

selected_subjects = cluster_and_select_subjects(X_train, num_clusters=40, random_state=42)

DTypePromotionError: The DType <class 'numpy.dtypes.DateTime64DType'> could not be promoted by <class 'numpy.dtypes.Float64DType'>. This means that no common DType exists for the given inputs. For example they cannot be stored in a single array unless the dtype is `object`. The full list of DTypes is: (<class 'numpy.dtypes.Int64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.DateTime64DType'>, <class 'numpy.dtypes.DateTime64DType'>, <class 'numpy.dtypes.Int64DType'>, <class 'numpy.dtypes.Int64DType'>, <class 'numpy.dtypes.Int64DType'>, <class 'numpy.dtypes.Int64DType'>, <class 'numpy.dtypes.Int64DType'>, <class 'numpy.dtypes.Int64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.Float64DType'>)

In [None]:
# Create X_core with rows where subject_id is in selected_subjects
X_core = X_train[X_train['subject_id'].isin(selected_subjects)]

# Update X_train to exclude the selected subjects
X_train = X_train[~X_train['subject_id'].isin(selected_subjects)]

In [None]:
from tensorflow.keras.preprocessing.sequence import pad_sequences
import torch
def generate_series_data(df, group_col="subject_id", maxlen=18):
  grouped = df.groupby(group_col)
  subject_sequences = [group.values[:, 1:] for _, group in grouped]
  padded_tensor = pad_sequences(subject_sequences, padding='post', dtype='float32')
  sequence_lengths = [len(seq) for seq in subject_sequences]
  padding_mask = np.zeros((len(sequence_lengths), maxlen), dtype=np.float32)
  for i, length in enumerate(sequence_lengths):
      padding_mask[i, :length] = 1.0
  labels = padded_tensor[:,:,-1]
  padded_tensor = padded_tensor[:,:,:-1]
  padded_tensor = torch.tensor(padded_tensor, dtype=torch.float32)
  labels = torch.tensor(labels, dtype=torch.float32)
  padding_mask = torch.tensor(padding_mask, dtype=torch.float32)
  return padded_tensor, labels, padding_mask

In [None]:
padded_tensor, labels, padding_mask = generate_series_data(X_train, group_col="subject_id", maxlen=18)
padd_tensor_core, labels_core, padding_mask_core = generate_series_data(X_core, group_col="subject_id", maxlen=18)

## Model

In [None]:
import torch
from torch.utils.data import Dataset


class PatientDataset(Dataset):
    def __init__(self, core, X, y, padding_mask, padding_mask_core, k=5, single_patient=False):
        self.core = core
        self.X = X
        self.y = y
        self.padding_mask = padding_mask
        self.padding_mask_core = padding_mask_core
        self.k = k
        self.single_patient = single_patient
        if self.single_patient:
            self.cal_graphs()

    def __len__(self):
        return len(self.X_train)
    
    def __getitem__(self, idx):
        return self.core, self.X[idx], self.y[idx], self.padding_mask[idx]
    
    def cal_graphs(self):
        edge_list = []
        for patient_idx in range(len(self.X)):
            edges = self.build_knn_graph(self.X[patient_idx:patient_idx+1], self.core, 
                                                self.padding_mask[patient_idx:patient_idx+1], self.padding_mask_core, k=self.k)
            edge_list.append(edges)
        self.edge_list =  torch.cat(edge_list, dim=1).T
        
    @staticmethod
    def build_knn_graph(batch, core, padding_mask_batch, padding_mask_core, k=5):
        """
        Build a KNN graph from batch and core tensors.
        
        Args:
            batch: 3D tensor (batch_size, seq_len, features)
            core: 3D tensor (core_size, seq_len, features)
            padding_mask_batch: 2D tensor (batch_size, seq_len) indicating valid time points
            padding_mask_core: 2D tensor (core_size, seq_len) indicating valid time points
            k: number of nearest neighbors for patient connections
        
        Returns:
            edge_index: tensor of shape (2, num_edges) representing graph edges
        """
        batch_size, seq_len, _ = batch.shape
        core_size = core.shape[0]
        total_patients = batch_size + core_size
    
        all_patients = torch.cat([batch, core], dim=0)
        all_padding_mask = torch.cat([padding_mask_batch, padding_mask_core], dim=0)
        
        edges = []
        
        for patient_idx in range(total_patients):
            for t in range(seq_len - 1):
            
                if all_padding_mask[patient_idx, t] > 0 and all_padding_mask[patient_idx, t + 1] > 0:
                    node_curr = patient_idx * seq_len + t
                    node_next = patient_idx * seq_len + t + 1
   
                    edges.append([node_curr, node_next])
                    edges.append([node_next, node_curr])
        
        for t in range(seq_len):
            valid_patients = all_padding_mask[:, t] > 0
            valid_indices = torch.where(valid_patients)[0]
            
            if len(valid_indices) > 1:
                features_t = all_patients[valid_indices, t, :]  
                
          
                distances = torch.cdist(features_t, features_t, p=2)
                
                for i, patient_idx in enumerate(valid_indices):
                    num_neighbors = min(k + 1, len(valid_indices))
                    _, nearest_indices = torch.topk(distances[i], num_neighbors, largest=False)
                    nearest_indices = nearest_indices[1:] 
                    
                    for j in nearest_indices:
                        neighbor_idx = valid_indices[j]
                        node_curr = patient_idx * seq_len + t
                        node_neighbor = neighbor_idx * seq_len + t
                        edges.append([node_curr, node_neighbor])
        
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        
        return edge_index

dataset = PatientDataset(core=padd_tensor_core, X=padded_tensor, y=labels, padding_mask=padding_mask, padding_mask_core=padding_mask_core, k=5)
dataset.cal_graphs()

In [None]:
import torch
from torch_geometric.nn import GATConv
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import torch.nn as nn
import torch.nn.functional as F

class GraphGRUMortalityModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, n1_gat_layers, n2_gru_layers, X_core, 
                 num_heads=4, dropout=0.1, seq_len=18):
        """
        Mortality prediction model with Graph Attention + GRU layers
        
        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden dimension for GAT and GRU layers
            n1_gat_layers: Number of Graph Attention layers
            n2_gru_layers: Number of GRU layers
            X_core_dim: Core set dimension (number of core patients)
            num_heads: Number of attention heads for GAT
            dropout: Dropout rate
            seq_len: Sequence length
        """
        super(GraphGRUMortalityModel, self).__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n1_gat_layers = n1_gat_layers
        self.n2_gru_layers = n2_gru_layers
        self.X_core = X_core
        self.seq_len = seq_len
        self.num_heads = num_heads
        
        self.gat_layers = nn.ModuleList()
        
        
        self.gat_layers.append(
            GATConv(input_dim, hidden_dim // num_heads, heads=num_heads, 
                   dropout=dropout, concat=True),
                    nn.ReLU()
        )
        
        for _ in range(n1_gat_layers - 1):
            self.gat_layers.append(
                GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads,
                       dropout=dropout, concat=True),
                        nn.ReLU()
            )
        
        self.gru = nn.GRU(hidden_dim, hidden_dim, n2_gru_layers, 
                         batch_first=True, dropout=dropout)
        
        self.classifier = nn.Linear(hidden_dim, 1)
        
  
        
    def forward(self, core, x ,padding_mask, edge_index):
        """
        Forward pass
        
        Args:
            core: Core patients tensor (X_core_dim, seq_len, input_dim)
            x: Batch patients tensor (batch_size, seq_len, input_dim)
            y: Target labels (batch_size, seq_len)
            padding_mask: Padding mask (batch_size, seq_len)
            edge_index: Graph edge indices (2, num_edges)
        
        Returns:
            predictions: Mortality predictions (batch_size, seq_len, 1)
        """
        batch_size = x.size(0)
        
        # Combine core and batch data
        all_patients = torch.cat([x, core], dim=0)  # (batch_size + X_core_dim, seq_len, input_dim)
        total_patients = batch_size + self.X_core_dim
        
        # Reshape for graph processing: (total_patients * seq_len, input_dim)
        graph_input = all_patients.view(total_patients * self.seq_len, -1)
        
        # Apply GAT layers
        for gat_layer in self.gat_layers:
            graph_input = F.relu(gat_layer(graph_input, edge_index))
            graph_input = self.dropout(graph_input)
        
        # Reshape back to sequence format: (total_patients, seq_len, hidden_dim)
        graph_output = graph_input.view(total_patients, self.seq_len, -1)
        
        # Extract only batch patients (exclude core)
        batch_output = graph_output[:batch_size]  # (batch_size, seq_len, hidden_dim)
        
        # Apply GRU layers
        # Pack sequences for efficient processing
        sequence_lengths = padding_mask.sum(dim=1).cpu().long()
        packed_input = pack_padded_sequence(batch_output, sequence_lengths, 
                                          batch_first=True, enforce_sorted=False)
        
        packed_output, _ = self.gru(packed_input)
        gru_output, _ = pad_packed_sequence(packed_output, batch_first=True, 
                                          total_length=self.seq_len)
        
        # Apply classifier
        predictions = self.classifier(gru_output)  # (batch_size, seq_len, 1)
        
        return predictions.squeeze(-1)  # (batch_size, seq_len)