# Preprocess Discharge Summaries

In [1]:
import pandas as pd
import numpy as np
import os
import psycopg2
import sqlalchemy
import string
import spacy
from spacy.symbols import ORTH
import scispacy
from collections import Counter
import re
from datetime import date, datetime, timedelta
import random
from sklearn.model_selection import GroupShuffleSplit, StratifiedShuffleSplit
from spellchecker import SpellChecker

Connect to the mimic database and set the search path to the 'mimiciii' schema

In [2]:
dbschema='mimiciii'
cnx = sqlalchemy.create_engine('postgresql+psycopg2://aa5118:mimic@localhost:5432/mimic',
                    connect_args={'options': '-csearch_path={}'.format(dbschema)})


Query the discharge summary notes joined on to patient data

In [3]:
sql = """
  SELECT
      p.subject_id, p.dob, p.gender,
      n.hadm_id, n.category, n.chartdate, n.row_id,
      ROUND((cast(chartdate as date) - cast(dob as date)) / 365.242,0)
          AS age_at_noteevent,
      n.text
  FROM patients p 
  INNER JOIN noteevents n 
  ON p.subject_id = n.subject_id
  WHERE ROUND((cast(chartdate as date) - cast(dob as date)) / 365.242,0) > 14
  AND n.category = 'Discharge summary'
  ORDER BY subject_id
  LIMIT 100;
"""

df = pd.read_sql_query(sqlalchemy.text(sql), cnx)
df.head()

Unnamed: 0,subject_id,dob,gender,hadm_id,category,chartdate,row_id,age_at_noteevent,text
0,3,2025-04-11,M,145834,Discharge summary,2101-10-31,44005,77.0,Admission Date: [**2101-10-20**] Discharg...
1,4,2143-05-12,F,185777,Discharge summary,2191-03-23,4788,48.0,Admission Date: [**2191-3-16**] Discharge...
2,6,2109-06-21,F,107064,Discharge summary,2175-06-15,20825,66.0,Admission Date: [**2175-5-30**] Dischar...
3,9,2108-01-26,M,150750,Discharge summary,2149-11-14,57115,42.0,"Name: [**Known lastname 10050**], [**Known fi..."
4,9,2108-01-26,M,150750,Discharge summary,2149-11-13,20070,42.0,Admission Date: [**2149-11-9**] Dischar...


Change data type of age to the smallest possible type of integer to save memory and get rid of decimal point

In [4]:
df['age_at_noteevent'] = pd.to_numeric(df['age_at_noteevent'], downcast='integer')
df.head(20)

Unnamed: 0,subject_id,dob,gender,hadm_id,category,chartdate,row_id,age_at_noteevent,text
0,3,2025-04-11,M,145834,Discharge summary,2101-10-31,44005,77,Admission Date: [**2101-10-20**] Discharg...
1,4,2143-05-12,F,185777,Discharge summary,2191-03-23,4788,48,Admission Date: [**2191-3-16**] Discharge...
2,6,2109-06-21,F,107064,Discharge summary,2175-06-15,20825,66,Admission Date: [**2175-5-30**] Dischar...
3,9,2108-01-26,M,150750,Discharge summary,2149-11-14,57115,42,"Name: [**Known lastname 10050**], [**Known fi..."
4,9,2108-01-26,M,150750,Discharge summary,2149-11-13,20070,42,Admission Date: [**2149-11-9**] Dischar...
5,11,2128-02-22,F,194540,Discharge summary,2178-05-11,30120,50,Admission Date: [**2178-4-16**] ...
6,12,2032-03-24,M,112213,Discharge summary,2104-08-20,50972,72,Admission Date: [**2104-8-7**] Discharge ...
7,13,2127-02-27,F,143045,Discharge summary,2167-01-15,57099,40,"Name: [**Known lastname 9900**], [**Known fir..."
8,13,2127-02-27,F,143045,Discharge summary,2167-01-15,20168,40,Admission Date: [**2167-1-8**] Discharg...
9,17,2087-07-14,F,194023,Discharge summary,2134-12-31,51782,47,Admission Date: [**2134-12-27**] ...


In [5]:
df.shape

(100, 9)

55404 'adult' (15 or over) discharge summaries - this is what we expect from our exploratory data analysis

The following punctuation marks frequently appear in the middle of words or between words without spacing meaning they are missed by the tokenizer. What we need to is to split the tokens on these punctuation marks after we have tokenized. We do this with regex. We then retokenize. This will substantially decreases the number of our unique words which we will replace with <UNK>

- ampersand
- brackets
- colons
- forward slashes(make sure to leave dates alone though)
- full stops
- hyphens
- equals signs
- semicolons
- plus signs

In [6]:
date_regex = re.compile(r'([0-9])-([0-9][0-9]?)-([0-9])') # change date format so spacy can recognise
newline_regex = re.compile(r'(\\n){3,}') # cap number of consecutive newline characters to 2
newline_regex2 = re.compile(r'(\\r){3,}') # cap number of consecutive newline characters to 2
ellipsis_regex = re.compile(r'(\.){2,}')
tilda_mult_regex = re.compile(r'(~){2,}')
atsign_mult_regex = re.compile(r'(@){2,}')

bracket_regex = re.compile(r'(.)(\()(.)')
bracket_regex2 = re.compile(r'(.)(\))(.)')
slash_regex = re.compile(r'(.)(\/)([^0-9])')
slash_regex2 = re.compile(r'([^0-9])(\/)(.)')
equals_regex = re.compile(r'(.)(=)(.)')
colon_regex = re.compile(r'(.)(:)(.)')
sq_bracket_regex = re.compile(r'(.)(\[)(.)')
dash_regex = re.compile(r'(.)(-)(.)')
dash_regex2 = re.compile(r'(-)([\S])')
plus_regex = re.compile(r'(.)(\+)(.)')
amp_regex = re.compile(r'(.)(&)(.)')
star_regex = re.compile(r'(.)(\*)(.)') 
comma_regex = re.compile(r'(.)(,)(.)')
tilda_regex = re.compile(r'(.)(~)(.)')
pipe_regex = re.compile(r'(.)(\|)(.)')
atsign_regex = re.compile(r'(.)(@)(.)')
dot_regex = re.compile(r'([^.][^0-9])(\.)([^0-9,][^.])')

dot_regex2 = re.compile(r'([^0-9])(\.)(.)')
semicol_regex = re.compile(r'(.);(.)')
caret_regex = re.compile(r'(.)\^(.)')

In [7]:
nlp = spacy.load('en_core_sci_md') # sciSpaCy

nlp.tokenizer.add_special_case(u'<PAR>', [{ORTH: u'<PAR>'}])
nlp.tokenizer.add_special_case(u'<UNK>', [{ORTH: u'<UNK>'}])

i = 0

def tokenise_text(text, counter):
    global i
    
    text = str(text)
    text = date_regex.sub(r'\1/\2/\3',text)
    text = newline_regex.sub(r' \\n\\n ',text)
    text = newline_regex2.sub(r' \\n\\n ',text)
    text = ellipsis_regex.sub(r'.',text)
    text = tilda_mult_regex.sub(r'~',text)
    text = atsign_mult_regex.sub(r'@',text)
    
    text = text.replace("[**","[").replace("**]","]")
    
    tokens = nlp.tokenizer(text)
    tokenised_text = ""
    
    for token in tokens:
        tokenised_text = tokenised_text + token.text + " "
    
    tokenised_text = tokenised_text.replace("\n"," <PAR> ")
    
    tokenised_text = bracket_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = bracket_regex2.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = slash_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = slash_regex2.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = slash_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = equals_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = colon_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = sq_bracket_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = dash_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = dash_regex.sub(r'\1 \2 \3',tokenised_text) # dash twice because sometimes it appears twice
    tokenised_text = dash_regex.sub(r'\1 \2 \3',tokenised_text) # dash thrice because sometimes it appears thrice
    tokenised_text = dash_regex2.sub(r'\1 \2',tokenised_text) # dash thrice because sometimes it appears thrice
    tokenised_text = plus_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = star_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = amp_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = comma_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = dot_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = atsign_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = tilda_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = pipe_regex.sub(r'\1 \2 \3',tokenised_text)
    tokenised_text = dot_regex2.sub(r'\1 \3',tokenised_text)
    tokenised_text = semicol_regex.sub(r'\1 \2',tokenised_text)
    tokenised_text = caret_regex.sub(r'\1 \2',tokenised_text)
        
    tokenised_text = ' '.join(tokenised_text.split())
    
    tokens = nlp.tokenizer(tokenised_text)
    tokenised_text = ""
    
    for token in tokens:
        tokenised_text = tokenised_text + token.text + " "
    
    counter.update(tokenised_text.lower().split())
    
    i += 1
    if (i % 100) == 0:
        print (i)
    
    return tokenised_text

In [8]:
word_freq = Counter()
df["text"] = df["text"].apply(tokenise_text, args = (word_freq,))

100


Below we isolate the tokens which appear 3 times or fewer. They are mostly misspellings.

In [9]:
infreq_words = [word for word in word_freq.keys() if word_freq[word] <= 3 and word[0].isdigit() == False]
print(len(infreq_words))
sorted(infreq_words)[10000:11000]

4337


[]

We try and see if we can correct the misspellings using the `pyspellchecker` library by using the Levenshtein Distance algorithm and comparing against a dictionary. We first add the words with >3 occurrence to our dictionary. This is because they include a lot of scientific/medical terms which might not already be there

In [10]:
freq_words = [word for word in word_freq.keys() if word_freq[word] > 3]
add_to_dictionary = " ".join(freq_words)
f=open("data/mimic_dict.txt", "w+")
f.write(add_to_dictionary)
f.close()

In [11]:
spell = SpellChecker()
spell.distance = 1  # set the distance parameter to just 1 edit away - much quicker
spell.word_frequency.load_text_file('data/mimic_dict.txt')

In [12]:
misspelled = spell.unknown(infreq_words)
misspell_dict = {}
for i, word in enumerate(misspelled):
    if (word != spell.correction(word)):
        misspell_dict[word] = spell.correction(word)
    if (i % 100 == 0):
        print(i)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400


In [13]:
print(len(misspell_dict))
misspell_dict

683


{'titers': 'tigers',
 'nonionic': 'non-ionic',
 'megace': 'menace',
 'actos': 'acts',
 'occ': 'oct',
 'hct%': 'hctc',
 'enought': 'enough',
 'titrate': 'nitrate',
 'lll': 'all',
 'labetelol': 'labetalol',
 'derm': 'erm',
 'schduling': 'scheduling',
 'lpm': 'pm',
 'spo2': 'spot',
 'mpa': 'map',
 'heliox': 'helix',
 'intesive': 'intensive',
 'isrs': 'isis',
 'ugi': 'ugh',
 'qsat': 'sat',
 'tachyardia': 'tachycardia',
 'hyst': 'host',
 'ptn': 'pen',
 'cnii': 'nii',
 'psh': 'push',
 'palpations': 'palpation',
 'chiari': 'chiaro',
 'flomax': 'lomax',
 'papular': 'popular',
 'prev': 'prey',
 'dilatations': 'dilatation',
 'wma': 'ma',
 'contigious': 'contiguous',
 'tvi': 'tv',
 'createnine': 'creatinine',
 'convinving': 'convincing',
 'strenght': 'strength',
 'wtih': 'with',
 'tranfer': 'transfer',
 'avute': 'acute',
 'integrillin': 'integrilin',
 'erythematosus': 'erythematous',
 'qwednesday': 'wednesday',
 'rusb': 'rush',
 'cartia': 'cardia',
 'hcfa': 'cfa',
 'alternans': 'alternants',
 'pr

We now have correct spellings for many words in our dictionary that occurred <= 3 times. Anything else will be marked as `UNK`. We will save these as text files to avoid having to run this computation again.

In [14]:
unk_words = [word for word in infreq_words if word not in list(misspell_dict.keys())]
len(unk_words)

3654

In [15]:
if 'treatement' in misspell_dict.keys():
    print (misspell_dict['treatement'])

treatment


In [16]:
np.savetxt('data/discharge_unk_words.txt', unk_words, fmt='%s', newline=os.linesep)
f=open("data/discharge_typos.txt", "w+")

for key in misspell_dict:
    f.write(key + '\t' + misspell_dict[key] + '\n')
f.close()

#pd.read_csv('data/discharge_typos.txt', sep='\t',header=None)

 We will correct spelling mistakes whilst any word left uncorrected will be replaced with `<UNK>`

In [40]:
def fix_typos(text, typos, unks):
    
    tokens = text.split()
    tokenised_text = ""
    
    for token in tokens:
        if token.lower() in typos.keys():
            token = typos[token.lower()]
        elif token.lower() in unks:
            token = "<UNK>"
        tokenised_text = tokenised_text + token + " "
    
    return tokenised_text

In [41]:
# apply tokenising function elementwise
df["text"] = df["text"].apply(fix_typos, args = (misspell_dict, unk_words))
df.head()

Unnamed: 0,subject_id,dob,gender,hadm_id,category,chartdate,row_id,age_at_noteevent,text,hint,ethnicity,diagnosis,admission_type,30d_unplan_readmit
0,3,2025-04-11,M,145834,Discharge summary,2101-10-31,44005,77,Admission Date : [ 2101/10/20 ] Discharge Date...,Admission Date : [ 2101/10/20 ] Discharge Date...,white,hypotension,EMERGENCY,N
1,4,2143-05-12,F,185777,Discharge summary,2191-03-23,4788,48,Admission Date : [ 2191/3/16 ] Discharge Date ...,Admission Date : [ 2191/3/16 ] Discharge Date : [,white,"fever , dehydration , failure to thrive",EMERGENCY,N
2,6,2109-06-21,F,107064,Discharge summary,2175-06-15,20825,66,Admission Date : [ 2175/5/30 ] Discharge Date ...,Admission Date : [ 2175/5/30 ] Discharge Date : [,white,chronic renal failure/sda,ELECTIVE,N
3,9,2108-01-26,M,150750,Discharge summary,2149-11-14,57115,42,"Name : [ Known lastname 10050 ] , [ Known firs...","Name : [ Known lastname 10050 ] , [ Known",other,hemorrhagic cva,EMERGENCY,N
4,9,2108-01-26,M,150750,Discharge summary,2149-11-13,20070,42,Admission Date : [ 2149/11/9 ] Discharge Date ...,Admission Date : [ 2149/11/9 ] Discharge Date : [,other,hemorrhagic cva,EMERGENCY,N


In [19]:
' '.join(df.head()['text'])

"Admission Date : [ 2101/10/20 ] Discharge Date : [ 2101/10/31 ] <PAR> <PAR> Date of Birth : [ 2025/4/11 ] Sex : M <PAR> <PAR> Service : Medicine <PAR> <PAR> CHIEF COMPLAINT : Admitted from rehabilitation for <PAR> hypotension ( systolic blood pressure to the 70s ) and <PAR> decreased urine output <PAR> <PAR> HISTORY OF PRESENT ILLNESS : The patient is a 76 - year - old <PAR> male who had been hospitalized at the [ Hospital1 190 ] from [ 10 - 11 ] through [ 10 - 19 ] of [ 2101 ] <PAR> after undergoing a left femoral - AT bypass graft and was <PAR> subsequently discharged to a rehabilitation facility <PAR> <PAR> On [ 2101/10/20 ] , he presented again to the [ Hospital1 346 ] after being found to have a systolic <PAR> blood pressure in the 70s and no urine output for 17 hours <PAR> A Foley catheter placed at the rehabilitation facility <PAR> yielded 100 cc of murky / brown urine There may also have <PAR> been purulent discharge at the penile meatus at this time <PAR> <PAR> On presentatio

The text field has now been fully cleaned and tokenised. We can proceed to extract the first few tokens to use as a hint and move forth with joining other tables and partitioning the dataset.

In [20]:
counter = 0
def produce_hint(text):
    global counter
    l = text.split()
    counter += 1
    if (counter % 10000) == 0:
        print (counter)
    return ' '.join(l[:10]) # first 10 tokens

df['hint'] = df['text'].map(lambda x: produce_hint(x))

print(df.shape)
df.head()

(100, 10)


Unnamed: 0,subject_id,dob,gender,hadm_id,category,chartdate,row_id,age_at_noteevent,text,hint
0,3,2025-04-11,M,145834,Discharge summary,2101-10-31,44005,77,Admission Date : [ 2101/10/20 ] Discharge Date...,Admission Date : [ 2101/10/20 ] Discharge Date...
1,4,2143-05-12,F,185777,Discharge summary,2191-03-23,4788,48,Admission Date : [ 2191/3/16 ] Discharge Date ...,Admission Date : [ 2191/3/16 ] Discharge Date : [
2,6,2109-06-21,F,107064,Discharge summary,2175-06-15,20825,66,Admission Date : [ 2175/5/30 ] Discharge Date ...,Admission Date : [ 2175/5/30 ] Discharge Date : [
3,9,2108-01-26,M,150750,Discharge summary,2149-11-14,57115,42,"Name : [ Known lastname 10050 ] , [ Known firs...","Name : [ Known lastname 10050 ] , [ Known"
4,9,2108-01-26,M,150750,Discharge summary,2149-11-13,20070,42,Admission Date : [ 2149/11/9 ] Discharge Date ...,Admission Date : [ 2149/11/9 ] Discharge Date : [


In [21]:
# patients above 89 years of age had their dob modified to be 300 years old at time of first event for privacy reasons
# change their age to instead be 90

df.loc[df['age_at_noteevent'] > 200, 'age_at_noteevent'] = 90

We now merge our dataframe with some extra information from the admissions table. This preprocessing has already been done in another notebook so here we just load the csv file. Primarily, this table provides us with the ethnicity of the patient and whether or not they go on to have an unplanned readmission within 30 days of being discharged. This is relevant for one of our downstream tasks using the artificial data

In [22]:
# admissions data

admissions = pd.read_csv('data/df_adm.csv', sep=',')
admissions.head()
df = pd.merge(df, admissions,  how='left', left_on=['subject_id','hadm_id'], right_on = ['subject_id','hadm_id'])

Now we load all the context data from other tables. We'll be using this to construct the structured input to our encoder

In [24]:
# lab items data

df_labitems = pd.read_sql_query('''
  SELECT l.subject_id, l.charttime, l.value, l.valueuom, l.flag, d.label
  FROM labevents l
  INNER JOIN d_labitems d 
  USING (itemid)
  ORDER BY l.subject_id
  LIMIT 10000;
''', cnx)

print(df_labitems.shape)
df_labitems.head()

(10000, 6)


Unnamed: 0,subject_id,charttime,value,valueuom,flag,label
0,2,2138-07-17 20:48:00,0,%,,Atypical Lymphocytes
1,2,2138-07-17 20:48:00,0,%,,Bands
2,2,2138-07-17 20:48:00,0,%,,Basophils
3,2,2138-07-17 20:48:00,0,%,,Eosinophils
4,2,2138-07-17 20:48:00,0,%,abnormal,Hematocrit


In [26]:
# prescriptions data

df_prescriptions = pd.read_sql_query('''
  SELECT subject_id, startdate, enddate, drug, prod_strength
  FROM prescriptions
  ORDER BY subject_id
  LIMIT 10000;
''', cnx)

print(df_prescriptions.shape)
df_prescriptions.head()

(10000, 5)


Unnamed: 0,subject_id,startdate,enddate,drug,prod_strength
0,2,2138-07-18,2138-07-20,NEO*IV*Gentamicin,10mg/mL-2mL
1,2,2138-07-18,2138-07-20,Syringe (Neonatal) *D5W*,1 Syringe
2,2,2138-07-18,2138-07-21,Ampicillin Sodium,500mg Vial
3,2,2138-07-18,2138-07-21,Send 500mg Vial,Send 500mg Vial
4,4,2191-03-16,2191-03-23,Guaifenesin-Codeine Phosphate,5ML UDCUP


In [27]:
# diagnoses data

df_diagnoses = pd.read_sql_query('''
  SELECT d.subject_id, d.hadm_id, d.seq_num, d.icd9_code, icd.short_title, icd.long_title
  FROM diagnoses_icd d
  INNER JOIN d_icd_diagnoses icd 
  USING (icd9_code)
  ORDER BY d.subject_id, d.seq_num
  --LIMIT 10000;
''', cnx)

print(df_diagnoses.shape)
df_diagnoses.head(10)

(634709, 6)


Unnamed: 0,subject_id,hadm_id,seq_num,icd9_code,short_title,long_title
0,2,163353,1,V3001,Single lb in-hosp w cs,"Single liveborn, born in hospital, delivered b..."
1,2,163353,2,V053,Need prphyl vc vrl hepat,Need for prophylactic vaccination and inoculat...
2,2,163353,3,V290,NB obsrv suspct infect,Observation for suspected infectious condition
3,3,145834,1,0389,Septicemia NOS,Unspecified septicemia
4,3,145834,2,78559,Shock w/o trauma NEC,Other shock without mention of trauma
5,3,145834,3,5849,Acute kidney failure NOS,"Acute kidney failure, unspecified"
6,3,145834,4,4275,Cardiac arrest,Cardiac arrest
7,3,145834,5,41071,"Subendo infarct, initial","Subendocardial infarction, initial episode of ..."
8,3,145834,6,4280,CHF NOS,"Congestive heart failure, unspecified"
9,3,145834,7,6826,Cellulitis of leg,"Cellulitis and abscess of leg, except foot"


In [28]:
# procedures data

hadm_proc_subset = list(pd.read_csv('data/df_proc_hadm_ids.csv', sep=',')['hadm_id'])

df_procedures = pd.read_sql_query('''
  SELECT p.subject_id, p.hadm_id, p.seq_num, p.icd9_code, icd.short_title, icd.long_title
  FROM procedures_icd p
  INNER JOIN d_icd_procedures icd 
  USING (icd9_code)
  ORDER BY p.subject_id, p.seq_num
  --LIMIT 10000;
''', cnx)

print(df_procedures.shape)
df_procedures = df_procedures[df_procedures['hadm_id'].isin(hadm_proc_subset)]
print(df_procedures.shape)
df_procedures.head(20)

(237948, 6)
(212258, 6)


Unnamed: 0,subject_id,hadm_id,seq_num,icd9_code,short_title,long_title
1,3,145834,1,9604,Insert endotracheal tube,Insertion of endotracheal tube
2,3,145834,2,9962,Heart countershock NEC,Other electric countershock of heart
3,3,145834,3,8964,Pulmon art wedge monitor,Pulmonary artery wedge monitoring
4,3,145834,4,9672,Cont inv mec ven 96+ hrs,Continuous invasive mechanical ventilation for...
5,3,145834,5,3893,Venous cath NEC,"Venous catheterization, not elsewhere classified"
6,3,145834,6,966,Entral infus nutrit sub,Enteral infusion of concentrated nutritional s...
7,4,185777,1,3893,Venous cath NEC,"Venous catheterization, not elsewhere classified"
8,4,185777,2,8872,Dx ultrasound-heart,Diagnostic ultrasound of heart
9,4,185777,3,3323,Other bronchoscopy,Other bronchoscopy
11,6,107064,1,5569,Kidney transplant NEC,Other kidney transplantation


In [29]:
# microbiology data

df_microbiology = pd.read_sql_query('''
  SELECT subject_id, hadm_id, MAX(chartdate) AS chartdate, spec_type_desc, string_agg(DISTINCT(org_name), ', ') AS organism
  FROM microbiologyevents
  GROUP BY subject_id, hadm_id, charttime, spec_type_desc
  ORDER BY subject_id, chartdate DESC
  --LIMIT 10000;
''', cnx)

print(df_microbiology.shape)
df_microbiology.head(20)

(340305, 5)


Unnamed: 0,subject_id,hadm_id,chartdate,spec_type_desc,organism
0,2,163353,2138-07-17,BLOOD CULTURE - NEONATE,
1,3,145834,2101-10-28,STOOL,
2,3,145834,2101-10-26,URINE,YEAST
3,3,145834,2101-10-21,BLOOD CULTURE,
4,3,145834,2101-10-21,URINE,YEAST
5,3,145834,2101-10-21,CATHETER TIP-IV,
6,3,145834,2101-10-21,BLOOD CULTURE,
7,3,145834,2101-10-21,BLOOD CULTURE ( MYCO/F LYTIC BOTTLE),
8,3,145834,2101-10-21,BLOOD CULTURE,
9,3,145834,2101-10-21,SPUTUM,YEAST


In [42]:
# phenotype classification subject - we need to make sure these are part of our test set

annotations = pd.read_csv('data/annotations.csv')
annotations.columns.values[0] = 'hadm_id'
annotations.columns.values[1] = 'subject_id'

pheno_subjects = list(annotations['subject_id'])
print(annotations.shape)
annotations.head()

(1610, 18)


Unnamed: 0,hadm_id,subject_id,chart.time,cohort,Obesity,Non.Adherence,Developmental.Delay.Retardation,Advanced.Heart.Disease,Advanced.Lung.Disease,Schizophrenia.and.other.Psychiatric.Disorders,Alcohol.Abuse,Other.Substance.Abuse,Chronic.Pain.Fibromyalgia,Chronic.Neurological.Dystrophies,Advanced.Cancer,Depression,Dementia,Unsure
0,118003,3644,118003,1,0,0,0,0,0,0,0,0,1,0,0,1,0,0
1,177830,97736,999999,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0
2,185673,27694,999999,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0
3,131938,16275,131938,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,198999,4059,198999,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0


We now split the dataset into training, validation and test sets. As we do so, we query the relevant context data and append it to the input

In [30]:
#%%timeit -n 3 -r 3

# Split the dataset in a grouped and stratified manner

def StratifiedGroupShuffleSplit(df_main):

    df_main = df_main.reindex(np.random.permutation(df_main.index)) # shuffle dataset
    
    # create empty train, val and test datasets
    df_train = pd.DataFrame()
    df_val = pd.DataFrame()
    df_test = pd.DataFrame()

    hparam_mse_wgt = 0.1 # must be between 0 and 1
    assert(0 <= hparam_mse_wgt <= 1)
    train_proportion = 0.8 # must be between 0 and 1
    assert(0 <= train_proportion <= 1)
    val_test_proportion = (1-train_proportion)/2

    subject_grouped_df_main = df_main.groupby(['subject_id'], sort=False, as_index=False)
    readmit_grouped_df_main = df_main.groupby('30d_unplan_readmit').count()[['subject_id']]/len(df_main)*100 
    
    # function to calculate loss
    def calc_mse_loss(df):
        grouped_df = df.groupby('30d_unplan_readmit').count()[['subject_id']]/len(df)*100
        df_temp = readmit_grouped_df_main.join(grouped_df, on = '30d_unplan_readmit', how = 'left', lsuffix = '_main')
        df_temp.fillna(0, inplace=True)
        df_temp['diff'] = (df_temp['subject_id_main'] - df_temp['subject_id'])**2
        mse_loss = np.mean(df_temp['diff'])
        return mse_loss
    
    directory = "data/preprocessed/"
    
    f_train = open(directory + "src-train.txt","w+")
    f_val = open(directory + "src-val.txt","w+")
    f_test = open(directory + "src-test.txt","w+")
    
    len_train = 0
    len_val = 0
    len_test = 0
    total_records = 0
    i = 0

    # loop the groups of subjects one by one
    for _, group in subject_grouped_df_main:

        total_records = len_train + len_val + len_test
        g = pd.DataFrame(group)
        subject_id = g['subject_id'].iloc[0]
        
        pre_left = df_prescriptions['subject_id'].searchsorted(subject_id, 'left')
        pre_right = df_prescriptions['subject_id'].searchsorted(subject_id, 'right')
        
        lab_left = df_labitems['subject_id'].searchsorted(subject_id, 'left')
        lab_right = df_labitems['subject_id'].searchsorted(subject_id, 'right')
        
        diag_left = df_diagnoses['subject_id'].searchsorted(subject_id, 'left')
        diag_right = df_diagnoses['subject_id'].searchsorted(subject_id, 'right')
        
        proc_left = df_procedures['subject_id'].searchsorted(subject_id, 'left')
        proc_right = df_procedures['subject_id'].searchsorted(subject_id, 'right')
        
        micro_left = df_microbiology['subject_id'].searchsorted(subject_id, 'left')
        micro_right = df_microbiology['subject_id'].searchsorted(subject_id, 'right')
        
        g_prescriptions = df_prescriptions[pre_left:pre_right]
        g_labitems = df_labitems[lab_left:lab_right]
        g_diagnoses = df_diagnoses[diag_left:diag_right]
        g_procedures = df_procedures[proc_left:proc_right]
        g_microbiology = df_microbiology[micro_left:micro_right]
        
        train = False
        val = False
        test = False
        i += 1
        
        # all subjects in the phenotyping dataset need to be in the test set
        if subject_id in pheno_subjects:
            df_test = df_test.append(g, ignore_index=True)
            len_test += len(g)
            test = True
            
        # just to add something to each group to start off with - otherwise we end up dividing by 0
        elif (len_train == 0 or len_val == 0 or len_test == 0):
            if (len_train == 0):
                df_train = df_train.append(g, ignore_index=True)
                len_train += len(g)
                train = True
            elif (len_val == 0):
                df_val = df_val.append(g, ignore_index=True)
                len_val += len(g)
                val = True
            else:
                df_test = df_test.append(g, ignore_index=True)
                len_test += len(g)
                test = True
        
        # every 100th group, balance the groups jointly by proportion and by ratio of unplanned readmissions
        elif (i % 100 == 0):
            
            mse_loss_diff_train = calc_mse_loss(df_train) - calc_mse_loss(df_train.append(g, ignore_index=True))
            mse_loss_diff_val = calc_mse_loss(df_val) - calc_mse_loss(df_val.append(g, ignore_index=True))
            mse_loss_diff_test = calc_mse_loss(df_test) - calc_mse_loss(df_test.append(g, ignore_index=True))

            len_diff_train = (train_proportion - (len_train/total_records))
            len_diff_val = (val_test_proportion - (len_val/total_records))
            len_diff_test = (val_test_proportion - (len_test/total_records)) 

            len_loss_diff_train = len_diff_train * abs(len_diff_train)
            len_loss_diff_val = len_diff_val * abs(len_diff_val)
            len_loss_diff_test = len_diff_test * abs(len_diff_test)

            loss_train = (hparam_mse_wgt * mse_loss_diff_train) + ((1-hparam_mse_wgt) * len_loss_diff_train)
            loss_val = (hparam_mse_wgt * mse_loss_diff_val) + ((1-hparam_mse_wgt) * len_loss_diff_val)
            loss_test = (hparam_mse_wgt * mse_loss_diff_test) + ((1-hparam_mse_wgt) * len_loss_diff_test)

            if (max(loss_train,loss_val,loss_test) == loss_train):
                df_train = df_train.append(g, ignore_index=True)
                len_train += len(g)
                train = True
            elif (max(loss_train,loss_val,loss_test) == loss_val):
                df_val = df_val.append(g, ignore_index=True)
                len_val += len(g)
                val = True
            else:
                df_test = df_test.append(g, ignore_index=True)
                len_test += len(g)
                test = True
            
            print ("Group " + str(i) + ". loss_train: " + str(loss_train) + " | " + "loss_val: " + str(loss_val) + " | " + "loss_test: " + str(loss_test) + " | ")
        
        # all the other groups - divided simply by ratios of the dataset splits
        else:
            
            if (train_proportion > (len_train/total_records)):
                df_train = df_train.append(g, ignore_index=True)
                len_train += len(g)
                train = True
            elif (val_test_proportion > (len_val/total_records)):
                df_val = df_val.append(g, ignore_index=True)
                len_val += len(g)
                val = True
            else:
                df_test = df_test.append(g, ignore_index=True)
                len_test += len(g)
                test = True
        
        # loop through every row in the group to get relevant prescriptions and lab items before appending to file
        for j, row in enumerate(g.itertuples()):
            
            chartdate = datetime.combine(row[6], datetime.min.time())
            cutoff = chartdate
            chartdate = cutoff + timedelta(days=1)    
            
            lab_condition = np.logical_and((g_labitems.charttime >= cutoff),
                                           (g_labitems.charttime < chartdate))
            lab_items = g_labitems[lab_condition]
            lab_items = lab_items.sort_values(by=['charttime'], ascending=False)
            
            pre_condition = np.logical_and((g_prescriptions.startdate >= cutoff),
                                           (g_prescriptions.startdate < chartdate))
            prescriptions = g_prescriptions[pre_condition]
            prescriptions = prescriptions.sort_values(by=['startdate'], ascending=False)
            
            lab_items_list = ""
            lab_items_length = len(lab_items)
            if (lab_items_length > 0):
                for k, lab_row in enumerate(lab_items.itertuples()):
                    flag = ""
                    if (pd.isna(lab_row[5]) == False):
                        flag = " , " + str(lab_row[5])

                    lab_items_list += str(lab_row[6]) + " , " + str(lab_row[3]) + " , " + str(lab_row[4]) + flag
                    if (k != (lab_items_length - 1)):
                        lab_items_list += " | "

            prescriptions_list = ""
            prescriptions_length = len(prescriptions)
            if (prescriptions_length > 0):
                for k, pre_row in enumerate(prescriptions.itertuples()):
                    prescriptions_list += str(pre_row[4]) + " , " + str(pre_row[5])
                    if (k != (prescriptions_length - 1)):
                        prescriptions_list += " | "
            
            if (train == True):
                f_train.write(str(row[10]) + " <H> " + str(row[3]) + " <G> " + str(row[8]) + " <A> " + 
                    prescriptions_list + " <M> " + lab_items_list + " <L>" + "\n")
            elif (val == True):
                f_val.write(str(row[10]) + " <H> " + " <T> " + str(row[3]) + " <G> " + str(row[8]) + " <A> " + 
                    prescriptions_list + " <M> " + lab_items_list + " <L>" + "\n")
            else:
                f_test.write(str(row[10]) + " <H> " + " <T> " + str(row[3]) + " <G> " + str(row[8]) + " <A> " + 
                    prescriptions_list + " <M> " + lab_items_list + " <L>" + "\n")
        
        if (i % 100 == 0):
            print (i)
    
    f_train.close()
    f_val.close()
    f_test.close()
    
    return df_train, df_val, df_test

In [31]:
src_train, src_val, src_test = StratifiedGroupShuffleSplit(df)

In [32]:
src_train

Unnamed: 0,subject_id,dob,gender,hadm_id,category,chartdate,row_id,age_at_noteevent,text,hint,ethnicity,diagnosis,admission_type,30d_unplan_readmit
0,68,2132-02-29,F,108329,Discharge summary,2174-01-18,9139,42,Admission Date : [ 2174/1/4 ] Discharge Date :...,Admission Date : [ 2174/1/4 ] Discharge Date : [,black,weakness,EMERGENCY,N
1,68,2132-02-29,F,108329,Discharge summary,2174-01-18,56002,42,"Name : [ Known lastname 5477 ] , [ Known first...","Name : [ Known lastname 5477 ] , [ Known",black,weakness,EMERGENCY,N
2,68,2132-02-29,F,170467,Discharge summary,2174-01-03,9138,42,Admission Date : [ 2173/12/15 ] Discharge Date...,Admission Date : [ 2173/12/15 ] Discharge Date...,black,pneumonia,EMERGENCY,Y
3,68,2132-02-29,F,108329,Discharge summary,2174-01-19,56004,42,"Name : [ Known lastname 5477 ] , [ Known first...","Name : [ Known lastname 5477 ] , [ Known",black,weakness,EMERGENCY,N
4,78,2128-07-01,M,100536,Discharge summary,2177-02-17,1787,49,Admission Date : [ 2177/2/14 ] Discharge Date ...,Admission Date : [ 2177/2/14 ] Discharge Date : [,black,opiate intoxication,EMERGENCY,N
5,19,1808-08-05,M,109235,Discharge summary,2108-08-11,25231,90,Admission Date : [ 2108/8/5 ] Discharge Date :...,Admission Date : [ 2108/8/5 ] Discharge Date : [,white,c-2 fracture,EMERGENCY,N
6,36,2061-08-17,M,122659,Discharge summary,2131-05-25,7413,70,Admission Date : [ 2131/5/12 ] Discharge Date ...,Admission Date : [ 2131/5/12 ] Discharge Date : [,white,chest pain/shortness of breath,EMERGENCY,N
7,36,2061-08-17,M,182104,Discharge summary,2131-05-08,7412,70,Admission Date : [ 2131/4/30 ] Discharge Date ...,Admission Date : [ 2131/4/30 ] Discharge Date : [,white,coronary artery disease\coronary artery bypass...,EMERGENCY,Y
8,36,2061-08-17,M,165660,Discharge summary,2134-05-20,7414,73,Admission Date : [ 2134/5/10 ] Discharge Date ...,Admission Date : [ 2134/5/10 ] Discharge Date : [,white,ventral hernia/sda,ELECTIVE,N
9,21,2047-04-04,M,111970,Discharge summary,2135-02-08,7238,88,Admission Date : [ 2135/1/30 ] Discharge Date ...,Admission Date : [ 2135/1/30 ] Discharge Date : [,white,sepsis,EMERGENCY,N
