In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import pickle5 as pickle
import os
import string
import gzip

import Levenshtein
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer, text_to_word_sequence
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [2]:
MIN_WORD_OCCURENCES = 5
MIN_SEQ_LEN = 9
MAX_SEQ_LEN = 2200
DATA_DIR = '../data/'

In [3]:
df_diag_diabetes_hadm_ids = pickle.load(open(f'{DATA_DIR}diag_diabetes_hadm_ids.p','rb'))
print('df_diag_diabetes_hadm_ids.shape:   ', df_diag_diabetes_hadm_ids.shape)
diag_diabetes_hadm_ids = df_diag_diabetes_hadm_ids.to_list()

diag_diabetes_hadm_ids_dict =  {str(k): v for v, k in enumerate(diag_diabetes_hadm_ids)}

print('len diag_diabetes_hadm_ids_dict: ', len(diag_diabetes_hadm_ids_dict))


df_diag_diabetes_hadm_ids.shape:    (14222,)
len diag_diabetes_hadm_ids_dict:  14222


In [4]:
print('100458' in diag_diabetes_hadm_ids_dict)

True


In [5]:
regular_diagnosis = {}
rolled_diagnosis = {}
regular_icd9_lookup = []
rolled_icd9_lookup = []

# with open('diagnosis.csv', 'rb') as f:
with gzip.open(f'{DATA_DIR}DIAGNOSES_ICD.csv.gz', 'rt') as f:
    reader = csv.reader(f)
    next(reader)
    for row in reader:
        if row[2] in diag_diabetes_hadm_ids_dict:
            note_id = (row[1], row[2])
            regular_icd9 = row[4]
            rolled_icd9 = regular_icd9[:3]

            if regular_icd9 not in regular_icd9_lookup:
                regular_icd9_lookup.append(regular_icd9)
            if rolled_icd9 not in rolled_icd9_lookup:
                rolled_icd9_lookup.append(rolled_icd9)

            regular_note_diagnosis = regular_diagnosis.get(note_id, [])
            rolled_note_diagnosis = rolled_diagnosis.get(note_id, [])
            regular_idx = regular_icd9_lookup.index(regular_icd9)
            rolled_idx = rolled_icd9_lookup.index(rolled_icd9)

            if regular_idx not in regular_note_diagnosis:
                regular_diagnosis[note_id] = regular_note_diagnosis + [regular_idx]
            if rolled_idx not in rolled_note_diagnosis:
                rolled_diagnosis[note_id] = rolled_note_diagnosis + [rolled_idx]

In [7]:
print(len(regular_diagnosis), len(rolled_diagnosis), len(regular_icd9_lookup), len(rolled_icd9_lookup))
# 58976 58976 6985 943
# 14222 14222 4103 781

pickle.dump( regular_diagnosis, open( f'{DATA_DIR}regular_diagnosis.pickle', "wb" ) )

regular_diagnosis = pickle.load(open( f'{DATA_DIR}regular_diagnosis.pickle', "rb" ) )

print(len(regular_diagnosis), len(rolled_diagnosis), len(regular_icd9_lookup), len(rolled_icd9_lookup))

14222 14222 4103 781
14222 14222 4103 781


In [10]:
texts = []
regular_labels = []
rolled_labels = []

unique_categories = []
texts_categories = []


# tt = string.maketrans(string.digits, 'd' * len(string.digits))
tt = str.maketrans(string.digits, 'd' * len(string.digits))
# with open('notes.csv', 'rb') as f:
count_row = 0
with gzip.open(f'{DATA_DIR}NOTEEVENTS.csv.gz', 'rt') as f:
    reader = csv.reader(f)
    next(reader)
    for row in reader:
        count_row = count_row + 1

        key = (row[1], row[2])
        cat = row[6]
        if count_row == 10:
            print(key, cat, row)
            
        if cat not in unique_categories:
            unique_categories.append(cat)

        if key in regular_diagnosis:
            text = row[-1].strip().translate(tt)
            if text:
                texts.append(text)
                texts_categories.append(unique_categories.index(cat))
                regular_labels.append(regular_diagnosis[key])
                rolled_labels.append(rolled_diagnosis[key])
print('Notes processing done!!')

('56174', '189681') Discharge summary ['183', '56174', '189681', '2118-12-09', '', '', 'Discharge summary', 'Report', '', '', 'Admission Date:  [**2118-12-7**]              Discharge Date:   [**2118-12-9**]\n\nDate of Birth:  [**2073-12-25**]             Sex:   F\n\nService: NEUROSURGERY\n\nAllergies:\nCodeine\n\nAttending:[**First Name3 (LF) 1854**]\nChief Complaint:\nSkull defect\n\nMajor Surgical or Invasive Procedure:\ns/p cranioplasty on [**2118-12-7**]\n\n\nHistory of Present Illness:\n44 yo female with a h/o left frontal AVM in the supplementary\nmotor area. The AVM was treated with stereotactic radiosurgery\n(Gamma Knife)in [**2114**]. In [**2116**], the patient developed a seizure\ndisorder. [**2118-5-27**] she developed\nheadaches and after an MRI and a digital angiogram showed no\nresidual pathological vessels, a contrast enhancing lesion\nwith massive focal residual edema was diagnosed- very\nlikely represents radionecrosis. The patient had midline\nshift and mass effect. O

In [11]:
print(len(texts), len(texts_categories), len(regular_labels), len(rolled_labels))
# 1851286 1851286 1851286 1851286
# 406190 406190 406190 406190

406190 406190 406190 406190


In [12]:
    print('Average regular labels per report:',
          sum(map(len, regular_labels)) / len(regular_labels))
    print('Average rolled labels per report:',
          sum(map(len, rolled_labels)) / len(rolled_labels))

Average regular labels per report: 17.420187596937392
Average rolled labels per report: 15.754735468623059


In [13]:
  tokenizer = Tokenizer()
  tokenizer.fit_on_texts(texts)

In [14]:
  word_index = tokenizer.word_index
  print('Unique tokens *before* preprocessing:', len(word_index))
  # Unique tokens *before* preprocessing: 371809

Unique tokens *before* preprocessing: 162715


In [15]:
  # Segment words on frequently/infrequently occuring
  frequent_words = []
  infrequent_words = []
  for word, count in tokenizer.word_counts.items():
      if count < MIN_WORD_OCCURENCES:
          infrequent_words.append(word)
      else:
          frequent_words.append(word)

In [16]:
print(len(frequent_words), len(infrequent_words))
# 109640 262169

53229 109486


In [20]:
def generate_infrequent_word_mapping(infrequent_words, frequent_words):
    print('infrequent_words len: ', len(infrequent_words))
    print('frequent_words len: ', len(frequent_words))
    if not os.path.exists(f'{DATA_DIR}infrequent_word_mapping.pickle'):
        infrequent_word_mapping = {}
        for idx, word in enumerate(infrequent_words):
            if idx % 1000 == 0:
                print('infrequent_words processed: ', idx)
            dists = np.vectorize(lambda x: Levenshtein.distance(word, x))(frequent_words)
            most_similar_word = frequent_words[np.argmin(dists)]
            infrequent_word_mapping[word] = most_similar_word
        with open(f'{DATA_DIR}infrequent_word_mapping.pickle', 'wb') as f:
            pickle.dump(infrequent_word_mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(f'{DATA_DIR}infrequent_word_mapping.pickle', 'rb') as f:
            infrequent_word_mapping = pickle.load(f)
    return infrequent_word_mapping

In [21]:
infrequent_word_mapping = generate_infrequent_word_mapping(infrequent_words,
                                                               frequent_words)

infrequent_words len:  109486
frequent_words len:  53229
infrequent_words processed:  0
infrequent_words processed:  1000
infrequent_words processed:  2000
infrequent_words processed:  3000
infrequent_words processed:  4000
infrequent_words processed:  5000
infrequent_words processed:  6000
infrequent_words processed:  7000
infrequent_words processed:  8000
infrequent_words processed:  9000
infrequent_words processed:  10000
infrequent_words processed:  11000
infrequent_words processed:  12000
infrequent_words processed:  13000
infrequent_words processed:  14000
infrequent_words processed:  15000
infrequent_words processed:  16000
infrequent_words processed:  17000
infrequent_words processed:  18000
infrequent_words processed:  19000
infrequent_words processed:  20000
infrequent_words processed:  21000
infrequent_words processed:  22000
infrequent_words processed:  23000
infrequent_words processed:  24000
infrequent_words processed:  25000
infrequent_words processed:  26000
infrequent_

In [22]:
# Replace infrequent word with a frequent similar word
infrequent_word_index = {}
for word in infrequent_words:
    most_similar_word = infrequent_word_mapping[word]
    infrequent_word_index[word] = word_index[most_similar_word]
    del word_index[word]
 
print('Unique tokens *after* preprocessing:', len(word_index))

Unique tokens *after* preprocessing: 53229


In [23]:
# Reimplementation of `tokenizer.texts_to_sequences`
sequences = []
for text in texts:
    seq = text_to_word_sequence(text)
    vec = []
    for word in seq:
        idx = word_index.get(word)
        if idx is not None:
            vec.append(idx)
        else:
            vec.append(infrequent_word_index[word])
    sequences.append(vec)

In [31]:
# Sequence must be < MAX_SEQ_LEN and > MIN_SEQ_LEN
seqs = []
cats = []
reg_labels = []
rol_labels = []
for seq, cat, reg, rol in zip(sequences, texts_categories, regular_labels, rolled_labels):
    if len(seq) < MAX_SEQ_LEN and len(seq) > MIN_SEQ_LEN:
        seqs.append(seq)
        cats.append(cat)
        reg_labels.append(reg)
        rol_labels.append(rol)
sequences = seqs
texts_categories = cats
regular_labels = reg_labels
rolled_labels = rol_labels

lens = list(map(len, sequences))

print('Shortest sequence has', min(lens), 'tokens')
print('Longest sequences has', max(lens), 'tokens')
print('Average tokens per sequence:', sum(lens) / len(sequences))


Shortest sequence has 10 tokens
Longest sequences has 2199 tokens
Average tokens per sequence: 309.05638814582744


In [32]:

with open(f'{DATA_DIR}word_index.pickle', 'wb') as f:
    pickle.dump(tokenizer.word_index, f, protocol=pickle.HIGHEST_PROTOCOL)

texts = pad_sequences(sequences)

# Encode labels as k-hot
reg = np.zeros((len(regular_labels), len(regular_icd9_lookup)), dtype=np.int32)
rol = np.zeros((len(rolled_labels), len(rolled_icd9_lookup)), dtype=np.int32)
for i, label in enumerate(regular_labels): reg[i][label] = 1
for i, label in enumerate(rolled_labels): rol[i][label] = 1


In [None]:
regular_labels = reg
rolled_labels = rol


In [39]:
for idx, lab in enumerate(rolled_labels[1]):
  if lab == True:
    print(idx, lab)

4 1
8 1
24 1
31 1
38 1
53 1
61 1
73 1
99 1


In [40]:
# Encode categories as 1-hot
cats = np.zeros((len(texts_categories), len(unique_categories)), dtype=np.float32)
for i, cat in enumerate(texts_categories): cats[i][cat] = 1
texts_categories = cats

# keep labels with >= 1 examples
regular_icd9_lookup = np.asarray(regular_icd9_lookup)
rolled_icd9_lookup = np.asarray(rolled_icd9_lookup)

keep = np.sum(regular_labels, 0) >= 1
regular_labels = regular_labels[:, keep]
regular_icd9_lookup = regular_icd9_lookup[keep]
keep = np.sum(rolled_labels, 0) >= 1
rolled_labels = rolled_labels[:, keep]
rolled_icd9_lookup = rolled_icd9_lookup[keep]

np.savez(f'{DATA_DIR}icd9_lookup.npz',
          regular_icd9_lookup=regular_icd9_lookup,
          rolled_icd9_lookup=rolled_icd9_lookup)

print('Texts shape:', texts.shape)
print('Categories shape:', texts_categories.shape)
print('Regular labels shape:', regular_labels.shape)
print('Rolled labels shape:', rolled_labels.shape)

# Shuffle
if os.path.exists(f'{DATA_DIR}shuffled_indices.npy'):
    indices = np.load(f'{DATA_DIR}shuffled_indices.npy')
else:
    indices = np.arange(texts.shape[0])
    np.random.shuffle(indices)
    np.save(f'{DATA_DIR}shuffled_indices.npy', indices)
texts = texts[indices]
texts_categories = texts_categories[indices]
regular_labels = regular_labels[indices]
rolled_labels = rolled_labels[indices]

np.savez(f'{DATA_DIR}data.npz',
          x=texts, cats=texts_categories,
          reg_y=regular_labels, rol_y=rolled_labels)

Texts shape: (399623, 2199)
Categories shape: (399623, 15)
Regular labels shape: (399623, 4097)
Rolled labels shape: (399623, 780)


In [43]:
for idx, txt in enumerate(texts[0]):
  if txt != 0:
    print(idx, txt)

2007 169
2008 17
2009 41
2010 241
2011 5429
2012 223
2013 2041
2014 3
2015 13
2016 1458
2017 3
2018 3432
2019 1859
2020 13237
2021 4709
2022 104
2023 3673
2024 3
2025 6233
2026 223
2027 155
2028 3916
2029 1593
2030 9
2031 77
2032 12066
2033 8578
2034 11
2035 6
2036 389
2037 389
2038 17
2039 1312
2040 19
2041 1
2042 395
2043 98
2044 536
2045 16503
2046 658
2047 837
2048 100
2049 3
2050 1947
2051 2215
2052 2
2053 1570
2054 2233
2055 567
2056 4
2057 5
2058 1732
2059 116
2060 162
2061 10
2062 81
2063 217
2064 4
2065 668
2066 261
2067 895
2068 713
2069 3
2070 806
2071 895
2072 1
2073 1278
2074 26
2075 16092
2076 2
2077 48
2078 1
2079 3237
2080 2317
2081 7
2082 552
2083 241
2084 2
2085 3417
2086 7
2087 2796
2088 54
2089 9008
2090 13
2091 295
2092 949
2093 7
2094 789
2095 156
2096 1402
2097 248
2098 4
2099 1629
2100 399
2101 530
2102 2406
2103 9
2104 15429
2105 248
2106 1460
2107 1072
2108 9
2109 286
2110 491
2111 745
2112 121
2113 56
2114 396
2115 19
2116 521
2117 197
2118 13
2119 206
2120 1

In [44]:
for idx, lab in enumerate(rolled_labels[1]):
  if lab == True:
    print(idx, lab)

6 1
7 1
8 1
14 1
18 1
27 1
31 1
33 1
38 1
44 1
71 1
72 1
73 1
78 1
93 1
94 1
98 1
100 1
128 1
200 1
326 1
498 1
