In [1]:
import sys
sys.executable

'/software/util/JupyterLab/capella/jupyterlab-4.0.4/bin/python3'

In [2]:
import gc
import os
import copy
import json
import random
import pickle
import numpy as np
import pandas as pd
# import networkx as nx
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn import preprocessing

import torch
import torch.nn.functional as F

ModuleNotFoundError: No module named 'pandas'

In [3]:
import time

In [4]:
# data_path = "/home/th748/scratch/ehr_dataset/physionet.org/files/mimiciv/2.2/hosp"
data_path = "mimic-iv-3.1/hosp"
data_path

'mimic-iv-3.1/hosp'

In [5]:
med_file = os.path.join(data_path, "prescriptions.csv.gz")
procedure_file = os.path.join(data_path, "procedures_icd.csv.gz")
diag_file = os.path.join(data_path, "diagnoses_icd.csv.gz")
admission_file = os.path.join(data_path, "admissions.csv.gz")
lab_test_file = os.path.join(data_path, "labevents.csv.gz")
patient_file = os.path.join(data_path, "patients.csv.gz")

In [6]:
# drug code mapping files from GAMENet repo
GAMENet_path = "/data/horse/ws/arsi805e-finetune/Thesis/GAMENet"
ndc2atc_file = f'{GAMENet_path}/data/ndc2atc_level4.csv' 
cid_atc = f'{GAMENet_path}/data/drug-atc.csv'
ndc2rxnorm_file = f'{GAMENet_path}/data/ndc2rxnorm_mapping.txt'

In [7]:
icd10 = False

# Dataset Preprocessing

In [42]:
def process_med():
    med_pd = pd.read_csv(med_file, dtype={'ndc':'category'})
    med_pd.drop(columns=['pharmacy_id', 'poe_id', 'poe_seq',
           'order_provider_id', 'stoptime', 'drug_type', 'drug',
           'formulary_drug_cd', 'gsn', 'prod_strength', 'form_rx',
           'dose_val_rx', 'dose_unit_rx', 'form_val_disp', 'form_unit_disp',
           'doses_per_24_hrs', 'route'], axis=1, inplace=True)
    med_pd.drop(index = med_pd[med_pd['ndc'] == '0'].index, axis=0, inplace=True)
    med_pd.fillna(method='pad', inplace=True)
    med_pd.dropna(inplace=True)
    med_pd.drop_duplicates(inplace=True)
    med_pd['starttime'] = pd.to_datetime(med_pd['starttime'], format='%Y-%m-%d %H:%M:%S')    
    med_pd.sort_values(by=['subject_id', 'hadm_id', 'starttime'], inplace=True)
    med_pd = med_pd.reset_index(drop=True)
    def filter_first24hour_med(med_pd):
        med_pd_new = med_pd.drop(columns=['ndc'])
        med_pd_new = med_pd_new.groupby(by=['subject_id','hadm_id']).head(1).reset_index(drop=True)
        med_pd_new = pd.merge(med_pd_new, med_pd, on=['subject_id','hadm_id','starttime'])
        med_pd_new = med_pd_new.drop(columns=['starttime'])
        return med_pd_new
    med_pd = filter_first24hour_med(med_pd) # or next line
    med_pd = med_pd.drop_duplicates()
    # visit >= 2
#     def process_visit_lg2(med_pd):
#         a = med_pd[['subject_id', 'hadm_id']].groupby(by='subject_id')['hadm_id'].unique().reset_index()
#         a['hadm_id_Len'] = a['hadm_id'].map(lambda x:len(x))
#         a = a[a['hadm_id_Len'] > 1]
#         return a 
#     med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)    
#     med_pd = med_pd.merge(med_pd_lg2[['subject_id']], on='subject_id', how='inner')    
    return med_pd.reset_index(drop=True)

def process_procedure():
    pro_pd = pd.read_csv(procedure_file, dtype={'icd_code':'category'})
    if icd10:
        pro_pd = pro_pd[pro_pd['icd_version']==10]
    else:
        pro_pd = pro_pd[pro_pd['icd_version']==9]
    print("After filtering, pro_pd['icd_version'].unique():", pro_pd['icd_version'].unique() if 'icd_version' in pro_pd else "No icd_version")
    print("First 5 proc codes:", pro_pd['icd_code'].head())
    pro_pd.drop(columns=['chartdate', 'icd_version'], inplace=True)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.sort_values(by=['subject_id', 'hadm_id', 'seq_num'], inplace=True)
    pro_pd.drop(columns=['seq_num'], inplace=True)
    pro_pd["icd_code"] = "PRO_" + pro_pd["icd_code"].astype(str)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.reset_index(drop=True, inplace=True)
    return pro_pd

def process_diag():
    diag_pd = pd.read_csv(diag_file)
    diag_pd.dropna(inplace=True)
    ## my addidtion:
    if icd10:
        diag_pd = diag_pd[diag_pd['icd_version']==10]
    else:
        diag_pd = diag_pd[diag_pd['icd_version']==9]
    print("After filtering, diag_pd['icd_version'].unique():", diag_pd['icd_version'].unique() if 'icd_version' in diag_pd else "No icd_version")
    print("First 5 diag codes:", diag_pd['icd_code'].head())
    diag_pd.drop(columns=['seq_num','icd_version'],inplace=True)
    diag_pd.drop_duplicates(inplace=True)
    diag_pd.sort_values(by=['subject_id','hadm_id'], inplace=True)
    return diag_pd.reset_index(drop=True)

def process_admission():
    ad_pd = pd.read_csv(admission_file)
    patient_pd = pd.read_csv(patient_file)
    
    ad_pd.drop(columns=['admission_type', 'admit_provider_id', 'admission_location',
       'discharge_location', 'insurance', 'language', 'marital_status', 'race',
       'edregtime', 'edouttime', 'hospital_expire_flag'], axis=1, inplace=True)
    patient_pd.drop(columns=['anchor_year', 'anchor_year_group'], axis=1, inplace=True)
    
    ad_pd["admittime"] = pd.to_datetime(ad_pd['admittime'], format='%Y-%m-%d %H:%M:%S')
    ad_pd["dischtime"] = pd.to_datetime(ad_pd['dischtime'], format='%Y-%m-%d %H:%M:%S')  # time for leaving hospital
    ad_pd = ad_pd.merge(patient_pd, on=['subject_id'], how='inner')
    
    # create features: age, death, number of days in this encounter, readmission (next visit)
    ad_pd["age"] = ad_pd['anchor_age']
    # ad_pd[ad_pd["age"] >= 300] = 90
    age = ad_pd["age"]
    bins = np.linspace(age.min(), age.max(), 20 + 1)
    ad_pd['age'] = pd.cut(age, bins=bins, labels=False, include_lowest=True)
    ad_pd['age'] = "age_" + ad_pd["age"].astype("str")
    
    ad_pd["death"] = ad_pd["dod"].notna()
    ad_pd["stay_days"] = (ad_pd["dischtime"] - ad_pd["admittime"]).dt.days
    ad_pd['admittime'] = ad_pd['admittime'].astype(str)
    ad_pd.sort_values(by=['subject_id', 'hadm_id', 'admittime'], inplace=True)
    ad_pd['next_visit'] = ad_pd.groupby('subject_id')['hadm_id'].shift(-1)
    ad_pd['readmission'] = ad_pd['next_visit'].notnull().astype(int)
    ad_pd.drop('next_visit', axis=1, inplace=True)
    ad_pd.drop(columns=['dischtime', 'dod', 'deathtime', 'anchor_age'], axis=1, inplace=True)
    ad_pd.drop_duplicates(inplace=True)
    return ad_pd.reset_index(drop=True)

def process_lab_test(n_bins=5):
    lab_pd = pd.read_csv(lab_test_file)
    lab_pd = lab_pd.groupby(by=['subject_id','itemid']).head(1).reset_index(drop=True)  # only consider the first value
    lab_pd = lab_pd[lab_pd["valuenum"].notna()]
    lab_pd = lab_pd[lab_pd["hadm_id"].notna()]
    
    # lab_pd.drop(columns=['ROW_ID'], axis=1, inplace=True)
    
    def contains_text(group):
        for item in group:
            try:
                float(item)
            except ValueError:
                return True
        return False
    for itemid in lab_pd['itemid'].unique():
        group = lab_pd[lab_pd['itemid'] == itemid]['value']
        # if the lab test contains text value then directly copy the value
        if contains_text(group):
            lab_pd.loc[lab_pd['itemid'] == itemid, 'value_bin'] = group
        else:
            # value->numeric
            values_numeric = pd.to_numeric(group, errors='coerce')
            if len(values_numeric.dropna()) < n_bins:
                lab_pd.loc[lab_pd['itemid'] == itemid, 'value_bin'] = group
            else:
                # cut
    #             bins = np.linspace(values_numeric.min(), values_numeric.max(), n_bins + 1)
    #             lab_pd.loc[df['ITEMID'] == itemid, 'value_bin'] = pd.cut(values_numeric, bins=bins, labels=False, include_lowest=True)
                lab_pd.loc[lab_pd['itemid'] == itemid, 'value_bin'] = pd.qcut(values_numeric, q=n_bins, labels=False, duplicates='drop')
        
    lab_pd["itemid"] = lab_pd["itemid"].astype(str)
    lab_pd["value_bin"] = lab_pd["value_bin"].astype(str)
    lab_pd["lab_test"] = lab_pd[["itemid", "value_bin"]].apply("-".join, axis=1)
    
    lab_pd.drop(columns=['labevent_id', 'specimen_id', 'itemid',
           'order_provider_id', 'charttime', 'storetime', 'value', 'valuenum',
           'valueuom', 'ref_range_lower', 'ref_range_upper', 'flag', 'priority',
           'comments', 'value_bin'], axis=1, inplace=True)
    lab_pd.drop_duplicates(inplace=True)
    return lab_pd.reset_index(drop=True)

In [43]:
def ndc2atc4(med_pd):
    with open(ndc2rxnorm_file, 'r') as f:
        ndc2rxnorm = eval(f.read())
    med_pd['RXCUI'] = med_pd['ndc'].map(ndc2rxnorm)
    med_pd.dropna(inplace=True)

    rxnorm2atc = pd.read_csv(ndc2atc_file)
    rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','ndc'])
    rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
    med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)
    
    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
    med_pd = med_pd.reset_index(drop=True)
    med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
    med_pd.drop(columns=['ndc', 'RXCUI'], inplace=True)
    med_pd = med_pd.rename(columns={'ATC4':'ndc'})
    med_pd['ndc'] = med_pd['ndc'].map(lambda x: x[:4])
    med_pd = med_pd.drop_duplicates()    
    return med_pd.reset_index(drop=True)

def filter_most_pro(pro_pd, threshold=800):
    pro_count = pro_pd.groupby(by=['icd_code']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    pro_pd = pro_pd[pro_pd['icd_code'].isin(pro_count.loc[:threshold, 'icd_code'])]
    return pro_pd.reset_index(drop=True)

def filter_most_diag(diag_pd, threshold=2000):
    diag_count = diag_pd.groupby(by=['icd_code']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    diag_pd = diag_pd[diag_pd['icd_code'].isin(diag_count.loc[:threshold, 'icd_code'])]
    return diag_pd.reset_index(drop=True)

def filter_most_lab(lab_pd, threshold=1500):
    lab_count = lab_pd.groupby(by=['lab_test']).size().reset_index().rename(columns={0:'count'}).sort_values(by=['count'],ascending=False).reset_index(drop=True)
    lab_pd = lab_pd[lab_pd['lab_test'].isin(lab_count.loc[:threshold, 'lab_test'])]
    return lab_pd.reset_index(drop=True)

In [59]:
from datetime import timedelta

def process_all():
    start = time.time()
    local_time = time.ctime(start)
    print("Local time:", local_time)
    print('\n')
    
    print('process_med')
    med_pd = process_med()
    med_pd = ndc2atc4(med_pd)

    print('process_diag')
    diag_pd = process_diag()
    diag_pd = filter_most_diag(diag_pd)
    print("Sample diag codes:", diag_pd['icd_code'].unique())

    print('process_pro')
    pro_pd = process_procedure()
    pro_pd = filter_most_pro(pro_pd)
    print("Sample proc codes:", pro_pd['icd_code'].unique())

    print('process_ad')
    ad_pd = process_admission()
    
    print('process_lab')
    lab_pd = process_lab_test()
    lab_pd = filter_most_lab(lab_pd)

    print("Processing complete....")
    
    med_pd_key = med_pd[['subject_id', 'hadm_id']].drop_duplicates()
    diag_pd_key = diag_pd[['subject_id', 'hadm_id']].drop_duplicates()
    pro_pd_key = pro_pd[['subject_id', 'hadm_id']].drop_duplicates()
    lab_pd_key = lab_pd[['subject_id', 'hadm_id']].drop_duplicates()
    ad_pd_key = ad_pd[['subject_id', 'hadm_id']].drop_duplicates()
    
    # filter key
    combined_key = med_pd_key.merge(diag_pd_key, on=['subject_id', 'hadm_id'], how='inner')
    combined_key = combined_key.merge(pro_pd_key, on=['subject_id', 'hadm_id'], how='inner')
    combined_key = combined_key.merge(lab_pd_key, on=['subject_id', 'hadm_id'], how='inner')
    combined_key = combined_key.merge(ad_pd_key, on=['subject_id', 'hadm_id'], how='inner')
    diag_pd = diag_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')
    med_pd = med_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')
    pro_pd = pro_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')
    lab_pd = lab_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')
    ad_pd = ad_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')

    # flatten and merge
    diag_pd = diag_pd.groupby(by=['subject_id','hadm_id'])['icd_code'].unique().reset_index()  
    med_pd = med_pd.groupby(by=['subject_id', 'hadm_id'])['ndc'].unique().reset_index()
    pro_pd = pro_pd.groupby(by=['subject_id','hadm_id'])['icd_code'].unique().reset_index().rename(columns={'icd_code':'pro_code'})  
    lab_pd = lab_pd.groupby(by=['subject_id','hadm_id'])['lab_test'].unique().reset_index()
    
    med_pd['ndc'] = med_pd['ndc'].map(lambda x: list(x))
    pro_pd['pro_code'] = pro_pd['pro_code'].map(lambda x: list(x))
    lab_pd['lab_test'] = lab_pd['lab_test'].map(lambda x: list(x))
    
    data = diag_pd.merge(med_pd, on=['subject_id', 'hadm_id'], how='inner')
    data = data.merge(pro_pd, on=['subject_id', 'hadm_id'], how='inner')
    data = data.merge(lab_pd, on=['subject_id', 'hadm_id'], how='inner')
    data = data.merge(ad_pd, on=['subject_id', 'hadm_id'], how='inner')
#     data['icd_code_Len'] = data['icd_code'].map(lambda x: len(x))
#     data['ndc_Len'] = data['ndc'].map(lambda x: len(x))

    data = data.sort_values(by=['subject_id', 'admittime'])
    
    # create feature: readmission within 30/90 admittime
    data['admittime'] = pd.to_datetime(data['admittime'])
    data['READMISSION_1M'] = data.groupby('subject_id')['admittime'].shift(-1) - data['admittime']
    data['READMISSION_3M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=90) else 0)
    data['READMISSION_1M'] = data['READMISSION_1M'].apply(lambda x: 1 if x <= timedelta(days=30) else 0)
    
    # create feature: diease in next 6/12 months    
    data['NEXT_DIAG_6M'] = data.apply(lambda x: data[(data['subject_id'] == x['subject_id']) & 
                                              (data['admittime'] > x['admittime']) & 
                                              (data['admittime'] <= x['admittime'] + timedelta(days=180))]['icd_code'].tolist(), axis=1)
    data['NEXT_DIAG_12M'] = data.apply(lambda x: data[(data['subject_id'] == x['subject_id']) & 
                                                   (data['admittime'] > x['admittime']) & 
                                                   (data['admittime'] <= x['admittime'] + timedelta(days=365))]['icd_code'].tolist(), axis=1)
    data['NEXT_DIAG_6M'] = data['NEXT_DIAG_6M'].apply(lambda x: x[0] if x else float('nan'))
    data['NEXT_DIAG_12M'] = data['NEXT_DIAG_12M'].apply(lambda x: x[0] if x else float('nan'))
    
    data.drop(columns=['admittime'], axis=1, inplace=True)

    end = time.time()

    print(f"Processing time = {end - start}s")
    
    return data.reset_index(drop=True), diag_pd, pro_pd

In [45]:
# start = time.time()
# print(f"Start time: {start}s")
# print('process_med')
# med_pd = process_med()
# med_pd = ndc2atc4(med_pd)

# print('process_diag')
# diag_pd = process_diag()
# diag_pd = filter_most_diag(diag_pd)

# print('process_pro')
# pro_pd = process_procedure()
# pro_pd = filter_most_pro(pro_pd)

# print('process_ad')
# ad_pd = process_admission()

# print('process_lab')
# lab_pd = process_lab_test()
# lab_pd = filter_most_lab(lab_pd)

# end = time.time()

# print(f"Processing complete.... in {end - start}s")

In [46]:
# med_pd

In [47]:
# diag_pd

In [48]:
# ad_pd

In [49]:
# lab_pd

In [50]:
# pro_pd

In [51]:
# med_pd_key = med_pd[['subject_id', 'hadm_id']].drop_duplicates()
# diag_pd_key = diag_pd[['subject_id', 'hadm_id']].drop_duplicates()
# pro_pd_key = pro_pd[['subject_id', 'hadm_id']].drop_duplicates()
# lab_pd_key = lab_pd[['subject_id', 'hadm_id']].drop_duplicates()
# ad_pd_key = ad_pd[['subject_id', 'hadm_id']].drop_duplicates()

In [52]:
# print('MED')
# print(med_pd_key)
# print('\n')
# print('DIAG')
# print(diag_pd_key)
# print('\n')
# print('PRO')
# print(pro_pd_key)
# print('\n')
# print('LAB')
# print(lab_pd_key)
# print('\n')
# print('ADMIS')
# print(ad_pd_key)
# print('\n')

In [53]:
# # filter key
# combined_key = med_pd_key.merge(diag_pd_key, on=['subject_id', 'hadm_id'], how='inner')
# combined_key = combined_key.merge(pro_pd_key, on=['subject_id', 'hadm_id'], how='inner')
# combined_key = combined_key.merge(lab_pd_key, on=['subject_id', 'hadm_id'], how='inner')
# combined_key = combined_key.merge(ad_pd_key, on=['subject_id', 'hadm_id'], how='inner')
# diag_pd = diag_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')
# med_pd = med_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')
# pro_pd = pro_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')
# lab_pd = lab_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')
# ad_pd = ad_pd.merge(combined_key, on=['subject_id', 'hadm_id'], how='inner')

In [54]:
# combined_key

In [55]:
# print('MED')
# print(med_pd)
# print('\n')
# print('DIAG')
# print(diag_pd)
# print('\n')
# print('PRO')
# print(pro_pd)
# print('\n')
# print('LAB')
# print(lab_pd)
# print('\n')
# print('ADMIS')
# print(ad_pd)
# print('\n')

In [60]:
def statistics(data):
    print('#patients ', data['subject_id'].unique().shape)
    print('#clinical events ', len(data))
    
    diag = data['icd_code'].values
    med = data['ndc'].values
    pro = data['pro_code'].values
    lab_test = data['lab_test'].values
    
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    unique_pro = set([j for i in pro for j in list(i)])
    unique_lab = set([j for i in lab_test for j in list(i)])
    
    print('#diagnosis ', len(unique_diag))
    print('#med ', len(unique_med))
    print('#procedure', len(unique_pro))
    print('#lab', len(unique_lab))
    
    avg_diag = avg_med = avg_pro = avg_lab = 0
    max_diag = max_med = max_pro = max_lab = 0
    cnt = max_visit = avg_visit = 0

    for subject_id in data['subject_id'].unique():
        item_data = data[data['subject_id'] == subject_id]
        x, y, z, k = [], [], [], []
        visit_cnt = 0
        for index, row in item_data.iterrows():
            visit_cnt += 1
            cnt += 1
            x.extend(list(row['icd_code']))
            y.extend(list(row['ndc']))
            z.extend(list(row['pro_code']))
            k.extend(list(row['lab_test']))
        x, y, z, k = set(x), set(y), set(z), set(k)
        avg_diag += len(x)
        avg_med += len(y)
        avg_pro += len(z)
        avg_lab += len(k)
        avg_visit += visit_cnt
        if len(x) > max_diag:
            max_diag = len(x)
        if len(y) > max_med:
            max_med = len(y) 
        if len(z) > max_pro:
            max_pro = len(z)
        if len(k) > max_lab:
            max_lab = len(k)
        if visit_cnt > max_visit:
            max_visit = visit_cnt

    print('#avg of diagnoses ', avg_diag/ cnt)
    print('#avg of medicines ', avg_med/ cnt)
    print('#avg of procedures ', avg_pro/ cnt)
    print('#avg of lab_test ', avg_lab/ cnt)
    print('#avg of vists ', avg_visit/ len(data['subject_id'].unique()))

    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of procedures ', max_pro)
    print('#max of lab_test ', max_lab)
    print('#max of visit ', max_visit)

In [61]:
# med_pd = pd.read_csv(med_file, dtype={'ndc':'category'})
# med_pd

In [62]:
data, diag_pd, pro_pd = process_all()

Local time: Sun Jul 13 16:24:50 2025


process_med


  med_pd = pd.read_csv(med_file, dtype={'ndc':'category'})
  med_pd.fillna(method='pad', inplace=True)


process_diag
After filtering, diag_pd['icd_version'].unique(): [9]
First 5 diag codes: 0     5723
1    78959
2     5715
3    07070
4      496
Name: icd_code, dtype: object
Sample diag codes: ['5723' '78959' '5715' ... '78843' '4464' '20020']
process_pro
After filtering, pro_pd['icd_version'].unique(): [9]
First 5 proc codes: 0    5491
1    5491
2    5491
3    8938
5    8938
Name: icd_code, dtype: category
Categories (14911, object): ['0009', '0010', '0011', '0012', ..., 'DW064ZZ', 'DWY17ZZ', 'XW033C6', 'XW043W5']
Sample proc codes: ['PRO_5491' 'PRO_8938' 'PRO_5551' 'PRO_3734' 'PRO_3728' 'PRO_3727'
 'PRO_3893' 'PRO_4524' 'PRO_7569' 'PRO_5011' 'PRO_4562' 'PRO_5459'
 'PRO_4513' 'PRO_0066' 'PRO_3607' 'PRO_0045' 'PRO_0041' 'PRO_3722'
 'PRO_8856' 'PRO_0044' 'PRO_8674' 'PRO_8669' 'PRO_8672' 'PRO_0139'
 'PRO_0331' 'PRO_3897' 'PRO_7359' 'PRO_4576' 'PRO_4525' 'PRO_734'
 'PRO_8853' 'PRO_8855' 'PRO_8696' 'PRO_7094' 'PRO_3950' 'PRO_0055'
 'PRO_0046' 'PRO_0040' 'PRO_3615' 'PRO_3612' 'PRO_3491' 'PRO_

In [70]:
diag_pd

Unnamed: 0,subject_id,hadm_id,icd_code
0,10000032,22595853,"[5723, 78959, 5715, 07070, 496, 29680, 30981, ..."
1,10000690,25860671,"[5070, 42833, 51881, 5849, 5781, 2763, 5119, 5..."
2,10000690,26146595,"[5609, 5849, 42832, 2930, 1539, 4280, 42731, 7..."
3,10000826,20032235,"[5712, 486, 78959, 5723, 5990, 2639, 2761, 511..."
4,10000826,21086876,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3..."
...,...,...,...
78036,19999442,26785317,"[34541, 43491, 431, 3485, V6284, 11284, 5990, ..."
78037,19999565,20486927,"[82101, E8859, V0382]"
78038,19999840,21033226,"[3453, 51881, 5070, 5180, 42741, 43821, 43811,..."
78039,19999840,26071774,"[43491, 43820, 34590, 43811, 4019, 2724, 3051]"


In [77]:
diag_pd[diag_pd['subject_id']==10001401]

Unnamed: 0,subject_id,hadm_id,icd_code,unusual_icd_codes


In [76]:
import re

def find_unusual_codes(code_list):
    # Find codes that contain letters (A-Z/a-z), but not starting with E/V (case insensitive)
    unusual = [code for code in code_list
               if re.match(r'[A-DF-UW-Z]', code, re.I)]  # Any code starting with letter, but not E/V
    return unusual

# Apply to the entire DataFrame
diag_pd['unusual_icd_codes'] = diag_pd['icd_code'].apply(find_unusual_codes)

# Find rows with any unusual code
rows_with_unusual = diag_pd[diag_pd['unusual_icd_codes'].apply(lambda x: len(x) > 0)]

print(f"Number of rows with unusual codes: {len(rows_with_unusual)}")
print("Sample rows with unusual codes:")
print(rows_with_unusual[['subject_id', 'hadm_id', 'icd_code', 'unusual_icd_codes']].head())

Number of rows with unusual codes: 0
Sample rows with unusual codes:
Empty DataFrame
Columns: [subject_id, hadm_id, icd_code, unusual_icd_codes]
Index: []


In [63]:
data

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
0,10000032,22595853,"[5723, 78959, 5715, 07070, 496, 29680, 30981, ...",[B01A],[PRO_5491],"[51114-1, 51120-0]",F,age_9,True,0,1,0,0,,
1,10000690,25860671,"[5070, 42833, 51881, 5849, 5781, 2763, 5119, 5...","[A06A, B01A]",[PRO_3893],[51009-___],F,age_18,True,9,1,0,0,,
2,10000690,26146595,"[5609, 5849, 42832, 2930, 1539, 4280, 42731, 7...","[A01A, A12C, N02B, N02A, A04A, J01M]",[PRO_4524],[50900-___],F,age_18,True,1,1,0,0,,
3,10000826,20032235,"[5712, 486, 78959, 5723, 5990, 2639, 2761, 511...",[B01A],[PRO_5491],"[51118-1, 50953-0, 50960-1.5, 50970-3.3, 50995...",F,age_3,False,6,1,1,1,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3..."
4,10000826,21086876,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[B03B, N07B]",[PRO_5491],"[51095-___, 51105-___]",F,age_3,False,6,1,1,1,"[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33...","[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78036,19999442,26785317,"[34541, 43491, 431, 3485, V6284, 11284, 5990, ...",[N06A],"[PRO_9671, PRO_9604, PRO_966]","[50910-___, 50911-4, 50802-0, 50804-23, 50808-...",M,age_6,False,15,0,0,0,,
78037,19999565,20486927,"[82101, E8859, V0382]","[J01M, B05C, A04A, A06A, N02B]","[PRO_7935, PRO_9955]","[50893-8.6, 50960-1.9, 50970-3.1]",F,age_13,False,3,0,0,0,,
78038,19999840,26071774,"[43491, 43820, 34590, 43811, 4019, 2724, 3051]",[N02B],"[PRO_8891, PRO_8841]","[50868-16, 50882-20, 50902-106, 50908-0, 50910...",M,age_10,True,3,0,0,1,"[3453, 51881, 5070, 5180, 42741, 43821, 43811,...","[3453, 51881, 5070, 5180, 42741, 43821, 43811,..."
78039,19999840,21033226,"[3453, 51881, 5070, 5180, 42741, 43821, 43811,...","[A01A, H02A, C10A]","[PRO_9604, PRO_9672, PRO_966, PRO_0331]","[51099-1, 50862-4.3, 50883-0.3, 50884-1, 50930...",M,age_10,True,6,1,0,0,,


In [78]:
data[:20]

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
0,10000032,22595853,"[5723, 78959, 5715, 07070, 496, 29680, 30981, ...",[B01A],[PRO_5491],"[51114-1, 51120-0]",F,age_9,True,0,1,0,0,,
1,10000690,25860671,"[5070, 42833, 51881, 5849, 5781, 2763, 5119, 5...","[A06A, B01A]",[PRO_3893],[51009-___],F,age_18,True,9,1,0,0,,
2,10000690,26146595,"[5609, 5849, 42832, 2930, 1539, 4280, 42731, 7...","[A01A, A12C, N02B, N02A, A04A, J01M]",[PRO_4524],[50900-___],F,age_18,True,1,1,0,0,,
3,10000826,20032235,"[5712, 486, 78959, 5723, 5990, 2639, 2761, 511...",[B01A],[PRO_5491],"[51118-1, 50953-0, 50960-1.5, 50970-3.3, 50995...",F,age_3,False,6,1,1,1,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3..."
4,10000826,21086876,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[B03B, N07B]",[PRO_5491],"[51095-___, 51105-___]",F,age_3,False,6,1,1,1,"[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33...","[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33..."
5,10000826,28289260,"[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33...","[N05B, N02A]",[PRO_5491],"[51114-1, 51067-___, 51094-1]",F,age_3,False,2,0,0,0,,
6,10000935,25849114,"[1970, 34830, 1977, 1539, 2762, 5780, 2869, 27...","[N02B, A06A]",[PRO_5011],"[50963-___, 50802--2, 50804-22, 50818-33, 5082...",F,age_9,True,15,1,0,0,,
7,10000980,26913865,"[41071, 42823, 41412, 5854, 4240, 4280, 41401,...","[A12C, V03A, A02B, C03C, C10A, B01A, C08C, H04...","[PRO_0066, PRO_3607, PRO_0045, PRO_0041, PRO_3...","[50863-61, 50883-0.1, 50884-0, 50885-0.3]",F,age_15,True,5,1,0,0,,
8,10000980,25242409,"[45342, 53551, 5854, 25040, 5849, 3572, 42832,...","[C07A, B01A, C08C, N06A, M04A, N02B, A06A, C10...",[PRO_4513],[50810-25],F,age_15,True,7,1,0,0,,
9,10001217,24597018,"[3484, 3485, 5180, 340, 04109, 3051, 4019, V16...","[A06A, N03A, N02A, N02B]","[PRO_0139, PRO_0331, PRO_3897]","[52264-___, 52272-0, 52281-___, 52286-___, 509...",F,age_10,False,6,1,0,0,,


In [64]:
print("Sample codes in final merged data:", data['icd_code'].explode().unique()[:20])

Sample codes in final merged data: ['5723' '78959' '5715' '07070' '496' '29680' '30981' 'V1582' '5070'
 '42833' '51881' '5849' '5781' '2763' '5119' '5180' 'V850' '4280' '42731'
 '42781']


In [79]:
data.iloc[12]

subject_id                                                 10001725
hadm_id                                                    25563031
icd_code          [78829, 9950, 49390, 53081, 30000, 311, 7291, ...
ndc               [N05B, M01A, D07A, C07A, C03B, J01D, N02A, C01...
pro_code                                       [PRO_8696, PRO_7094]
lab_test          [50868-14, 50882-24, 50893-9.1, 50912-0.8, 509...
gender                                                            F
age                                                           age_7
death                                                         False
stay_days                                                         2
readmission                                                       0
READMISSION_1M                                                    0
READMISSION_3M                                                    0
NEXT_DIAG_6M                                                    NaN
NEXT_DIAG_12M                                   

In [80]:
statistics(data)

#patients  (58894,)
#clinical events  78041
#diagnosis  2001
#med  140
#procedure 801
#lab 1387
#avg of diagnoses  10.567137786548097
#avg of medicines  3.017093579016158
#avg of procedures  2.9668507579349317
#avg of lab_test  10.313130277674556
#avg of vists  1.3251095187964819
#max of diagnoses  114
#max of medicines  27
#max of procedures  41
#max of lab_test  71
#max of visit  12


In [81]:
print("ICD flag: ", icd10)
if icd10:
    with open("./dataset/mimic_icd10.pkl", "wb") as outfile:
        pickle.dump(data, outfile)
else:
    with open("./dataset/mimic.pkl", "wb") as outfile:
        pickle.dump(data, outfile)

print(f"Data stored in: {outfile}")

ICD flag:  False
Data stored in: <_io.BufferedWriter name='./dataset/mimic.pkl'>


# Dataset Split

In [8]:
if icd10:
    dataset = "mimic_icd10.pkl"
else:
    dataset = "mimic.pkl"
print(dataset)
mimic_data = pickle.load(open("dataset/mimic.pkl", 'rb'))
mimic_data.columns

mimic.pkl


Index(['subject_id', 'hadm_id', 'icd_code', 'ndc', 'pro_code', 'lab_test',
       'gender', 'age', 'death', 'stay_days', 'readmission', 'READMISSION_1M',
       'READMISSION_3M', 'NEXT_DIAG_6M', 'NEXT_DIAG_12M'],
      dtype='object')

In [9]:
mimic_data.head(20)

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
0,10000032,22595853,"[5723, 78959, 5715, 07070, 496, 29680, 30981, ...",[B01A],[PRO_5491],"[51114-1, 51120-0]",F,age_9,True,0,1,0,0,,
1,10000690,25860671,"[5070, 42833, 51881, 5849, 5781, 2763, 5119, 5...","[A06A, B01A]",[PRO_3893],[51009-___],F,age_18,True,9,1,0,0,,
2,10000690,26146595,"[5609, 5849, 42832, 2930, 1539, 4280, 42731, 7...","[A01A, A12C, N02B, N02A, A04A, J01M]",[PRO_4524],[50900-___],F,age_18,True,1,1,0,0,,
3,10000826,20032235,"[5712, 486, 78959, 5723, 5990, 2639, 2761, 511...",[B01A],[PRO_5491],"[51118-1, 50953-0, 50960-1.5, 50970-3.3, 50995...",F,age_3,False,6,1,1,1,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3..."
4,10000826,21086876,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[B03B, N07B]",[PRO_5491],"[51095-___, 51105-___]",F,age_3,False,6,1,1,1,"[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33...","[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33..."
5,10000826,28289260,"[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33...","[N05B, N02A]",[PRO_5491],"[51114-1, 51067-___, 51094-1]",F,age_3,False,2,0,0,0,,
6,10000935,25849114,"[1970, 34830, 1977, 1539, 2762, 5780, 2869, 27...","[N02B, A06A]",[PRO_5011],"[50963-___, 50802--2, 50804-22, 50818-33, 5082...",F,age_9,True,15,1,0,0,,
7,10000980,26913865,"[41071, 42823, 41412, 5854, 4240, 4280, 41401,...","[A12C, V03A, A02B, C03C, C10A, B01A, C08C, H04...","[PRO_0066, PRO_3607, PRO_0045, PRO_0041, PRO_3...","[50863-61, 50883-0.1, 50884-0, 50885-0.3]",F,age_15,True,5,1,0,0,,
8,10000980,25242409,"[45342, 53551, 5854, 25040, 5849, 3572, 42832,...","[C07A, B01A, C08C, N06A, M04A, N02B, A06A, C10...",[PRO_4513],[50810-25],F,age_15,True,7,1,0,0,,
9,10001217,24597018,"[3484, 3485, 5180, 340, 04109, 3051, 4019, V16...","[A06A, N03A, N02A, N02B]","[PRO_0139, PRO_0331, PRO_3897]","[52264-___, 52272-0, 52281-___, 52286-___, 509...",F,age_10,False,6,1,0,0,,


In [10]:
len(mimic_data)

78041

In [11]:
mimic_data.iloc[0]

subject_id                                                 10000032
hadm_id                                                    22595853
icd_code          [5723, 78959, 5715, 07070, 496, 29680, 30981, ...
ndc                                                          [B01A]
pro_code                                                 [PRO_5491]
lab_test                                         [51114-1, 51120-0]
gender                                                            F
age                                                           age_9
death                                                          True
stay_days                                                         0
readmission                                                       1
READMISSION_1M                                                    0
READMISSION_3M                                                    0
NEXT_DIAG_6M                                                    NaN
NEXT_DIAG_12M                                   

In [10]:
mimic_data[mimic_data['subject_id']==10000032]

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
0,10000032,22595853,"[5723, 78959, 5715, 07070, 496, 29680, 30981, ...",[B01A],[PRO_5491],"[51114-1, 51120-0]",F,age_9,True,0,1,0,0,,


In [12]:
# next diagnosis (6M) is not Nan
mimic_data[mimic_data["NEXT_DIAG_6M"].notna()]

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
3,10000826,20032235,"[5712, 486, 78959, 5723, 5990, 2639, 2761, 511...",[B01A],[PRO_5491],"[51118-1, 50953-0, 50960-1.5, 50970-3.3, 50995...",F,age_3,False,6,1,1,1,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3..."
4,10000826,21086876,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[B03B, N07B]",[PRO_5491],"[51095-___, 51105-___]",F,age_3,False,6,1,1,1,"[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33...","[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33..."
13,10002013,21975601,"[99672, 42832, 4111, 5849, 41401, E8781, 4280,...","[N02B, H04A, B05C, R03A, A02B]","[PRO_0066, PRO_3607, PRO_3722, PRO_0045, PRO_0...",[51482-1],F,age_9,False,2,1,0,0,"[41401, 42832, 5180, 4280, 4139, 4400, 4019, 3...","[41401, 42832, 5180, 4280, 4139, 4400, 4019, 3..."
18,10002428,28662225,"[0383, 78552, 5184, 5845, 34831, 486, 51881, 0...","[A02B, H03A]","[PRO_9604, PRO_9671, PRO_3893, PRO_3891, PRO_3...","[50817-95, 50802-1, 50804-29, 50818-49, 50820-...",F,age_16,False,17,1,1,1,"[03843, 51881, 42843, 5990, 00845, 99591, 4280...","[03843, 51881, 42843, 5990, 00845, 99591, 4280..."
31,10003400,29483621,"[28412, 20300, 4589, 42731, 5853, 2851, V5861,...","[B01A, J01F, C05A, C07A]",[PRO_4523],"[51143-0, 51144-5, 51251-1, 51255-0, 50900-___...",F,age_14,True,7,0,0,0,"[2866, 51881, 5845, 20300, 2639, 99809, 2930, ...","[2866, 51881, 5845, 20300, 2639, 99809, 2930, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78012,19997367,25038503,"[5781, 2851, 5723, 27900, 3481, 42832, 45621, ...","[B03B, C10A, A01A]","[PRO_4513, PRO_4523, PRO_8872]","[50883-0.4, 50884-3, 50864-___, 50802-0, 50804...",F,age_12,False,4,1,0,1,"[78551, 03849, 99592, 51851, 5722, 45620, 5070...","[78551, 03849, 99592, 51851, 5722, 45620, 5070..."
78018,19998330,24492004,"[51881, 42833, 486, 5849, 49121, 4280, 40390, ...","[S01E, C08C, C10A, A07E, C09A]","[PRO_9671, PRO_9604]","[50808-1.18, 51283-1.3, 50953-3]",F,age_14,True,7,1,0,1,"[49121, 42833, 51881, 4280, 42731, 32723, 2852...","[49121, 42833, 51881, 4280, 42731, 32723, 2852..."
78022,19998497,28129567,"[41401, 5184, 4111, 25000, 2809, 2724, 45829, ...","[A06A, B01A]","[PRO_0066, PRO_3607, PRO_3606, PRO_3723, PRO_8...",[50885-0.3],F,age_17,True,3,1,0,0,"[41401, 5845, 2762, 4111, 25000, 28521, 5854, ...","[41401, 5845, 2762, 4111, 25000, 28521, 5854, ..."
78031,19999287,22997012,"[1629, 486, 7907, 25000, 496, V0481, 2724, 716...","[R03A, R01A]","[PRO_3391, PRO_3324]","[51476-2, 51493-2, 51516-1]",F,age_14,True,5,1,1,1,"[51884, 4829, 9341, 496, V462, 1625, 99659, 25...","[51884, 4829, 9341, 496, V462, 1625, 99659, 25..."


This has duplicates as well. hence it is 4733 rows. Without duplicates it is below:

In [13]:
len(set(mimic_data[mimic_data["NEXT_DIAG_6M"].notna()]["subject_id"].values.tolist()))

8151

In [14]:
len(mimic_data["subject_id"].unique())

58894

38760 unique patients

In [17]:
mimic_data[mimic_data["NEXT_DIAG_12M"].notna()]

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
3,10000826,20032235,"[5712, 486, 78959, 5723, 5990, 2639, 2761, 511...",[B01A],[PRO_5491],"[51118-1, 50953-0, 50960-1.5, 50970-3.3, 50995...",F,age_3,False,6,1,1,1,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3..."
4,10000826,21086876,"[5711, 99591, 78959, 2761, 5990, 5119, 5710, 3...","[B03B, N07B]",[PRO_5491],"[51095-___, 51105-___]",F,age_3,False,6,1,1,1,"[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33...","[5723, 78959, 2761, 5712, 2875, 5711, 7242, 33..."
13,10002013,21975601,"[99672, 42832, 4111, 5849, 41401, E8781, 4280,...","[N02B, H04A, B05C, R03A, A02B]","[PRO_0066, PRO_3607, PRO_3722, PRO_0045, PRO_0...",[51482-1],F,age_9,False,2,1,0,0,"[41401, 42832, 5180, 4280, 4139, 4400, 4019, 3...","[41401, 42832, 5180, 4280, 4139, 4400, 4019, 3..."
18,10002428,28662225,"[0383, 78552, 5184, 5845, 34831, 486, 51881, 0...","[A02B, H03A]","[PRO_9604, PRO_9671, PRO_3893, PRO_3891, PRO_3...","[50817-95, 50802-1, 50804-29, 50818-49, 50820-...",F,age_16,False,17,1,1,1,"[03843, 51881, 42843, 5990, 00845, 99591, 4280...","[03843, 51881, 42843, 5990, 00845, 99591, 4280..."
31,10003400,29483621,"[28412, 20300, 4589, 42731, 5853, 2851, V5861,...","[B01A, J01F, C05A, C07A]",[PRO_4523],"[51143-0, 51144-5, 51251-1, 51255-0, 50900-___...",F,age_14,True,7,0,0,0,"[2866, 51881, 5845, 20300, 2639, 99809, 2930, ...","[2866, 51881, 5845, 20300, 2639, 99809, 2930, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78013,19997367,20617667,"[78551, 03849, 99592, 51851, 5722, 45620, 5070...","[C07A, C10A, A01A, B03B]","[PRO_3523, PRO_3723, PRO_3733, PRO_3961, PRO_8...","[51491-7.0, 51498-1.008, 50806-108, 50808-1.15...",F,age_12,False,29,1,0,0,,"[0380, 78552, 51881, 5722, 4210, 5119, 20500, ..."
78018,19998330,24492004,"[51881, 42833, 486, 5849, 49121, 4280, 40390, ...","[S01E, C08C, C10A, A07E, C09A]","[PRO_9671, PRO_9604]","[50808-1.18, 51283-1.3, 50953-3]",F,age_14,True,7,1,0,1,"[49121, 42833, 51881, 4280, 42731, 32723, 2852...","[49121, 42833, 51881, 4280, 42731, 32723, 2852..."
78022,19998497,28129567,"[41401, 5184, 4111, 25000, 2809, 2724, 45829, ...","[A06A, B01A]","[PRO_0066, PRO_3607, PRO_3606, PRO_3723, PRO_8...",[50885-0.3],F,age_17,True,3,1,0,0,"[41401, 5845, 2762, 4111, 25000, 28521, 5854, ...","[41401, 5845, 2762, 4111, 25000, 28521, 5854, ..."
78031,19999287,22997012,"[1629, 486, 7907, 25000, 496, V0481, 2724, 716...","[R03A, R01A]","[PRO_3391, PRO_3324]","[51476-2, 51493-2, 51516-1]",F,age_14,True,5,1,1,1,"[51884, 4829, 9341, 496, V462, 1625, 99659, 25...","[51884, 4829, 9341, 496, V462, 1625, 99659, 25..."


In [18]:
# get the patient id with the labels
pat_readmission = set(mimic_data[mimic_data["READMISSION_1M"] == 1]["subject_id"].values.tolist())
print(len(pat_readmission))
pat_nextdiag_6m = set(mimic_data[mimic_data["NEXT_DIAG_6M"].notna()]["subject_id"].values.tolist())
print(len(pat_nextdiag_6m))
pat_nextdiag_12m = set(mimic_data[mimic_data["NEXT_DIAG_12M"].notna()]["subject_id"].values.tolist())
print(len(pat_nextdiag_12m))
pat_death = set(mimic_data[mimic_data["death"]]["subject_id"].values.tolist())
print(len(pat_death))
# pat_all_label = list(pat_readmission | pat_nextdiag_6m | pat_nextdiag_12m | pat_death)
pat_all_label = list(pat_readmission | pat_nextdiag_6m | pat_nextdiag_12m | pat_death)

3706
8151
9591
15608


In [19]:
len(pat_all_label)

20867

In [20]:
pat_all = mimic_data["subject_id"].unique().tolist()

In [21]:
n_pretrain_patient = int(len(pat_all) * 0.7)
# pretrain_patient = np.random.choice(list(set(pat_all) - set(pat_all_label)), n_pretrain_patient, replace=False).tolist()
pretrain_patient = np.random.choice(list(set(pat_all) - set(pat_all_label)), n_pretrain_patient, replace=False).tolist()
downstream_patient = list(set(pat_all) - set(pretrain_patient))
print(len(pretrain_patient), len(downstream_patient))

ValueError: Cannot take a larger sample than population when 'replace=False'

In [22]:
len(pat_all_label)

20867

In [23]:
label_free_patients = list(set(pat_all) - set(pat_all_label))
len(label_free_patients)

38027

In [24]:
# Calculate the pool you want to sample from:
label_free_patients = list(set(pat_all) - set(pat_all_label))
n_pretrain_patient = int(len(label_free_patients) * 0.7)  # 70% of only those WITHOUT labels

# Now, you can safely sample without replacement:
pretrain_patient = np.random.choice(label_free_patients, n_pretrain_patient, replace=False).tolist()
downstream_patient = list(set(pat_all) - set(pretrain_patient))

print(len(pretrain_patient), len(downstream_patient))
print("Total unique patients in split:", len(set(pretrain_patient) | set(downstream_patient)))
print("Total unique in original:", len(set(pat_all)))

26618 32276
Total unique patients in split: 58894
Total unique in original: 58894


In [25]:
n_pretrain_patient

26618

In [26]:
pretrain_patient[:10]

[19053975,
 11991967,
 19013037,
 11576270,
 12589387,
 12978079,
 14745196,
 13032235,
 17997063,
 12714566]

In [27]:
downstream_patient[:10]

[18874374,
 10878995,
 11272213,
 18350105,
 12189736,
 13893673,
 10747946,
 11403312,
 18612273,
 19398714]

In [28]:
pretrain_dataset = mimic_data[mimic_data["subject_id"].isin(set(pretrain_patient))]
pretrain_dataset

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
10,10001338,22119639,"[56211, 5849, 5695, 99859, 6822, 04111, E8786,...","[D07A, N06A]",[PRO_4576],"[51237-1.1, 51274-___, 51275-28.4, 51143-0, 51...",F,age_6,False,17,1,0,0,,
11,10001492,27463908,"[41071, 42983, 2449, 4240]",[H03A],"[PRO_3722, PRO_8853, PRO_8855]","[50908-3, 50903-1, 51000-___, 50893-8.4, 50960...",F,age_14,False,1,0,0,0,,
24,10002870,25351634,"[220, 2180, 2449, 2720, 3051]","[A12A, M01A]","[PRO_6563, PRO_6841]","[51146-0.5, 51200-1.9, 51254-5.6, 50960-1.9, 5...",F,age_10,False,2,0,0,0,,
25,10002976,27179825,"[7291, 5849, V5867, 412, 41401, 4019, V4582, 2...","[C07A, A01A]",[PRO_8321],"[50889-___, 50893-8.6, 50960-2.4, 50970-2.4, 5...",M,age_14,False,4,0,0,0,,
26,10003019,23693618,"[135, 7856, 51889, 2724, V4579]","[N02B, A04A, A12A, N02A]","[PRO_3220, PRO_3422, PRO_403, PRO_3323]","[50960-1.8, 50893-8.8, 50970-3.6]",M,age_13,False,1,1,0,0,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78005,19996219,22115476,"[57400, 5531, 2720]",[C10A],"[PRO_5123, PRO_5349]","[50893-8.6, 50960-2.1, 50970-3.2]",M,age_12,False,1,0,0,0,,
78006,19997062,28549819,"[57400, 5761, 5715, 57410]","[B01A, A07E, A05A]","[PRO_5123, PRO_5014]","[50893-9.4, 50956-23, 50960-2.1, 50970-2.5, 51...",M,age_8,False,2,0,0,0,,
78026,19998626,24639135,"[566, 56942]",[C01B],[PRO_4901],"[50868-13, 50882-27, 50893-8.7, 50902-104, 509...",F,age_4,False,2,0,0,0,,
78036,19999442,26785317,"[34541, 43491, 431, 3485, V6284, 11284, 5990, ...",[N06A],"[PRO_9671, PRO_9604, PRO_966]","[50910-___, 50911-4, 50802-0, 50804-23, 50808-...",M,age_6,False,15,0,0,0,,


In [29]:
if icd10:
    with open("./dataset/mimic_pretrain_icd10.pkl", "wb") as outfile:
        pickle.dump(pretrain_dataset, outfile)
else:
    with open("./dataset/mimic_pretrain.pkl", "wb") as outfile:
        pickle.dump(pretrain_dataset, outfile)

print(icd10)
print(f"Pretrain dataset stored in: {outfile}")

False
Pretrain dataset stored in: <_io.BufferedWriter name='./dataset/mimic_pretrain.pkl'>


In [7]:
import pickle

pretrain_data = pickle.load(open("dataset/mimic_pretrain.pkl", 'rb'))
pretrain_data.head()

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
10,10001338,22119639,"[56211, 5849, 5695, 99859, 6822, 04111, E8786,...","[D07A, N06A]",[PRO_4576],"[51237-1.1, 51274-___, 51275-28.4, 51143-0, 51...",F,age_6,False,17,1,0,0,,
11,10001492,27463908,"[41071, 42983, 2449, 4240]",[H03A],"[PRO_3722, PRO_8853, PRO_8855]","[50908-3, 50903-1, 51000-___, 50893-8.4, 50960...",F,age_14,False,1,0,0,0,,
24,10002870,25351634,"[220, 2180, 2449, 2720, 3051]","[A12A, M01A]","[PRO_6563, PRO_6841]","[51146-0.5, 51200-1.9, 51254-5.6, 50960-1.9, 5...",F,age_10,False,2,0,0,0,,
25,10002976,27179825,"[7291, 5849, V5867, 412, 41401, 4019, V4582, 2...","[C07A, A01A]",[PRO_8321],"[50889-___, 50893-8.6, 50960-2.4, 50970-2.4, 5...",M,age_14,False,4,0,0,0,,
26,10003019,23693618,"[135, 7856, 51889, 2724, V4579]","[N02B, A04A, A12A, N02A]","[PRO_3220, PRO_3422, PRO_403, PRO_3323]","[50960-1.8, 50893-8.8, 50970-3.6]",M,age_13,False,1,1,0,0,,


In [30]:
train_ratio, val_ratio = 0.4, 0.3

#6m
n_finetune_pat, n_val_pat = int(len(pat_nextdiag_6m) * train_ratio), int(len(pat_nextdiag_6m) * val_ratio)
pat_nextdiag_6m = list(pat_nextdiag_6m)
np.random.shuffle(pat_nextdiag_6m)
finetune_pat, val_pat, test_pat = pat_nextdiag_6m[:n_finetune_pat], \
                                    pat_nextdiag_6m[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    pat_nextdiag_6m[n_finetune_pat+n_val_pat:]
finetune_dataset6m = mimic_data[mimic_data["subject_id"].isin(set(finetune_pat))]
print(f"FineTune dataset for 6M diagnosis prediction = {len(finetune_dataset6m)}")
val_dataset6m = mimic_data[mimic_data["subject_id"].isin(set(val_pat))]
print(f"Val. dataset for 6M diagnosis prediction = {len(val_dataset6m)}")
test_dataset6m = mimic_data[mimic_data["subject_id"].isin(set(test_pat))]
print(f"Test dataset for 6M diagnosis prediction = {len(test_dataset6m)}")

# 12m
n_finetune_pat, n_val_pat = int(len(pat_nextdiag_12m) * train_ratio), int(len(pat_nextdiag_12m) * val_ratio)
pat_nextdiag_12m = list(pat_nextdiag_12m)
np.random.shuffle(pat_nextdiag_12m)
finetune_pat, val_pat, test_pat = pat_nextdiag_12m[:n_finetune_pat], \
                                    pat_nextdiag_12m[n_finetune_pat:n_finetune_pat+n_val_pat], \
                                    pat_nextdiag_12m[n_finetune_pat+n_val_pat:]
finetune_dataset12m = mimic_data[mimic_data["subject_id"].isin(set(finetune_pat))]
val_dataset12m = mimic_data[mimic_data["subject_id"].isin(set(val_pat))]
test_dataset12m = mimic_data[mimic_data["subject_id"].isin(set(test_pat))]

FineTune dataset for 6M diagnosis prediction = 8767
Val. dataset for 6M diagnosis prediction = 6524
Test dataset for 6M diagnosis prediction = 6526


In [31]:
if icd10:
    with open("./dataset/mimic_nextdiag_6m_icd10.pkl", "wb") as outfile:
        pickle.dump([finetune_dataset6m, val_dataset6m, test_dataset6m], outfile)
    with open("./dataset/mimic_nextdiag_12m_icd10.pkl", "wb") as outfile:
        pickle.dump([finetune_dataset12m, val_dataset12m, test_dataset12m], outfile)
else:
    with open("./dataset/mimic_nextdiag_6m.pkl", "wb") as outfile:
        pickle.dump([finetune_dataset6m, val_dataset6m, test_dataset6m], outfile)
    with open("./dataset/mimic_nextdiag_12m.pkl", "wb") as outfile:
        pickle.dump([finetune_dataset12m, val_dataset12m, test_dataset12m], outfile)

print(outfile)

<_io.BufferedWriter name='./dataset/mimic_nextdiag_12m.pkl'>


In [21]:
mimic_nextdiag_6m = pickle.load(open("dataset/mimic_nextdiag_6m.pkl", 'rb'))
mimic_nextdiag_6m[2]

Unnamed: 0,subject_id,hadm_id,icd_code,ndc,pro_code,lab_test,gender,age,death,stay_days,readmission,READMISSION_1M,READMISSION_3M,NEXT_DIAG_6M,NEXT_DIAG_12M
72,10008924,27441295,"[56723, 5849, 78959, 5723, 2761, 2867, 5712, 3...",[N02B],"[PRO_4516, PRO_5491]","[50893-8.8, 50960-1.5, 50970-4.6, 50953-0, 508...",F,age_7,True,8,0,0,1,"[99811, 5724, 5845, 5722, 56723, 2851, 5723, 7...","[99811, 5724, 5845, 5722, 56723, 2851, 5723, 7..."
73,10008924,23676183,"[99811, 5724, 5845, 5722, 56723, 2851, 5723, 7...",[B03B],"[PRO_9462, PRO_5491, PRO_3893]","[50813-1.7, 51099-0, 50802-1, 50804-27, 50818-...",F,age_7,True,20,1,0,0,,
138,10016991,24172189,"[1536, 1962, 5990, 0416]","[A02B, M01A, A04A, A06A, A03F, N02B, A12A]",[PRO_1733],"[50900-___, 50960-1.8, 50893-8.3, 50970-3.0, 5...",M,age_7,False,4,1,0,0,"[V5811, 1536]","[V5811, 1536]"
139,10016991,27389040,"[V5811, 1536]","[L01B, A04A, A01A]","[PRO_9925, PRO_8914]","[51237-1.0, 51274-11.5]",M,age_7,False,0,0,0,0,,
154,10018052,27285907,"[1977, 78951, 7994, 2536, 7892, 1749, 53081, 2...",[B01A],[PRO_5491],"[50893-8.5, 50960-2.0, 50970-3.4]",F,age_7,True,1,0,1,1,"[1749, 1977, 78951, 2767, V860, V4986]","[1749, 1977, 78951, 2767, V860, V4986]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
77925,19986198,25207205,"[431, 99702, 6822, 2768, 3485, 3439, 2449, 34590]","[J01D, A12C]","[PRO_966, PRO_3893]","[51144-0, 50889-___, 50813-0.9, 51491-7.0, 514...",F,age_3,False,13,1,0,0,,
78009,19997293,27119651,"[34839, 42823, 5849, 3241, 4260, 73008, 70703,...",[C01B],[PRO_8628],[50995-1],M,age_15,True,10,1,0,1,"[4260, 486, 51881, 42823, 70703, 496, 3441, 27...","[4260, 486, 51881, 42823, 70703, 496, 3441, 27..."
78010,19997293,28847872,"[4260, 486, 51881, 42823, 70703, 496, 3441, 27...","[A06A, N05C]","[PRO_3771, PRO_3781, PRO_9671, PRO_3778, PRO_8...","[50908-1, 50823-4]",M,age_15,True,12,0,0,0,,
78018,19998330,24492004,"[51881, 42833, 486, 5849, 49121, 4280, 40390, ...","[S01E, C08C, C10A, A07E, C09A]","[PRO_9671, PRO_9604]","[50808-1.18, 51283-1.3, 50953-3]",F,age_14,True,7,1,0,1,"[49121, 42833, 51881, 4280, 42731, 32723, 2852...","[49121, 42833, 51881, 4280, 42731, 32723, 2852..."


In [None]:
mimic_nextdiag_12m = pickle.load(open("dataset/mimic_nextdiag_12m.pkl", 'rb'))
mimic_nextdiag_12m.head()

# PRETRAINING TRIAL

In [15]:
def expand_level2():
    level2 = ['001-009', '010-018', '020-027', '030-041', '042', '045-049', '050-059', '060-066', '070-079', '080-088',
              '090-099', '100-104', '110-118', '120-129', '130-136', '137-139', '140-149', '150-159', '160-165',
              '170-176',
              '176', '179-189', '190-199', '200-208', '209', '210-229', '230-234', '235-238', '239', '240-246',
              '249-259',
              '260-269', '270-279', '280-289', '290-294', '295-299', '300-316', '317-319', '320-327', '330-337', '338',
              '339', '340-349', '350-359', '360-379', '380-389', '390-392', '393-398', '401-405', '410-414', '415-417',
              '420-429', '430-438', '440-449', '451-459', '460-466', '470-478', '480-488', '490-496', '500-508',
              '510-519',
              '520-529', '530-539', '540-543', '550-553', '555-558', '560-569', '570-579', '580-589', '590-599',
              '600-608',
              '610-611', '614-616', '617-629', '630-639', '640-649', '650-659', '660-669', '670-677', '678-679',
              '680-686',
              '690-698', '700-709', '710-719', '720-724', '725-729', '730-739', '740-759', '760-763', '764-779',
              '780-789',
              '790-796', '797-799', '800-804', '805-809', '810-819', '820-829', '830-839', '840-848', '850-854',
              '860-869',
              '870-879', '880-887', '890-897', '900-904', '905-909', '910-919', '920-924', '925-929', '930-939',
              '940-949',
              '950-957', '958-959', '960-979', '980-989', '990-995', '996-999', 'V01-V91', 'V01-V09', 'V10-V19',
              'V20-V29',
              'V30-V39', 'V40-V49', 'V50-V59', 'V60-V69', 'V70-V82', 'V83-V84', 'V85', 'V86', 'V87', 'V88', 'V89',
              'V90',
              'V91', 'E000-E899', 'E000', 'E001-E030', 'E800-E807', 'E810-E819', 'E820-E825', 'E826-E829', 'E830-E838',
              'E840-E845', 'E846-E849', 'E850-E858', 'E860-E869', 'E870-E876', 'E878-E879', 'E880-E888', 'E890-E899',
              'E900-E909', 'E910-E915', 'E916-E928', 'E929', 'E930-E949', 'E950-E959', 'E960-E969', 'E970-E978',
              'E980-E989', 'E990-E999']

    level2_expand = {}
    for i in level2:
        tokens = i.split('-')
        if i[0] == 'V':
            if len(tokens) == 1:
                level2_expand[i] = i
            else:
                for j in range(int(tokens[0][1:]), int(tokens[1][1:]) + 1):
                    level2_expand["V%02d" % j] = i
        elif i[0] == 'E':
            if len(tokens) == 1:
                level2_expand[i] = i
            else:
                for j in range(int(tokens[0][1:]), int(tokens[1][1:]) + 1):
                    level2_expand["E%03d" % j] = i
        else:
            if len(tokens) == 1:
                level2_expand[i] = i
            else:
                for j in range(int(tokens[0]), int(tokens[1]) + 1):
                    level2_expand["%03d" % j] = i
    return level2_expand


# for diagnosis
def build_icd9_tree(graph_voc, unique_codes):
    code2sentence = {}
    # graph_voc = Voc()

    root_node = 'icd9_root'
    level3_dict = expand_level2()
    for code in unique_codes:
        level1 = code
        level2 = level1[:4] if level1[0] == 'E' else level1[:3]
        level3 = level3_dict[level2]
        level4 = root_node

        # sample = [level1, level2, level3, level4]
        sample = [level1, level2, level3]

        graph_voc.add_sentence(sample)
        code2sentence[code] = sample

    return code2sentence, graph_voc


# for medication
def build_atc_tree(graph_voc, unique_codes):
    code2sentence = {}

    root_node = 'atc_root'
    for code in unique_codes:
        # sample = [code] + [code[:i] for i in [4, 3, 1]] + [root_node]
        sample = [code[:i] for i in [4, 3, 1]]
        graph_voc.add_sentence(sample)
        code2sentence[code] = sample
    
    return code2sentence, graph_voc


## MY ADDITION (for diagnosis):
def expand_icd10_blocks():
    icd10_blocks = [
        # Chapter I: Certain infectious and parasitic diseases (A00–B99)
        'A00-A09','A15-A19','A20-A28','A30-A49','A50-A64','A65-A69','A70-A74',
        'A75-A79','A80-A89','A90-A99',
        'B00-B09','B15-B19','B20-B24','B25-B34','B35-B49','B50-B64','B65-B83',
        'B85-B89','B90-B94','B95-B97','B99',
        # Chapter II: Neoplasms (C00–D49)
        'C00-C14','C15-C26','C30-C39','C40-C41','C43-C44','C45-C49','C50','C51-C58',
        'C60-C63','C64-C68','C69-C72','C73-C75','C76-C80','C81-C85','C88-C90',
        'C91-C95','C96-C97','D00-D09','D10-D36','D37-D48',
        # Chapter III: Blood/immune (D50–D89)
        'D50-D53','D55-D59','D60-D64','D65-D69','D70-D77','D80-D89',
        # Chapter IV: Endocrine/nutritional/metabolic (E00–E89)
        'E00-E07','E08-E13','E15-E16','E20-E35','E36','E40-E46','E50-E64',
        'E65-E68','E70-E88','E89',
        # Chapter V: Mental/behavioural (F00-F99)
        'F00-F09','F10-F19','F20-F29','F30-F39','F40-F48','F50-F59',
        'F60-F69','F70-F79','F80-F89','F90-F98','F99',
        # Chapter VI: Nervous system (G00-G99)
        'G00-G09','G10-G14','G20-G26','G30-G32','G35-G37','G40-G47',
        'G50-G59','G60-G64','G70-G73','G80-G83','G90-G99',
        # Chapter VII–VIII: Eye/ear (H00-H95)
        'H00-H05','H10-H13','H15-H22','H25-H28','H30-H36','H40-H42',
        'H43-H45','H46-H48','H49-H52','H53-H54','H55-H59',
        'H60-H62','H65-H75','H80-H83','H90-H95',
        # Chapter IX, X, XI, XII, XIII, XIV
        'I00-I02','I05-I09','I10-I15','I20-I25','I26-I28','I30-I52','I60-I69',
        'I70-I79','I80-I89','I95-I99',
        'J00-J06','J09-J18','J20-J22','J30-J39','J40-J47','J60-J70','J80-J84',
        'J85-J86','J90-J94','J95-J99',
        'K00-K14','K20-K31','K35-K38','K40-K46','K50-K52','K55-K64','K65-K68',
        'K70-K77','K80-K87','K90-K95',
        'L00-L08','L10-L14','L20-L30','L40-L45','L50-L54','L55-L59','L60-L75',
        'L80-L99',
        'M00-M03','M05-M14','M15-M19','M20-M25','M26-M27','M30-M36','M40-M54',
        'M60-M63','M65-M68','M70-M79','M80-M85','M86-M90','M91-M94','M95-M99',
        'N00-N08','N10-N16','N17-N19','N20-N23','N25-N29','N30-N39','N40-N51',
        'N60-N64','N70-N77','N80-N98','N99',
        # Chapter XV–XXI: O00-O99, P00-P96, Q00-Q99, R00-R99, S00-T98, V01-Y98, Z00-Z99
        'O00-O08','O09-O09','O10-O16','O20-O29','O30-O48','O60-O75','O80-O84',
        'O85-O92','O94-O99',
        'P00-P04','P05-P08','P10-P15','P20-P29','P35-P39','P50-P61','P70-P74',
        'P76-P78','P80-P83','P90-P96',
        'Q00-Q07','Q10-Q18','Q20-Q28','Q30-Q34','Q35-Q37','Q38-Q45','Q50-Q56',
        'Q60-Q64','Q65-Q79','Q80-Q89','Q90-Q99',
        'R00-R09','R10-R19','R20-R23','R25-R29','R30-R39','R40-R46','R47-R49',
        'R50-R69','R70-R79','R80-R82','R83-R89','R90-R99',
        'S00-S09','S10-S19','S20-S29','S30-S39','S40-S49','S50-S59','S60-S69',
        'S70-S79','S80-S89','S90-S99','T00-T07','T08-T14','T15-T19','T20-T32',
        'T33-T35','T36-T50','T51-T65','T66-T78','T79-T79','T80-T88','T90-T98',
        'V01-V09','V10-V19','V20-V29','V30-V39','V40-V49','V50-V59','V60-V69',
        'V70-V79','V80-V89','V90-V99','Y00-Y34','Y35-Y38','Y40-Y84','Y85-Y89',
        'Y90-Y98','Z00-Z13','Z14-Z15','Z16','Z17','Z18','Z19','Z20-Z29','Z30-Z39',
        'Z40-Z53','Z54','Z55-Z65','Z66','Z67','Z68','Z69-Z76','Z77-Z99',
    ]
    block_expand = {}
    for block in icd10_blocks:
        parts = block.split('-')
        if len(parts) == 1:
            block_expand[parts[0]] = block
        else:
            start_letter, start_num = parts[0][0], int(parts[0][1:])
            end_num = int(parts[1][1:])
            for i in range(start_num, end_num+1):
                code3 = f"{start_letter}{i:02d}"
                block_expand[code3] = block
    return block_expand


def build_icd10_tree(graph_voc, unique_codes):
    code2sentence = {}
    block_dict = expand_icd10_blocks()
    for code in unique_codes:
        short = code.split('.')[0][:3]
        block = block_dict.get(short, 'Unknown')
        sample = [code, short, block]
        graph_voc.add_sentence(sample)
        code2sentence[code] = sample
    return code2sentence, graph_voc

In [16]:
class Voc(object):
    def __init__(self):
        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        for word in sentence:
            if word not in self.word2idx:
                self.idx2word[len(self.word2idx)] = word
                self.word2idx[word] = len(self.word2idx)


class EHRTokenizer(object):
    def __init__(self, diag_sentences, med_sentences, lab_sentences, pro_sentences, gender_set, age_set, age_gender_set=None, special_tokens=("[PAD]", "[CLS]", "[SEP]", "[MASK]")):

        self.vocab = Voc()

        # special tokens
        self.vocab.add_sentence(special_tokens)
        self.n_special_tokens = len(special_tokens)
        self.age_voc = self.add_vocab(age_set)
        self.diag_voc = self.add_vocab(diag_sentences)
        self.med_voc = self.add_vocab(med_sentences)
        self.lab_voc = self.add_vocab(lab_sentences)
        if pro_sentences is not None:
            self.pro_voc = self.add_vocab(pro_sentences)
        else:
            self.pro_voc = Voc()
        self.gender_voc = self.add_vocab(gender_set)
        if age_gender_set is not None:
            self.age_gender_voc = self.add_vocab(age_gender_set)
        else:
            self.age_gender_voc = Voc()

        assert len(special_tokens) + len(self.age_voc.idx2word) + len(self.diag_voc.idx2word) + len(self.med_voc.idx2word) + \
                len(self.lab_voc.idx2word) + len(self.pro_voc.idx2word) + len(self.gender_voc.idx2word) + len(self.age_gender_voc.idx2word) == len(self.vocab.idx2word)

    def build_tree(self):
        # create tree for diagnosis and medication
        diag2tree, self.diag_tree_voc = build_icd9_tree(Voc(), list(self.diag_voc.idx2word.values()))
        med2tree, self.med_tree_voc = build_atc_tree(Voc(), list(self.med_voc.idx2word.values()))
        
        diag_tree_table = []
        for diag_id in range(len(self.diag_voc.idx2word)):
            diag_tree = diag2tree[self.diag_voc.idx2word[diag_id]]  # [code1, code2, ...]
            diag_tree_table.append([self.diag_tree_voc.word2idx[code] for code in diag_tree])
        
        med_tree_table = []
        for med_id in range(len(self.med_voc.idx2word)):
            med_tree = med2tree[self.med_voc.idx2word[med_id]]  # [code1, code2, ...]
            med_tree_table.append([self.med_tree_voc.word2idx[code] for code in med_tree])
        
        # [n_diag/med, n_level]
        self.diag_tree_table, self.med_tree_table = torch.tensor(diag_tree_table), torch.tensor(med_tree_table)

    def add_vocab(self, sentences):
        voc = self.vocab
        specific_voc = Voc()
        for sentence in sentences:
            voc.add_sentence(sentence)
            specific_voc.add_sentence(sentence)
        return specific_voc

    def convert_tokens_to_ids(self, tokens, voc_type="all"):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            if voc_type == "all":
                ids.append(self.vocab.word2idx[token])
            elif voc_type == "diag":
                ids.append(self.diag_voc.word2idx[token])
            elif voc_type == "med":
                ids.append(self.med_voc.word2idx[token])
            elif voc_type == "lab":
                ids.append(self.lab_voc.word2idx[token])
            elif voc_type == "pro":
                ids.append(self.pro_voc.word2idx[token])
        return ids

    def convert_ids_to_tokens(self, ids, voc_type="all"):
        """Converts a sequence of ids in wordpiece tokens using the vocab."""
        tokens = []
        for i in ids:
            if voc_type == "all":
                tokens.append(self.vocab.idx2word[i])
            elif voc_type == "diag":
                tokens.append(self.diag_voc.idx2word[i])
            elif voc_type == "med":
                tokens.append(self.med_voc.idx2word[i])
            elif voc_type == "lab":
                tokens.append(self.lab_voc.idx2word[i])
            elif voc_type == "pro":
                tokens.append(self.pro_voc.idx2word[i])
        return tokens
    
    def token_id_range(self, voc_type="diag"):
        age_size = len(self.age_voc.idx2word)
        diag_size = len(self.diag_voc.idx2word)
        med_size = len(self.med_voc.idx2word)
        lab_size = len(self.lab_voc.idx2word)

        if voc_type == "diag":
            return [self.n_special_tokens + age_size, self.n_special_tokens + age_size + diag_size]
        elif voc_type == "med":
            return [self.n_special_tokens + age_size + diag_size, self.n_special_tokens + age_size + diag_size + med_size]
        elif voc_type == "lab":
            return [self.n_special_tokens + age_size + diag_size + med_size, self.n_special_tokens + age_size + diag_size + med_size + lab_size]
        elif voc_type == "pro":
            return [self.n_special_tokens + age_size + diag_size + med_size + lab_size, len(self.vocab.idx2word)]
    
    def token_number(self, voc_type="diag"):
        if voc_type == "diag":
            return len(self.diag_voc.idx2word)
        elif voc_type == "med":
            return len(self.med_voc.idx2word)
        elif voc_type == "lab":
            return len(self.lab_voc.idx2word)
        elif voc_type == "pro":
            return len(self.pro_voc.idx2word)
    
    def random_token(self, voc_type="diag"):
        # randomly sample a token from the vocabulary
        if voc_type == "diag":
            return self.diag_voc.idx2word[np.random.randint(len(self.diag_voc.idx2word))]
        elif voc_type == "med":
            return self.med_voc.idx2word[np.random.randint(len(self.med_voc.idx2word))]
        elif voc_type == "lab":
            return self.lab_voc.idx2word[np.random.randint(len(self.lab_voc.idx2word))]
        elif voc_type == "pro":
            return self.pro_voc.idx2word[np.random.randint(len(self.pro_voc.idx2word))]

In [17]:
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class PretrainEHRDataset(Dataset):
    def __init__(self, data_pd, tokenizer: EHRTokenizer, token_type=['diag', 'med', 'pro', 'lab']):
        self.tokenizer = tokenizer

        def transform_data(data):
            records, ages = {}, {}
            genders = {}
            for subject_id in data['subject_id'].unique():
                item_df = data[data['subject_id'] == subject_id]
                genders[subject_id] = [item_df.head(1)["gender"].values[0]]

                patient, age = [], []
                for _, row in item_df.iterrows():
                    admission = []
                    if "diag" in token_type:
                        admission.append(list(row['icd_code']))
                    if "med" in token_type:
                        admission.append(list(row['ndc']))
                    if "pro" in token_type:
                        admission.append(list(row['pro_code']))
                    if "lab" in token_type:
                        admission.append(list(row['lab_test']))
                    patient.append(admission)
                    age.append(row['age'])
                records[subject_id] = list(patient)
                ages[subject_id] = age
            return records, ages, genders

        self.records, self.ages, self.genders = transform_data(data_pd)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, item):
        subject_id = list(self.records.keys())[item]

        input_tokens = ['[CLS]']
        visit_positions = [0]
        age_tokens = [self.ages[subject_id][0]]
        for idx, adm in enumerate(self.records[subject_id]):  # each subject have multiple admissions, idx:visit id
            cur_input_tokens = []
            for i in range(len(adm)):  # one admission may have many kinds of entities, [[diag], [med], ...]
                cur_input_tokens.extend(list(adm[i]))
            if idx != len(self.records[subject_id]) - 1:  # add [SEP] token between visits
                cur_input_tokens.append('[SEP]')
            input_tokens.extend(cur_input_tokens)
            visit_positions.extend([idx] * len(cur_input_tokens))
            age_tokens.extend([self.ages[subject_id][idx]] * len(cur_input_tokens))

        visit_positions = torch.tensor(visit_positions)
        age_tokens = torch.tensor(self.tokenizer.convert_tokens_to_ids(age_tokens))

        # masked token prediction
        non_special_tokens_idx = [idx for idx, x in enumerate(input_tokens) if x != '[CLS]' and x != '[SEP]']
        masked_tokens_idx = np.random.choice(non_special_tokens_idx, max(1, int(len(non_special_tokens_idx) * 0.15)))
        masked_tokens = [input_tokens[idx] for idx in masked_tokens_idx]
        masked_input_tokens = input_tokens.copy()
        for idx in masked_tokens_idx:
            if np.random.random() < 0.8:
                masked_input_tokens[idx] = '[MASK]'
            elif np.random.random() < 0.5:
                # TODO: sample based on the entity type
                masked_input_tokens[idx] = np.random.choice(list(self.tokenizer.diag_voc.word2idx.keys()))
        masked_input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(masked_input_tokens))
        masked_tokens_idx = torch.tensor(masked_tokens_idx)
        # TODO: consider all kinds of entities
        masked_lm_labels = torch.tensor(self.tokenizer.convert_tokens_to_ids(masked_tokens, voc_type='diag'))
        return masked_input_ids, visit_positions, age_tokens, masked_lm_labels, masked_tokens_idx


class FinetuneEHRDataset(Dataset):
    def __init__(self, data_pd, tokenizer: EHRTokenizer, token_type=['diag', 'med', 'pro', 'lab'], task='death'):
        self.tokenizer = tokenizer
        self.task = task

        def transform_data(data, task):
            age_records = {}
            hadm_records = {}  # including current admission and previous admissions
            genders = {}
            labels = {}
            for subject_id in data['subject_id'].unique():
                item_df = data[data['subject_id'] == subject_id]
                patient, ages = [], []
                
                for _, row in item_df.iterrows():
                    admission = []
                    hadm_id = row['hadm_id']
                    if "diag" in token_type:
                        admission.append(list(row['icd_code']))
                    if "med" in token_type:
                        admission.append(list(row['ndc']))
                    if "pro" in token_type:
                        admission.append(list(row['pro_code']))
                    if "lab" in token_type:
                        admission.append(list(row['lab_test']))
                    patient.append(admission)
                    ages.append(row['age'])
                    if task in ["death", "stay", "readmission"]:  # binary prediction
                        hadm_records[hadm_id] = list(patient)
                        age_records[hadm_id] = ages
                        genders[hadm_id] = [item_df.head(1)["gender"].values[0]]
                        if "READMISSION" in row:
                            labels[hadm_id] = [row["DEATH"], row["STAY_DAYS"], row["READMISSION"]]
                        else:
                            labels[hadm_id] = [row["DEATH"], row["STAY_DAYS"]]
                    else:  # next diagnosis prediction
                        label = row["NEXT_DIAG_6M"] if task == "next_diag_6m" else row["NEXT_DIAG_12M"]
                        if str(label) != "nan":  # only include the admission with next diagnosis
                            hadm_records[hadm_id] = list(patient)
                            age_records[hadm_id] = ages
                            genders[hadm_id] = [item_df.head(1)["gender"].values[0]]
                            labels[hadm_id] = list(label)

            return hadm_records, age_records, genders, labels
        self.records, self.ages, self.genders, self.labels = transform_data(data_pd, task)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, item):
        hadm_id = list(self.records.keys())[item]

        input_tokens = ['[CLS]']
        visit_positions = [0]
        age_tokens = [self.ages[hadm_id][0]]
        for idx, adm in enumerate(self.records[hadm_id]):  # each subject have multiple admissions, idx:visit id
            cur_input_tokens = []
            for i in range(len(adm)):  # one admission may have many kinds of entities, [[diag], [med], ...]
                cur_input_tokens.extend(list(adm[i]))
            if idx != len(self.records[hadm_id]) - 1:  # add [SEP] token between visits
                cur_input_tokens.append('[SEP]')
            input_tokens.extend(cur_input_tokens)
            visit_positions.extend([idx] * len(cur_input_tokens))
            age_tokens.extend([self.ages[hadm_id][idx]] * len(cur_input_tokens))

        if self.task == "death":
            # predict if the patient will die in the hospital
            labels = torch.tensor([self.labels[hadm_id][0]]).float()
        elif self.task == "stay":
            # predict if the patient will stay in the hospital for more than 7 days
            labels = (torch.tensor([self.labels[hadm_id][1]]) > 7).float()
        elif self.task == "readmission":
            # predict if the patient will be readmitted within 1 month
            labels = torch.tensor([self.labels[hadm_id][2]]).float()
        else:
            # predict the next diagnosis in 6 months or 12 months
            input_tokens.extend(['[SEP]', '[MASK]'])
            visit_positions.extend([visit_positions[-1]+1, visit_positions[-1]+1])
            age_tokens.extend([age_tokens[-1], age_tokens[-1]])
            label_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(self.labels[hadm_id], voc_type='diag'))
            labels = torch.zeros(self.tokenizer.token_number(voc_type='diag')).long()
            labels[label_ids] = 1  # multi-hop vector

        visit_positions = torch.tensor(visit_positions)
        age_tokens = torch.tensor(self.tokenizer.convert_tokens_to_ids(age_tokens))
        input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(input_tokens))

        return input_ids, visit_positions, age_tokens, labels


def batcher(tokenizer):
    def batcher_dev(batch):
        input_ids, visit_positions, age_tokens = [feat[0] for feat in batch], [feat[1] for feat in batch], [feat[2] for feat in batch]
        if len(batch[0]) > 3:
            batch_data = []
            for d in range(3, len(batch[0])):
                data = [feat[d] for feat in batch]
                max_len = max([len(x) for x in data])
                data = [F.pad(x, (0, max_len - len(x)), "constant", 0) for x in data]
                data = torch.stack(data)
                batch_data.append(data)

        # padding
        pad_id = tokenizer.vocab.word2idx['[PAD]']
        max_len = max([len(x) for x in input_ids])
        input_ids = [F.pad(x, (0, max_len - len(x)), "constant", pad_id) for x in input_ids]
        visit_positions = [F.pad(x, (0, max_len - len(x)), "constant", pad_id) for x in visit_positions]
        age_tokens = [F.pad(x, (0, max_len - len(x)), "constant", pad_id) for x in age_tokens]
        
        input_ids = torch.stack(input_ids)
        visit_positions = torch.stack(visit_positions)
        age_tokens = torch.stack(age_tokens)

        if len(batch[0]) > 3:
            return input_ids, age_tokens, visit_positions, *batch_data
        else:
            return input_ids, age_tokens, visit_positions
    
    return batcher_dev

In [18]:
def _pad_sequence(seqs, pad_id=0):
    # seqs: a list of tensor [n, m]
    max_len = max([x.shape[1] for x in seqs])
    return torch.cat([F.pad(x, (0, max_len - x.shape[1]), "constant", pad_id) for x in seqs], dim=0)


class HBERTPretrainEHRDataset(PretrainEHRDataset):
    def __init__(self, data_pd, tokenizer: EHRTokenizer, token_type=['diag', 'med', 'pro', 'lab'], mask_rate=0.15, anomaly_rate=0.1):
        super().__init__(data_pd, tokenizer, token_type)
        
        self.mask_rate = mask_rate
        self.anomaly_rate = anomaly_rate
        self.token_type = token_type
        self.token_type_map = {i:t for i, t in enumerate(token_type)}

    def _id2multi_hot(self, ids, dim):
        multi_hot = torch.zeros(dim)
        multi_hot[ids] = 1
        return multi_hot

    def __getitem__(self, item):
        subject_id = list(self.records.keys())[item]

        input_tokens, token_types, masked_labels, anomaly_labels = [], [], [None for _ in range(len(self.token_type))], []
        for idx, adm in enumerate(self.records[subject_id]):  # each subject have multiple admissions, idx:visit id
            adm_tokens, adm_token_types, adm_masked_labels = [str(self.ages[subject_id][idx]) + "_" + str(self.genders[subject_id][0])], [0], []  # replace [CLS] token with age
            adm_anomaly_labels = []

            for i in range(len(adm)):  # one admission have many kinds of entities, [[diag], [med], ...]
                cur_tokens = list(adm[i])

                # randomly mask tokens
                non_special_tokens_idx = [idx for idx, x in enumerate(cur_tokens)]
                masked_tokens_idx = np.random.choice(non_special_tokens_idx, max(1, int(len(non_special_tokens_idx) * self.mask_rate)))
                masked_tokens = [cur_tokens[idx] for idx in masked_tokens_idx]
                masked_tokens_idx_ = set(masked_tokens_idx.tolist())  # for fast lookup
                non_masked_tokens = [cur_tokens[idx] for idx in non_special_tokens_idx if idx not in masked_tokens_idx_]

                # randomly replace tokens with other tokens
                if self.anomaly_rate > 0 and len(non_masked_tokens) > 0:
                    candidate_token_idx = [idx for idx, x in enumerate(non_masked_tokens)]
                    anomaly_tokens_idx = np.random.choice(candidate_token_idx, max(1, int(len(candidate_token_idx) * self.anomaly_rate)))
                    for ano_idx in anomaly_tokens_idx:
                        non_masked_tokens[ano_idx] = self.tokenizer.random_token(voc_type=self.token_type_map[i])
                        adm_anomaly_labels.append(len(adm_tokens) + ano_idx + 1)  # the position of the anomaly token, +1 for [MASK] tokens

                adm_tokens.extend([f"[MASK{i}]"] + non_masked_tokens)  # [[MASK1], diag1, diag2, [MASK2], med1, med2]
                adm_token_types.extend([i + 1] * (len(non_masked_tokens) + 1))  # [0, 1, 1, 2, 2]
                adm_masked_labels.append(masked_tokens)  # [[diag1, diag2], [med1, med2]]

            input_tokens.append(torch.tensor([self.tokenizer.convert_tokens_to_ids(adm_tokens)]))
            token_types.append(torch.tensor([adm_token_types]))
            for i in range(len(self.token_type)):
                label_ids = self.tokenizer.convert_tokens_to_ids(adm_masked_labels[i], voc_type=self.token_type_map[i])
                label_hop = self._id2multi_hot(label_ids, dim=self.tokenizer.token_number(self.token_type_map[i])).unsqueeze(dim=0)
                if masked_labels[i] is None:
                    masked_labels[i] = label_hop
                else:
                    masked_labels[i] = torch.cat([masked_labels[i], label_hop])
            
            if len(adm_anomaly_labels) > 0:
                anomaly_labels.append(self._id2multi_hot(adm_anomaly_labels, dim=len(adm_tokens)).unsqueeze(dim=0))
            else:
                anomaly_labels.append(torch.zeros(len(adm_tokens)).unsqueeze(dim=0))

        visit_positions = torch.tensor(list(range(len(input_tokens))))  # [0, 1, 2, ...]
        input_tokens = _pad_sequence(input_tokens, pad_id=self.tokenizer.vocab.word2idx["[PAD]"])
        token_types = _pad_sequence(token_types, pad_id=0)
        anomaly_labels = _pad_sequence(anomaly_labels, pad_id=0) if len(anomaly_labels) > 0 else None
        n_adms = len(input_tokens)
        if n_adms > 1:
            # create a fully connected graph between admission
            edge_index = torch.tensor([[i, j] for i in range(n_adms) for j in range(n_adms)]).t()  # [2, n_adms * n_adms]
        else:
            edge_index = torch.tensor([])
        return input_tokens, token_types, edge_index, visit_positions, masked_labels, anomaly_labels


class HBERTFinetuneEHRDataset(FinetuneEHRDataset):
    def __init__(self, data_pd, tokenizer, token_type=['diag', 'med', 'pro', 'lab'], task='death'):
        super().__init__(data_pd, tokenizer, token_type, task)

    def __getitem__(self, item):
        hadm_id = list(self.records.keys())[item]

        input_tokens, token_types = [], []
        for idx, adm in enumerate(self.records[hadm_id]):  # each subject have multiple admissions, idx:visit id
            adm_tokens = [str(self.ages[hadm_id][idx]) + "_" + self.genders[hadm_id][0]]  # replace [CLS] token with age
            # adm_tokens = [self.ages[hadm_id][idx]]  # replace [CLS] token with age
            adm_token_types = [0]

            for i in range(len(adm)):
                cur_tokens = list(adm[i])
                adm_tokens.extend(cur_tokens)
                adm_token_types.extend([i + 1] * len(cur_tokens))
            
            # input_tokens.append(torch.tensor([self.tokenizer.convert_tokens_to_ids(adm_tokens)]))
            input_tokens.append(adm_tokens)
            # token_types.append(torch.tensor([adm_token_types]))
            token_types.append(adm_token_types)

        if self.task == "death":
            # predict if the patient will die in the hospital
            labels = torch.tensor([self.labels[hadm_id][0]]).float()
        elif self.task == "stay":
            # predict if the patient will stay in the hospital for more than 7 days
            labels = (torch.tensor([self.labels[hadm_id][1]]) > 7).float()
        elif self.task == "readmission":
            # predict if the patient will be readmitted within 1 month
            labels = torch.tensor([self.labels[hadm_id][2]]).float()
        else:
            # predict the next diagnosis in 6 months or 12 months
            input_tokens[-1] = [input_tokens[-1][0]] + ["[MASK0]"] + input_tokens[-1][1:]
            token_types[-1] = [token_types[-1][0]] + [1] + token_types[-1][1:]
            label_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(self.labels[hadm_id], voc_type='diag'))
            labels = torch.zeros(self.tokenizer.token_number(voc_type='diag')).long()
            labels[label_ids] = 1  # multi-hop vector

        visit_positions = torch.tensor(list(range(len(input_tokens))))  # [0, 1, 2, ...]
        input_tokens = [torch.tensor([self.tokenizer.convert_tokens_to_ids(x)]) for x in input_tokens]
        token_types = [torch.tensor([x]) for x in token_types]
        input_tokens = _pad_sequence(input_tokens, pad_id=self.tokenizer.vocab.word2idx["[PAD]"])
        token_types = _pad_sequence(token_types, pad_id=0)
        n_adms = len(input_tokens)
        if n_adms > 1:
            # create a fully connected graph between admission
            edge_index = torch.tensor([[i, j] for i in range(n_adms) for j in range(n_adms)]).t()  # [2, n_adms * n_adms]
        else:
            edge_index = torch.tensor([])
        
        return input_tokens, token_types, edge_index, visit_positions, labels


def batcher(tokenizer, n_token_type=3, is_train=True):
    def batcher_dev(batch):
        raw_input_ids, raw_input_types, raw_edge_indexs, raw_visit_positions, raw_labels = [feat[0] for feat in batch], [feat[1] for feat in batch], [feat[2] for feat in batch], [feat[3] for feat in batch], [feat[4] for feat in batch]

        pad_id = tokenizer.vocab.word2idx["[PAD]"]
        max_n_tokens = max([x.size(1) for x in raw_input_ids])
        input_ids = torch.cat([F.pad(raw_input_id, (0, max_n_tokens - raw_input_id.size(1)), "constant", pad_id) for raw_input_id in raw_input_ids], dim=0)

        max_n_token_types = max([x.size(1) for x in raw_input_types])
        input_types = torch.cat([F.pad(raw_input_type, (0, max_n_token_types - raw_input_type.size(1)), "constant", 0) for raw_input_type in raw_input_types], dim=0)

        n_cumsum_nodes = [0] + np.cumsum([input_id.size(0) for input_id in raw_input_ids]).tolist()
        edge_index = []
        for i, raw_edge_index in enumerate(raw_edge_indexs):
            if raw_edge_index.shape[0] > 0:
                edge_index.append(raw_edge_index + n_cumsum_nodes[i])
        edge_index = torch.cat(edge_index, dim=1) if len(edge_index) > 0 else None

        visit_positions = torch.cat(raw_visit_positions, dim=0)

        if is_train:
            labels = []  # [n_token_type, B, n_tokens], each element is a multi-hop label tensor
            for i in range(n_token_type):
                labels.append(torch.cat([x[i] for x in raw_labels]))
            
            raw_anomaly_labels = [feat[5] for feat in batch]
            if raw_anomaly_labels[0] is not None:
                max_n_anomaly_labels = max([x.size(1) for x in raw_anomaly_labels])
                anomaly_labels = torch.cat([F.pad(raw_anomaly_label, (0, max_n_anomaly_labels - raw_anomaly_label.size(1)), "constant", 0) for raw_anomaly_label in raw_anomaly_labels], dim=0)
            else:
                anomaly_labels = None
            return input_ids, input_types, edge_index, visit_positions, labels, anomaly_labels
        else:
            labels = torch.stack(raw_labels, dim=0)
            labeled_batch_idx = [n - 1 for n in n_cumsum_nodes[1:]]  # indicate the index of the to-be-predicted admission
            return input_ids, input_types, edge_index, visit_positions, labeled_batch_idx, labels
    
    return batcher_dev

## MODELS

In [19]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, config):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        return self.fc2(self.dropout(gelu(self.fc1(x))))


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()

        self.d_k, self.n_heads = config.hidden_size // config.num_attention_heads, config.num_attention_heads

        self.W_Q = nn.Linear(config.hidden_size, config.hidden_size)
        self.W_K = nn.Linear(config.hidden_size, config.hidden_size)
        self.W_V = nn.Linear(config.hidden_size, config.hidden_size)
        self.W_output = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    
    def ScaledDotProductAttention(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1)) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = self.dropout(nn.Softmax(dim=-1)(scores))
        context = torch.matmul(attn, V)
        return context, attn

    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.size(0)

        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = self.ScaledDotProductAttention(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k) # context: [batch_size x len_q x n_heads * d_v]
        return self.W_output(context)


class TransformerBlock(nn.Module):
    def __init__(self, config):
        super(TransformerBlock, self).__init__()
        self.self_attn = MultiHeadAttention(config)
        self.pos_ffn = PoswiseFeedForwardNet(config)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.norm_attn = nn.LayerNorm(config.hidden_size)
        self.norm_ffn = nn.LayerNorm(config.hidden_size)

    def forward(self, x, self_attn_mask):
        norm_x = self.norm_attn(x)
        x = x + self.dropout(self.self_attn(norm_x, norm_x, norm_x, self_attn_mask))

        norm_x = self.norm_ffn(x)
        x = x + self.dropout(self.pos_ffn(norm_x))
        return x

In [20]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum

class EdgeModule(nn.Module):
    def __init__(self, config):
        super(EdgeModule, self).__init__()

        self.n_types = 4 + 1  # +1 for [CLS]

        self.left_transform = nn.Parameter(torch.zeros(self.n_types, config.hidden_size, config.hidden_size))
        self.right_transform = nn.Parameter(torch.zeros(self.n_types, config.hidden_size, config.hidden_size))
        self.output = nn.Linear(config.hidden_size * 2, config.edge_hidden_size)

        self.init_parameters()

    def init_parameters(self):
        nn.init.xavier_uniform_(self.left_transform)
        nn.init.xavier_uniform_(self.right_transform)

    def forward(self, token_embs, token_types):
        # token_embs: [batch_size, seq_len, hidden_size], token_types: [batch_size, seq_len]
        batch_size, seq_len, hidden_size = token_embs.size()

        # encode token according to its type
        left_trans = self.left_transform[token_types]  # [batch_size, seq_len, hidden_size, hidden_size]
        right_trans = self.right_transform[token_types]  # [batch_size, seq_len, hidden_size, hidden_size]

        left_embs = einsum(token_embs, left_trans, 'b l d, b l m d -> b l m')  # [batch_size, seq_len, hidden_size]
        right_embs = einsum(token_embs, right_trans, 'b l d, b l m d -> b l m')  # [batch_size, seq_len, hidden_size]

        edge_embs = torch.cat((left_embs.unsqueeze(dim=2).repeat(1, 1, seq_len, 1), 
                               right_embs.unsqueeze(dim=1).repeat(1, seq_len, 1, 1)), dim=-1)
        return self.output(edge_embs)


class MultiHeadEdgeAttention(MultiHeadAttention):
    def __init__(self, config):
        super().__init__(config)

        self.d_edge = config.edge_hidden_size
        self.W_output = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.W_K_edge = nn.Linear(config.edge_hidden_size, config.edge_hidden_size)
        self.W_edge = nn.Linear(config.edge_hidden_size, 1)
        self.W_edge_output = nn.Linear(self.d_edge * self.n_heads, config.hidden_size)

    def forward(self, Q, K, V, attn_mask, edge_embs):
        batch_size, n_tokens = Q.size(0), Q.size(1)

        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        k_s_edge = self.W_K_edge(edge_embs).view(batch_size, n_tokens, n_tokens, -1)
        edge_bias = self.W_edge(edge_embs).view(batch_size, 1, n_tokens, n_tokens)
        edge_bias = edge_bias * (2 ** -0.5)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # bias attention scores with edge representation
        scores = torch.matmul(q_s, k_s.transpose(-1, -2)) * ((2 * q_s.size(-1)) ** -0.5)
        scores = scores + edge_bias
        scores.masked_fill_(attn_mask, -1e9)  # [batch_size, n_heads, n_tokens, n_tokens]
        attn = self.dropout(nn.Softmax(dim=-1)(scores))
        
        # add edge context into aggregated context
        context = torch.matmul(attn, v_s)
        edge_context = einsum(attn, k_s_edge, 'b h n m, b n m d -> b h n d')  # [batch_size, n_tokens, n_heads, d_k]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k) # context: [batch_size x len_q x n_heads * d_v]
        edge_context = edge_context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_edge)
        edge_context = self.W_edge_output(edge_context)
        return self.W_output(torch.cat([context, edge_context], dim=-1))


class EdgeTransformerBlock(TransformerBlock):
    def __init__(self, config):
        super().__init__(config)

        self.self_attn = MultiHeadEdgeAttention(config)
        self.norm_edge = nn.LayerNorm(config.edge_hidden_size)

    def forward(self, x, edge_embs, self_attn_mask):
        norm_x = self.norm_attn(x)
        norm_edge_embs = self.norm_edge(edge_embs)
        x = x + self.dropout(self.self_attn(norm_x, norm_x, norm_x, self_attn_mask, norm_edge_embs))

        norm_x = self.norm_ffn(x)
        x = x + self.dropout(self.pos_ffn(norm_x))
        return x

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.utils import softmax
from torch_scatter import scatter_mean, scatter_sum


class DotAttnConv(nn.Module):
    def __init__(self, in_channels, out_channels, n_heads=1, n_max_visits=15, temp=1.):
        super(DotAttnConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_heads, self.temp = n_heads, temp

        self.pos_encoding = nn.Embedding(n_max_visits, in_channels)
        self.W_q = nn.Linear(in_channels, out_channels, bias=False)
        self.W_k = nn.Linear(in_channels, out_channels, bias=False)
        self.W_v = nn.Linear(in_channels, out_channels, bias=False)
        self.W_out = nn.Linear(out_channels, out_channels, bias=False)
        self.ln = nn.LayerNorm(out_channels)

    def forward(self, x, edge_index, visit_pos):
        # x: [N, in_channels], edge_index: [2, E]
        N, device = x.size(0), x.device
        isolated_nodes_mask = ~torch.isin(torch.arange(N).to(x.device), edge_index[1].unique())
        isolated_nodes = isolated_nodes_mask.nonzero(as_tuple=False).squeeze()

        pos_encoding = self.pos_encoding(visit_pos)
        h_q, h_k, h_v = self.W_q(x + pos_encoding), self.W_k(x + pos_encoding), self.W_k(x)
        h_q, h_k, h_v = h_q.reshape(N, self.n_heads, -1), h_k.reshape(N, self.n_heads, -1), h_v.reshape(N, self.n_heads, -1)
        
        attn_scores = torch.sum(h_q[edge_index[0]] * h_k[edge_index[1]], dim=-1) / self.temp  # [N_edges, n_heads]
        dst_nodes = torch.cat([edge_index[1] + N*i for i in range(self.n_heads)], dim=0).to(device)
        attn_scores = softmax(attn_scores.reshape(-1), dst_nodes, num_nodes=N * self.n_heads).unsqueeze(dim=-1)  # [N_edges * n_heads, 1]

        # aggregation
        h_v = h_v.permute(1, 0, 2).reshape(N*self.n_heads, -1)
        src_nodes = torch.cat([edge_index[0] + N*i for i in range(self.n_heads)], dim=0).to(device)
        out = scatter_sum(src=h_v[src_nodes] * attn_scores, index=dst_nodes, dim_size=N * self.n_heads, dim=0)
        out = out.reshape(self.n_heads, N, -1).permute(1, 0, 2).reshape(N, -1)

        out = self.W_out(self.ln(out)) + x
        out[isolated_nodes] = x[isolated_nodes]
        return out

HEART

In [22]:
class TreeEmbeddings(nn.Module):
    def __init__(self, config, diag_tree_table, med_tree_table, n_diag_tokens, n_med_tokens, diag_range, med_range):
        super(TreeEmbeddings, self).__init__()
        # tree_table: [n_diag/n_med, n_level]
        self.n_dim = config.hidden_size
        self.diag_range, self.med_range = diag_range, med_range
        self.diag_tree_table, self.med_tree_table = diag_tree_table, med_tree_table

        self.diag_tokens = nn.Embedding(n_diag_tokens, config.hidden_size // diag_tree_table.shape[1])
        self.med_tokens = nn.Embedding(n_med_tokens, config.hidden_size // med_tree_table.shape[1])
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)  # [PAD] token
        self.emb_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_types):
        B, N = input_ids.shape[0], input_ids.shape[1]

        # concat the embedding at each layer
        diag_tree_tokens = self.diag_tokens(self.diag_tree_table.to(input_ids.device)).reshape(-1, self.n_dim)
        med_tree_tokens = self.med_tokens(self.med_tree_table.to(input_ids.device)).reshape(-1, self.n_dim)

        input_ids = input_ids.reshape(-1)
        diag_mask = (input_ids >= self.diag_range[0]) * (input_ids < self.diag_range[1])
        med_mask = (input_ids >= self.med_range[0]) * (input_ids < self.med_range[1])

        words_embeddings = self.word_embeddings(input_ids)
        diag_embeddings = diag_tree_tokens[input_ids[diag_mask] - self.diag_range[0]]
        med_embeddings = med_tree_tokens[input_ids[med_mask] - self.med_range[0]]

        # replace the diagnosis and medication embeddings with tree embeddings
        words_embeddings[diag_mask] = diag_embeddings
        words_embeddings[med_mask] = med_embeddings
        words_embeddings = words_embeddings.reshape(B, N, -1)

        # words_embeddings = words_embeddings + self.type_embedding(token_types)

        return self.emb_dropout(words_embeddings)


class HBERTEmbeddings(nn.Module):
    def __init__(self, config):
        super(HBERTEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)  # [PAD] token
        self.emb_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_types):
        # embedding the indexed sequence to sequence of vectors
        words_embeddings = self.word_embeddings(input_ids)
        return self.emb_dropout(words_embeddings)


# Hierarchical Transformer
class HiTransformer(nn.Module):
    def __init__(self, config):
        super(HiTransformer, self).__init__()
        
        # multi-layers transformer blocks, deep network
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)])

        # multi-layers transformer blocks, deep network
        if config.gat == "dotattn":
            self.cross_attentions = nn.ModuleList(
                [DotAttnConv(config.hidden_size, config.hidden_size, config.gnn_n_heads, config.max_visit_size, config.gnn_temp) for _ in range(config.num_hidden_layers)])
        elif config.gat == "None":
            self.cross_attentions = None

    def forward(self, x, edge_index, mask, visit_positions):
        # running over multiple transformer blocks
        for i in range(len(self.transformer_blocks)):
            x = self.transformer_blocks[i](x, mask)  # [B, L, D]
            if edge_index is not None and self.cross_attentions is not None:
                x = torch.cat([self.cross_attentions[i](x[:, 0], edge_index, visit_positions).unsqueeze(dim=1), 
                                x[:, 1:]], dim=1)  # communicate between visits
        return x


# Hierarchical Transformer with Edge Representation
class HiEdgeTransformer(nn.Module):
    def __init__(self, config):
        super(HiEdgeTransformer, self).__init__()

        self.edge_module = EdgeModule(config)
        
        self.transformer_blocks = nn.ModuleList(
            [EdgeTransformerBlock(config) for _ in range(config.num_hidden_layers)])
        
        if config.gat == "dotattn":
            self.cross_attentions = nn.ModuleList(
                [DotAttnConv(config.hidden_size, config.hidden_size, config.gnn_n_heads, config.max_visit_size, config.gnn_temp) for _ in range(config.num_hidden_layers)])
        elif config.gat == "None":
            self.cross_attentions = None

    def forward(self, x, x_types, edge_index, mask, visit_positions):
        edge_embs = self.edge_module(x, x_types)  # [B, L, L, D]
        for i in range(len(self.transformer_blocks)):
            x = self.transformer_blocks[i](x, edge_embs, mask)  # [B, L, D]
            if edge_index is not None and self.cross_attentions is not None:
                x = torch.cat([self.cross_attentions[i](x[:, 0], edge_index, visit_positions).unsqueeze(dim=1), 
                                x[:, 1:]], dim=1)  # communicate between visits
        return x


class MaskedPredictionHead(nn.Module):
    def __init__(self, config, voc_size):
        super(MaskedPredictionHead, self).__init__()
        self.cls = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size), 
                nn.ReLU(), 
                nn.Linear(config.hidden_size, voc_size)
            )

    def forward(self, input):
        return self.cls(input)


# binary classification task
class BinaryPredictionHead(nn.Module):
    def __init__(self, config):
        super(BinaryPredictionHead, self).__init__()
        self.cls = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size), 
                nn.ReLU(), 
                nn.Linear(config.hidden_size, 1)
            )

    def forward(self, input):
        return self.cls(input)


class HBERT_Pretrain(nn.Module):
    def __init__(self, config, tokenizer):
        super(HBERT_Pretrain, self).__init__()

        if config.diag_med_emb == "simple":
            self.embeddings = HBERTEmbeddings(config)
        elif config.diag_med_emb == "tree":
            diag_tree_table, med_tree_table = tokenizer.diag_tree_table, tokenizer.med_tree_table
            n_diag_tokens, n_med_tokens = len(tokenizer.diag_tree_voc.idx2word), len(tokenizer.med_tree_voc.idx2word)
            diag_range, med_range = tokenizer.token_id_range("diag"), tokenizer.token_id_range("med")
            self.embeddings = TreeEmbeddings(config, diag_tree_table, med_tree_table, 
                                            n_diag_tokens, n_med_tokens, diag_range, med_range)

        # self.loss_fn = torch.nn.BCEWithLogitsLoss
        self.loss_fn = F.binary_cross_entropy_with_logits
        self.pos_weight = torch.tensor(config.pos_weight)
        self.encoder = config.encoder

        if config.encoder == "hi":
            self.transformer = HiTransformer(config)
        elif config.encoder == "hi_edge":
            self.transformer = HiEdgeTransformer(config)
        else:
            raise NotImplementedError

        self.mask_token_id = config.mask_token_id  # {token_type: masked_id}
        predicted_token_type = config.predicted_token_type  # ["diag", "med", "pro", "lab"]
        label_vocab_size = config.label_vocab_size  # {token_type: vocab_size}
        for token_type in predicted_token_type:
            self.add_module(f"{token_type}_cls", MaskedPredictionHead(config, label_vocab_size[token_type]))
        if config.anomaly_rate > 0:
            self.anomaly_loss_weight = config.anomaly_loss_weight
            self.anomaly_detection_head = BinaryPredictionHead(config)

    def forward(self, input_ids, token_types, edge_index, visit_positions, masked_labels, anomaly_labels):
        device = input_ids.device
        pad_mask = (input_ids > 0)
        pair_pad_mask = pad_mask.unsqueeze(1).repeat(1, input_ids.size(1), 1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embeddings(input_ids, token_types)

        if self.encoder == "hi":
            x = self.transformer(x, edge_index, ~pair_pad_mask, visit_positions)
        elif self.encoder == "hi_edge":
            x = self.transformer(x, token_types, edge_index, ~pair_pad_mask, visit_positions)

        ave_loss, loss_dict = 0, {}
        for i, (token_type, mask_id) in enumerate(self.mask_token_id.items()):
            masked_token_emb = x[input_ids == mask_id]
            prediction = self._modules[f"{token_type}_cls"](masked_token_emb)
            loss = self.loss_fn(prediction, masked_labels[i].to(input_ids.device))
            # loss = self.loss_fn(prediction.view(-1), masked_labels[i].view(-1).to(device), pos_weight=self.pos_weight.to(device))
            ave_loss += loss
            loss_dict[token_type] = loss.cpu().item()
        
        if anomaly_labels is not None:
            anomaly_prediction = self.anomaly_detection_head(x)
            # anomaly_loss = self.loss_fn(reduction='none')(anomaly_prediction.view(-1), anomaly_labels.view(-1))
            anomaly_loss = self.loss_fn(anomaly_prediction.view(-1), anomaly_labels.view(-1), reduction='none')
            anomaly_loss = (anomaly_loss * pad_mask.view(-1)).sum() / pad_mask.sum()
            ave_loss += self.anomaly_loss_weight * anomaly_loss
            loss_dict["anomaly"] = anomaly_loss.cpu().item()
        else:
            loss_dict["anomaly"] = 0.

        return ave_loss / len(loss_dict), loss_dict


class HBERT_Finetune(nn.Module):
    def __init__(self, config, tokenizer):
        super(HBERT_Finetune, self).__init__()

        if config.diag_med_emb == "simple":
            self.embeddings = HBERTEmbeddings(config)
        elif config.diag_med_emb == "tree":
            diag_tree_table, med_tree_table = tokenizer.diag_tree_table, tokenizer.med_tree_table
            n_diag_tokens, n_med_tokens = len(tokenizer.diag_tree_voc.idx2word), len(tokenizer.med_tree_voc.idx2word)
            diag_range, med_range = tokenizer.token_id_range("diag"), tokenizer.token_id_range("med")
            self.embeddings = TreeEmbeddings(config, diag_tree_table, med_tree_table, 
                                            n_diag_tokens, n_med_tokens, diag_range, med_range)

        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.encoder = config.encoder
        self.diag_mask_id = 3  # the idx of [MASK0] token
        self.task = config.task

        if config.encoder == "hi":
            self.transformer = HiTransformer(config)
        elif config.encoder == "hi_edge":
            self.transformer = HiEdgeTransformer(config)
        else:
            raise NotImplementedError

        if config.task in ["death", "stay", "readmission"]:
            self.downstream_cls = BinaryPredictionHead(config)
        else:
            self.downstream_cls = MaskedPredictionHead(config, config.label_vocab_size)

    def load_weight(self, checkpoint_dict):
        param_dict = dict(self.named_parameters())
        for key in checkpoint_dict.keys():
            if key in param_dict:
                param_dict[key].data.copy_(checkpoint_dict[key])
    
    def forward(self, input_ids, token_types, edge_index, visit_positions, labeled_ids):
        pad_mask = (input_ids > 0).unsqueeze(1).repeat(1, input_ids.size(1), 1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embeddings(input_ids, token_types)

        if self.encoder == "hi":
            x = self.transformer(x, edge_index, ~pad_mask, visit_positions)
        elif self.encoder == "hi_edge":
            x = self.transformer(x, token_types, edge_index, ~pad_mask, visit_positions)

        if self.task in ["death", "stay", "readmission"]:
            prediction = self.downstream_cls(x[labeled_ids][:, 0])
        else:
            labeled_ids, labeled_x = input_ids[labeled_ids], x[labeled_ids]
            masked_pos_embs = labeled_x[labeled_ids == self.diag_mask_id]
            prediction = self.downstream_cls(masked_pos_embs)
        return prediction

In [23]:
def read_data(args, all_data_path, pretrain_data_path):
    ehr_data = pickle.load(open(all_data_path, 'rb'))
    diag_sentences = ehr_data["icd_code"].values.tolist()
    med_sentences = ehr_data["ndc"].values.tolist()
    lab_sentences = ehr_data["lab_test"].values.tolist()
    if args.dataset == "mimic":
        pro_sentences = ehr_data["pro_code"].values.tolist()
        gender_set = [["M"], ["F"]]
        age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["age"].values.tolist()) for gender in ["M", "F"]]
    else:
        pro_sentences = None
        gender_set = [["Female"], ["Male"], ["Unknown"], ["Other"]]
        age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["age"].values.tolist()) for gender in ["Female", "Male", "Unknown", "Other"]]
    age_set = [[c] for c in set(ehr_data["age"].values.tolist())]
    
    ehr_pretrain_data = pickle.load(open(pretrain_data_path, 'rb'))
    tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args.special_tokens)
    if args.dataset == "mimic":
        tokenizer.build_tree()
    dataset = HBERTPretrainEHRDataset(ehr_pretrain_data, tokenizer, token_type=args.predicted_token_type, mask_rate=args.mask_rate)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=batcher(tokenizer, n_token_type=len(args.predicted_token_type)), shuffle=True)
    return tokenizer, dataloader

In [24]:
def set_random_seed(seed: int = 42):
    import os, random, numpy as np
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass

In [25]:
import os
import sys
import copy
import wandb
import random
import pickle
import logging
import numpy as np
from tqdm.notebook import tqdm  # <--- better for Jupyter
from copy import deepcopy
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

from argparse import Namespace

args = Namespace(
    seed=0,
    dataset="mimic",
    device=0,
    batch_size=32,
    lr=2e-5,
    epochs=50,
    use_wandb=False,
    encoder="hi_edge",
    mask_rate=0.7,
    anomaly_rate=0.05,
    anomaly_loss_weight=1,
    pos_weight=1,
    num_hidden_layers=5,
    num_attention_heads=6,
    attention_probs_dropout_prob=0.2,
    hidden_dropout_prob=0.2,
    edge_hidden_size=32,
    hidden_size=288,
    intermediate_size=288,
    gnn_n_heads=1,
    gnn_temp=1,
    gat="dotattn",
    diag_med_emb="tree"
)

print(vars(args))  # See all your parameters

{'seed': 0, 'dataset': 'mimic', 'device': 0, 'batch_size': 32, 'lr': 2e-05, 'epochs': 50, 'use_wandb': False, 'encoder': 'hi_edge', 'mask_rate': 0.7, 'anomaly_rate': 0.05, 'anomaly_loss_weight': 1, 'pos_weight': 1, 'num_hidden_layers': 5, 'num_attention_heads': 6, 'attention_probs_dropout_prob': 0.2, 'hidden_dropout_prob': 0.2, 'edge_hidden_size': 32, 'hidden_size': 288, 'intermediate_size': 288, 'gnn_n_heads': 1, 'gnn_temp': 1, 'gat': 'dotattn', 'diag_med_emb': 'tree'}


In [26]:
vars(args)

{'seed': 0,
 'dataset': 'mimic',
 'device': 0,
 'batch_size': 32,
 'lr': 2e-05,
 'epochs': 50,
 'use_wandb': False,
 'encoder': 'hi_edge',
 'mask_rate': 0.7,
 'anomaly_rate': 0.05,
 'anomaly_loss_weight': 1,
 'pos_weight': 1,
 'num_hidden_layers': 5,
 'num_attention_heads': 6,
 'attention_probs_dropout_prob': 0.2,
 'hidden_dropout_prob': 0.2,
 'edge_hidden_size': 32,
 'hidden_size': 288,
 'intermediate_size': 288,
 'gnn_n_heads': 1,
 'gnn_temp': 1,
 'gat': 'dotattn',
 'diag_med_emb': 'tree'}

In [44]:
if args.dataset == "mimic":
    args.max_visit_size = 15
    args.predicted_token_type = ["diag", "med", "pro", "lab"]
    args.mask_token_id = {"diag":3, "med":4, "pro":5, "lab":6}
    args.special_tokens = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")
    data_path = "/data/horse/ws/arsi805e-finetune/Thesis/dataset/mimic.pkl"
    pretrain_data_path = "/data/horse/ws/arsi805e-finetune/Thesis/dataset/mimic_pretrain.pkl"
elif args.dataset == "eicu":
    args.max_visit_size = 24
    args.predicted_token_type = ["diag", "med", "lab"]
    args.mask_token_id = {"diag":3, "med":4, "lab":5}
    args.special_tokens = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]")
    data_path = "/home/username/ehr_bert/dataset/eicu.pkl"
    pretrain_data_path = "/home/username/ehr_bert/dataset/eicu_pretrain.pkl"
else:
    raise ValueError("Unknown dataset")


In [45]:
def read_data(args, all_data_path, pretrain_data_path):
    ehr_data = pickle.load(open(all_data_path, 'rb'))
    diag_sentences = ehr_data["icd_code"].values.tolist()
    med_sentences = ehr_data["ndc"].values.tolist()
    lab_sentences = ehr_data["lab_test"].values.tolist()
    if args.dataset == "mimic":
        pro_sentences = ehr_data["pro_code"].values.tolist()
        gender_set = [["M"], ["F"]]
        age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["age"].values.tolist()) for gender in ["M", "F"]]
    else:
        pro_sentences = None
        gender_set = [["Female"], ["Male"], ["Unknown"], ["Other"]]
        age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["age"].values.tolist()) for gender in ["Female", "Male", "Unknown", "Other"]]
    age_set = [[c] for c in set(ehr_data["age"].values.tolist())]
    
    ehr_pretrain_data = pickle.load(open(pretrain_data_path, 'rb'))
    tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args.special_tokens)
    print(f"Tokenzier: \n {tokenizer}")
    if args.dataset == "mimic":
        tokenizer.build_tree()
    dataset = HBERTPretrainEHRDataset(ehr_pretrain_data, tokenizer, token_type=args.predicted_token_type, mask_rate=args.mask_rate)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=batcher(tokenizer, n_token_type=len(args.predicted_token_type)), shuffle=True)
    return tokenizer, dataloader

In [46]:
args

Namespace(seed=0, dataset='mimic', device=0, batch_size=32, lr=2e-05, epochs=50, use_wandb=False, encoder='hi_edge', mask_rate=0.7, anomaly_rate=0.05, anomaly_loss_weight=1, pos_weight=1, num_hidden_layers=5, num_attention_heads=6, attention_probs_dropout_prob=0.2, hidden_dropout_prob=0.2, edge_hidden_size=32, hidden_size=288, intermediate_size=288, gnn_n_heads=1, gnn_temp=1, gat='dotattn', diag_med_emb='tree', max_visit_size=15, predicted_token_type=['diag', 'med', 'pro', 'lab'], mask_token_id={'diag': 3, 'med': 4, 'pro': 5, 'lab': 6}, special_tokens=('[PAD]', '[CLS]', '[SEP]', '[MASK0]', '[MASK1]', '[MASK2]', '[MASK3]'))

In [47]:
data_path

'/data/horse/ws/arsi805e-finetune/Thesis/dataset/mimic.pkl'

In [48]:
tokenizer, dataloader = read_data(args, data_path, pretrain_data_path)
logging.info(f"Loaded data from {pretrain_data_path}")

Tokenzier: 
 <__main__.EHRTokenizer object at 0x1514053127b0>


2025-07-13 17:10:50,347 - INFO - Loaded data from /data/horse/ws/arsi805e-finetune/Thesis/dataset/mimic_pretrain.pkl


In [49]:
set_random_seed(args.seed)

args.vocab_size = 7 + len(tokenizer.diag_voc.idx2word) + \
                  len(tokenizer.pro_voc.idx2word) + \
                  len(tokenizer.med_voc.idx2word) + \
                  len(tokenizer.lab_voc.idx2word) + \
                  len(tokenizer.age_voc.idx2word) + \
                  len(tokenizer.gender_voc.idx2word) + \
                  len(tokenizer.age_gender_voc.idx2word)

args.label_vocab_size = {"diag":len(tokenizer.diag_voc.idx2word), 
                        "pro":len(tokenizer.pro_voc.idx2word), 
                        "med":len(tokenizer.med_voc.idx2word), 
                        "lab":len(tokenizer.lab_voc.idx2word)}  # {token_type: vocab_size}

loss_entity = ["diag", "med", "pro", "lab", "anomaly"]

device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
model = HBERT_Pretrain(args, tokenizer).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
logging.info(f"Initialized model")
print(f"MODEL: \n {model}")

Device: cuda:0


2025-07-13 17:10:50,605 - INFO - Initialized model


MODEL: 
 HBERT_Pretrain(
  (embeddings): TreeEmbeddings(
    (diag_tokens): Embedding(2651, 96)
    (med_tokens): Embedding(226, 96)
    (word_embeddings): Embedding(4398, 288, padding_idx=0)
    (emb_dropout): Dropout(p=0.2, inplace=False)
  )
  (transformer): HiEdgeTransformer(
    (edge_module): EdgeModule(
      (output): Linear(in_features=576, out_features=32, bias=True)
    )
    (transformer_blocks): ModuleList(
      (0-4): 5 x EdgeTransformerBlock(
        (self_attn): MultiHeadEdgeAttention(
          (W_Q): Linear(in_features=288, out_features=288, bias=True)
          (W_K): Linear(in_features=288, out_features=288, bias=True)
          (W_V): Linear(in_features=288, out_features=288, bias=True)
          (W_output): Linear(in_features=576, out_features=288, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
          (W_K_edge): Linear(in_features=32, out_features=32, bias=True)
          (W_edge): Linear(in_features=32, out_features=1, bias=True)
          (W_

In [50]:
import time

if args.use_wandb:
    wandb.init(project="ehr_bert", name="Pretrain-HBERT")
    wandb.config.update(vars(args))
    wandb.watch(model, log='all')

save_path = "/data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/"
os.makedirs(save_path, exist_ok=True)

for epoch in range(1, 1 + args.epochs):
    train_iter = tqdm(dataloader, ncols=140)
    model.train()
    ave_loss, ave_loss_dict = 0., {token_type: 0. for token_type in loss_entity}

    for step, batch in enumerate(train_iter):
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        loss, loss_dict = model(*batch)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_iter.set_description(f"Epoch:{epoch:03d}, Step:{step:03d}, loss:{loss.item():.4f}")
        ave_loss += loss.item()
        ave_loss_dict = {token_type: ave_loss_dict[token_type] + loss_dict[token_type] for token_type in loss_entity}

    ave_loss /= (step + 1)
    ave_loss_dict = {token_type: ave_loss_dict[token_type] / (step + 1) for token_type in loss_entity}
    print(f"Epoch:{epoch:03d}, average loss:{ave_loss:.4f}")

    if args.use_wandb:
        record_dict = {f"loss": ave_loss}
        record_dict.update({f"loss_{token_type}": ave_loss_dict[token_type] for token_type in loss_entity})
        wandb.log(record_dict)

    if epoch % 5 == 0 or epoch == 1:
        torch.save(model.cpu().state_dict(), f"{save_path}/pretrained_{epoch}.pt")
        logging.info(f"Saved model to {save_path}/pretrained_{epoch}.pt")
        model.to(device)

if args.use_wandb:
    wandb.finish()


  0%|                                                                                                         …

2025-07-13 17:11:26,763 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_1.pt


Epoch:001, average loss:0.1519


  0%|                                                                                                         …

Epoch:002, average loss:0.0821


  0%|                                                                                                         …

Epoch:003, average loss:0.0782


  0%|                                                                                                         …

Epoch:004, average loss:0.0757


  0%|                                                                                                         …

2025-07-13 17:13:30,735 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_5.pt


Epoch:005, average loss:0.0733


  0%|                                                                                                         …

Epoch:006, average loss:0.0714


  0%|                                                                                                         …

Epoch:007, average loss:0.0701


  0%|                                                                                                         …

Epoch:008, average loss:0.0693


  0%|                                                                                                         …

Epoch:009, average loss:0.0682


  0%|                                                                                                         …

2025-07-13 17:16:06,762 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_10.pt


Epoch:010, average loss:0.0675


  0%|                                                                                                         …

Epoch:011, average loss:0.0669


  0%|                                                                                                         …

Epoch:012, average loss:0.0663


  0%|                                                                                                         …

Epoch:013, average loss:0.0658


  0%|                                                                                                         …

Epoch:014, average loss:0.0650


  0%|                                                                                                         …

2025-07-13 17:18:42,832 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_15.pt


Epoch:015, average loss:0.0648


  0%|                                                                                                         …

Epoch:016, average loss:0.0642


  0%|                                                                                                         …

Epoch:017, average loss:0.0639


  0%|                                                                                                         …

Epoch:018, average loss:0.0637


  0%|                                                                                                         …

Epoch:019, average loss:0.0633


  0%|                                                                                                         …

2025-07-13 17:21:18,283 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_20.pt


Epoch:020, average loss:0.0628


  0%|                                                                                                         …

Epoch:021, average loss:0.0625


  0%|                                                                                                         …

Epoch:022, average loss:0.0623


  0%|                                                                                                         …

Epoch:023, average loss:0.0617


  0%|                                                                                                         …

Epoch:024, average loss:0.0616


  0%|                                                                                                         …

2025-07-13 17:23:54,231 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_25.pt


Epoch:025, average loss:0.0614


  0%|                                                                                                         …

Epoch:026, average loss:0.0609


  0%|                                                                                                         …

Epoch:027, average loss:0.0606


  0%|                                                                                                         …

Epoch:028, average loss:0.0605


  0%|                                                                                                         …

Epoch:029, average loss:0.0602


  0%|                                                                                                         …

2025-07-13 17:26:30,660 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_30.pt


Epoch:030, average loss:0.0601


  0%|                                                                                                         …

Epoch:031, average loss:0.0597


  0%|                                                                                                         …

Epoch:032, average loss:0.0595


  0%|                                                                                                         …

Epoch:033, average loss:0.0593


  0%|                                                                                                         …

Epoch:034, average loss:0.0590


  0%|                                                                                                         …

2025-07-13 17:29:06,961 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_35.pt


Epoch:035, average loss:0.0588


  0%|                                                                                                         …

Epoch:036, average loss:0.0586


  0%|                                                                                                         …

Epoch:037, average loss:0.0584


  0%|                                                                                                         …

Epoch:038, average loss:0.0581


  0%|                                                                                                         …

Epoch:039, average loss:0.0579


  0%|                                                                                                         …

2025-07-13 17:31:41,841 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_40.pt


Epoch:040, average loss:0.0578


  0%|                                                                                                         …

Epoch:041, average loss:0.0577


  0%|                                                                                                         …

Epoch:042, average loss:0.0575


  0%|                                                                                                         …

Epoch:043, average loss:0.0575


  0%|                                                                                                         …

Epoch:044, average loss:0.0571


  0%|                                                                                                         …

2025-07-13 17:34:16,874 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_45.pt


Epoch:045, average loss:0.0571


  0%|                                                                                                         …

Epoch:046, average loss:0.0571


  0%|                                                                                                         …

Epoch:047, average loss:0.0568


  0%|                                                                                                         …

Epoch:048, average loss:0.0568


  0%|                                                                                                         …

Epoch:049, average loss:0.0565


  0%|                                                                                                         …

2025-07-13 17:36:52,062 - INFO - Saved model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model//pretrained_50.pt


Epoch:050, average loss:0.0563


# FINETUNING 

In [27]:
import os
import logging
import random
import pickle
import numpy as np
from tqdm.auto import tqdm
from collections import defaultdict
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
from types import SimpleNamespace

import wandb  # optional, skip if not using

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

# ------ Setup args as a Namespace or SimpleNamespace -------
args = SimpleNamespace(
    seed=42,
    dataset="mimic",
    device=0,
    task="next_diag_6m",  # change as needed
    pretrain_epoch=50,
    batch_size=8,
    eval_batch_size=8,
    encoder="hi_edge",
    pretrain_mask_rate=0.5,
    pretrain_anomaly_rate=0.05,
    pretrain_anomaly_loss_weight=1,
    pretrain_pos_weight=1,
    lr=1e-4,
    epochs=10,
    use_wandb=False,  # change as needed
    num_hidden_layers=5,
    num_attention_heads=6,
    attention_probs_dropout_prob=0.2,
    hidden_dropout_prob=0.2,
    edge_hidden_size=32,
    hidden_size=288,
    intermediate_size=288,
    save_model=True,
    gnn_n_heads=1,
    gnn_temp=1.0,
    gat="dotattn",
    diag_med_emb="tree",
    # Add others if needed below
)

args

namespace(seed=42,
          dataset='mimic',
          device=0,
          task='next_diag_6m',
          pretrain_epoch=50,
          batch_size=8,
          eval_batch_size=8,
          encoder='hi_edge',
          pretrain_mask_rate=0.5,
          pretrain_anomaly_rate=0.05,
          pretrain_anomaly_loss_weight=1,
          pretrain_pos_weight=1,
          lr=0.0001,
          epochs=10,
          use_wandb=False,
          num_hidden_layers=5,
          num_attention_heads=6,
          attention_probs_dropout_prob=0.2,
          hidden_dropout_prob=0.2,
          edge_hidden_size=32,
          hidden_size=288,
          intermediate_size=288,
          save_model=True,
          gnn_n_heads=1,
          gnn_temp=1.0,
          gat='dotattn',
          diag_med_emb='tree')

In [28]:
vars(args)

{'seed': 42,
 'dataset': 'mimic',
 'device': 0,
 'task': 'next_diag_6m',
 'pretrain_epoch': 50,
 'batch_size': 8,
 'eval_batch_size': 8,
 'encoder': 'hi_edge',
 'pretrain_mask_rate': 0.5,
 'pretrain_anomaly_rate': 0.05,
 'pretrain_anomaly_loss_weight': 1,
 'pretrain_pos_weight': 1,
 'lr': 0.0001,
 'epochs': 10,
 'use_wandb': False,
 'num_hidden_layers': 5,
 'num_attention_heads': 6,
 'attention_probs_dropout_prob': 0.2,
 'hidden_dropout_prob': 0.2,
 'edge_hidden_size': 32,
 'hidden_size': 288,
 'intermediate_size': 288,
 'save_model': True,
 'gnn_n_heads': 1,
 'gnn_temp': 1.0,
 'gat': 'dotattn',
 'diag_med_emb': 'tree'}

In [29]:
# ----- Set file paths here (modify for your data location) -----
root = "/data/horse/ws/arsi805e-finetune/Thesis"
root_ehr_bert = root + '/ehr_bert'
args.special_tokens = ("[PAD]", "[CLS]", "[SEP]", "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]")

if args.dataset == "mimic":
    args.predicted_token_type = ["diag", "med", "pro", "lab"]
    all_data_path = f"{root}/dataset/mimic.pkl"
    if args.task == "next_diag_6m":
        finetune_data_path = f"{root}/dataset/mimic_nextdiag_6m.pkl"
    elif args.task == "next_diag_12m":
        finetune_data_path = f"{root}/dataset/mimic_nextdiag_12m.pkl"
    else:    
        finetune_data_path = f"{root}/dataset/mimic_downstream.pkl"
    args.max_visit_size = 15
else:
    args.predicted_token_type = ["diag", "med", "lab"]
    all_data_path = f"{root}/dataset/eicu.pkl"
    finetune_data_path = f"{root}/dataset/eicu_downstream.pkl"
    args.max_visit_size = 24

print("[PATH LOADED SUCCESSFULLY....]")
print(f"MIMIC file loaded from: {all_data_path}")

[PATH LOADED SUCCESSFULLY....]
MIMIC file loaded from: /data/horse/ws/arsi805e-finetune/Thesis/dataset/mimic.pkl


In [30]:
exp_name = (
    "HBERT" +
    "-" + str(args.encoder) +
    "-" + str(args.pretrain_mask_rate) +
    "-" + str(args.pretrain_anomaly_rate) +
    "-" + str(args.pretrain_anomaly_loss_weight) +
    "-" + str(args.pretrain_pos_weight) +
    "-" + str(args.hidden_size) +
    "-" + str(args.edge_hidden_size) +
    "-" + str(args.num_hidden_layers) +
    "-" + str(args.num_attention_heads) +
    "-" + str(args.attention_probs_dropout_prob) +
    "-" + str(args.hidden_dropout_prob) +
    "-" + str(args.intermediate_size) +
    "-" + str(args.gat) +
    "-" + str(args.gnn_n_heads) +
    "-" + str(args.gnn_temp) +
    "-" + str(args.diag_med_emb)
)
exp_name

'HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree'

In [31]:
# ------ Data loading ------
def read_data_FineTune(args, all_data_path, finetune_data_path):
    ehr_data = pickle.load(open(all_data_path, 'rb'))
    diag_sentences = ehr_data["icd_code"].values.tolist()
    med_sentences = ehr_data["ndc"].values.tolist()
    lab_sentences = ehr_data["lab_test"].values.tolist()
    if args.dataset == "mimic":
        pro_sentences = ehr_data["pro_code"].values.tolist()
        gender_set = [["M"], ["F"]]
        age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["age"].values.tolist()) for gender in ["M", "F"]]
    else:
        pro_sentences = None
        gender_set = [["Female"], ["Male"], ["Unknown"], ["Other"]]
        age_gender_set = [[str(c) + "_" + gender] for c in set(ehr_data["age"].values.tolist()) for gender in ["Female", "Male", "Unknown", "Other"]]
    age_set = [[c] for c in set(ehr_data["age"].values.tolist())]    

    tokenizer = EHRTokenizer(diag_sentences, med_sentences, lab_sentences, pro_sentences, gender_set, age_set, age_gender_set, special_tokens=args.special_tokens)
    if args.dataset == "mimic":
        tokenizer.build_tree()

    train_data, val_data, test_data = pickle.load(open(finetune_data_path, 'rb'))
    train_dataset = HBERTFinetuneEHRDataset(train_data, tokenizer, token_type=args.predicted_token_type, task=args.task)
    val_dataset = HBERTFinetuneEHRDataset(val_data, tokenizer, token_type=args.predicted_token_type, task=args.task)
    test_dataset = HBERTFinetuneEHRDataset(test_data, tokenizer, token_type=args.predicted_token_type, task=args.task)
    
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=batcher(tokenizer, is_train=False), shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.eval_batch_size, collate_fn=batcher(tokenizer, is_train=False), shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.eval_batch_size, collate_fn=batcher(tokenizer, is_train=False), shuffle=False)

    return tokenizer, train_dataloader, val_dataloader, test_dataloader

In [45]:
# ----- Helper: evaluation -----
@torch.no_grad()
def evaluate(model, dataloader, device, task_type="binary"):
    model.eval()
    predicted_scores, gt_labels = [], []
    for step, batch in enumerate(dataloader):
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1]
        output_logits = model(*batch[:-1])
        predicted_scores.append(output_logits)
        gt_labels.append(labels)
    
    if task_type == "binary":
        predicted_scores = torch.cat(predicted_scores, dim=0).view(-1)
        gt_labels = torch.cat(gt_labels, dim=0).view(-1).cpu().numpy()
        scores = predicted_scores.cpu().numpy()
        predicted_labels = (predicted_scores > 0).float().cpu().numpy()

        precision = (predicted_labels * gt_labels).sum() / (predicted_labels.sum() + 1e-8)
        recall = (predicted_labels * gt_labels).sum() / (gt_labels.sum() + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        roc_auc = roc_auc_score(gt_labels, scores)
        precision_curve, recall_curve, _ = precision_recall_curve(gt_labels, scores)
        pr_auc = auc(recall_curve, precision_curve)

        return {"precision":precision, "recall":recall, "f1":f1, "roc_auc":roc_auc, "pr_auc":pr_auc}
    else:
        predicted_scores = torch.cat(predicted_scores, dim=0).cpu()  # [B, -1]
        gt_labels = torch.cat(gt_labels, dim=0).cpu()

        print(f"Predicted Sores: {predicted_scores}")
        print(f"Ground truth labels: {gt_labels}")

        print("----- Some Analysis -----")
        print(f"Predicted score shape: {predicted_scores.shape}")
        print(f"GT label shape: {gt_labels.shape}")

        ave_f1, ave_auc, ave_prauc, ave_recall, ave_precision = [], [], [], [], []
        for i in range(predicted_scores.size(0)):
            scores, labels = predicted_scores[i].squeeze().clone(), gt_labels[i].squeeze().clone()

            predicted_labels = (scores > 0).float().cpu().numpy()
            labels = labels.float().cpu().numpy()
            precision = (predicted_labels * labels).sum() / (predicted_labels.sum() + 1e-8)
            recall = (predicted_labels * labels).sum() / (labels.sum() + 1e-8)
            ave_f1.append(2 * precision * recall / (precision + recall + 1e-8))
            ave_auc.append(roc_auc_score(labels, scores))
            precision_curve, recall_curve, _ = precision_recall_curve(labels, scores)
            ave_prauc.append(auc(recall_curve, precision_curve))
            ave_recall.append(recall)
            ave_precision.append(precision)

        ave_f1, ave_auc, ave_prauc, ave_recall, ave_precision = (
            np.mean(ave_f1), np.mean(ave_auc), np.mean(ave_prauc),
            np.mean(ave_recall), np.mean(ave_precision)
        )
        return {"recall":ave_recall, "precision":ave_precision, "f1":ave_f1, "auc":ave_auc, "prauc":ave_prauc}

In [33]:
# ---- Training setup ----
set_random_seed(args.seed)

tokenizer, train_dataloader, val_dataloader, test_dataloader = read_data_FineTune(args, all_data_path, finetune_data_path)

In [34]:
len(test_dataloader)

406

In [37]:
pretrained_weight_path = "/data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/" + f"pretrained_{args.pretrain_epoch}.pt"
print(f"Pretrained weights loaded from file: {pretrained_weight_path}")

Pretrained weights loaded from file: /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/pretrained_50.pt


In [46]:
args.vocab_size = len(args.special_tokens) + \
                  len(tokenizer.diag_voc.idx2word) + \
                  len(tokenizer.pro_voc.idx2word) + \
                  len(tokenizer.med_voc.idx2word) + \
                  len(tokenizer.lab_voc.idx2word) + \
                  len(tokenizer.age_voc.idx2word) + \
                  len(tokenizer.gender_voc.idx2word) + \
                  len(tokenizer.age_gender_voc.idx2word)

args.label_vocab_size = len(tokenizer.diag_voc.idx2word)  # only for diagnosis

device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
model = HBERT_Finetune(args, tokenizer)
print("Model initialized....")
print(f"\nModel: \n {model}")

if args.pretrain_epoch > 0:
    model.load_weight(torch.load(pretrained_weight_path, map_location=device))
    print(f"\nLoaded pretrained model from {pretrained_weight_path}")

finetune_exp_name = f"Finetune-{args.task}-{exp_name}"
save_path = f"{root_ehr_bert}/saved_model/{finetune_exp_name}"
os.makedirs(save_path, exist_ok=True)
print(f"\nFinetuned model saved to: {save_path}")

model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

if args.task in ["death", "stay", "readmission"]:
    eval_metric = "f1"
    task_type = "binary"
    loss_fn = F.binary_cross_entropy_with_logits
else:
    eval_metric = "prauc"
    task_type = "l2r"
    loss_fn = lambda x, y: F.binary_cross_entropy_with_logits(x, y)

print(f"\nTask Type: {task_type}")

if args.use_wandb:
    wandb.init(project="ehr_bert", name=finetune_exp_name)
    wandb.config.update(vars(args))
    wandb.watch(model, log='all')

best_score, best_val_metric, best_test_metric = 0., None, None

for epoch in range(1, 1 + args.epochs):
    train_iter = tqdm(train_dataloader, ncols=140)
    model.train()
    ave_loss = 0.

    for step, batch in enumerate(train_iter):
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        labels = batch[-1].float()
        output_logits = model(*batch[:-1])

        loss = loss_fn(output_logits.view(-1), labels.view(-1))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_iter.set_description(f"Epoch:{epoch: 03d}, Step:{step: 03d}, loss:{loss.item():.4f}")
        ave_loss += loss.item()

    ave_loss /= (step + 1)
    val_metric = evaluate(model, val_dataloader, device, task_type=task_type)
    test_metric = evaluate(model, test_dataloader, device, task_type=task_type)
    print(f"Epoch:{epoch: 03d}, average loss:{ave_loss:.4f}")
    print("Val:", val_metric)
    print("Test:", test_metric)

    if val_metric[eval_metric] > best_score:
        best_score = val_metric[eval_metric]
        best_val_metric = val_metric
        best_test_metric = test_metric

    if args.use_wandb:
        record_dict = {f"loss": ave_loss}
        record_dict.update({f"val_{k}":v for k, v in val_metric.items()})
        record_dict.update({f"test_{k}":v for k, v in test_metric.items()})
        wandb.log(record_dict)

    if args.save_model:
        torch.save(model.cpu().state_dict(), f"{save_path}/{epoch}.pt")
        logging.info(f"save model to {save_path}/{epoch}.pt")
        model.to(device)

print("-----------------")
print(f"best val metric: {best_val_metric}")
print(f"best test metric: {best_test_metric}")

if args.use_wandb:
    wandb.finish()

Device: cuda:0
Model initialized....

Model: 
 HBERT_Finetune(
  (embeddings): TreeEmbeddings(
    (diag_tokens): Embedding(2651, 96)
    (med_tokens): Embedding(226, 96)
    (word_embeddings): Embedding(4398, 288, padding_idx=0)
    (emb_dropout): Dropout(p=0.2, inplace=False)
  )
  (loss_fn): BCEWithLogitsLoss()
  (transformer): HiEdgeTransformer(
    (edge_module): EdgeModule(
      (output): Linear(in_features=576, out_features=32, bias=True)
    )
    (transformer_blocks): ModuleList(
      (0-4): 5 x EdgeTransformerBlock(
        (self_attn): MultiHeadEdgeAttention(
          (W_Q): Linear(in_features=288, out_features=288, bias=True)
          (W_K): Linear(in_features=288, out_features=288, bias=True)
          (W_V): Linear(in_features=288, out_features=288, bias=True)
          (W_output): Linear(in_features=576, out_features=288, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
          (W_K_edge): Linear(in_features=32, out_features=32, bias=True)
          (W

  0%|                                                                                                         …

Predicted Sores: tensor([[ -3.6145,  -2.1973,  -3.2359,  ...,  -9.0892,  -9.9497,  -9.9118],
        [ -3.1866,  -1.7978,  -3.0658,  ...,  -8.7547, -10.0428, -10.2816],
        [ -2.8221,  -1.6359,  -2.6502,  ...,  -8.1609,  -9.2824,  -9.3210],
        ...,
        [ -3.6656,  -2.1279,  -3.7517,  ...,  -9.4805, -10.8054, -11.3812],
        [ -3.5529,  -2.2493,  -3.4050,  ...,  -9.0637,  -9.8782,  -9.9200],
        [ -3.6572,  -2.2082,  -3.6477,  ...,  -9.7365, -10.6272, -11.1785]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[ -2.7997,  -1.3305,  -3.1107,  ...,  -8.4252,  -9.9410, -10.5253],
        [ -3.5956,  -2.1410,  -3.7670,  ...,  -9.8894, -11.1102, -

2025-07-14 10:24:25,196 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/1.pt


Epoch: 01, average loss:0.0677
Val: {'recall': np.float32(0.008708145), 'precision': np.float32(0.0997335), 'f1': np.float32(0.015754258), 'auc': np.float64(0.8645676079355915), 'prauc': np.float64(0.1522231524831325)}
Test: {'recall': np.float32(0.008512921), 'precision': np.float32(0.08987784), 'f1': np.float32(0.01518192), 'auc': np.float64(0.862878897148297), 'prauc': np.float64(0.15606537425298653)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -4.5174,  -3.9595,  -4.6507,  ..., -10.4043, -10.0170, -10.1102],
        [ -3.5042,  -3.0154,  -2.7237,  ...,  -9.4488, -10.4488, -10.4434],
        [ -2.9625,  -2.7444,  -2.4664,  ...,  -8.9354,  -9.9238,  -9.5685],
        ...,
        [ -4.4686,  -3.7752,  -4.3978,  ..., -10.8625, -10.8472, -11.4905],
        [ -4.6091,  -4.0935,  -4.9563,  ..., -10.5095,  -9.7289, -10.2643],
        [ -3.8921,  -3.5555,  -3.9436,  ..., -10.6084, -10.4061, -11.1773]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[ -1.8357,  -1.3216,  -1.5965,  ...,  -8.9128, -10.9110, -10.9077],
        [ -3.5857,  -3.5154,  -3.6659,  ..., -10.6778, -11.4703, -

2025-07-14 10:24:53,092 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/2.pt


Epoch: 02, average loss:0.0345
Val: {'recall': np.float32(0.06911934), 'precision': np.float32(0.42247117), 'f1': np.float32(0.112225026), 'auc': np.float64(0.8726785764803058), 'prauc': np.float64(0.19364868743184)}
Test: {'recall': np.float32(0.072617956), 'precision': np.float32(0.46204922), 'f1': np.float32(0.11807963), 'auc': np.float64(0.8712624830299209), 'prauc': np.float64(0.19678725640231612)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -6.1494,  -5.0640,  -5.5158,  ..., -13.0290, -13.0021, -12.8992],
        [ -3.7291,  -2.8153,  -2.5991,  ..., -10.0156, -11.0711, -11.0816],
        [ -3.6770,  -3.0100,  -2.7166,  ...,  -9.9831, -11.1621, -10.9490],
        ...,
        [ -5.9657,  -4.8546,  -5.6999,  ..., -11.4388, -11.1506, -11.8699],
        [ -6.4027,  -5.3688,  -6.0335,  ..., -12.7834, -12.3712, -12.7713],
        [ -4.8488,  -4.2636,  -4.6142,  ..., -11.2460, -10.9108, -11.9671]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[ -0.7637,   0.3837,   0.3398,  ...,  -9.5742, -11.6979, -12.1421],
        [ -4.2102,  -4.0393,  -4.3584,  ..., -10.4789, -11.1474, -

2025-07-14 10:25:20,806 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/3.pt


Epoch: 03, average loss:0.0333
Val: {'recall': np.float32(0.08458793), 'precision': np.float32(0.45441583), 'f1': np.float32(0.13507324), 'auc': np.float64(0.8814322489961537), 'prauc': np.float64(0.2188281335101017)}
Test: {'recall': np.float32(0.086001836), 'precision': np.float32(0.4751101), 'f1': np.float32(0.13793363), 'auc': np.float64(0.8792959206077511), 'prauc': np.float64(0.22236822287195504)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -6.4630,  -5.3724,  -5.9918,  ..., -13.2629, -13.7330, -12.7194],
        [ -3.6167,  -2.5673,  -3.3195,  ..., -10.7131, -11.9107, -11.8995],
        [ -3.4000,  -2.6081,  -3.3774,  ..., -11.1285, -12.8326, -12.3716],
        ...,
        [ -7.3594,  -5.7484,  -7.1348,  ..., -11.8198, -12.0070, -11.5319],
        [ -7.4424,  -5.9737,  -6.8388,  ..., -13.0569, -13.4256, -13.2712],
        [ -6.1815,  -5.0549,  -6.4380,  ..., -11.1038, -11.3826, -11.5499]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[  0.3253,   1.2449,  -0.4281,  ..., -10.7470, -13.8855, -13.1709],
        [ -6.0822,  -5.4594,  -7.1142,  ..., -11.1099, -11.4804, -

2025-07-14 10:25:48,418 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/4.pt


Epoch: 04, average loss:0.0323
Val: {'recall': np.float32(0.1006805), 'precision': np.float32(0.5277901), 'f1': np.float32(0.16032055), 'auc': np.float64(0.8873185467592815), 'prauc': np.float64(0.239293042546235)}
Test: {'recall': np.float32(0.10334113), 'precision': np.float32(0.55508846), 'f1': np.float32(0.16482471), 'auc': np.float64(0.8851736323923826), 'prauc': np.float64(0.24284987175139472)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -6.8582,  -6.4077,  -5.7406,  ..., -14.5964, -13.6856, -12.8219],
        [ -4.5635,  -3.7417,  -3.2570,  ..., -11.5918, -11.9174, -12.4433],
        [ -4.7443,  -4.2258,  -3.9980,  ..., -12.4884, -13.3430, -13.5984],
        ...,
        [ -7.6120,  -5.8315,  -6.5352,  ..., -12.0507, -11.5663, -11.0707],
        [ -8.0457,  -6.8380,  -6.5165,  ..., -14.3495, -13.6647, -13.9722],
        [ -6.0326,  -5.0253,  -6.0140,  ..., -10.6257, -10.9013, -11.0606]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[  0.3250,   0.7672,  -0.1590,  ..., -10.7320, -14.1912, -12.8502],
        [ -6.6763,  -5.6684,  -7.0010,  ..., -11.5323, -11.5066, -

2025-07-14 10:26:16,044 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/5.pt


Epoch: 05, average loss:0.0314
Val: {'recall': np.float32(0.12590444), 'precision': np.float32(0.5533016), 'f1': np.float32(0.19400993), 'auc': np.float64(0.8922930981161225), 'prauc': np.float64(0.24871936660572166)}
Test: {'recall': np.float32(0.12944558), 'precision': np.float32(0.57222706), 'f1': np.float32(0.19943379), 'auc': np.float64(0.8906072863019037), 'prauc': np.float64(0.255244870431211)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -5.5154,  -5.1422,  -5.0709,  ..., -14.0520, -13.4623, -12.9973],
        [ -4.2153,  -2.8065,  -3.3860,  ..., -11.0823, -11.4728, -12.1030],
        [ -4.8364,  -3.6759,  -4.5046,  ..., -12.7320, -13.7594, -14.6120],
        ...,
        [ -7.9207,  -5.2851,  -7.0926,  ..., -11.7644, -11.5949, -11.3437],
        [ -7.8884,  -6.2291,  -6.4677,  ..., -14.2223, -13.8600, -15.0756],
        [ -5.6288,  -4.4216,  -5.9800,  ...,  -9.8597, -10.7577, -11.5900]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[  0.9446,   1.7681,   0.0835,  ..., -10.8267, -14.3889, -13.1325],
        [ -7.2166,  -5.4313,  -7.3616,  ..., -12.4648, -11.4793, -

2025-07-14 10:26:43,599 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/6.pt


Epoch: 06, average loss:0.0307
Val: {'recall': np.float32(0.13341683), 'precision': np.float32(0.5620838), 'f1': np.float32(0.20299445), 'auc': np.float64(0.8959654876688778), 'prauc': np.float64(0.25913739480396936)}
Test: {'recall': np.float32(0.13657007), 'precision': np.float32(0.583517), 'f1': np.float32(0.20824853), 'auc': np.float64(0.8941724868521465), 'prauc': np.float64(0.2646814254769317)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -5.5977,  -5.0665,  -4.3719,  ..., -14.6600, -14.0313, -12.9975],
        [ -4.6273,  -2.6414,  -3.6711,  ..., -11.7055, -12.0137, -12.7368],
        [ -5.1199,  -3.9296,  -5.1013,  ..., -13.6919, -15.0952, -15.6021],
        ...,
        [ -9.2480,  -5.6272,  -7.6742,  ..., -12.1386, -11.9606, -11.3120],
        [ -7.8225,  -5.5798,  -5.3377,  ..., -13.7431, -14.5117, -15.1996],
        [ -5.2204,  -4.0100,  -5.5247,  ...,  -9.7020, -11.3757, -11.8162]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[  0.1151,   0.8308,  -1.8219,  ..., -12.1971, -16.4017, -14.8391],
        [ -8.7486,  -6.0036,  -8.2906,  ..., -13.5704, -12.4306, -

2025-07-14 10:27:11,283 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/7.pt


Epoch: 07, average loss:0.0301
Val: {'recall': np.float32(0.143976), 'precision': np.float32(0.5609398), 'f1': np.float32(0.21533673), 'auc': np.float64(0.8988223543471447), 'prauc': np.float64(0.2691411556489929)}
Test: {'recall': np.float32(0.14674005), 'precision': np.float32(0.58894706), 'f1': np.float32(0.22038631), 'auc': np.float64(0.8971939317801387), 'prauc': np.float64(0.27315512580137175)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -7.1426,  -5.7206,  -4.7859,  ..., -16.1973, -14.5680, -13.8038],
        [ -4.9845,  -2.4967,  -3.2662,  ..., -10.5924, -10.5996, -11.3280],
        [ -5.3998,  -4.0680,  -5.0065,  ..., -13.9324, -14.7439, -15.6854],
        ...,
        [ -9.9958,  -5.5687,  -8.5907,  ..., -12.7839, -13.1363, -11.9336],
        [ -8.6132,  -5.8908,  -5.8094,  ..., -15.6885, -15.9015, -16.8325],
        [ -6.5986,  -5.1613,  -6.5606,  ..., -10.7697, -11.7442, -12.4037]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[  0.5597,   1.2163,  -0.7397,  ..., -11.3831, -14.9858, -13.4319],
        [ -7.4666,  -4.5903,  -6.2562,  ..., -14.0231, -13.1399, -

2025-07-14 10:27:38,846 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/8.pt


Epoch: 08, average loss:0.0295
Val: {'recall': np.float32(0.14811444), 'precision': np.float32(0.56238574), 'f1': np.float32(0.22082554), 'auc': np.float64(0.9008671707150797), 'prauc': np.float64(0.27319878001191544)}
Test: {'recall': np.float32(0.15123168), 'precision': np.float32(0.5896405), 'f1': np.float32(0.22708263), 'auc': np.float64(0.8995144117942292), 'prauc': np.float64(0.2801306007080481)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -7.3699,  -6.0922,  -4.6081,  ..., -16.1994, -14.4121, -13.9667],
        [ -5.5936,  -2.8883,  -3.0530,  ..., -11.4291, -10.9953, -11.9799],
        [ -6.1493,  -4.6486,  -5.5485,  ..., -14.0388, -14.9354, -16.2184],
        ...,
        [-11.1010,  -6.6787,  -9.0507,  ..., -13.9647, -13.4843, -12.6927],
        [ -9.5797,  -6.6748,  -6.2637,  ..., -15.4299, -15.5048, -17.1078],
        [ -7.9207,  -6.3729,  -7.4805,  ..., -11.5065, -12.1186, -13.4430]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[  0.5587,   1.6716,  -0.1856,  ..., -11.8280, -15.2589, -13.9680],
        [ -7.9994,  -4.5343,  -6.3284,  ..., -15.3845, -13.8694, -

2025-07-14 10:28:06,538 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/9.pt


Epoch: 09, average loss:0.0289
Val: {'recall': np.float32(0.1573699), 'precision': np.float32(0.55600625), 'f1': np.float32(0.2306454), 'auc': np.float64(0.9030000054062169), 'prauc': np.float64(0.27744277263635647)}
Test: {'recall': np.float32(0.16003361), 'precision': np.float32(0.58153456), 'f1': np.float32(0.23553608), 'auc': np.float64(0.900786546700267), 'prauc': np.float64(0.28311852886060784)}


  0%|                                                                                                         …

Predicted Sores: tensor([[ -7.1189,  -6.1312,  -5.0557,  ..., -16.2848, -13.3306, -13.2514],
        [ -6.3592,  -3.5092,  -4.3789,  ..., -11.3406, -10.9547, -11.8675],
        [ -7.1041,  -5.4248,  -6.3510,  ..., -14.7640, -15.3594, -16.7681],
        ...,
        [-11.2289,  -7.0510,  -9.9360,  ..., -13.5963, -13.3008, -12.8809],
        [ -9.2740,  -6.3913,  -6.1975,  ..., -14.8917, -14.2556, -15.9404],
        [ -8.3528,  -6.9007,  -8.1331,  ..., -11.4716, -11.7346, -13.1652]])
Ground truth labels: tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
----- Some Analysis -----
Predicted score shape: torch.Size([3252, 2001])
GT label shape: torch.Size([3252, 2001])
Predicted Sores: tensor([[ -0.4897,  -0.1799,  -2.1954,  ..., -12.3655, -15.6033, -13.3512],
        [ -8.0335,  -5.7194,  -7.0833,  ..., -15.9088, -14.4768, -

2025-07-14 10:28:34,455 - INFO - save model to /data/horse/ws/arsi805e-finetune/Thesis/ehr_bert/saved_model/Finetune-next_diag_6m-HBERT-hi_edge-0.5-0.05-1-1-288-32-5-6-0.2-0.2-288-dotattn-1-1.0-tree/10.pt


Epoch: 10, average loss:0.0284
Val: {'recall': np.float32(0.15291238), 'precision': np.float32(0.57836694), 'f1': np.float32(0.22681643), 'auc': np.float64(0.9040776082079067), 'prauc': np.float64(0.28434501990170685)}
Test: {'recall': np.float32(0.155953), 'precision': np.float32(0.60882443), 'f1': np.float32(0.23367436), 'auc': np.float64(0.9024004351060698), 'prauc': np.float64(0.2899076529444907)}
-----------------
best val metric: {'recall': np.float32(0.15291238), 'precision': np.float32(0.57836694), 'f1': np.float32(0.22681643), 'auc': np.float64(0.9040776082079067), 'prauc': np.float64(0.28434501990170685)}
best test metric: {'recall': np.float32(0.155953), 'precision': np.float32(0.60882443), 'f1': np.float32(0.23367436), 'auc': np.float64(0.9024004351060698), 'prauc': np.float64(0.2899076529444907)}


In [40]:
wandb.init()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33masn5898[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## TEXT MODALITY

In [None]:
import pandas as pd

# Load EHR and Notes
ehr_df = pd.read_pickle("mimic.pkl")       # or read_csv/read_feather/...
notes_df = pd.read_csv("cleaned_notes.csv")  # or whichever format you use



In [None]:
# Ensure hadm_id is string or int in both
ehr_df['hadm_id'] = ehr_df['hadm_id'].astype(str)
notes_df['hadm_id'] = notes_df['hadm_id'].astype(str)

# Merge
merged_df = ehr_df.merge(notes_df, on="hadm_id", how="left")