In [3]:
import csv
import itertools
import math
import operator
import os
import sys
from collections import Counter, defaultdict
from os import path

import jsonlines
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import plotly.offline as py
from numpy import int32

import datasets

In [4]:
# Every path below can be overridden by an environment variable of the same
# (upper-case) name, so the notebook can run without editing this cell.
SPLITS_DIR = os.environ.get('SPLITS_DIR', 'splits')  # folder containing the dataset splits information
MIMIC_DATA_DIR = os.environ.get('MIMIC_DATA_DIR', '/path/to/mimicdata/')  # path to the folder containing the MIMIC3 data (see readme)
OUT_DIR = os.environ.get('OUT_DIR', '/path/to/out_dir')  # path to the output folder
notes_tokenized_file = os.environ.get('NOTES_TOKENIZED_FILE', '/path/to/notes_tokenized.ndjson')  # path to the tokenized notes (output of preprocess.py)

vocab_min = 3  # discard tokens appearing in fewer than this many documents

# Data processing

## Combine diagnosis and procedure codes and reformat them

MIMIC-III stores diagnosis and procedure codes in separate tables, and both are given without their decimal points, so the same digit string can denote both a diagnosis and a procedure. Naively combining the two tables would therefore cause collisions, so we first add the periods back in the right place.

In [6]:
def read_mimic_table(filename, dtype):
    """Load one MIMIC-III CSV from MIMIC_DATA_DIR with the given column dtypes.

    Only empty fields count as missing (keep_default_na=False), so strings
    that pandas would normally treat as NA are kept verbatim.
    """
    return pd.read_csv(path.join(MIMIC_DATA_DIR, filename), dtype=dtype,
                       keep_default_na=False, na_values='')


code_table_dtypes = {'SUBJECT_ID': int32, 'HADM_ID': int32, 'ICD9_CODE': str}
desc_table_dtypes = {'ICD9_CODE': str}

dfproc = read_mimic_table('PROCEDURES_ICD.csv', code_table_dtypes)
dfproc_desc = read_mimic_table('D_ICD_PROCEDURES.csv', desc_table_dtypes)
dfdiag = read_mimic_table('DIAGNOSES_ICD.csv', code_table_dtypes)
dfdiag_desc = read_mimic_table('D_ICD_DIAGNOSES.csv', desc_table_dtypes)

# Missing codes become '' so downstream string handling never sees NaN.
# NOTE(review): with int32 ID dtypes, read_csv presumably already rejects missing
# IDs, so these dropna calls look like a no-op safeguard -- confirm.
dfproc = dfproc.assign(ICD9_CODE=dfproc['ICD9_CODE'].fillna('')).dropna(axis=0, subset=['SUBJECT_ID', 'HADM_ID'])
dfdiag = dfdiag.assign(ICD9_CODE=dfdiag['ICD9_CODE'].fillna('')).dropna(axis=0, subset=['SUBJECT_ID', 'HADM_ID'])

In [7]:
def reformat_codes(codes, is_diag):
    """Apply ``datasets.reformat`` to every code in a Series of ICD-9 codes.

    ``is_diag`` is True for the diagnosis tables and False for the procedure
    tables (this is how the flag is used for each table below). Returns a new
    Series of strings with the same index as ``codes``.
    """
    return codes.apply(lambda code: str(datasets.reformat(str(code), is_diag)))


# Each table is reformatted from its *own* ICD9_CODE column, selected by name.
# Assigning a Series computed from another table would index-align unrelated
# codes onto the description tables, which have different rows.
dfdiag['ICD9_CODE'] = reformat_codes(dfdiag['ICD9_CODE'], True)
dfdiag_desc['ICD9_CODE'] = reformat_codes(dfdiag_desc['ICD9_CODE'], True)
dfproc['ICD9_CODE'] = reformat_codes(dfproc['ICD9_CODE'], False)
dfproc_desc['ICD9_CODE'] = reformat_codes(dfproc_desc['ICD9_CODE'], False)
dfcodes = pd.concat([dfdiag, dfproc])

## Load tokenized notes

- Select only discharge summaries and their addenda, dropping notes flagged with `ISERROR`

In [9]:
df = pd.read_json(notes_tokenized_file, lines=True,
                  dtype={0: int32, 1: int32, 5: bool, 6: int32})
# dtype is keyed by position, i.e. the rows carry no column names; name them here.
df.columns = ['SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'CATEGORY', 'DESCRIPTION', 'ISERROR', 'N_TOKENS', 'TEXT']
# Keep discharge summaries that are not flagged as erroneous.
is_kept_note = df['CATEGORY'].eq('Discharge summary') & df['ISERROR'].eq(False)
df = df.loc[is_kept_note]

In [10]:
# Merge all notes of an admission into a single document: their sentence lists
# are concatenated in CHARTDATE order and their token counts summed.
# A stable sort keeps notes that share a CHARTDATE in their original file order.
df = (
    df.sort_values(['CHARTDATE'], kind='mergesort')
    .groupby(['SUBJECT_ID', 'HADM_ID'], as_index=False)
    .agg({
        'TEXT': lambda doclist: list(itertools.chain.from_iterable(doclist)),
        # The string alias avoids pandas' FutureWarning about passing the builtin sum.
        'N_TOKENS': 'sum',
    })
)

In [11]:
# Corpus statistics over the merged discharge summaries, before labels are joined.
num_tok = df['N_TOKENS'].sum()
distinct_tok = len({token for note in df['TEXT'] for sent in note for token in sent})

In [13]:
for stat_label, stat_value in (("Num distinct tokens", distinct_tok), ("Num tokens", num_tok)):
    print(stat_label, str(stat_value))

Num distinct tokens 154543
Num tokens 79335737


In [14]:
# Admissions with notes vs. admissions with codes.
df['HADM_ID'].nunique(dropna=False), dfcodes['HADM_ID'].nunique(dropna=False)

(52726, 58976)

# Join labels with the set of discharge summaries

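Some admissions have codes but no discharge summary (52,726 admissions with notes vs. 58,976 with codes above), so they are not among our notes. We restrict the codes to the admissions that have notes. Note that the count displayed after the next cell (In [16]) was recorded before the merge cell (In [25]) was executed, so it still shows the unjoined 58,976.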

In [25]:
# Keep only codes whose admission has at least one discharge summary in df.
hadm_ids = pd.DataFrame({'HADM_ID': df['HADM_ID'].unique()}, dtype=int32)
dfcodes = dfcodes.merge(hadm_ids, on=['HADM_ID'], how='inner')

In [16]:
# Admissions and distinct codes remaining after the join.
dfcodes['HADM_ID'].nunique(dropna=False), dfcodes['ICD9_CODE'].nunique(dropna=False)

(58976, 9017)

## Append labels to notes in a single file

In [17]:
# Collapse the codes of each admission into one LABELS list, skipping empty
# codes (missing ICD9_CODE values were filled with '').
dfcodes_grouped = (
    dfcodes
    .groupby(['SUBJECT_ID', 'HADM_ID'], as_index=False)
    .agg({'ICD9_CODE': lambda codes: [str(code) for code in codes if len(str(code)) > 0]})
    .rename(columns={'ICD9_CODE': 'LABELS'})
)
# Only admissions that have both notes and labels survive the inner join.
dflabeled = pd.merge(df, dfcodes_grouped, how='inner', on=['SUBJECT_ID', 'HADM_ID'])
with jsonlines.open(path.join(OUT_DIR, 'notes_labeled.ndjson'), 'w') as labeled:
    # One JSON array per admission, in column order. A single to_numpy() avoids
    # building a pandas Series for every row as iterrows() does.
    labeled.write_all(dflabeled.to_numpy().tolist())

Let's sanity-check the combined data we just made: are all HADM_IDs accounted for, and are the vocabulary stats unchanged?

In [18]:
# Recount tokens over the labeled notes and compare them with the counts taken
# before the label join. A mismatch means the join changed the set of notes.
# Because the earlier num_tok came from N_TOKENS, this also checks that
# N_TOKENS agrees with the token lists.
cnt = Counter(token for doc in dflabeled['TEXT'] for sentence in doc for token in sentence)
labeled_distinct_tok = len(cnt)
labeled_num_tok = sum(cnt.values())
if (labeled_distinct_tok, labeled_num_tok) != (distinct_tok, num_tok):
    print("WARNING: token statistics changed after joining labels:",
          (distinct_tok, num_tok), "->", (labeled_distinct_tok, labeled_num_tok))
# Keep the module-level names pointing at the labeled-corpus statistics.
distinct_tok = labeled_distinct_tok
num_tok = labeled_num_tok

In [19]:
# Token statistics of the labeled notes.
labeled_stats = ("num distinct tokens", len(cnt), "num tokens", sum(cnt.values()))
print(*labeled_stats)

num distinct tokens 154543 num tokens 79335737


In [20]:
# Admissions that have both notes and labels.
dflabeled['HADM_ID'].nunique(dropna=False)

52726

In [21]:
py.init_notebook_mode(connected=True)
# Quantile curve of the per-admission token count (N_TOKENS of the labeled notes).
x = np.linspace(0, 1, 1000)
# One vectorized call instead of 1000 separate quantile computations over the column.
y = dflabeled['N_TOKENS'].quantile(x).to_numpy()
trace = go.Scatter(x=x, y=y)
data = [trace]
layout = go.Layout(
    title='Note length quantiles',
    xaxis=dict(title='Quantile'),
    yaxis=dict(title='N_TOKENS per admission'),
)
fig = go.Figure(data=data, layout=layout)
py.plot(fig, filename=path.join(OUT_DIR, 'notes_len_quantiles.html'))
py.iplot(fig)

## Create train/dev/test splits

In [22]:
base_name = path.join(OUT_DIR, "notes_labeled")
splits = {}
for split in ['train', 'dev', 'test']:
    # Each split file lists HADM_IDs in a single column without a header row.
    split_ids_df = pd.read_csv(path.join(SPLITS_DIR, '{}_full_hadm_ids.csv'.format(split)), names=['HADM_ID'], dtype=int32)
    splits[split] = pd.merge(dflabeled, split_ids_df, how='inner', on=['HADM_ID'])
    with jsonlines.open('{}_{}.ndjson'.format(base_name, split), 'w') as split_out:
        # One JSON array per admission; avoids building a Series per row (iterrows).
        split_out.write_all(splits[split].to_numpy().tolist())

## Build vocabulary from training data

In [23]:
vname = path.join(OUT_DIR, 'vocab.csv')

# Over the training split only, count how often each token occurs (word_cnt)
# and in how many documents it appears at least once (doc_cnt).
word_cnt = Counter()
doc_cnt = Counter()
for doc in splits['train']['TEXT']:
    tokens = [token for sentence in doc for token in sentence]
    doc_cnt.update(set(tokens))
    word_cnt.update(tokens)

# One row per token, built directly so an empty training split yields an empty
# table instead of failing on tuple unpacking.
vocab = pd.DataFrame(
    [(word, doc_cnt[word], count) for word, count in word_cnt.items()],
    columns=['word', 'doc_count', 'word_count'],
)
vocab = vocab.sort_values(by=['doc_count'], ascending=False)
# Unfiltered counts, without header or index.
vocab.to_csv(path.join(OUT_DIR, 'word_count.csv'), header=False, index=False)

# Keep tokens seen in at least vocab_min training documents; one token per line.
vocab = vocab[vocab['doc_count'] >= vocab_min]
# Explicit UTF-8 so non-ASCII tokens do not depend on the platform's locale encoding.
with open(vname, 'w', encoding='utf-8') as f:
    f.write(vocab['word'].str.cat(sep='\n'))

## Filter each split to the top K diagnosis/procedure codes

In [41]:
K: int = 50  # number of most frequent codes kept for the top-K label set

In [46]:
# Count the occurrences of each code across all LABELS lists and keep the K most frequent.
# explode() turns empty lists into NaN, which dropna() removes.
all_codes = pd.Series(dflabeled['LABELS'].explode().dropna().to_numpy())
codes_top_K = (
    all_codes.groupby(all_codes).count()              # index: code, sorted ascending
    .sort_values(ascending=False, kind='mergesort')   # stable: tied counts stay in code order
    .head(K)
)

In [48]:
# Show the K most frequent codes with their occurrence counts.
codes_top_K

401.9     20053
38.93     14444
428.0     12842
427.31    12594
414.01    12179
96.04      9932
96.6       9161
584.9      8907
250.00     8784
96.71      8619
272.4      8504
518.81     7249
99.04      7147
39.61      6809
599.0      6442
530.81     6156
96.72      5926
272.0      5766
285.9      5296
88.56      5240
244.9      4788
486        4733
38.91      4575
285.1      4499
36.15      4390
276.2      4358
496        4296
99.15      4172
995.92     3792
V58.61     3698
507.0      3592
038.9      3580
88.72      3500
585.9      3367
403.90     3350
311        3347
305.1      3272
37.22      3248
412        3203
33.24      3188
39.95      3178
287.5      3002
410.71     3001
276.1      2985
V45.81     2943
424.0      2878
45.13      2849
V15.82     2741
511.9      2693
93.90      2663
dtype: int64

In [49]:
top_codes_path = path.join(OUT_DIR, 'top_{}_codes.csv'.format(K))
# One "code,count" row per top-K code, without a header row.
codes_top_K.to_csv(top_codes_path, header=False)

In [64]:
base_name = path.join(OUT_DIR, 'notes_labeled_top_{}'.format(K))
splits_top_K = {}
codes_top_K_set = set(codes_top_K.index)
for split in ['train', 'dev', 'test']:
    split_ids_top_K_df = pd.read_csv(path.join(SPLITS_DIR, '{}_{}_hadm_ids.csv'.format(split, K)), names=['HADM_ID'], dtype=int32)
    splits_top_K[split] = pd.merge(dflabeled, split_ids_top_K_df, how='inner', on=['HADM_ID'])
    # Restrict labels to the top-K codes, deduplicated, in their original order.
    # Iterating a set of strings is not reproducible across interpreter runs
    # (string hash randomization), so a set intersection is not used here.
    splits_top_K[split]['LABELS'] = splits_top_K[split]['LABELS'].apply(
        lambda labels: list(dict.fromkeys(code for code in labels if code in codes_top_K_set)))
    with jsonlines.open('{}_{}.ndjson'.format(base_name, split), 'w') as split_out:
        # Write the top-K frame itself, not the full-label `splits[split]`.
        split_out.write_all(splits_top_K[split].to_numpy().tolist())