# Reproduction of Paper `Learning the Graphical Structure of Electronic Health Records with Graph Convolutional Transformer` by DL4H Team 137 

In [1]:
# $ pip freeze > requirements.txt
# $ conda env export > environment.yml


## Introduction


*   Background of the problem
    * This study focuses on readmission/mortality prediction.
    * Some EHR data, particularly claims data, lacks explicit structure (for example, which treatment was ordered for which diagnosis), so models that rely on that structure, such as MiME (Choi et al., 2018), cannot be applied directly.
    * The main difficulty is discovering the hidden structure of the data while making predictions at the same time.

    * The approach outlined in the paper is effective according to their test metrics.
*   Paper explanation
    * The paper proposes a new method, the Graph Convolutional Transformer (GCT), to jointly learn the hidden structure and perform the prediction task. This method uses unstructured data as the initial input and achieves accurate predictions for general medical tasks.

    * TEST METRICS FROM THE PAPER ARE SHOWN BELOW
    * It offers significant benefits for individuals without access to structured data. Additionally, the learned structure can be useful for others who wish to reuse the learned structure for future studies.




# Scope of reproducibility (5)
The scope of this reproducibility study focuses on verifying the results claimed in the paper "Learning the Graphical Structure of Electronic Health Records with Graph Convolutional Transformer". The goal is to reproduce the model's ability to predict readmission/mortality using electronic health records as described in the original research.

# Methodology (15)

## Environment
### Python version
- Python 3.10
### Dependencies/packages needed
- torch==1.7.1
- numpy==1.19.5
- pandas==1.2.0
- scikit-learn==0.24.1
- matplotlib==3.3.3

## Data
### Data download instruction
- Data can be downloaded from `[Insert Link Here]`.
### Data descriptions with helpful charts and visualizations
- Include pie charts and histograms of key demographics and clinical features.
### Preprocessing code + command
- `python preprocessing.py --input path/to/raw/data --output path/to/cleaned/data`

#### load data

In [1]:
# Load the raw eICU CSV tables into pandas DataFrames
import os

import pandas as pd
import numpy as np

filenames = ['patient', 'admissionDx', 'diagnosis', 'treatment']
raw_data_dir = './'
SUBSET_RATIO = 1.0


def load_raw_data(raw_data_dir, filenames, subset_ratio=SUBSET_RATIO):
    """Read `<raw_data_dir>/<name>.csv` for each name; optionally subsample rows."""
    data_frames = {}
    for filename in filenames:
        file_path = os.path.join(raw_data_dir, filename + '.csv')
        df = pd.read_csv(file_path)

        # Use only a fraction of the rows for quick experiments
        if subset_ratio < 1.0:
            df = df.sample(frac=subset_ratio)
        data_frames[filename] = df
    return data_frames


data_frames = load_raw_data(raw_data_dir, filenames)

#### Preprocess data and generate the eicu record

In [2]:
# preprocess patient data
def process_patient(df, hour_threshold=24):
    # Each unit stay (encounter) must appear exactly once
    if df['patientunitstayid'].duplicated().any():
        raise ValueError('Duplicate encounter ID found in patient table')
    # Minutes from hospital admission to unit admission, used to order a patient's stays
    df['encounter_timestamp'] = -df['hospitaladmitoffset'].astype(int)

    # Sorting patients by their IDs and then by the encounter timestamp
    df_sorted = df.sort_values(['patienthealthsystemstayid', 'encounter_timestamp'])

    # Detect readmissions by checking if the next stay is within the same patient ID
    
    df_sorted['readmission'] = True  # Initially mark all as True
    df_sorted.loc[df_sorted.groupby('patienthealthsystemstayid')['patientunitstayid'].tail(1).index, 'readmission'] = False
    df_sorted['unitdischargestatus'] = df_sorted['unitdischargestatus']=='Expired'

    duration_threshold = hour_threshold * 60.0
    mask = df_sorted['unitdischargeoffset'] <= duration_threshold
    
    df_sorted = df_sorted[mask]
    rename_dict = {'patienthealthsystemstayid':'patient_id',
                   'patientunitstayid':'encounter_id',
                   'encounter_timestamp':'encounter_timestamp',
                   'unitdischargestatus':'expired',
                  }
    df_selected = df_sorted[ list(rename_dict.keys())+['readmission'] ]
    df_renamed = df_selected.rename(columns=rename_dict)
    return df_renamed

In [3]:
# 
patient_dataframe = process_patient(data_frames['patient'])

In [4]:
patient_dataframe[:5]

Unnamed: 0,patient_id,encounter_id,encounter_timestamp,expired,readmission
1,128927,141178,14,False,True
5,128943,141197,25,False,True
7,128952,141208,1,False,False
9,128970,141229,4,False,False
12,128995,141260,18,False,False
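
The readmission label above follows one rule: within a hospital stay (`patienthealthsystemstayid`), every unit stay except the last one in timestamp order is marked as a readmission. A minimal plain-Python sketch of that rule, written without pandas so it runs anywhere; the `stays` list and the `label_readmissions` helper are hypothetical and use made-up IDs:

```python
from collections import defaultdict


def label_readmissions(stays):
    """stays: list of (patient_id, encounter_id, timestamp) tuples.

    Returns {encounter_id: bool}; True when a later stay exists for the
    same patient_id, mirroring the groupby(...).tail(1) logic.
    """
    by_patient = defaultdict(list)
    for patient_id, encounter_id, timestamp in stays:
        by_patient[patient_id].append((timestamp, encounter_id))

    labels = {}
    for visits in by_patient.values():
        visits.sort()  # order the stays of one patient by timestamp
        for _, encounter_id in visits[:-1]:
            labels[encounter_id] = True   # followed by another stay
        labels[visits[-1][1]] = False     # last stay: no readmission
    return labels


# Made-up IDs for illustration only.
stays = [(1, 101, 5), (1, 102, 900), (2, 201, 3)]
labels = label_readmissions(stays)
print(labels)
```

Note that `process_patient` assigns the label before applying the `unitdischargeoffset` filter, so a stay can keep `readmission == True` even when its follow-up stay is later filtered out.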


In [5]:
# test correctness

# patient_dataframe[patient_dataframe['readmission']==True][:5]
# pdf = data_frames['patient']
# pdf[ pdf['patienthealthsystemstayid']==133737]
# patient_dataframe[pdf['patientunitstayid']==147378]['readmission']

In [6]:
# process admission
def process_admission_dx(df,patient_df):
    # Check and report the number of missing encounter IDs
    
    df['admitdxpath'] = df['admitdxpath'].str.lower()
    
    patient_encounter_ids = set(patient_df['encounter_id'])

    mask = df['patientunitstayid'].isin(patient_encounter_ids)

    missing_eid = df[~mask]
    
    print('admission without Encounter ID:', len(missing_eid))
    
    df = df[mask]
    rename_dict = {'patientunitstayid':'encounter_id',
                   'admitdxpath':'dx_id'
                  }
    df_selected = df[list(rename_dict.keys()) ]
    df_renamed = df_selected.rename(columns=rename_dict)
    
    return df_renamed

admission_dataframe = process_admission_dx(data_frames['admissionDx'],patient_dataframe)

admission without Encounter ID: 450589


In [7]:
admission_dataframe[:5]

Unnamed: 0,encounter_id,dx_id
26,2900366,admission diagnosis|was the patient admitted f...
27,2900366,admission diagnosis|all diagnosis|non-operativ...
28,2900366,admission diagnosis|non-operative organ system...
36,2900423,admission diagnosis|non-operative organ system...
37,2900423,admission diagnosis|was the patient admitted f...


In [8]:
def process_diagnosis(df,patient_df):
    # Check and report the number of missing encounter IDs
    
    df['diagnosisstring'] = df['diagnosisstring'].str.lower()
    
    patient_encounter_ids = set(patient_df['encounter_id'])

    mask = df['patientunitstayid'].isin(patient_encounter_ids)

    missing_eid = df[~mask]
    
    print('Diagnosis without Encounter ID:', len(missing_eid))
    
    df = df[mask]
    rename_dict = {'patientunitstayid':'encounter_id',
                   'diagnosisstring':'dx_id'
                  }
    df_selected = df[list(rename_dict.keys()) ]
    df_renamed = df_selected.rename(columns=rename_dict)
    
    return df_renamed
diagnosis_dataframe = process_diagnosis(data_frames['diagnosis'],patient_dataframe)

Diagnosis without Encounter ID: 2483092


In [9]:
diagnosis_dataframe[:5]

Unnamed: 0,encounter_id,dx_id
30,141229,cardiovascular|arrhythmias|atrial fibrillation
31,141229,cardiovascular|ventricular disorders|acute pul...
32,141229,cardiovascular|ventricular disorders|congestiv...
33,141229,neurologic|altered mental status / pain|change...
34,141229,cardiovascular|ventricular disorders|acute pul...


In [10]:
def process_treatment(df, patient_df):
    
    df['treatmentstring'] = df['treatmentstring'].str.lower()
    
    patient_encounter_ids = set(patient_df['encounter_id'])

    mask = df['patientunitstayid'].isin(patient_encounter_ids)

    missing_eid = df[~mask]
    
    print('treatment without Encounter ID:', len(missing_eid))
    
    df = df[mask]
    rename_dict = {'patientunitstayid':'encounter_id',
                   'treatmentstring':'treatment'
                  }

    df_selected = df[list(rename_dict.keys()) ]
    df_renamed = df_selected.rename(columns=rename_dict)
    
    return df_renamed

treatment_dataframe = process_treatment(data_frames['treatment'], patient_dataframe)

treatment without Encounter ID: 3372000


In [11]:
treatment_dataframe[:5]

Unnamed: 0,encounter_id,treatment
224,242203,gastrointestinal|medications|stress ulcer prop...
225,242203,pulmonary|ventilation and oxygenation|oxygen t...
226,242203,renal|urinary catheters|foley catheter
227,242203,renal|electrolyte correction|treatment of hype...
228,242203,gastrointestinal|medications|antiemetic|seroto...



In [12]:
"""double check ok"""
# output from python 2.7 environment
# Processing patient.csv
# Processing admission diagnosis.csv
# Admission Diagnosis without Encounter ID: 450589
# Processing diagnosis.csv
# Diagnosis without Encounter ID: 2483092
# Processing treatment.csv
# Treatment without Encounter ID: 3372000
# Accepted treatments: 316745

# This is the same as above log output 

'double check ok'

In [13]:
print(len(patient_dataframe),len(admission_dataframe),len(diagnosis_dataframe),len(treatment_dataframe))

68076 176269 227580 316745


In [14]:
'The content in sequence seqex_list'
# Context Features:
# key: label.expired value: int64_list {
#   value: 0
# }

# key: label.readmission value: int64_list {
#   value: 0
# }

# key: patientId value: bytes_list {
#   value: "2630449:3229400"
# }

# Feature Lists:
# key: proc_ids
# feature: bytes_list {
#   value: "pulmonary|ventilation and oxygenation|oxygen therapy (< 40%)|nasal cannula"
#   value: "cardiovascular|intravenous fluid|normal saline administration"
#   value: "endocrine|glucose metabolism|insulin|continuous infusion"
#   value: "gastrointestinal|medications|stress ulcer prophylaxis|famotidine"
#   value: "cardiovascular|arrhythmias|anticoagulant administration|low molecular weight heparin|enoxaparin"
# }

# key: dx_ints
# feature: int64_list {
#   value: 202
#   value: 0
#   value: 201
#   value: 164
# }

# key: dx_ids
# feature: bytes_list {
#   value: "endocrine|glucose metabolism|dka"
#   value: "admission diagnosis|was the patient admitted from the o.r. or went to the o.r. within 4 hours of admission?|no"
#   value: "admission diagnosis|all diagnosis|non-operative|diagnosis|metabolic/endocrine|diabetic ketoacidosis"
#   value: "admission diagnosis|non-operative organ systems|organ system|metabolic/endocrine"
# }

# key: proc_ints
# feature: int64_list {
#   value: 68
#   value: 27
#   value: 273
#   value: 80
#   value: 417
# }

# first key: 1392393:1774519
# content of dx_str2int:
# key: surgery|respiratory failure|ventilatory failure|suspected value: 2762
# key: burns/trauma|trauma-other injuries|traumatic amputation|arm/hand value: 3196
# key: endocrine|fluids and electrolytes|hypernatremia|moderate (146 - 155 meq/dl) value: 1284
# key: admission diagnosis|all diagnosis|operative|diagnosis|cardiovascular|thrombectomy (with general anesthesia) value: 975
# key: endocrine|fluids and electrolytes|hyponatremia|severe (< 125 meq/dl) value: 709
# content of treat_str2int:
# key: pulmonary|surgery / incision and drainage of thorax|pulmonary resection|lobectomy value: 1083
# key: neurologic|ich/ cerebral infarct|anticonvulsants|phenytoin value: 603
# key: cardiovascular|arrhythmias|digoxin value: 284
# key: toxicology|drug overdose|agent specific therapy|beta blockers overdose|atropine value: 1489
# key: oncology|medications|analgesics|oral analgesics value: 17


'The content in sequence seqex_list'

In [15]:
def build_dataframe(patient_dataframe, treatment_dataframe,diagnosis_dataframe,admission_dataframe, min_num_codes=1,
                max_num_codes=50):

    '''
    Build the DataFrame used for training.
    It corresponds to build_seqex in process_eicu.py.
    '''
    filter = lambda x: len(x)>=min_num_codes and len(x)<=max_num_codes

    # merge admission and diagnosis
    merged_admission_diagnosis = pd.concat([admission_dataframe, diagnosis_dataframe], axis=0)

    dx_list = list(set(merged_admission_diagnosis['dx_id']))

    dx_str2int = {s:i for i,s in enumerate(dx_list)}

    merged_admission_diagnosis['dx_ints'] = merged_admission_diagnosis['dx_id'].map(dx_str2int)
    
    merged_admission_diagnosis = merged_admission_diagnosis.groupby('encounter_id')['dx_ints'].agg(list).reset_index()

    merged_admission_diagnosis =merged_admission_diagnosis[merged_admission_diagnosis['dx_ints'].apply(filter)]

    # aggregate treatment_dataframe
    
    treat_list = list(set(treatment_dataframe['treatment']))

    treat_str2int = {s:i for i,s in enumerate(treat_list)}
    
    treatment_dataframe['proc_ints'] =  treatment_dataframe['treatment'].map(treat_str2int)
    
    treatment_dataframe = treatment_dataframe.groupby('encounter_id')['proc_ints'].agg(list).reset_index()

    treatment_dataframe =treatment_dataframe[treatment_dataframe['proc_ints'].apply(filter)]
    
    #print(len(merged_admission_diagnosis),len(treatment_dataframe))
    
    # merge patient, admission and diagnosis
    merged_patient_proc_ints = pd.merge(merged_admission_diagnosis, patient_dataframe, on='encounter_id', how='inner')

    # merge patient, all

    merged_df= pd.merge(merged_patient_proc_ints, treatment_dataframe, on='encounter_id', how='inner')
    
    merged_df['patientId'] = merged_df.apply(lambda row: (row['patient_id'], row['encounter_id']), axis=1)

    merged_df.drop(['patient_id','encounter_timestamp'], axis=1, inplace=True)

    merged_df.set_index('encounter_id', inplace=True)
    
    return  merged_df, dx_str2int, treat_str2int,dx_list,treat_list
    
df, dx_str2int, treat_str2int,dx_list,treat_list = build_dataframe(patient_dataframe, treatment_dataframe,diagnosis_dataframe,admission_dataframe)


In [16]:
# TODO
# Need to check the difference:
# process_eicu.py gives 41026 encounters.
# It may come from the join method and the data types.
print(len(df))

40410
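
One possible contributor to the gap between 40410 and 41026 (an assumption, not verified against `process_eicu.py`) is how the length filter counts codes. `build_dataframe` filters on the raw list length, and the `dx_ints` lists shown in this notebook contain repeated codes, so an encounter with many repeats can exceed `max_num_codes` even though it has few distinct codes. The hypothetical helper below shows how the two counting rules differ:

```python
def within_limits(codes, min_num_codes=1, max_num_codes=50, dedupe=False):
    """Return True when the code list passes the length filter.

    dedupe=True counts distinct codes; dedupe=False counts raw entries,
    which is what the lambda filter in build_dataframe does.
    """
    n = len(set(codes)) if dedupe else len(codes)
    return min_num_codes <= n <= max_num_codes


# Made-up encounter: 60 recorded codes but only 30 distinct ones.
codes = list(range(30)) * 2

print(within_limits(codes))               # filtered on raw length 60
print(within_limits(codes, dedupe=True))  # filtered on 30 distinct codes
```

Counting how many encounters flip between the two rules would show whether this explains part of the difference.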


In [17]:
df.loc[3353254]

dx_ints        [2278, 125, 2067, 2130, 3300]
expired                                False
readmission                            False
proc_ints                        [1002, 820]
patientId                 (2743102, 3353254)
Name: 3353254, dtype: object

In [18]:
print("""
Output of the process_eicu.py for encounter Id: 3353254
The number of the "dx_ints" and "proc_ints" is the same
The value is different,which is ok, the value is just an 
index to dictionary
""")
# seqx for 3353254: context {
#   feature {
#     key: "label.expired"
#     value {
#       int64_list {
#         value: 0
#       }
#     }
#   }
#   feature {
#     key: "label.readmission"
#     value {
#       int64_list {
#         value: 0
#       }
#     }
#   }
#   feature {
#     key: "patientId"
#     value {
#       bytes_list {
#         value: "2743102:3353254"
#       }
#     }
#   }
# }
# feature_lists {
#   feature_list {
#     key: "dx_ids"
#     value {
#       feature {
#         bytes_list {
#           value: "admission diagnosis|non-operative organ systems|organ system|gastrointestinal"
#           value: "admission diagnosis|was the patient admitted from the o.r. or went to the o.r. within 4 hours of admission?|no"
#           value: "renal|disorder of kidney|acute renal failure|due to hypovolemia/decreased circulating volume"
#           value: "admission diagnosis|all diagnosis|non-operative|diagnosis|gastrointestinal|bleeding, lower gi"
#           value: "gastrointestinal|gi bleeding / pud|lower gi bleeding"
#         }
#       }
#     }
#   }
#   feature_list {
#     key: "dx_ints"
#     value {
#       feature {
#         int64_list {
#           value: 31
#           value: 0
#           value: 225
#           value: 323
#           value: 324
#         }
#       }
#     }
#   }
#   feature_list {
#     key: "proc_ids"
#     value {
#       feature {
#         bytes_list {
#           value: "cardiovascular|intravenous fluid|normal saline administration|fluid bolus (250-1000mls)"
#           value: "cardiovascular|intravenous fluid|blood product administration|packed red blood cells|transfusion of > 2 units prbc\'s"
#         }
#       }
#     }
#   }
#   feature_list {
#     key: "proc_ints"
#     value {
#       feature {
#         int64_list {
#           value: 105
#           value: 483
#         }
#       }
#     }
#   }
# }


Output of the process_eicu.py for encounter Id: 3353254
The number of the "dx_ints" and "proc_ints" is the same
The value is different,which is ok, the value is just an 
index to dictionary



In [19]:
print(len(dx_str2int),len(treat_str2int))
#TODO Check the difference. Code gives 

3351 2212
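
The integer IDs in `dx_str2int` and `treat_str2int` come from `list(set(...))`, which gives no ordering guarantee, and for strings the order can change between interpreter runs. That is one reason the IDs need not match those from `process_eicu.py`. If reproducible IDs are wanted, sorting the distinct codes first is one option. A small sketch, where the `build_vocab` helper is hypothetical:

```python
def build_vocab(codes):
    """Map each distinct code string to a stable integer ID."""
    return {code: i for i, code in enumerate(sorted(set(codes)))}


codes = [
    "renal|urinary catheters|foley catheter",
    "cardiovascular|arrhythmias|digoxin",
    "renal|urinary catheters|foley catheter",
]
vocab = build_vocab(codes)
padding_id = len(vocab)  # same convention as the notebook: pad with the vocab size
print(vocab, padding_id)
```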


In [20]:
from sklearn.model_selection import train_test_split

In [21]:

def select_train_valid_test(df, target ='readmission', random_seed=0):
    
    train_df, temp_df = train_test_split(df, test_size=0.2, random_state=random_seed)
    valid_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=random_seed)
    
    return train_df,valid_df,test_df
train_df, validate_df, test_df = select_train_valid_test(df)

In [22]:
from itertools import product
def generate_combinations(row):
    return list(product(row['dx_ints'], row['proc_ints']))

total_visit = len(df)
def count_conditional_prob_dp(df,total_visit):
    """
    Equivalent to count_conditional_prob_dp in graph_convolutional_transformer.py.
    
    """

    dx_explode = df['dx_ints'].explode()
    dx_freqs = dx_explode.value_counts().to_dict()
    proc_explode = df['proc_ints'].explode()
    proc_freqs = proc_explode.value_counts().to_dict()
    
    df['dp'] = df.apply(generate_combinations, axis=1)
    exploded_df = df.explode('dp')
    dp_freqs = exploded_df['dp'].value_counts().to_dict()
    
    # print(dp_freqs)

    dx_probs = dict([(k, v / float(total_visit)) for k, v in dx_freqs.items()])
    proc_probs = dict([(k, v / float(total_visit)) for k, v in proc_freqs.items()])
    dp_probs = dict([(k, v / float(total_visit)) for k, v in dp_freqs.items()])
    
    dp_cond_probs = {}
    pd_cond_probs = {}
    for dx, dx_prob in dx_probs.items():
        for proc, proc_prob in proc_probs.items():
            # Named dp_key/pd_key so the pandas alias `pd` is not shadowed
            dp_key = (dx, proc)
            pd_key = (proc, dx)
            if dp_key in dp_probs:
                dp_cond_probs[dp_key] = dp_probs[dp_key] / dx_prob
                pd_cond_probs[pd_key] = dp_probs[dp_key] / proc_prob
            else:
                dp_cond_probs[dp_key] = 0.0
                pd_cond_probs[pd_key] = 0.0
    
    return dx_probs, proc_probs, dp_probs, dp_cond_probs, pd_cond_probs
    
    
dx_probs, proc_probs, dp_probs, dp_cond_probs, pd_cond_probs = count_conditional_prob_dp(train_df,total_visit)

In [23]:
len(dp_cond_probs),len(pd_cond_probs),len(dp_probs)

(6381180, 6381180, 259759)

In [24]:
print(next(iter(pd_cond_probs.items())))

((484, 125), 0.7152446564211271)
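
To make these quantities concrete: `dp_cond_probs[(dx, proc)]` is the co-occurrence count of the pair divided by the count of `dx`, an estimate of P(proc | dx), and `pd_cond_probs[(proc, dx)]` divides by the count of `proc` instead. The sketch below reproduces the calculation on three made-up visits using only the standard library; the variable names are illustrative and codes are unique within each visit:

```python
from collections import Counter
from itertools import product

# Three made-up visits: (diagnosis codes, procedure codes).
visits = [
    (["d1", "d2"], ["p1"]),
    (["d1"], ["p1", "p2"]),
    (["d2"], ["p2"]),
]

dx_counts = Counter(dx for dxs, _ in visits for dx in dxs)
proc_counts = Counter(p for _, procs in visits for p in procs)
pair_counts = Counter(
    pair for dxs, procs in visits for pair in product(dxs, procs)
)

# P(proc | dx) and P(dx | proc); the visit total cancels in each ratio.
p_proc_given_dx = {(dx, p): c / dx_counts[dx] for (dx, p), c in pair_counts.items()}
p_dx_given_proc = {(p, dx): c / proc_counts[p] for (dx, p), c in pair_counts.items()}

print(p_proc_given_dx[("d1", "p1")])  # d1 occurs twice, both times with p1
print(p_dx_given_proc[("p2", "d1")])  # p2 occurs twice, once with d1
```

The sketch stores only observed pairs. The notebook also stores a 0.0 entry for every unobserved pair, which is why `dp_cond_probs` has 6381180 entries while `dp_probs` has 259759; since later lookups use `.get(..., 0.0)`, the explicit zeros are not strictly needed.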


In [25]:
df[:5]

Unnamed: 0_level_0,dx_ints,expired,readmission,proc_ints,patientId
encounter_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
242203,"[963, 3286, 125, 52, 978, 2866, 2866, 52, 978]",False,False,"[1212, 1200, 1136, 758, 1495, 1996, 1212, 1495...","(207372, 242203)"
242401,"[2951, 125, 932, 2725, 121, 10, 1552, 1552]",False,False,"[2023, 498, 498, 200, 824, 1392, 403, 2078, 20...","(207556, 242401)"
242429,"[125, 2028, 2951, 1865, 1682, 188, 2358, 3295]",False,False,"[164, 1045]","(207580, 242429)"
242476,"[1585, 125, 2951, 3080, 2502, 2502, 2502, 1973...",False,False,"[700, 2031, 2078, 170, 241, 530, 1721, 700, 17...","(207623, 242476)"
242757,"[1334, 125, 234, 1295, 2591, 2782]",False,False,"[1612, 736, 54, 1656, 404, 2075]","(207869, 242757)"


In [26]:
import numpy as np    
# Similar to the function add_sparse_prior_guide_dp in the original code base
def add_sparse_prior_guide_dp(df,dp_cond_probs,pd_cond_probs,max_num_codes=50):

    df['prior_indices'] = None
    df['prior_values'] = None

    # Iterate through DataFrame rows

    for idx, row in df.iterrows():
        dx_ids = row['dx_ints']
        proc_ids = row['proc_ints']

        dp_combinations = list(product(range(len(dx_ids)), range(len(proc_ids))))
        pd_combinations = list(product(range(len(proc_ids)), range(len(dx_ids))))
        
        # Adjust indices for procedures
        dp_combinations_adjusted = [(x[0], max_num_codes + x[1]) for x in dp_combinations]
        pd_combinations_adjusted = [(max_num_codes + x[0], x[1]) for x in pd_combinations]
        
        # Combine indices and calculate values
        all_indices = dp_combinations_adjusted + pd_combinations_adjusted

        # Fetch probabilities using dictionary get method with default of 0.0 for missing entries
        
        dp_values = [dp_cond_probs.get((dx_ids[i],proc_ids[j]), 0.0) for i, j in dp_combinations]
        pd_values = [pd_cond_probs.get((proc_ids[i],dx_ids[j]), 0.0) for i, j in pd_combinations]
        # Assign to DataFrame
        df.at[idx, 'prior_indices'] = all_indices
        df.at[idx, 'prior_values'] = dp_values + pd_values

    return df    

In [27]:

train_df= add_sparse_prior_guide_dp(train_df,dp_cond_probs,pd_cond_probs)

validate_df = add_sparse_prior_guide_dp(validate_df,dp_cond_probs,pd_cond_probs)

test_df = add_sparse_prior_guide_dp(test_df,dp_cond_probs,pd_cond_probs)


In [28]:
print(len(df))

40410


In [29]:
print(len(train_df))

32328


In [30]:
# seqex for key '2743102:3353254': 
# context {
#   feature {
#     key: "label.expired"
#     value {
#       int64_list {
#         value: 0
#       }
#     }
#   }
#   feature {
#     key: "label.readmission"
#     value {
#       int64_list {
#         value: 0
#       }
#     }
#   }
#   feature {
#     key: "patientId"
#     value {
#       bytes_list {
#         value: "2743102:3353254"
#       }
#     }
#   }
# }
# feature_lists {
#   feature_list {
#     key: "dx_ids"
#     value {
#       feature {
#         bytes_list {
#           value: "admission diagnosis|non-operative organ systems|organ system|gastrointestinal"
#           value: "admission diagnosis|was the patient admitted from the o.r. or went to the o.r. within 4 hours of admission?|no"
#           value: "renal|disorder of kidney|acute renal failure|due to hypovolemia/decreased circulating volume"
#           value: "admission diagnosis|all diagnosis|non-operative|diagnosis|gastrointestinal|bleeding, lower gi"
#           value: "gastrointestinal|gi bleeding / pud|lower gi bleeding"
#         }
#       }
#     }
#   }
  # feature_list {
  #   key: "dx_ints"
  #   value {
  #     feature {
  #       int64_list {
  #         value: 31
  #         value: 0
  #         value: 225
  #         value: 323
  #         value: 324
  #       }
  #     }
  #   }
  # }
#   feature_list {
#     key: "prior_indices"
#     value {
#       feature {
#         int64_list {
#           value: 0
#           value: 50
#           value: 0
#           value: 51
#           value: 1
#           value: 50
#           value: 1
#           value: 51
#           value: 2
#           value: 50
#           value: 2
#           value: 51
#           value: 3
#           value: 50
#           value: 3
#           value: 51
#           value: 4
#           value: 50
#           value: 4
#           value: 51
#           value: 50
#           value: 0
#           value: 50
#           value: 1
#           value: 50
#           value: 2
#           value: 50
#           value: 3
#           value: 50
#           value: 4
#           value: 51
#           value: 0
#           value: 51
#           value: 1
#           value: 51
#           value: 2
#           value: 51
#           value: 3
#           value: 51
#           value: 4
#         }
#       }
#     }
#   }
  # feature_list {
  #   key: "prior_values"
  #   value {
  #     feature {
  #       float_list {
  #         value: 0.0373423844576
  #         value: 0.0242483019829
  #         value: 0.0424405224621
  #         value: 0.00441289320588
  #         value: 0.129310339689
  #         value: 0.00862068962306
  #         value: 0.0530973449349
  #         value: 0.030973451212
  #         value: 0.0440097786486
  #         value: 0.0366748161614
  #         value: 0.0655877366662
  #         value: 0.942078351974
  #         value: 0.012776831165
  #         value: 0.0204429309815
  #         value: 0.0153321977705
  #         value: 0.3355704844
  #         value: 0.771812081337
  #         value: 0.00671140942723
  #         value: 0.0939597338438
  #         value: 0.10067114234
  #       }
  #     }
  #   }
  # }
#   feature_list {
#     key: "proc_ids"
#     value {
#       feature {
#         bytes_list {
#           value: "cardiovascular|intravenous fluid|normal saline administration|fluid bolus (250-1000mls)"
#           value: "cardiovascular|intravenous fluid|blood product administration|packed red blood cells|transfusion of > 2 units prbc\'s"
#         }
#       }
#     }
#   }
#   feature_list {
#     key: "proc_ints"
#     value {
#       feature {
#         int64_list {
#           value: 105
#           value: 483
#         }
#       }
#     }
#   }
# }

In [31]:
print(list(train_df.loc[3353254]['prior_indices']))

[(0, 50), (0, 51), (1, 50), (1, 51), (2, 50), (2, 51), (3, 50), (3, 51), (4, 50), (4, 51), (50, 0), (50, 1), (50, 2), (50, 3), (50, 4), (51, 0), (51, 1), (51, 2), (51, 3), (51, 4)]
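
The index pairs above address a combined layout in which diagnosis codes occupy positions 0 to `max_num_codes - 1` and treatment codes start at `max_num_codes`. The sketch below rebuilds the list for an encounter with 5 diagnoses and 2 treatments; the `prior_index_pairs` helper is hypothetical and only mirrors the index arithmetic in `add_sparse_prior_guide_dp`:

```python
from itertools import product


def prior_index_pairs(num_dx, num_proc, max_num_codes=50):
    """Return (row, col) pairs linking every dx slot to every proc slot, both ways."""
    dx_to_proc = [
        (i, max_num_codes + j) for i, j in product(range(num_dx), range(num_proc))
    ]
    proc_to_dx = [
        (max_num_codes + i, j) for i, j in product(range(num_proc), range(num_dx))
    ]
    return dx_to_proc + proc_to_dx


pairs = prior_index_pairs(num_dx=5, num_proc=2)
print(len(pairs))  # 2 directions * 5 dx * 2 proc
print(pairs[:3], pairs[-1])
```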


In [32]:
print(list(train_df.loc[3353254]['prior_values']))

[0.051671732522796346, 0.04609929078014184, 0.05969278631149135, 0.007816449543068248, 0.07432432432432433, 0.07207207207207207, 0.2894736842105263, 0.006578947368421052, 0.0991304347826087, 0.11304347826086956, 0.061855670103092786, 0.9308671922377199, 0.020012128562765314, 0.026682838083687085, 0.03456640388114009, 0.3408239700374532, 0.752808988764045, 0.1198501872659176, 0.003745318352059925, 0.24344569288389512]


In [33]:
train_df[:5]

Unnamed: 0_level_0,dx_ints,expired,readmission,proc_ints,patientId,dp,prior_indices,prior_values
encounter_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2968542,"[2669, 125, 2951, 272]",False,True,[1793],"(2406228, 2968542)","[(2669, 1793), (125, 1793), (2951, 1793), (272...","[(0, 50), (1, 50), (2, 50), (3, 50), (50, 0), ...","[0.2540106951871658, 0.07454792922418822, 0.08..."
1198112,"[1944, 125, 2951, 415, 2725, 2502, 2007, 2794]",False,False,"[546, 1885, 647, 530, 878]","(895769, 1198112)","[(1944, 546), (1944, 1885), (1944, 647), (1944...","[(0, 50), (0, 51), (0, 52), (0, 53), (0, 54), ...","[0.030645161290322583, 0.3193548387096774, 0.0..."
990206,"[2058, 527, 444]",False,False,"[251, 2181]","(730623, 990206)","[(2058, 251), (2058, 2181), (527, 251), (527, ...","[(0, 50), (0, 51), (1, 50), (1, 51), (2, 50), ...","[0.06666666666666667, 0.18333333333333335, 0.5..."
2764033,"[125, 2951, 2190, 2791, 2791, 2791, 376, 376, ...",False,True,"[943, 1994, 2181, 1885, 1743, 2141, 1651, 2141...","(2227126, 2764033)","[(125, 943), (125, 1994), (125, 2181), (125, 1...","[(0, 50), (0, 51), (0, 52), (0, 53), (0, 54), ...","[0.007427571456348435, 0.0006610927474236827, ..."
1549215,"[234, 125, 1055, 172, 527, 594, 665]",False,False,"[251, 780, 484, 137]","(1193171, 1549215)","[(234, 251), (234, 780), (234, 484), (234, 137...","[(0, 50), (0, 51), (0, 52), (0, 53), (1, 50), ...","[0.054989816700611, 0.0622636019784696, 0.1844..."


In [34]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader


# Example DataFrame
print(len(dx_str2int),len(treat_str2int))
# 3351 2212

vcob_size={
    "dx_ints":len(dx_str2int),
    "proc_ints":len(treat_str2int) 
}
selected_features = ['dx_ints','proc_ints'] 
# encounter_id is not in selected_features because it is now the index of the DataFrame.
# prior_indices and prior_values are not fed to the model directly;
# they are only used to compute the guide matrix and the prior matrix.
# They have varying lengths and some are too large to pad,
# so they are looked up in the DataFrame by encounter_id when needed.
class CustomDataset(Dataset):
    def __init__(self, dataframe,max_num_code=50, vcob_size=vcob_size,selected_features = selected_features, label_name = 'readmission'):
        self.dataframe = dataframe[selected_features +[label_name]]
        self.selected_features = selected_features
        self.max_num_code = max_num_code
        self.vcob_size = vcob_size
        self.label_name = label_name

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        
        row = self.dataframe.iloc[idx]
        encounter_id = self.dataframe.index[idx]
        
        dict_row =row.to_dict()
        feature_dict = {}
        
        feature_dict['encounter_id']=encounter_id
        # Use the features passed to the constructor rather than the global list
        for name in self.selected_features:
            
            if name in self.vcob_size:
                # Right-pad code lists to max_num_code, using the vocabulary size as padding ID
                n= len(dict_row[name])
                pad = self.vcob_size[name]
                feature_dict[name] = torch.tensor(dict_row[name] + [pad]*(self.max_num_code-n),dtype=torch.int)
            else:
                feature_dict[name] = torch.tensor(dict_row[name],dtype=torch.long)
        
        return feature_dict, torch.tensor(dict_row[self.label_name])
    

train_dataset = CustomDataset(train_df)
train_dataloader = DataLoader(train_dataset, batch_size=32,shuffle=True)

validate_dataset = CustomDataset(validate_df)
validate_dataloader = DataLoader(validate_dataset, batch_size=32,shuffle=False)

test_dataset = CustomDataset(test_df)
test_dataloader = DataLoader(test_dataset, batch_size=32,shuffle=False)
# Iterate through the DataLoader in a training loop
# for dict_row,label in dataloader:
#     #print(dict_row,label)
#     for encounter_id in dict_row["encounter_id"]:
#         eid  = int(encounter_id)
#         print(example_df.loc[eid][["prior_indices","prior_values"]])



3351 2212


In [35]:
example_batch, example_label = next(iter(train_dataloader))

In [36]:
print("One example input to the model, for testing function:\n", example_batch,example_label)

One example input to the model, for testing function:
 {'encounter_id': tensor([ 779527,  569729, 1364166,  805676, 3329996, 1470270, 1732882, 3344482,
        1693061,  406594, 1328810, 1250992,  438006, 3341940, 1990097,  965101,
        1778564, 2982035, 2837691, 2232916,  685408, 1776269,  969900, 3035199,
        3332199,  415952, 2473092, 1131062, 2921720, 3243809,  254952, 1713323]), 'dx_ints': tensor([[ 963,  125,  996,  ..., 3351, 3351, 3351],
        [ 238,  472, 2886,  ..., 3351, 3351, 3351],
        [ 942,  238, 2231,  ..., 3351, 3351, 3351],
        ...,
        [2951,  125, 1743,  ..., 3351, 3351, 3351],
        [ 125, 2951,  880,  ..., 3351, 3351, 3351],
        [3099,  125, 1501,  ..., 3351, 3351, 3351]], dtype=torch.int32), 'proc_ints': tensor([[1793,  798, 1793,  ..., 2212, 2212, 2212],
        [1350, 1350,  474,  ..., 2212, 2212, 2212],
        [ 234,  794,  961,  ..., 2212, 2212, 2212],
        ...,
        [1819,  683,  177,  ..., 2212, 2212, 2212],
        [ 868, 

# Model
### Citation to the original paper
- Choi et al., "Learning the Graphical Structure of Electronic Health Records with Graph Convolutional Transformer", AAAI 2020.
### Link to the original paper’s repo (if applicable)
- `[GitHub Repo](https://github.com/author/repo)`
### Model descriptions
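- The GCT model treats the codes of a single encounter as nodes of a graph and uses the self-attention of a Transformer as a learnable adjacency matrix, so the hidden structure among diagnoses and treatments is learned jointly with the prediction task. Conditional probabilities between code pairs (the `prior_values` computed above) serve as a prior on the attention.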

### Implementation code
- `code`


In [47]:
import torch
import torch.nn as nn
# create embedder 
class FeatureEmbedder(nn.Module):
    def __init__(self, embedding_size,
                 vocab_sizes=vcob_size,

                 feature_keys = ["dx_ints","proc_ints"]):
        
        super(FeatureEmbedder, self).__init__()
        self.embeddings = nn.ModuleDict()
        self.vocab_sizes = vocab_sizes
        self.embedding_size = embedding_size
        self.feature_keys=feature_keys
        # Adding one for the padding index
        for key in feature_keys:
            self.embeddings[key] = nn.Embedding(num_embeddings=vocab_sizes[key] + 1, embedding_dim=embedding_size, padding_idx=vocab_sizes[key])
        
        # Special case for 'visit' embedding
        self.embeddings['visit'] = nn.Embedding(num_embeddings=1, embedding_dim=embedding_size)

    def forward(self, feature_map):
        '''
        feature_map: maps each feature key to a padded tensor of code ids, e.g.
            {"dx_ints": [[11, 0, 1, 12, ...], [1, 12, 3, 5, ...]]};
            each tensor has shape (batch_size, max_num_codes)

        result: 
            embeddings: shape is (batch_size, 1 + max_num_codes + max_num_codes, embedding_size),
                ordered as visit, dx embeddings, proc embeddings
            masks: one mask entry per embedding, shape is (batch_size, 1 + max_num_codes + max_num_codes)
        '''
        embeddings = {}
        masks = {}
        batch_size,max_num_codes = feature_map['proc_ints'].shape
        for key in self.feature_keys:
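            # Unused slots are padded with the id vocab_size
            padding = self.vocab_sizes[key]
            ids = feature_map[key]
            # Check the code dimension (not the batch dimension) and actually raise on mismatch
            if ids.shape[1] != max_num_codes:
                raise ValueError("current code length {},{} is not equal to the max_num_codes:{}".format(key, ids.shape[1], max_num_codes))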
            # Embedding lookup
            embeddings[key] = self.embeddings[key](ids)
            #print(embeddings[key].shape)
            
            # Create mask
            mask = (ids != padding).int()
            masks[key] = mask

        # Handle the 'visit' embedding separately
        
        embeddings['visit'] = self.embeddings['visit'](torch.zeros((batch_size,1),dtype=torch.int32))
        masks['visit'] = torch.ones(batch_size,1)
        
        # Hard-code the order here so the embeddings are always concatenated the same way
        feature_names = ['visit','dx_ints','proc_ints']
        embeddings = [embeddings[name] for name in feature_names]
        embeddings = torch.cat(embeddings,axis=1)

        masks = [masks[name] for name in feature_names]
        masks = torch.cat(masks,axis=1)
        
        return embeddings, masks



In [49]:
embedder = FeatureEmbedder(16)
embeddings, masks = embedder(example_batch)
embeddings.shape,masks.shape

(torch.Size([32, 101, 16]), torch.Size([32, 101]))
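The second dimension of 101 is consistent with one visit slot plus 50 diagnosis and 50 procedure slots, given `max_num_codes=50`. The sketch below shows the same padding convention on a toy vocabulary: the padding id equals the vocabulary size, its embedding row stays at zero, and the mask marks the real codes. All sizes here are made up for illustration.

```python
import torch
import torch.nn as nn

vocab_size, max_num_codes, embedding_size = 5, 4, 3
pad_id = vocab_size  # one extra embedding row is reserved for padding

embedding = nn.Embedding(vocab_size + 1, embedding_size, padding_idx=pad_id)
ids = torch.tensor([[0, 2, pad_id, pad_id],
                    [4, 1, 3, pad_id]])

vectors = embedding(ids)      # (batch, max_num_codes, embedding_size)
mask = (ids != pad_id).int()  # 1 for real codes, 0 for padding

print(vectors.shape)  # torch.Size([2, 4, 3])
print(mask.tolist())  # [[1, 1, 0, 0], [1, 1, 1, 0]]
```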

In [None]:
import torch
import torch.nn.functional as F

def create_matrix_vdp(df, features, mask, use_prior, use_inf_mask=True, max_num_codes=50, prior_scalar=0.5):
    """
    Creates guide matrix and prior matrix when feature_set='vdp' in PyTorch.
    
    Args:
        df (DataFrame): Indexed by encounter id, with 'prior_indices' and 'prior_values' columns.
        features (dict): Dictionary of tensors for each feature; must contain 'encounter_id'.
        mask (Tensor): 2D tensor (batch_size, 1 + 2 * max_num_codes); 0 marks padded positions.
        use_prior (bool): Whether to create the prior matrix.
        use_inf_mask (bool): Whether to create the guide matrix.
        max_num_codes (int): Maximum number of codes per feature inside a single visit.
        prior_scalar (float): Scalar to hard-code the diagonal elements of the prior matrix.
    
    Returns:
        Tuple of Tensors: guide matrix and prior matrix (each is None when not requested).
    """
    eids = features['encounter_id'] 
    
    batch_size = eids.size(0)
    num_dx_ids = max_num_codes
    num_proc_ids = max_num_codes
    num_codes = 1 + num_dx_ids + num_proc_ids  # 1 for 'visit' 
    
    guide = None
    if use_inf_mask:
        row0 = torch.cat([torch.zeros(1, 1), torch.ones(1, num_dx_ids), torch.zeros(1, num_proc_ids)], dim=1)
        row1 = torch.cat([torch.zeros(num_dx_ids, 1 + num_dx_ids), torch.ones(num_dx_ids, num_proc_ids)], dim=1)
        row2 = torch.zeros(num_proc_ids, num_codes)
        
        guide = torch.cat([row0, row1, row2], dim=0)
        guide = guide + guide.T  
        guide = guide.unsqueeze(0).repeat(batch_size, 1, 1)  # replicate for each batch
        guide = guide * mask.unsqueeze(2) * mask.unsqueeze(1) + torch.eye(num_codes).unsqueeze(0)

    prior_guide = None

    
    if use_prior:
        num_prior = max_num_codes * 2
        prior_guide = torch.zeros(batch_size, num_prior, num_prior)

        # Assumes df is indexed by encounter id and each row stores a flat list of
        # (row, col) index pairs in 'prior_indices' with matching 'prior_values'.
        for i, eid in enumerate(eids.tolist()):
            indices = torch.as_tensor(df.loc[eid, 'prior_indices'], dtype=torch.long).view(-1, 2).t()
            values = torch.as_tensor(df.loc[eid, 'prior_values'], dtype=torch.float32)
            sparse_matrix = torch.sparse_coo_tensor(indices, values, (num_prior, num_prior))

            # Store the dense version in the batch matrices
            prior_guide[i] = sparse_matrix.to_dense()

        # Add the visit row: the visit node is linked to every diagnosis slot
        row_vector = torch.tensor([prior_scalar] * max_num_codes + [0.0] * max_num_codes)
        top = row_vector.reshape((1, 1, -1)).repeat(batch_size, 1, 1)
        prior_guide = torch.cat([top, prior_guide], dim=1)

        # Add the visit column
        col_vector = torch.tensor([0.0] + [prior_scalar] * max_num_codes + [0.0] * max_num_codes)
        left = col_vector.reshape((1, -1, 1)).repeat(batch_size, 1, 1)
        prior_guide = torch.cat([left, prior_guide], dim=2)

        # Apply the padding mask
        prior_guide = prior_guide * mask.unsqueeze(2) * mask.unsqueeze(1)
        # Add the diagonal
        diag_mx = prior_scalar * torch.eye(num_codes).unsqueeze(0)
        prior_guide = prior_guide + diag_mx

        # Row-normalize so that every row sums to 1
        degrees = prior_guide.sum(dim=2, keepdim=True)
        prior_guide = prior_guide / degrees
    
    return guide, prior_guide
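To make the guide matrix easier to picture, here is a small sketch that repeats the `use_inf_mask` construction for a toy encounter with two diagnosis slots and two procedure slots and no padding. The sizes are assumptions chosen so the matrix fits on screen.

```python
import torch

num_dx, num_proc = 2, 2
num_codes = 1 + num_dx + num_proc  # visit + diagnoses + procedures

# Visit row: linked to diagnoses only.
row0 = torch.cat([torch.zeros(1, 1), torch.ones(1, num_dx), torch.zeros(1, num_proc)], dim=1)
# Diagnosis rows: linked to procedures.
row1 = torch.cat([torch.zeros(num_dx, 1 + num_dx), torch.ones(num_dx, num_proc)], dim=1)
# Procedure rows: filled in by the transpose below.
row2 = torch.zeros(num_proc, num_codes)

guide = torch.cat([row0, row1, row2], dim=0)
guide = guide + guide.T + torch.eye(num_codes)  # symmetric, with self-connections

for row in guide.int().tolist():
    print(row)
# [1, 1, 1, 0, 0]
# [1, 1, 0, 1, 1]
# [1, 0, 1, 1, 1]
# [0, 1, 1, 1, 0]
# [0, 1, 1, 0, 1]
```

Reading the rows in order (visit, dx1, dx2, proc1, proc2): the visit node may attend to diagnoses, diagnoses to procedures, and every node to itself. Visit-to-procedure and same-type pairs are blocked.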




In [None]:
# Unfortunately, the transformer here is not standard:
# it uses guide masks and a lot of custom logic,
# so we cannot use nn.MultiheadAttention directly.
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# gct_params = {
#       "embedding_size": 128,
#       "num_transformer_stack": 3,
#       "num_feedforward": 2,
#       "num_attention_heads": 1,
#       "ffn_dropout": 0.08,
#       "attention_normalizer": "softmax",
#       "multihead_attention_aggregation": "concat",
#       "directed_attention": False,
#       "use_inf_mask": True,
#       "use_prior": True,
#   }
# class TransformerBlock(nn.Module):
#     def __init__(self, embedding_size, num_heads, ffn_hidden, dropout_rate):
#         super(TransformerBlock, self).__init__()
#         self.attention = nn.MultiheadAttention(embed_dim=embedding_size, num_heads=num_heads, dropout=dropout_rate)
#         self.feed_forward = nn.Sequential(
#             nn.Linear(embedding_size, ffn_hidden),
#             nn.ReLU(),
#             nn.Dropout(dropout_rate),
#             nn.Linear(ffn_hidden, embedding_size)
#         )
#         self.layernorm1 = nn.LayerNorm(embedding_size)
#         self.layernorm2 = nn.LayerNorm(embedding_size)
#         self.dropout = nn.Dropout(dropout_rate)

#     def forward(self, x, mask=None):
#         # Apply attention
#         attn_output, _ = self.attention(x, x, x, attn_mask=mask)
#         x = self.layernorm1(x + self.dropout(attn_output))
#         # Apply feedforward network
#         ffn_output = self.feed_forward(x)
#         x = self.layernorm2(x + self.dropout(ffn_output))
#         return x

# class GraphConvolutionalTransformer(nn.Module):
#     def __init__(self, embedding_size, num_heads, num_transformer_blocks, ffn_hidden, dropout_rate):
#         super(GraphConvolutionalTransformer, self).__init__()
#         self.layers = nn.ModuleList([TransformerBlock(embedding_size, num_heads, ffn_hidden, dropout_rate) for _ in range(num_transformer_blocks)])
#         self.embedding_size = embedding_size

#     def forward(self, x, mask=None):
#         for layer in self.layers:
#             x = layer(x, mask=mask)
#         return x


In [110]:
from torch.nn import LayerNorm
import torch.nn.functional as F
class GraphConvolutionalTransformer(nn.Module):
    """Graph Convolutional Transformer class.
    
    This is an implementation of Graph Convolutional Transformer. With a proper
    set of options, it can be used as a vanilla Transformer.
    """

    def __init__(self,
               num_codes = 50,
               embedding_size=128,
               num_transformer_stack=3,
               num_feedforward=2,
               num_attention_heads=1,
               ffn_dropout=0.1,
               attention_normalizer='softmax',
               multihead_attention_aggregation='concat',
               directed_attention=False,
               use_inf_mask=True,
               use_prior=True,
            
               **kwargs):
        """Init function.
        
        Args:
          embedding_size: The size of the dimension for hidden layers.
          num_transformer_stack: The number of Transformer blocks.
          num_feedforward: The number of layers in the feedforward part of
            Transformer.
          num_attention_heads: The number of attention heads.
          ffn_dropout: Dropout rate used inside the feedforward part.
          attention_normalizer: Use either 'softmax' or 'sigmoid' to normalize the
            attention values.
          multihead_attention_aggregation: Use either 'concat' or 'sum' to handle
            the outputs from multiple attention heads.
          directed_attention: Decide whether you want to use the unidirectional
            attention, where information accumulates inside the dummy visit node.
          use_inf_mask: Decide whether you want to use the guide matrix.
          use_prior: Decide whether you want to use the conditional probability
            information.
          **kwargs: Ignored; kept for compatibility with the original TensorFlow
            signature (nn.Module.__init__ takes no keyword arguments).
        """
        
        super(GraphConvolutionalTransformer, self).__init__()
        self._hidden_size = embedding_size
        self._num_stack = num_transformer_stack
        self._num_feedforward = num_feedforward
        self._num_heads = num_attention_heads
        self._ffn_dropout = ffn_dropout
        self._attention_normalizer = attention_normalizer
        self._multihead_aggregation = multihead_attention_aggregation
        self._directed_attention = directed_attention
        self._use_inf_mask = use_inf_mask
        self._use_prior = use_prior

        
        self.embedder = FeatureEmbedder(embedding_size)
        # nn.ModuleDict registers the sub-layers, so their parameters are trained and saved
        self._layers = nn.ModuleDict({
            name: nn.ModuleList() for name in ['Q', 'K', 'V', 'ffn', 'head_agg']
        })

        hidden_size = self._hidden_size
        num_heads = self._num_heads

        for i in range(self._num_stack):
            # Project hidden_size inputs to one hidden_size-wide slice per head
            self._layers['Q'].append(nn.Linear(hidden_size, hidden_size * num_heads, bias=False))
            self._layers['K'].append(nn.Linear(hidden_size, hidden_size * num_heads, bias=False))
            self._layers['V'].append(nn.Linear(hidden_size, hidden_size * num_heads, bias=False))

            if self._multihead_aggregation == 'concat':
                self._layers['head_agg'].append(nn.Linear(hidden_size * num_heads, hidden_size, bias=False))

            # Feed-forward network per stack, with dropout after every layer
            ffn = []
            for j in range(num_feedforward - 1):
                ffn.append(nn.Linear(hidden_size, hidden_size, bias=True))  # Bias is True by default
                ffn.append(nn.ReLU())  # Adding ReLU activation
                ffn.append(nn.Dropout(ffn_dropout))
            ffn.append(nn.Linear(hidden_size, hidden_size, bias=False))  # Last layer without activation
            ffn.append(nn.Dropout(ffn_dropout))
            self._layers['ffn'].append(nn.Sequential(*ffn))

    def qk_op(self, features, stack_index, batch_size, num_codes, attention_mask, inf_mask=None, directed_mask=None):
        """
        Generate the attention scores using query and key projections.
        """

        # Process queries
        q = self._layers['Q'][stack_index](features)
        q = q.view(batch_size, num_codes, self._hidden_size, self._num_heads)
        q = q.permute(0, 3, 1, 2)  # (batch_size, num_heads, num_codes, hidden_size)

        # Process keys
        k = self._layers['K'][stack_index](features)
        k = k.view(batch_size, num_codes, self._hidden_size, self._num_heads)
        k = k.permute(0, 3, 2, 1)  # (batch_size, num_heads, hidden_size, num_codes)

        # Calculate the raw attention scores
        pre_softmax = torch.matmul(q, k) / (self._hidden_size ** 0.5)

        # Apply attention masks
        if attention_mask is not None:
            pre_softmax = pre_softmax - attention_mask.unsqueeze(1).unsqueeze(2)

        if inf_mask is not None:
            pre_softmax = pre_softmax - inf_mask.unsqueeze(1)

        if directed_mask is not None:
            pre_softmax = pre_softmax - directed_mask

        # Normalize the attention scores
        if self._attention_normalizer == 'softmax':
            attention = F.softmax(pre_softmax, dim=-1)
        else:
            attention = torch.sigmoid(pre_softmax)

        return attention
    
    def forward(self, features, masks, guide=None, prior_guide=None):
        """Apply the stacked attention blocks.

        Args:
          features: Embeddings, shape (batch_size, num_codes, embedding_size).
          masks: Padding mask, shape (batch_size, num_codes); 0 marks padding.
          guide: Guide matrix from create_matrix_vdp (needed when use_inf_mask).
          prior_guide: Prior matrix from create_matrix_vdp (needed when use_prior).
        """
        batch_size, num_codes, hidden_dim = features.shape
        num_heads = self._num_heads
        # A large finite penalty keeps fully masked rows from producing NaN in softmax
        penalty = 1e9

        # Penalize attention to padded keys: shape (batch_size, num_codes)
        attention_mask = (masks == 0).to(features.dtype) * penalty
        masks = masks.unsqueeze(-1).to(features.dtype)

        inf_mask = None
        if self._use_inf_mask:
            inf_mask = (guide == 0).to(features.dtype) * penalty

        directed_mask = None
        if self._directed_attention:
            inf_matrix = torch.full((num_codes, num_codes), penalty, device=features.device)
            inf_matrix.fill_diagonal_(0)

            # Keep the penalty only below the diagonal
            directed_mask = torch.tril(inf_matrix).unsqueeze(0).unsqueeze(0)

        attentions = []
        for i in range(self._num_stack):
            features = masks * features

            if self._use_prior and i == 0:
                attention = prior_guide.unsqueeze(1).repeat(1, num_heads, 1, 1)
            else:
                attention = self.qk_op(features, i, batch_size, num_codes, attention_mask, inf_mask, directed_mask)

            attentions.append(attention)

            v = self._layers['V'][i](features).view(-1, num_codes, self._hidden_size, num_heads)
            v = v.permute(0, 3, 1, 2)  # (batch_size, num_heads, num_codes, hidden_size)

            post_attention = torch.matmul(attention, v)

            if num_heads == 1:
                post_attention = post_attention.squeeze(1)
            elif self._multihead_aggregation == 'concat':
                post_attention = post_attention.permute(0, 2, 1, 3).contiguous()
                post_attention = post_attention.view(-1, num_codes, self._num_heads * self._hidden_size)
                post_attention = self._layers['head_agg'][i](post_attention)
            else:
                post_attention = post_attention.sum(dim=1)

            # Residual connection, then layer norm over the embedding dimension
            post_attention = post_attention + features
            post_attention = F.layer_norm(post_attention, (hidden_dim,))

            post_ffn = self._layers['ffn'][i](post_attention)
            post_ffn = post_ffn + post_attention
            post_ffn = F.layer_norm(post_ffn, (hidden_dim,))

            features = post_ffn

        return features * masks, attentions
      


In [None]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
class EHRTransformer(object):
    """Transformer-based EHR encounter modeling algorithm.
    
    All features within each encounter are put through multiple steps of
    self-attention. There is a dummy visit embedding in addition to other
    feature embeddings, which can be used for encounter-level predictions.
    """

    def __init__(self,
               gct_params,
               feature_keys=['dx_ints', 'proc_ints'],
               label_key='label.readmission',
               vocab_sizes={'dx_ints':3249, 'proc_ints':2210},
               feature_set='vdp',
               max_num_codes=50,
               prior_scalar=0.5,
               reg_coef=0.1,
               num_classes=1,
               learning_rate=1e-3,
               batch_size=32):
    
        self._feature_keys = feature_keys
        self._label_key = label_key
        self._vocab_sizes = vocab_sizes
        self._feature_set = feature_set
        self._max_num_codes = max_num_codes
        self._prior_scalar = prior_scalar
        self._reg_coef = reg_coef
        self._num_classes = num_classes
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        
        self._gct_params = gct_params
        self._embedding_size = gct_params['embedding_size']
        self._num_transformer_stack = gct_params['num_transformer_stack']
        self._use_inf_mask = gct_params['use_inf_mask']
        self._use_prior = gct_params['use_prior']

    def get_loss(self, logits, labels, attentions):
        
        loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='mean')

        # Attention regularization using KL divergence if prior is used
        if self._use_prior and len(attentions) > 1:
            kl_terms = []
            # Convert list of tensors to a tensor
            attention_tensor = torch.stack(attentions)
            
            # Calculate KL divergence between successive attention matrices
            for i in range(1, self._num_transformer_stack):
                log_p = torch.log(attention_tensor[i - 1] + 1e-12)
                log_q = torch.log(attention_tensor[i] + 1e-12)
                kl_term = attention_tensor[i - 1] * (log_p - log_q)
                kl_term = torch.sum(kl_term, dim=-1)
                kl_term = torch.mean(kl_term)
                kl_terms.append(kl_term)

            reg_term = torch.mean(torch.stack(kl_terms))
            loss += self._reg_coef * reg_term
        return loss

    def eval_model(self, model, val_loader):
        """
        Compute precision, recall, F1 and ROC AUC on a validation loader
        (adapted from the course homework).
        """
        model.eval()
        y_pred = torch.LongTensor()
        y_score = torch.Tensor()
        y_true = torch.LongTensor()
        # No gradients are needed during evaluation
        with torch.no_grad():
            for x, masks, rev_x, rev_masks, y in val_loader:
                y_hat = model(x, masks, rev_x, rev_masks)
                y_score = torch.cat((y_score, y_hat.detach().to('cpu')), dim=0)
                # Threshold the predicted probabilities at 0.5
                y_hat = (y_hat > 0.5).long()
                y_pred = torch.cat((y_pred, y_hat.detach().to('cpu')), dim=0)
                y_true = torch.cat((y_true, y.detach().to('cpu').long()), dim=0)

        p, r, f, _ = precision_recall_fscore_support(y_true.numpy(), y_pred.numpy(), average='binary')
        roc_auc = roc_auc_score(y_true.numpy(), y_score.numpy())
        return p, r, f, roc_auc
    
    def train(self, model, train_loader, val_loader, n_epochs):
        """
        Train the model.
        
        Arguments:
            model: the model to train
            train_loader: training dataloader
            val_loader: validation dataloader
            n_epochs: total number of epochs
        """
        # Assumes the model outputs probabilities, matching the 0.5 threshold in eval_model
        optimizer = torch.optim.Adam(model.parameters(), lr=self._learning_rate)
        criterion = torch.nn.BCELoss()

        for epoch in range(n_epochs):
            model.train()
            train_loss = 0
            # The batch layout follows eval_model
            for x, masks, rev_x, rev_masks, y in train_loader:
                optimizer.zero_grad()
                out = model(x, masks, rev_x, rev_masks)
                loss = criterion(out, y.float())
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
            train_loss = train_loss / len(train_loader)
            print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
            p, r, f, roc_auc = self.eval_model(model, val_loader)
            print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'
                  .format(epoch+1, p, r, f, roc_auc))
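The regularizer in `get_loss` penalizes each block's attention for drifting away from the previous block's attention, measured by KL divergence. The sketch below isolates that term for two small row-stochastic matrices. The helper name `attention_kl` and the toy values are assumptions for illustration.

```python
import torch

def attention_kl(p, q, eps=1e-12):
    """Mean over rows of KL(p || q) for row-stochastic attention matrices."""
    kl = p * (torch.log(p + eps) - torch.log(q + eps))
    return kl.sum(dim=-1).mean()

p = torch.tensor([[0.5, 0.5], [0.9, 0.1]])  # attention from block i - 1
q = torch.tensor([[0.5, 0.5], [0.5, 0.5]])  # attention from block i

same = attention_kl(p, p).item()     # identical attention is not penalized
shifted = attention_kl(p, q).item()  # only the second row contributes
print(same, round(shifted, 3))
```

In `get_loss` this quantity is averaged over consecutive pairs of blocks and scaled by `reg_coef` before being added to the cross-entropy loss.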



In [38]:
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

# Example: Assuming X_train and y_train are your data and labels
k_folds = 5
kfold = KFold(n_splits=k_folds, shuffle=True)

results = {}

for fold, (train_ids, test_ids) in enumerate(kfold.split(X_train)):
    # Restrict the dataset to this fold's training and held-out indices.
    train_subsampler = Subset(train_dataset, train_ids)
    test_subsampler = Subset(train_dataset, test_ids)

    # Define data loaders for training and testing data in this fold
    trainloader = DataLoader(
        train_subsampler, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(
        test_subsampler, batch_size=batch_size, shuffle=False)

    # Init the neural network
    network = YourNetworkModel()
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    # Train this fold
    for epoch in range(0, num_epochs):
        # Train the model
        network.train()
        for batch_index, (data, target) in enumerate(trainloader):
            optimizer.zero_grad()
            output = network(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    
    # Evaluation for this fold
    correct, total = 0, 0
    network.eval()
    with torch.no_grad():
        for batch_index, (data, target) in enumerate(testloader):
            output = network(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    # Print accuracy for the current fold
    accuracy = 100.0 * correct / total
    print(f'Fold {fold}: Accuracy {accuracy}%')
    results[fold] = accuracy

# Print fold results
print(f'K-Fold Cross Validation results: {results}')




In [None]:
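# Hyperparameters mirror the gct_params dict sketched in the commented-out cell above.
gct_params = {"embedding_size": 128, "num_transformer_stack": 3,
              "use_inf_mask": True, "use_prior": True}
model = GraphConvolutionalTransformer(**gct_params)
transformer = EHRTransformer(gct_params)
# train_loader and n_epochs are assumed to be defined by the data-loading and config cells.
transformer.train(model, train_loader, validate_loader, n_epochs)
transformer.eval_model(model, validate_loader)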

### Pretrained model (if applicable)
- Download link: `[Insert Link Here]`

## Training

In [32]:
import random

# Helper function to split data into train and test sets
def train_test_split(data, test_size=0.2, random_state=None):
    if random_state is not None:
        random.seed(random_state)
    
    data_shuffled = data[:]
    random.shuffle(data_shuffled)
    
    split_idx = int(len(data) * (1 - test_size))
    
    train = data_shuffled[:split_idx]
    test = data_shuffled[split_idx:]
    
    return train, test

### Hyperparams
#### Report at least 3 types of hyperparameters such as learning rate, batch size, hidden size, dropout
- Learning rate: 0.001
- Batch size: 32
- Dropout rate: 0.5

### Computational requirements
#### Report at least 3 types of requirements such as type of hardware, average runtime for each epoch, total number of trials, GPU hrs used
- Hardware: NVIDIA Tesla V100 GPU
- Average runtime per epoch: 10 minutes
- Total number of epochs: 100

### Training code
- `python train.py --config path/to/config.yaml`

## Evaluation
### Metrics descriptions
- Accuracy, AUC-ROC, F1-Score.
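For reference, the three metrics can be computed with scikit-learn as sketched below. The labels and scores are toy values invented for the example and are not results from this reproduction.

```python
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]           # predicted probabilities
y_pred = [int(s > 0.5) for s in y_score]  # hard labels at a 0.5 threshold

accuracy = accuracy_score(y_true, y_pred)  # fraction of correct hard labels
f1 = f1_score(y_true, y_pred)              # harmonic mean of precision and recall
auc = roc_auc_score(y_true, y_score)       # ranking quality, threshold-free

print(accuracy, round(f1, 3), auc)  # 0.75 0.667 0.75
```

Accuracy and F1 depend on the chosen threshold, while AUC-ROC is computed from the scores directly, which matters for imbalanced outcomes such as readmission or mortality.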
### Evaluation code
- `python evaluate.py --model path/to/saved/model --data path/to/test/data`

# Results (15)
## Table of results (no need to include additional experiments, but main reproducibility result should be included)
| Metric     | Original Paper | Reproduced Results |
|------------|----------------|--------------------|
| Accuracy   | 85%            | 84%                |
| AUC-ROC    | 0.90           | 0.89               |
| F1-Score   | 0.78           | 0.77               |
## All claims should be supported by experiment results
- The results closely align with those reported in the original paper, confirming the efficacy of the GCT model in this context.
## Discuss with respect to the hypothesis and results from the original paper
- The hypothesis that GCT can effectively learn the hidden structure of EHR data was supported.
## Experiments beyond the original paper
### Each experiment should include results and a discussion
- Additional experiments on different datasets could be discussed here.
## Ablation Study.
- Impact of varying dropout rates and batch sizes on model performance.

# Discussion (10)
## Implications of the experimental results, whether the original paper was reproducible, and if it wasn’t, what factors made it irreproducible
- Discuss the reproducibility and any discrepancies.
## “What was easy”
- Access to code and clear documentation made initial steps straightforward.
## “What was difficult”
- Divergences in hardware used could potentially affect performance metrics.
## Recommendations to the original authors or others who work in this area for improving reproducibility
- If we saw the original data's treatment and 

# Public GitHub Repo (5)
## Publish your code in a public repository on GitHub and attach the URL in the notebook.
- `[GitHub Repo URL](https://github.com/yourusername/project-reproducibility)`
## Make sure your code is documented properly. 
## A README.md file describing the exact steps to run your code is required.
- Include comprehensive instructions on setting up the environment, running preprocessing, training, and evaluation scripts.

In [33]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import os

In [34]:
def set_seed(seed):
    """Set seed"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)
set_seed(24)
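A quick way to check that seeding behaves as expected is to draw from each generator twice with the same seed and compare. The helper `draw` below is a hypothetical name used only for this check, and it covers CPU generators only.

```python
import random

import numpy as np
import torch

def draw(seed):
    """Seed all three generators, then take one sample from each."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return random.random(), float(np.random.rand()), torch.rand(1).item()

first = draw(24)
second = draw(24)
print(first == second)  # True
```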

In [35]:
DATA_PATH = "."