In [1]:
import pandas as pd
import numpy as np
import torch
import os
import psycopg2
from sqlalchemy import create_engine 
import string
import spacy
import re
from datetime import date, datetime, timedelta
import random
#from sklearn.model_selection import GroupShuffleSplit, StratifiedShuffleSplit

gpu_access = torch.cuda.is_available()

In [3]:
# connect to the mimic database and set the search path to the 'mimiciii' schema

dbschema='mimiciii'
cnx = create_engine('postgresql+psycopg2://aa5118:mimic@localhost:5432/mimic',
                    connect_args={'options': '-csearch_path={}'.format(dbschema)})


In [9]:
# main dataframe - join 'patients' to 'noteevents' and only look at adults (>=15yo)

df_main = pd.read_sql_query('''
  SELECT
      p.subject_id, p.dob, p.gender,
      n.category, n.chartdate, n.row_id, n.charttime,
      ROUND((cast(chartdate as date) - cast(dob as date)) / 365.242,0)
          AS age_at_noteevent,
      n.text
  FROM patients p 
  INNER JOIN noteevents n 
  ON p.subject_id = n.subject_id
  WHERE ROUND((cast(chartdate as date) - cast(dob as date)) / 365.242,0) > 14
  ORDER BY subject_id
  --LIMIT 1000;
''', cnx)

print(df_main.shape)
df_main.head()

(1657776, 9)


Unnamed: 0,subject_id,dob,gender,category,chartdate,row_id,charttime,age_at_noteevent,text
0,3,2025-04-11,M,Radiology,2101-10-17,768523,2101-10-17 08:35:00,77.0,[**2101-10-17**] 8:35 AM\n ART DUP EXT LO UNI;...
1,3,2025-04-11,M,Radiology,2101-10-22,768950,2101-10-22 16:27:00,77.0,[**2101-10-22**] 4:27 PM\n CHEST (PORTABLE AP)...
2,3,2025-04-11,M,Radiology,2101-10-20,768808,2101-10-20 17:49:00,77.0,[**2101-10-20**] 5:49 PM\n CT ABDOMEN W/O CONT...
3,3,2025-04-11,M,Radiology,2101-10-21,768885,2101-10-21 16:43:00,77.0,[**2101-10-21**] 4:43 PM\n CHEST (PORTABLE AP)...
4,3,2025-04-11,M,Radiology,2101-10-24,769043,2101-10-24 08:05:00,77.0,[**2101-10-24**] 8:05 AM\n CHEST (PORTABLE AP)...


In [10]:
# function to preprocess the text from the 'noteevents' table and tokenise using the spaCy tokenizer

nlp = spacy.load('en')
regex = re.compile(r'([0-9])-([0-9][0-9]?)-([0-9])')
counter = 0

def tokenise_text(text):
    global counter
    
    text = regex.sub(r'\1/\2/\3',text)
    text = text.replace("[**","[").replace("**]","]")
    
    #text = text.lower()
    tokens = nlp.tokenizer(str(text))
    tokenised_text = ""
    
    for token in tokens:
        tokenised_text = tokenised_text + str(token) + " "
    
    tokenised_text = tokenised_text.replace("\n"," <par> ").replace("\r"," <par> ")
    tokenised_text = ' '.join(tokenised_text.split())
    
    counter += 1
    if (counter % 10000) == 0:
        print (counter)
    
    return tokenised_text

In [11]:
# apply tokenising function

df_main["text"] = df_main["text"].apply(tokenise_text)
df_main.head()

10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
540000
550000
560000
570000
630000
640000
650000
660000
670000
680000
690000
700000
710000
720000
730000
740000
750000
760000
770000
780000
790000
800000
810000
820000
830000
840000
850000
860000
870000
880000
890000
900000
910000
920000
930000
940000
950000
960000
970000
980000
990000
1000000
1010000
1020000
1030000
1040000
1050000
1060000
1070000
1080000
1090000
1100000
1110000
1120000
1130000
1140000
1150000
1160000
1170000
1180000
1190000
1200000
1210000
1220000
1230000
1240000
1250000
1260000
1270000
1280000
1290000
1300000
1310000
1320000
1330000
1340000
1350000
1360000
1370000
1380000
1390000
1400000
1410000
1420000
1430000
1440000
1450000
1460000
1470000
1480000
1

Unnamed: 0,subject_id,dob,gender,category,chartdate,row_id,charttime,age_at_noteevent,text
0,3,2025-04-11,M,Radiology,2101-10-17,768523,2101-10-17 08:35:00,77.0,[ 2101/10/17 ] 8:35 AM <par> ART DUP EXT LO UN...
1,3,2025-04-11,M,Radiology,2101-10-22,768950,2101-10-22 16:27:00,77.0,[ 2101/10/22 ] 4:27 PM <par> CHEST ( PORTABLE ...
2,3,2025-04-11,M,Radiology,2101-10-20,768808,2101-10-20 17:49:00,77.0,[ 2101/10/20 ] 5:49 PM <par> CT ABDOMEN W / O ...
3,3,2025-04-11,M,Radiology,2101-10-21,768885,2101-10-21 16:43:00,77.0,[ 2101/10/21 ] 4:43 PM <par> CHEST ( PORTABLE ...
4,3,2025-04-11,M,Radiology,2101-10-24,769043,2101-10-24 08:05:00,77.0,[ 2101/10/24 ] 8:05 AM <par> CHEST ( PORTABLE ...


In [12]:
# use the first n tokens of the text as a hint
counter = 0
def produce_hint(text):
    global counter
    l = text.split()
    counter += 1
    if (counter % 10000) == 0:
        print (counter)
    return ' '.join(l[:10]) # first 10 tokens

df_main['hint'] = df_main['text'].map(lambda x: produce_hint(x))

print(df_main.shape)
df_main.head()

10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
480000
490000
500000
510000
520000
530000
540000
550000
560000
570000
580000
590000
600000
610000
620000
630000
640000
650000
660000
670000
680000
690000
700000
710000
720000
730000
740000
750000
760000
770000
780000
790000
800000
810000
820000
830000
840000
850000
860000
870000
880000
890000
900000
910000
920000
930000
940000
950000
960000
970000
980000
990000
1000000
1010000
1020000
1030000
1040000
1050000
1060000
1070000
1080000
1090000
1100000
1110000
1120000
1130000
1140000
1150000
1160000
1170000
1180000
1190000
1200000
1210000
1220000
1230000
1240000
1250000
1260000
1270000
1280000
1290000
1300000
1310000
1320000
1330000
1340000
1350000
1360000
1370000
1380000
1390

Unnamed: 0,subject_id,dob,gender,category,chartdate,row_id,charttime,age_at_noteevent,text,hint
0,3,2025-04-11,M,Radiology,2101-10-17,768523,2101-10-17 08:35:00,77.0,[ 2101/10/17 ] 8:35 AM <par> ART DUP EXT LO UN...,[ 2101/10/17 ] 8:35 AM <par> ART DUP EXT LO
1,3,2025-04-11,M,Radiology,2101-10-22,768950,2101-10-22 16:27:00,77.0,[ 2101/10/22 ] 4:27 PM <par> CHEST ( PORTABLE ...,[ 2101/10/22 ] 4:27 PM <par> CHEST ( PORTABLE AP
2,3,2025-04-11,M,Radiology,2101-10-20,768808,2101-10-20 17:49:00,77.0,[ 2101/10/20 ] 5:49 PM <par> CT ABDOMEN W / O ...,[ 2101/10/20 ] 5:49 PM <par> CT ABDOMEN W /
3,3,2025-04-11,M,Radiology,2101-10-21,768885,2101-10-21 16:43:00,77.0,[ 2101/10/21 ] 4:43 PM <par> CHEST ( PORTABLE ...,[ 2101/10/21 ] 4:43 PM <par> CHEST ( PORTABLE AP
4,3,2025-04-11,M,Radiology,2101-10-24,769043,2101-10-24 08:05:00,77.0,[ 2101/10/24 ] 8:05 AM <par> CHEST ( PORTABLE ...,[ 2101/10/24 ] 8:05 AM <par> CHEST ( PORTABLE AP


In [13]:
# patients above 89 years of age had their dob modified to be 300 years old at time of first event for privacy reasons
# change their age to instead be 90

df_main.loc[df_main['age_at_noteevent'] > 200, 'age_at_noteevent'] = 90

In [14]:
# lab items data

df_labitems = pd.read_sql_query('''
  SELECT l.subject_id, l.charttime, l.value, l.valueuom, l.flag, d.label
  FROM labevents l
  INNER JOIN d_labitems d 
  USING (itemid)
  ORDER BY l.subject_id
  --LIMIT 20;
''', cnx)

print(df_labitems.shape)
df_labitems.head()

(27854055, 6)


Unnamed: 0,subject_id,charttime,value,valueuom,flag,label
0,2,2138-07-17 20:48:00,0,%,,Atypical Lymphocytes
1,2,2138-07-17 20:48:00,0,%,,Bands
2,2,2138-07-17 20:48:00,0,%,,Basophils
3,2,2138-07-17 20:48:00,0,%,,Eosinophils
4,2,2138-07-17 20:48:00,0,%,abnormal,Hematocrit


In [15]:
# prescriptions data

df_prescriptions = pd.read_sql_query('''
  SELECT subject_id, startdate, enddate, drug, prod_strength
  FROM prescriptions
  ORDER BY subject_id
  --LIMIT 20;
''', cnx)

print(df_prescriptions.shape)
df_prescriptions.head()

(4156450, 5)


Unnamed: 0,subject_id,startdate,enddate,drug,prod_strength
0,2,2138-07-18,2138-07-20,NEO*IV*Gentamicin,10mg/mL-2mL
1,2,2138-07-18,2138-07-20,Syringe (Neonatal) *D5W*,1 Syringe
2,2,2138-07-18,2138-07-21,Ampicillin Sodium,500mg Vial
3,2,2138-07-18,2138-07-21,Send 500mg Vial,Send 500mg Vial
4,4,2191-03-16,2191-03-23,Guaifenesin-Codeine Phosphate,5ML UDCUP


In [16]:
#%%timeit -n 3 -r 3

# Split the dataset in a grouped and stratified manner

def StratifiedGroupShuffleSplit(df_main):

    df_main = df_main.reindex(np.random.permutation(df_main.index)) # shuffle dataset
    
    # create empty train, val and test datasets
    df_train = pd.DataFrame()
    df_val = pd.DataFrame()
    df_test = pd.DataFrame()

    hparam_mse_wgt = 0.1 # must be between 0 and 1
    assert(0 <= hparam_mse_wgt <= 1)
    train_proportion = 0.8 # must be between 0 and 1
    assert(0 <= train_proportion <= 1)
    val_test_proportion = (1-train_proportion)/2

    subject_grouped_df_main = df_main.groupby(['subject_id'], sort=False, as_index=False)
    category_grouped_df_main = df_main.groupby('category').count()[['subject_id']]/len(df_main)*100 
    
    # function to calculate loss
    def calc_mse_loss(df):
        grouped_df = df.groupby('category').count()[['subject_id']]/len(df)*100
        df_temp = category_grouped_df_main.join(grouped_df, on = 'category', how = 'left', lsuffix = '_main')
        df_temp.fillna(0, inplace=True)
        df_temp['diff'] = (df_temp['subject_id_main'] - df_temp['subject_id'])**2
        mse_loss = np.mean(df_temp['diff'])
        return mse_loss
    
    directory = "/mimic/data/preprocessed/"
    
    f_train = open(directory + "src-train.txt","w+")
    f_val = open(directory + "src-val.txt","w+")
    f_test = open(directory + "src-test.txt","w+")
    
    len_train = 0
    len_val = 0
    len_test = 0
    total_records = 0
    i = 0

    # loop the groups of subjects one by one
    for _, group in subject_grouped_df_main:

        total_records = len_train + len_val + len_test
        g = pd.DataFrame(group)
        subject_id = g['subject_id'].iloc[0]
        
        pre_left = df_prescriptions['subject_id'].searchsorted(subject_id, 'left')
        pre_right = df_prescriptions['subject_id'].searchsorted(subject_id, 'right')
        
        lab_left = df_labitems['subject_id'].searchsorted(subject_id, 'left')
        lab_right = df_labitems['subject_id'].searchsorted(subject_id, 'right')
        
        g_prescriptions = df_prescriptions[pre_left:pre_right]
        g_labitems = df_labitems[lab_left:lab_right]
        i += 1
        
        train = False
        val = False
        test = False
        
        # first three groups only
        if (i < 4):
            if (i == 1):
                df_train = df_train.append(g, ignore_index=True)
                len_train += len(g)
                train = True
            elif (i == 2):
                df_val = df_val.append(g, ignore_index=True)
                len_val += len(g)
                val = True
            else:
                df_test = df_test.append(g, ignore_index=True)
                len_test += len(g)
                test = True
        
        # all the other groups except every 500th
        if ((i % 1000 != 0) & (i > 3)):
            
            if (train_proportion > (len_train/total_records)):
                df_train = df_train.append(g, ignore_index=True)
                len_train += len(g)
                train = True
            elif (val_test_proportion > (len_val/total_records)):
                df_val = df_val.append(g, ignore_index=True)
                len_val += len(g)
                val = True
            else:
                df_test = df_test.append(g, ignore_index=True)
                len_test += len(g)
                test = True
        
        # every 500th group, balance the groups by proportion and by categories
        elif (i % 1000 == 0):
            
            mse_loss_diff_train = calc_mse_loss(df_train) - calc_mse_loss(df_train.append(g, ignore_index=True))
            mse_loss_diff_val = calc_mse_loss(df_val) - calc_mse_loss(df_val.append(g, ignore_index=True))
            mse_loss_diff_test = calc_mse_loss(df_test) - calc_mse_loss(df_test.append(g, ignore_index=True))

            len_diff_train = (train_proportion - (len_train/total_records))
            len_diff_val = (val_test_proportion - (len_val/total_records))
            len_diff_test = (val_test_proportion - (len_test/total_records)) 

            len_loss_diff_train = len_diff_train * abs(len_diff_train)
            len_loss_diff_val = len_diff_val * abs(len_diff_val)
            len_loss_diff_test = len_diff_test * abs(len_diff_test)

            loss_train = (hparam_mse_wgt * mse_loss_diff_train) + ((1-hparam_mse_wgt) * len_loss_diff_train)
            loss_val = (hparam_mse_wgt * mse_loss_diff_val) + ((1-hparam_mse_wgt) * len_loss_diff_val)
            loss_test = (hparam_mse_wgt * mse_loss_diff_test) + ((1-hparam_mse_wgt) * len_loss_diff_test)

            if (max(loss_train,loss_val,loss_test) == loss_train):
                df_train = df_train.append(g, ignore_index=True)
                len_train += len(g)
                train = True
            elif (max(loss_train,loss_val,loss_test) == loss_val):
                df_val = df_val.append(g, ignore_index=True)
                len_val += len(g)
                val = True
            else:
                df_test = df_test.append(g, ignore_index=True)
                len_test += len(g)
                test = True
            
            print ("Group " + str(i) + ". loss_train: " + str(loss_train) + " | " + "loss_val: " + str(loss_val) + " | " + "loss_test: " + str(loss_test) + " | ")
        
        # loop through every row in the group to get relevant prescriptions and lab items before appending to file
        for j, row in enumerate(g.itertuples()):
            
            charttime = row[7]
            chartdate = datetime.combine(row[5], datetime.min.time())
            category = str(row[4])

            if (pd.isna(charttime)):
                if (category == "Discharge summary"):
                    cutoff = chartdate
                    chartdate = cutoff + timedelta(days=1)
                else:
                    cutoff = chartdate - timedelta(days=1)
                
                lab_condition = np.logical_and((g_labitems.charttime >= cutoff),
                                               (g_labitems.charttime < chartdate))
                lab_items = g_labitems[lab_condition]

            else:
                cutoff = charttime - timedelta(days=1)
                lab_condition = np.logical_and((g_labitems.charttime >= cutoff),
                                               (g_labitems.charttime < charttime))
                lab_items = g_labitems[lab_condition]
            
            pre_condition = np.logical_and((g_prescriptions.startdate >= cutoff),
                                           (g_prescriptions.startdate < chartdate))
            prescriptions = g_prescriptions[pre_condition]

            lab_items_list = ""
            lab_items_length = len(lab_items)
            if (lab_items_length > 0):
                for k, lab_row in enumerate(lab_items.itertuples()):
                    flag = ""
                    if (pd.isna(lab_row[5]) == False):
                        flag = " , " + str(lab_row[5])

                    lab_items_list += str(lab_row[6]) + " , " + str(lab_row[3]) + " , " + str(lab_row[4]) + flag
                    if (k != (lab_items_length - 1)):
                        lab_items_list += " | "

            prescriptions_list = ""
            prescriptions_length = len(prescriptions)
            if (prescriptions_length > 0):
                for k, pre_row in enumerate(prescriptions.itertuples()):
                    prescriptions_list += str(pre_row[4]) + " , " + str(pre_row[5])
                    if (k != (prescriptions_length - 1)):
                        prescriptions_list += " | "
            
            if (train == True):
                f_train.write(str(row[10]) + " <H> " + str(row[4]) + " <T> " + str(row[3]) + " <G> " + str(row[8]) + " <A> " + 
                    prescriptions_list + " <0> " + lab_items_list + " <1>" + "\n")
            elif (val == True):
                f_val.write(str(row[10]) + " <H> " + str(row[4]) + " <T> " + str(row[3]) + " <G> " + str(row[8]) + " <A> " + 
                    prescriptions_list + " <0> " + lab_items_list + " <1>" + "\n")
            else:
                f_test.write(str(row[10]) + " <H> " + str(row[4]) + " <T> " + str(row[3]) + " <G> " + str(row[8]) + " <A> " + 
                    prescriptions_list + " <0> " + lab_items_list + " <1>" + "\n")
        
        if (i % 100 == 0):
            print (i)
    
    f_train.close()
    f_val.close()
    f_test.close()
    
    return df_train, df_val, df_test

src_train, src_val, src_test = StratifiedGroupShuffleSplit(df_main)

100
200
300
400
500
600
700
800
900
Group 1000. loss_train: -0.009780177877207419 | loss_val: -0.06859242740061777 | loss_test: 0.025535763483189013 | 
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
Group 2000. loss_train: 0.0010680695764008696 | loss_val: 0.01235420354373026 | loss_test: 0.004414589996162895 | 
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
Group 3000. loss_train: 0.0004236930892738452 | loss_val: 0.0033952138924803503 | loss_test: 0.003301266507898932 | 
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
Group 4000. loss_train: 0.0006383022498184798 | loss_val: 0.006316876077024962 | loss_test: 0.007849462263128219 | 
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
Group 5000. loss_train: 0.000779708757901185 | loss_val: 0.00940934270178325 | loss_test: 0.01018220345436419 | 
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
Group 6000. loss_train: 0.00015105327689797924 | loss_val: 0.0016135225487527564 | loss_test: 0.0016726088544993082 | 
6000
6100
6200
630

In [19]:
# INSPECT STRATIFICATION

df = src_val #  change to src_train/src_test/src_val to inspect length and stratification
print (len(df))

category_grouped_df_main = df_main.groupby('category').count()[['subject_id']]/len(df_main)*100
grouped_df = df.groupby('category').count()[['subject_id']]/len(df)*100
df_temp = category_grouped_df_main.join(grouped_df, on = 'category', how = 'left', lsuffix = '_main')
df_temp.fillna(0, inplace=True)
df_temp

165778


Unnamed: 0_level_0,subject_id_main,subject_id
category,Unnamed: 1_level_1,Unnamed: 2_level_1
Case Management,0.058331,0.060925
Consult,0.005912,0.005429
Discharge summary,3.342068,3.287529
ECG,12.571904,12.493817
Echo,2.689748,2.690948
General,0.500731,0.481367
Nursing,13.485296,13.651389
Nursing/other,25.218365,25.228317
Nutrition,0.568111,0.585723
Pharmacy,0.006213,0.008445


In [24]:
src_train.tail()

print(df_labitems[df_labitems.subject_id == 60326])

          subject_id           charttime value valueuom      flag  \
22671346       60326 2142-06-06 00:55:00   8.1    mg/dL  abnormal   
22671347       60326 2142-06-06 00:55:00    25    mEq/L      None   
22671348       60326 2142-06-06 00:55:00    11    mEq/L      None   
22671349       60326 2142-06-07 04:40:00   9.8     K/uL      None   
22671350       60326 2142-06-07 04:40:00  3.66     m/uL  abnormal   
22671351       60326 2142-06-07 04:40:00  13.3        %      None   
22671352       60326 2142-06-07 04:40:00   250     K/uL      None   
22671353       60326 2142-06-07 04:40:00    84       fL      None   
22671354       60326 2142-06-07 04:40:00  33.0        %      None   
22671355       60326 2142-06-07 04:40:00  27.7       pg      None   
22671356       60326 2142-06-07 04:40:00  10.1     g/dL  abnormal   
22671357       60326 2142-06-07 04:40:00  30.7        %  abnormal   
22671358       60326 2142-06-07 04:40:00     5    mg/dL  abnormal   
22671359       60326 2142-06-07 04

In [25]:
# create target files

tgt_train = pd.DataFrame(src_train, columns = ["text"])
tgt_val = pd.DataFrame(src_val, columns = ["text"])
tgt_test = pd.DataFrame(src_test, columns = ["text"])

print(tgt_test.shape)
tgt_test.head()

(165778, 1)


Unnamed: 0,text
0,Demographics <par> Day of intubation : <par> D...
1,[ 2204/8/29 ] 5:41 AM <par> CHEST ( PORTABLE A...
2,TITLE : <par> Pt given medication nebulizer as...
3,NPN 1900 -0700 <par> <par> SIG EVENTS : PT BEC...
4,TITLE : <par> Chief Complaint : dyspnea <par> ...


In [26]:
# save target files to disk

np.savetxt('/mimic/data/preprocessed/tgt-train.txt', tgt_train, fmt='%s', newline=os.linesep)
np.savetxt('/mimic/data/preprocessed/tgt-val.txt', tgt_val, fmt='%s', newline=os.linesep)
np.savetxt('/mimic/data/preprocessed/tgt-test.txt', tgt_test, fmt='%s', newline=os.linesep)

In [28]:
%cd /mimic/data/preprocessed/
!wc -mlw src-train.txt src-test.txt src-val.txt

/mimic/data/preprocessed
   1326220  538649596 2270440965 src-train.txt
    165778   67321633  283911588 src-test.txt
    165778   67819433  286016623 src-val.txt
   1657776  673790662 2840369176 total


In [1]:
%cd /mimic/data/preprocessed/
!wc -mlw tgt-train.txt tgt-test.txt tgt-val.txt

/mimic/data/preprocessed
   1326220  636511443 3059913745 tgt-train.txt
    165778   79307995  381068082 tgt-test.txt
    165778   79770220  383320545 tgt-val.txt
   1657776  795589658 3824302372 total
