# Preprocess Discharge Summaries

In [7]:
import getpass
import os
import random
import re
import string
from collections import Counter
from datetime import date, datetime, timedelta
from urllib.parse import quote_plus

import numpy as np
import pandas as pd
import psycopg2
import scispacy
import spacy
import sqlalchemy
from sklearn.model_selection import GroupShuffleSplit, StratifiedShuffleSplit
from spacy.symbols import ORTH
from spellchecker import SpellChecker

Connect to the MIMIC-III database and set the search path to the `mimiciii` schema.

In [8]:
# Connection settings come from the environment so that no password is stored
# in the notebook (its source and outputs are easily shared):
#   MIMIC_DB_URL       full SQLAlchemy URL; when set, it overrides the rest
#   MIMIC_DB_USER      database user (defaults to the notebook's original user)
#   MIMIC_DB_PASSWORD  password; prompted for interactively when unset
dbschema = 'mimiciii'
db_url = os.environ.get('MIMIC_DB_URL')
if db_url is None:
    db_user = os.environ.get('MIMIC_DB_USER', 'aa5118')
    db_password = os.environ.get('MIMIC_DB_PASSWORD') or getpass.getpass('MIMIC database password: ')
    # quote_plus keeps passwords containing URL-reserved characters (@ : /) valid.
    db_url = 'postgresql+psycopg2://{}:{}@localhost:5432/mimic'.format(
        quote_plus(db_user), quote_plus(db_password))

# '-csearch_path=...' makes unqualified table names resolve inside the MIMIC-III schema.
cnx = sqlalchemy.create_engine(db_url,
                    connect_args={'options': '-csearch_path={}'.format(dbschema)})


Query the discharge summaries, joined to patient demographics, keeping only notes written when the patient's rounded age was over 14 (i.e. 15 or over).

In [3]:
# Discharge summaries of patients whose rounded age at the note exceeds 14,
# joined to their date of birth and gender. Age is computed in the database.
sql = """
  SELECT
      p.subject_id, p.dob, p.gender,
      n.hadm_id, n.category, n.chartdate, n.row_id,
      ROUND((cast(chartdate as date) - cast(dob as date)) / 365.242,0)
          AS age_at_noteevent,
      n.text
  FROM patients p 
  INNER JOIN noteevents n 
  ON p.subject_id = n.subject_id
  WHERE ROUND((cast(chartdate as date) - cast(dob as date)) / 365.242,0) > 14
  AND n.category = 'Discharge summary'
  ORDER BY subject_id
  --LIMIT 100;
"""

notes_stmt = sqlalchemy.text(sql)
df = pd.read_sql_query(notes_stmt, con=cnx)
df.head()

Unnamed: 0,subject_id,dob,gender,hadm_id,category,chartdate,row_id,age_at_noteevent,text
0,3,2025-04-11,M,145834,Discharge summary,2101-10-31,44005,77.0,Admission Date: [**2101-10-20**] Discharg...
1,4,2143-05-12,F,185777,Discharge summary,2191-03-23,4788,48.0,Admission Date: [**2191-3-16**] Discharge...
2,6,2109-06-21,F,107064,Discharge summary,2175-06-15,20825,66.0,Admission Date: [**2175-5-30**] Dischar...
3,9,2108-01-26,M,150750,Discharge summary,2149-11-14,57115,42.0,"Name: [**Known lastname 10050**], [**Known fi..."
4,9,2108-01-26,M,150750,Discharge summary,2149-11-13,20070,42.0,Admission Date: [**2149-11-9**] Dischar...


Downcast the age column to the smallest integer type that can hold it, saving memory and dropping the decimal point.

In [4]:
# Ages arrive as rounded floats; store them in the narrowest integer dtype that fits.
df['age_at_noteevent'] = df['age_at_noteevent'].pipe(pd.to_numeric, downcast='integer')
df.head()

Unnamed: 0,subject_id,dob,gender,hadm_id,category,chartdate,row_id,age_at_noteevent,text
0,3,2025-04-11,M,145834,Discharge summary,2101-10-31,44005,77,Admission Date: [**2101-10-20**] Discharg...
1,4,2143-05-12,F,185777,Discharge summary,2191-03-23,4788,48,Admission Date: [**2191-3-16**] Discharge...
2,6,2109-06-21,F,107064,Discharge summary,2175-06-15,20825,66,Admission Date: [**2175-5-30**] Dischar...
3,9,2108-01-26,M,150750,Discharge summary,2149-11-14,57115,42,"Name: [**Known lastname 10050**], [**Known fi..."
4,9,2108-01-26,M,150750,Discharge summary,2149-11-13,20070,42,Admission Date: [**2149-11-9**] Dischar...


In [5]:
(len(df.index), len(df.columns))  # (rows, columns) of the notes frame

(55404, 9)

55,404 'adult' (aged 15 or over) discharge summaries — the number we expect from our exploratory data analysis.

The following punctuation marks frequently appear inside words, or between words with no spacing, so the tokenizer does not split on them. After the first tokenization pass we therefore split tokens on these marks with regular expressions, then retokenize. This substantially decreases the number of unique tokens, and hence the number of rare words we later replace with `<UNK>`.

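- ampersands
- round and square brackets
- colons
- forward slashes (leaving digit/digit sequences such as dates, e.g. `2101/10/20`, intact)
- full stops
- hyphens
- equals signs
- semicolons
- plus signs
- asterisks, commas, tildes, pipes, at signs and carets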

In [6]:
# Patterns used by tokenise_text. The first group is applied to the raw note
# before the first spaCy pass; the second group to the space-joined tokens.

# digit-d(d)-digit, e.g. the '1-10-2' inside 2101-10-20; the three digit parts are captured.
date_regex = re.compile(r'([0-9])-([0-9][0-9]?)-([0-9])') # change date format so spacy can recognise
# NOTE(review): in these raw strings `\\n` / `\\r` is the regex for a literal
# backslash followed by 'n' / 'r', NOT a newline / carriage-return character.
# On text containing real line breaks they never match -- confirm which form
# the note text actually contains.
newline_regex = re.compile(r'(\\n){3,}') # cap number of consecutive newline characters to 2
newline_regex2 = re.compile(r'(\\r){3,}') # cap number of consecutive newline characters to 2
# Runs of two or more '.', '~' or '@'.
ellipsis_regex = re.compile(r'(\.){2,}')
tilda_mult_regex = re.compile(r'(~){2,}')
atsign_mult_regex = re.compile(r'(@){2,}')

# Splitting patterns. Most capture (char before)(mark)(char after) so the mark
# can be re-emitted with spaces around it. For these (.)...(.) patterns, '.'
# does not match a newline and a mark at the very start/end of the string is
# never matched. Each match consumes its neighbours, so adjacent marks
# (e.g. '--') are not all split in a single pass.
bracket_regex = re.compile(r'(.)(\()(.)')
bracket_regex2 = re.compile(r'(.)(\))(.)')
# Only where the next (slash_regex) / previous (slash_regex2) character is not
# a digit, so digit/digit sequences such as 2101/10/20 are left intact.
slash_regex = re.compile(r'(.)(\/)([^0-9])')
slash_regex2 = re.compile(r'([^0-9])(\/)(.)')
equals_regex = re.compile(r'(.)(=)(.)')
colon_regex = re.compile(r'(.)(:)(.)')
sq_bracket_regex = re.compile(r'(.)(\[)(.)')
dash_regex = re.compile(r'(.)(-)(.)')
# A hyphen immediately followed by a non-whitespace character.
dash_regex2 = re.compile(r'(-)([\S])')
plus_regex = re.compile(r'(.)(\+)(.)')
amp_regex = re.compile(r'(.)(&)(.)')
star_regex = re.compile(r'(.)(\*)(.)') 
comma_regex = re.compile(r'(.)(,)(.)')
tilda_regex = re.compile(r'(.)(~)(.)')
pipe_regex = re.compile(r'(.)(\|)(.)')
atsign_regex = re.compile(r'(.)(@)(.)')
# A full stop with two characters of context on each side. The character just
# before it must not be a digit, and the one just after must not be a digit or
# a comma, so decimals such as 2.5 are not matched.
dot_regex = re.compile(r'([^.][^0-9])(\.)([^0-9,][^.])')

# A full stop preceded by a non-digit; the stop sits in its own group.
dot_regex2 = re.compile(r'([^0-9])(\.)(.)')
# For ';' and '^' only the neighbouring characters are captured, not the mark itself.
semicol_regex = re.compile(r'(.);(.)')
caret_regex = re.compile(r'(.)\^(.)')

In [7]:
nlp = spacy.load('en_core_sci_md') # sciSpaCy

# Keep the paragraph and unknown-word markers as single, unsplit tokens.
for marker in (u'<PAR>', u'<UNK>'):
    nlp.tokenizer.add_special_case(marker, [{ORTH: marker}])

i = 0  # progress counter incremented by tokenise_text

def tokenise_text(text, counter):
    """Normalise and tokenise one discharge summary.

    Steps:
    1. Regex clean-up of the raw text.
    2. A first spaCy tokenisation pass.
    3. Conversion of line breaks into ``<PAR>`` markers.
    4. Regex splitting of punctuation the tokenizer left attached to words.
    5. Whitespace normalisation.
    6. A second tokenisation pass.

    Args:
        text: the note; coerced with ``str`` (so NaN becomes ``'nan'``).
        counter: a ``collections.Counter`` updated in place with the
            lower-cased tokens of the result.

    Returns:
        The tokens of the second pass, each followed by a single space.

    Relies on the module-level ``nlp`` tokenizer and the ``*_regex`` patterns.
    It also increments the global progress counter ``i``, printing it every
    100 notes.
    """
    global i

    text = str(text)
    text = date_regex.sub(r'\1/\2/\3', text)
    # These two patterns only match literal backslash-n / backslash-r escape text.
    text = newline_regex.sub(r' \\n\\n ', text)
    text = newline_regex2.sub(r' \\n\\n ', text)
    # Cap runs of real line breaks at two. Each break becomes a <PAR> marker
    # below; the patterns above cannot see them.
    text = re.sub(r'\n{3,}', '\n\n', text)
    text = ellipsis_regex.sub(r'.', text)
    text = tilda_mult_regex.sub(r'~', text)
    text = atsign_mult_regex.sub(r'@', text)

    # De-identified spans look like [**...**]; keep just the brackets.
    text = text.replace("[**", "[").replace("**]", "]")

    # First pass. Building the string with one join is linear in the note
    # length; repeated `s = s + tok + " "` was quadratic.
    tokenised_text = "".join(token.text + " " for token in nlp.tokenizer(text))
    tokenised_text = tokenised_text.replace("\n", " <PAR> ")

    # Order matters. Some patterns run more than once because neighbouring
    # matches overlap (e.g. '--' or '---').
    split_steps = (
        (bracket_regex, r'\1 \2 \3'),
        (bracket_regex2, r'\1 \2 \3'),
        (slash_regex, r'\1 \2 \3'),
        (slash_regex2, r'\1 \2 \3'),
        (slash_regex, r'\1 \2 \3'),
        (equals_regex, r'\1 \2 \3'),
        (colon_regex, r'\1 \2 \3'),
        (sq_bracket_regex, r'\1 \2 \3'),
        (dash_regex, r'\1 \2 \3'),
        (dash_regex, r'\1 \2 \3'),
        (dash_regex, r'\1 \2 \3'),
        (dash_regex2, r'\1 \2'),
        (plus_regex, r'\1 \2 \3'),
        (star_regex, r'\1 \2 \3'),
        (amp_regex, r'\1 \2 \3'),
        (comma_regex, r'\1 \2 \3'),
        (dot_regex, r'\1 \2 \3'),
        (atsign_regex, r'\1 \2 \3'),
        (tilda_regex, r'\1 \2 \3'),
        (pipe_regex, r'\1 \2 \3'),
        (dot_regex2, r'\1 \3'),      # drops the full stop
        (semicol_regex, r'\1 \2'),   # drops the semicolon
        (caret_regex, r'\1 \2'),     # drops the caret
    )
    for pattern, replacement in split_steps:
        tokenised_text = pattern.sub(replacement, tokenised_text)

    tokenised_text = ' '.join(tokenised_text.split())

    # Second pass over the re-spaced text.
    tokenised_text = "".join(token.text + " " for token in nlp.tokenizer(tokenised_text))

    counter.update(tokenised_text.lower().split())

    i += 1
    if i % 100 == 0:
        print(i)

    return tokenised_text

In [8]:
word_freq = Counter()
# Tokenise every note. word_freq accumulates lower-cased token counts as a side effect.
df["text"] = df["text"].apply(lambda note: tokenise_text(note, word_freq))

100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200
14300
14400
14500
14600
14700
14800
14900
15000
15100
15200
15300
15400
15500
15600
15700
15800
15900
16000
16100
16200
16300
16400
16500
16600
16700
16800
16900
17000
17100
17200
17300
17400
17500
17600
17700
17800
17900
18000
18100
18200
18300
18400
1850

Below we isolate the tokens that appear 3 times or fewer, ignoring tokens that start with a digit. They are mostly misspellings.

In [21]:
# Rare (<= 3 occurrences) tokens that do not start with a digit.
infreq_words = [token for token, count in word_freq.items()
                if count <= 3 and not token[0].isdigit()]
print(len(infreq_words))
sorted(infreq_words)[10000:11000]

97573


['awakeness',
 'awakense',
 'awaker',
 'awakfulness',
 'awakining',
 'awakke',
 'awakre',
 'awakw',
 'award',
 'awarenes',
 'awarranted',
 'awave',
 'awawre',
 'aways',
 'awb',
 'awc',
 'aweek',
 'awell',
 'aweyes',
 'awf',
 'awith',
 'awkae',
 'awke',
 'awkoe',
 'awkwardly',
 'awlays',
 'awmi',
 'awning',
 'awnsering',
 'awopke',
 'awwake',
 'ax0',
 'ax0x3',
 'ax4',
 'ax95',
 'axacerbation',
 'axallary',
 'axatia',
 'axcluded',
 'axhypotension',
 'axi',
 'axialblood',
 'axiallr',
 'axiilary',
 'axilalry',
 'axillaries',
 'axillarly',
 'axillary04/29/12',
 'axillaryadenopathy',
 'axillas',
 'axillia',
 'axilliary',
 'axilllary',
 'axillobifem',
 'axillofem',
 'axillory',
 'axilo',
 'axious',
 'axisand',
 'axisdeviation',
 'axithro',
 'axithromycin',
 'axium',
 'axix',
 'axlnd',
 'axo3',
 'axobifem',
 'axol',
 'axomal',
 'axon',
 'axones',
 'axopt',
 'axos',
 'axox',
 'axox0',
 'axquired',
 'axterixis',
 'axtreonam',
 'axx3',
 'axycodone',
 'aygestin',
 'aygestrin',
 'aymloidosis',
 'ay

Next we try to correct these misspellings with the `pyspellchecker` library, which proposes dictionary words within a small Levenshtein (edit) distance of each unknown word. We first add every token that occurs more than 3 times to its dictionary, because these include many scientific/medical terms that the default English dictionary may lack.

In [22]:
# Tokens seen more than 3 times are treated as valid vocabulary; many are
# medical terms missing from the spell checker's default dictionary.
freq_words = [word for word, count in word_freq.items() if count > 3]
add_to_dictionary = " ".join(freq_words)
# `with` closes the file even if the write fails. The encoding is explicit
# because pyspellchecker's load_text_file reads UTF-8, whereas open() would
# otherwise use the platform's locale encoding.
with open("data/mimic_dict.txt", "w", encoding="utf-8") as f:
    f.write(add_to_dictionary)

In [23]:
# Only consider candidates one edit away (much quicker than the library
# default), and extend the built-in dictionary with the MIMIC vocabulary.
spell = SpellChecker(distance=1)
spell.word_frequency.load_text_file('data/mimic_dict.txt')

In [24]:
# Propose a correction for every rare word the extended dictionary does not know.
# - correction() is expensive, so it is called once per word (it was called twice).
# - Newer pyspellchecker releases return None when no candidate exists; those
#   words are skipped, so the mapping only ever holds strings.
# - The loop index is not named `i`, because that global is tokenise_text's
#   progress counter.
misspelled = spell.unknown(infreq_words)
misspell_dict = {}
for n_checked, word in enumerate(misspelled):
    candidate = spell.correction(word)
    if candidate is not None and candidate != word:
        misspell_dict[word] = candidate
    if n_checked % 100 == 0:
        print(n_checked)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000


KeyboardInterrupt: 

In [13]:
# The mapping holds tens of thousands of entries. Report its size and show a
# small sample instead of rendering the whole dict into the notebook.
print(len(misspell_dict))
dict(list(misspell_dict.items())[:50])

46432


{'withrawing': 'withdrawing',
 'redistrubution': 'redistribution',
 'silocone': 'silicone',
 '.s1': 's1',
 'deterioate': 'deteriorate',
 'hosptialzation': 'hospitalzation',
 'coloc': 'colon',
 'incidintally': 'incidentally',
 'loprsesor': 'lopressor',
 'doscontinued': 'discontinued',
 'reitterated': 'reiterated',
 'vancomycis': 'vancomycin',
 'flumzenil': 'flumazenil',
 'obsructions': 'obstructions',
 'antipsychoitics': 'antipsychotics',
 'mycfungin': 'micfungin',
 'palpaltion': 'palpation',
 'miinutes': 'minutes',
 'congition': 'condition',
 'pliva': 'oliva',
 'paretal': 'parental',
 'doxasozyn': 'doxasozin',
 'stidy': 'study',
 'ketalog': 'kenalog',
 'changesa': 'changes',
 'makung': 'making',
 'pracentesis': 'parcentesis',
 'brinchoscopies': 'bronchoscopies',
 'overalli': 'overall',
 'twingy': 'thingy',
 'perninious': 'pernicious',
 'succeptible': 'susceptible',
 'pancreatit': 'pancreatic',
 'abue': 'able',
 'startedv': 'started',
 'transalvular': 'transvalvular',
 'retinoids': 'ret

We now have proposed corrections for many of the words that occurred 3 times or fewer. Any remaining rare word will be replaced with `<UNK>`. We save both lists as text files to avoid re-running this computation.

In [17]:
# Rare words with no proposed correction will become <UNK>.
# Membership is tested against the dict itself, an O(1) hash lookup. The
# previous `list(misspell_dict.keys())` rebuilt and scanned a list for every
# word (O(n*m)), and the cell was interrupted before finishing.
unk_words = [word for word in infreq_words if word not in misspell_dict]
len(unk_words)

KeyboardInterrupt: 

In [None]:
# Spot-check the correction proposed for one known misspelling.
if 'treatement' in misspell_dict:
    print(misspell_dict['treatement'])

In [None]:
# Persist both lists so the spell-checking above need not be re-run.
# Plain UTF-8 writes are used for two reasons:
# - np.savetxt encodes as latin-1 by default, which fails on other characters.
# - np.savetxt writes in text mode, so os.linesep as the newline risks doubled
#   carriage returns on Windows.
with open('data/discharge_unk_words.txt', 'w', encoding='utf-8') as f:
    f.writelines(word + '\n' for word in unk_words)

with open("data/discharge_typos.txt", "w", encoding="utf-8") as f:
    for key, correction in misspell_dict.items():
        if correction is None:  # no candidate was found; nothing to record
            continue
        f.write(key + '\t' + correction + '\n')

# Reload the typo map with:
# pd.read_csv('data/discharge_typos.txt', sep='\t', header=None)

We now correct the spelling mistakes; any rare word left uncorrected is replaced with `<UNK>`.

In [None]:
def fix_typos(text, typos, unks):
    """Replace known misspellings and mark remaining rare words as ``<UNK>``.

    Args:
        text: whitespace-separated tokens, e.g. the output of ``tokenise_text``.
        typos: mapping from a lower-cased misspelling to its replacement.
        unks: collection of lower-cased words to replace with ``"<UNK>"``.
            A list is accepted but is converted to a set on every call, so
            pass a set when calling this once per note.

    Returns:
        The resulting tokens, each followed by a single space (so a non-empty
        result ends in a trailing space, matching ``tokenise_text``).

    Lookups use the lower-cased token. A typo match takes priority over an
    UNK match. Replacements are inserted exactly as stored in ``typos``;
    other tokens keep their original case.
    """
    # Per-token membership in a list was a linear scan of ~50k words.
    if not isinstance(unks, (set, frozenset, dict)):
        unks = set(unks)

    # The previous version appended to a string that was never initialised
    # (UnboundLocalError on the first token).
    fixed = []
    for token in text.split():
        key = token.lower()
        if key in typos:
            token = typos[key]
        elif key in unks:
            token = "<UNK>"
        fixed.append(token + " ")
    return "".join(fixed)

In [None]:
# Apply the typo / <UNK> replacement to every note. The UNK words are passed
# as a set, so each token lookup is a hash probe rather than a scan of the list.
unk_word_set = set(unk_words)
df["text"] = df["text"].apply(fix_typos, args=(misspell_dict, unk_word_set))
df.head()

In [None]:
' '.join(df['text'].head())  # the first five cleaned notes as one string

The text field has now been fully cleaned and tokenised. We can now extract the first few tokens of each note to use as a hint, then move on to joining other tables and partitioning the dataset.

In [26]:
counter = 0  # number of notes processed by produce_hint


def produce_hint(text):
    """Return the first 10 whitespace-separated tokens of ``text``, space-joined.

    Also increments the global ``counter`` and prints it every 10,000 calls
    as a progress indicator.
    """
    global counter
    tokens = text.split()
    counter += 1
    if counter % 10000 == 0:
        print(counter)
    return ' '.join(tokens[:10])

df['hint'] = df['text'].map(produce_hint)  # leading tokens of each note

print(df.shape)
df.head()

10000
20000
30000
40000
50000
(55404, 10)


Unnamed: 0,subject_id,dob,gender,hadm_id,category,chartdate,row_id,age_at_noteevent,text,hint
0,3,2025-04-11,M,145834,Discharge summary,2101-10-31,44005,77,Admission Date : [ 2101/10/20 ] Discharge Date...,Admission Date : [ 2101/10/20 ] Discharge Date...
1,4,2143-05-12,F,185777,Discharge summary,2191-03-23,4788,48,Admission Date : [ 2191/3/16 ] Discharge Date ...,Admission Date : [ 2191/3/16 ] Discharge Date : [
2,6,2109-06-21,F,107064,Discharge summary,2175-06-15,20825,66,Admission Date : [ 2175/5/30 ] Discharge Date ...,Admission Date : [ 2175/5/30 ] Discharge Date : [
3,9,2108-01-26,M,150750,Discharge summary,2149-11-14,57115,42,"Name : [ Known lastname 10050 ] , [ Known firs...","Name : [ Known lastname 10050 ] , [ Known"
4,9,2108-01-26,M,150750,Discharge summary,2149-11-13,20070,42,Admission Date : [ 2149/11/9 ] Discharge Date ...,Admission Date : [ 2149/11/9 ] Discharge Date : [


In [27]:
# For privacy, patients older than 89 had their date of birth shifted so that
# they appear 300 years old at their first event. Recode those ages to 90.
shifted_age = df['age_at_noteevent'] > 200
df.loc[shifted_age, 'age_at_noteevent'] = 90

In [10]:
# Admissions data, including the 30-day unplanned readmission flag.
admissions = pd.read_csv('data/df_adm.csv')  # comma is read_csv's default separator
admissions
# TODO: left-join into the notes once needed:
# df = pd.merge(df, admissions, how='left', on=['subject_id', 'hadm_id'])

Unnamed: 0,subject_id,hadm_id,ethnicity,diagnosis,admission_type,30d_unplan_readmit
0,3,145834,white,hypotension,EMERGENCY,N
1,4,185777,white,"fever , dehydration , failure to thrive",EMERGENCY,N
2,6,107064,white,chronic renal failure/sda,ELECTIVE,N
3,9,150750,other,hemorrhagic cva,EMERGENCY,N
4,11,194540,white,brain mass,EMERGENCY,N
5,12,112213,white,pancreatic cancer/sda,ELECTIVE,N
6,13,143045,white,coronary artery disease,EMERGENCY,N
7,17,194023,white,patient foramen ovale\ patent foramen ovale mi...,ELECTIVE,N
8,17,161087,white,pericardial effusion,EMERGENCY,N
9,18,188822,white,hypoglycemia;seizures,EMERGENCY,N


In [28]:
# Lab results with their item labels. ORDER BY subject_id is required because
# StratifiedGroupShuffleSplit slices this frame with searchsorted.
# NOTE(review): LIMIT 10000 presumably keeps only the first few subjects, so
# most notes get no lab items in the split -- remove it for a full run.
lab_sql = '''
  SELECT l.subject_id, l.charttime, l.value, l.valueuom, l.flag, d.label
  FROM labevents l
  INNER JOIN d_labitems d 
  USING (itemid)
  ORDER BY l.subject_id
  LIMIT 10000;
'''
df_labitems = pd.read_sql_query(lab_sql, con=cnx)

print(df_labitems.shape)
df_labitems.head()

(10000, 6)


Unnamed: 0,subject_id,charttime,value,valueuom,flag,label
0,2,2138-07-17 20:48:00,0,%,,Atypical Lymphocytes
1,2,2138-07-17 20:48:00,0,%,,Bands
2,2,2138-07-17 20:48:00,0,%,,Basophils
3,2,2138-07-17 20:48:00,0,%,,Eosinophils
4,2,2138-07-17 20:48:00,0,%,abnormal,Hematocrit


In [29]:
# Prescriptions, sorted by subject_id because StratifiedGroupShuffleSplit
# slices this frame with searchsorted.
# NOTE(review): LIMIT 10000 truncates the table -- remove it for a full run.
prescriptions_sql = '''
  SELECT subject_id, startdate, enddate, drug, prod_strength
  FROM prescriptions
  ORDER BY subject_id
  LIMIT 10000;
'''
df_prescriptions = pd.read_sql_query(prescriptions_sql, con=cnx)

print(df_prescriptions.shape)
df_prescriptions.head()

(10000, 5)


Unnamed: 0,subject_id,startdate,enddate,drug,prod_strength
0,2,2138-07-18,2138-07-20,NEO*IV*Gentamicin,10mg/mL-2mL
1,2,2138-07-18,2138-07-21,Ampicillin Sodium,500mg Vial
2,2,2138-07-18,2138-07-21,Send 500mg Vial,Send 500mg Vial
3,2,2138-07-18,2138-07-20,Syringe (Neonatal) *D5W*,1 Syringe
4,4,2191-03-16,2191-03-16,Primaquine Phosphate,26.3MG TAB PK


In [11]:
# ICD-9 diagnoses with their short and long titles.
# NOTE(review): seq_num restarts for each admission, so ordering by
# (subject_id, seq_num) interleaves a subject's admissions; consider adding
# hadm_id to the ORDER BY. The INNER JOIN also drops codes missing from
# d_icd_diagnoses.
diagnoses_sql = '''
  SELECT d.subject_id, d.hadm_id, d.seq_num, d.icd9_code, icd.short_title, icd.long_title
  FROM diagnoses_icd d
  INNER JOIN d_icd_diagnoses icd 
  USING (icd9_code)
  ORDER BY d.subject_id, d.seq_num
  LIMIT 10000;
'''
df_diagnoses = pd.read_sql_query(diagnoses_sql, con=cnx)

print(df_diagnoses.shape)
df_diagnoses.head()

(10000, 6)


Unnamed: 0,subject_id,hadm_id,seq_num,icd9_code,short_title,long_title
0,2,163353,1,V3001,Single lb in-hosp w cs,"Single liveborn, born in hospital, delivered b..."
1,2,163353,2,V053,Need prphyl vc vrl hepat,Need for prophylactic vaccination and inoculat...
2,2,163353,3,V290,NB obsrv suspct infect,Observation for suspected infectious condition
3,3,145834,1,0389,Septicemia NOS,Unspecified septicemia
4,3,145834,2,78559,Shock w/o trauma NEC,Other shock without mention of trauma


In [12]:
# ICD-9 procedures with their short and long titles.
# NOTE(review): as with diagnoses, seq_num is per admission, so consider also
# ordering by hadm_id.
procedures_sql = '''
  SELECT p.subject_id, p.hadm_id, p.seq_num, p.icd9_code, icd.short_title, icd.long_title
  FROM procedures_icd p
  INNER JOIN d_icd_procedures icd 
  USING (icd9_code)
  ORDER BY p.subject_id, p.seq_num
  LIMIT 10000;
'''
df_procedures = pd.read_sql_query(procedures_sql, con=cnx)

print(df_procedures.shape)
df_procedures.head()

(10000, 6)


Unnamed: 0,subject_id,hadm_id,seq_num,icd9_code,short_title,long_title
0,2,163353,1,9955,Vaccination NEC,Prophylactic administration of vaccine against...
1,3,145834,1,9604,Insert endotracheal tube,Insertion of endotracheal tube
2,3,145834,2,9962,Heart countershock NEC,Other electric countershock of heart
3,3,145834,3,8964,Pulmon art wedge monitor,Pulmonary artery wedge monitoring
4,3,145834,4,9672,Cont inv mec ven 96+ hrs,Continuous invasive mechanical ventilation for...


In [15]:
# All MetaVision procedure events (the LIMIT is commented out in the SQL).
proc_events_sql = '''
  SELECT * FROM procedureevents_mv
  ORDER BY subject_id
  --LIMIT 10000;
'''
df_proc_events = pd.read_sql_query(proc_events_sql, con=cnx)

print(df_proc_events.shape)
df_proc_events.head()

(258066, 25)


Unnamed: 0,row_id,subject_id,hadm_id,icustay_id,starttime,endtime,itemid,value,valueuom,location,...,ordercategoryname,secondaryordercategoryname,ordercategorydescription,isopenbag,continueinnextdept,cancelreason,statusdescription,comments_editedby,comments_canceledby,comments_date
0,74780,23,124321,234044.0,2157-10-21 12:15:00,2157-10-22 16:08:00,224276,1673.0,min,,...,Peripheral Lines,,Task,1,0,0,FinishedRunning,,,NaT
1,74779,23,124321,234044.0,2157-10-21 12:15:00,2157-10-22 16:08:00,224276,1673.0,min,,...,Peripheral Lines,,Task,1,0,0,FinishedRunning,,,NaT
2,74778,23,124321,234044.0,2157-10-21 12:15:00,2157-10-22 16:08:00,224263,1673.0,min,,...,Invasive Lines,,Task,1,0,0,FinishedRunning,,,NaT
3,74777,23,124321,234044.0,2157-10-21 12:15:00,2157-10-22 16:08:00,225752,1673.0,min,,...,Invasive Lines,,Task,1,0,0,FinishedRunning,,,NaT
4,162248,34,144319,290505.0,2191-02-23 11:40:00,2191-02-23 11:41:00,227194,1.0,,,...,Intubation/Extubation,,Electrolytes,0,0,0,FinishedRunning,,,NaT


In [36]:
# Split the dataset in a grouped (by subject) and gender-stratified manner,
# writing the model source files as we go.


def _gender_percentages(frame):
    """Percentage of the rows of ``frame`` falling into each gender value.

    The denominator is ``len(frame)``. Rows with a missing gender therefore
    lower every percentage without forming a group of their own.
    """
    return frame.groupby('gender')['subject_id'].count() / len(frame) * 100


def _rows_for_subject(frame, subject_id):
    """Rows of ``frame`` whose subject_id equals ``subject_id``.

    ``frame`` must be sorted by subject_id (binary search via searchsorted).
    """
    left = frame['subject_id'].searchsorted(subject_id, 'left')
    right = frame['subject_id'].searchsorted(subject_id, 'right')
    return frame.iloc[left:right]


def _format_prescriptions(prescriptions):
    """'drug , strength' for each row, joined with ' | '."""
    return " | ".join(str(row.drug) + " , " + str(row.prod_strength)
                      for row in prescriptions.itertuples())


def _format_lab_items(lab_items):
    """'label , value , unit' (plus ' , flag' when set) per row, joined with ' | '."""
    entries = []
    for row in lab_items.itertuples():
        entry = str(row.label) + " , " + str(row.value) + " , " + str(row.valueuom)
        if not pd.isna(row.flag):
            entry += " , " + str(row.flag)
        entries.append(entry)
    return " | ".join(entries)


def _source_line(note, g_labitems, g_prescriptions):
    """Build one source line for ``note``.

    The line holds the hint, gender, age, same-day prescriptions and same-day
    lab items. Lab items charted, and prescriptions started, during the
    note's chartdate (midnight to midnight) are included, newest first.
    """
    day_start = datetime.combine(note.chartdate, datetime.min.time())
    day_end = day_start + timedelta(days=1)

    labs = g_labitems[(g_labitems.charttime >= day_start) & (g_labitems.charttime < day_end)]
    labs = labs.sort_values(by=['charttime'], ascending=False)
    meds = g_prescriptions[(g_prescriptions.startdate >= day_start)
                           & (g_prescriptions.startdate < day_end)]
    meds = meds.sort_values(by=['startdate'], ascending=False)

    return (str(note.hint) + " <H> " + str(note.gender) + " <G> "
            + str(note.age_at_noteevent) + " <A> " + _format_prescriptions(meds)
            + " <M> " + _format_lab_items(labs) + " <L>" + "\n")


def StratifiedGroupShuffleSplit(df_main, train_proportion=0.8, hparam_mse_wgt=0.1,
                                balance_every=500, directory="data/preprocessed/",
                                prescriptions_df=None, labitems_df=None, seed=None):
    """Split notes into train/val/test, keeping all notes of a subject together.

    Placement rules:
    - Subjects are visited in random order.
    - The first three subjects seed train, val and test respectively.
    - After that, each subject goes to the first split (train, then val)
      still below its target share of rows; otherwise it goes to test.
    - Every ``balance_every``-th subject instead goes to the split with the
      highest score. The score weights the improvement in gender-distribution
      MSE by ``hparam_mse_wgt``, and the signed squared shortfall from the
      split's target share by ``1 - hparam_mse_wgt``.

    Output: for each note, the line
    ``hint <H> gender <G> age <A> prescriptions <M> lab items <L>``
    is written to ``src-train.txt``, ``src-val.txt`` or ``src-test.txt``
    in ``directory``.

    Args:
        df_main: notes frame with ``subject_id``, ``gender``, ``chartdate``,
            ``age_at_noteevent`` and ``hint`` columns.
        train_proportion: target share of rows for train. Val and test each
            target half of the remainder.
        hparam_mse_wgt: weight of the gender-MSE term in the balancing score.
        balance_every: every n-th subject (n >= 1) is placed by the balancing score.
        directory: output directory for the three ``src-*.txt`` files.
        prescriptions_df: frame sorted by subject_id with ``startdate``,
            ``drug`` and ``prod_strength`` columns (default: global ``df_prescriptions``).
        labitems_df: frame sorted by subject_id with ``charttime``, ``value``,
            ``valueuom``, ``flag`` and ``label`` columns (default: global ``df_labitems``).
        seed: seed for the subject shuffle. ``None`` uses NumPy's global RNG,
            as before.

    Returns:
        ``(df_train, df_val, df_test)``: the notes assigned to each split.

    Changes from the previous version:
    - Balancing now happens every 500th subject, as the old comment said; the
      code used 1000.
    - Val/test lines no longer carry an extra ' <T> ' that train lines lack.
    - Frames are built with pd.concat; DataFrame.append was removed in
      pandas 2.0 and is quadratic.
    - Output files are closed even if the loop is interrupted.
    """
    assert 0 <= hparam_mse_wgt <= 1
    assert 0 <= train_proportion <= 1
    assert balance_every >= 1
    val_test_proportion = (1 - train_proportion) / 2
    if prescriptions_df is None:
        prescriptions_df = df_prescriptions
    if labitems_df is None:
        labitems_df = df_labitems

    # Shuffle so that subjects are visited in random order.
    if seed is None:
        order = np.random.permutation(df_main.index)
    else:
        order = np.random.RandomState(seed).permutation(df_main.index)
    df_main = df_main.reindex(order)

    gender_pct_main = _gender_percentages(df_main)

    def calc_mse_loss(frame):
        # Mean squared difference between the gender percentages of `frame`
        # and those of the whole dataset; genders absent from `frame` count as 0%.
        pct = _gender_percentages(frame).reindex(gender_pct_main.index, fill_value=0)
        return np.mean((gender_pct_main - pct) ** 2)

    split_names = ('train', 'val', 'test')
    targets = {'train': train_proportion, 'val': val_test_proportion, 'test': val_test_proportion}
    parts = {name: [] for name in split_names}
    lengths = {name: 0 for name in split_names}

    with open(os.path.join(directory, "src-train.txt"), "w") as f_train, \
            open(os.path.join(directory, "src-val.txt"), "w") as f_val, \
            open(os.path.join(directory, "src-test.txt"), "w") as f_test:
        outputs = {'train': f_train, 'val': f_val, 'test': f_test}

        for i, (_, group) in enumerate(df_main.groupby('subject_id', sort=False), start=1):
            total_records = sum(lengths.values())
            subject_id = group['subject_id'].iloc[0]
            g_prescriptions = _rows_for_subject(prescriptions_df, subject_id)
            g_labitems = _rows_for_subject(labitems_df, subject_id)

            if i < 4:
                # Seed each split with one subject so all three are non-empty.
                split = split_names[i - 1]
            elif i % balance_every != 0:
                # Fill train, then val, up to their target share; the rest goes to test.
                if train_proportion > lengths['train'] / total_records:
                    split = 'train'
                elif val_test_proportion > lengths['val'] / total_records:
                    split = 'val'
                else:
                    split = 'test'
            else:
                # Balance by gender distribution and by size.
                losses = {}
                for name in split_names:
                    current = pd.concat(parts[name], ignore_index=True)
                    mse_gain = calc_mse_loss(current) - calc_mse_loss(
                        pd.concat([current, group], ignore_index=True))
                    len_diff = targets[name] - lengths[name] / total_records
                    losses[name] = (hparam_mse_wgt * mse_gain
                                    + (1 - hparam_mse_wgt) * len_diff * abs(len_diff))
                best = max(losses.values())
                if losses['train'] == best:
                    split = 'train'
                elif losses['val'] == best:
                    split = 'val'
                else:
                    split = 'test'
                print("Group " + str(i) + ". loss_train: " + str(losses['train']) + " | "
                      + "loss_val: " + str(losses['val']) + " | "
                      + "loss_test: " + str(losses['test']) + " | ")

            parts[split].append(group)
            lengths[split] += len(group)

            out = outputs[split]
            for note in group.itertuples():
                out.write(_source_line(note, g_labitems, g_prescriptions))

            if i % 100 == 0:
                print(i)

    return tuple(pd.concat(parts[name], ignore_index=True) if parts[name] else pd.DataFrame()
                 for name in split_names)

In [37]:
src_train, src_val, src_test = StratifiedGroupShuffleSplit(df_main=df)  # writes src-*.txt

100
200
300


KeyboardInterrupt: 