In [None]:
# Colab-only setup: mount Google Drive under /content/drive/ so the data
# directory used by the rest of the notebook is reachable.
from google.colab import drive
drive.mount('/content/drive/',force_remount=True)

Mounted at /content/drive/


In [None]:

# Directory on the mounted Drive used as the root for every file path built in
# this notebook (input CSV tables as well as generated outputs).
path_to_data = "/content/drive/My Drive/eICU_Data"

In [None]:
# Record the TensorFlow version this notebook was run with.
import tensorflow as tf
print(tf.__version__)

2.15.0


In [None]:
from __future__ import division
from __future__ import print_function

import pickle
import csv
import os
import sys
import numpy as np
import sklearn.model_selection as ms
import tensorflow as tf
from IPython.display import display, HTML


class EncounterInfo(object):
  """Plain record describing one encounter and the codes attached to it.

  The scalar fields are supplied at construction time; the collection fields
  start empty and are meant to be filled in afterwards by the caller.

  Attributes:
    patient_id: Identifier of the patient the encounter belongs to.
    encounter_id: Identifier of the encounter itself.
    encounter_timestamp: Value used to order a patient's encounters.
    expired: Truthy if the encounter ended with the patient expiring.
    readmission: Truthy if the encounter is labelled as a readmission.
    dx_ids: List of diagnosis code strings (duplicates allowed).
    rx_ids: List, initially empty.
    labs: Dict, initially empty.
    physicals: List, initially empty.
    treatments: List of treatment code strings (duplicates allowed).
  """

  def __init__(self, patient_id, encounter_id, encounter_timestamp, expired,
               readmission):
    self.patient_id = patient_id
    self.encounter_id = encounter_id
    self.encounter_timestamp = encounter_timestamp
    self.expired = expired
    self.readmission = readmission
    # Every instance gets its own fresh containers.
    self.dx_ids = []
    self.rx_ids = []
    self.labs = {}
    self.physicals = []
    self.treatments = []

  def __repr__(self):
    # Collections are summarised by size so that printing an encounter (or a
    # dict of encounters) stays readable instead of showing a memory address.
    return ('EncounterInfo(patient_id=%r, encounter_id=%r, '
            'encounter_timestamp=%r, expired=%r, readmission=%r, '
            'num_dx_ids=%d, num_rx_ids=%d, num_labs=%d, num_physicals=%d, '
            'num_treatments=%d)' %
            (self.patient_id, self.encounter_id, self.encounter_timestamp,
             self.expired, self.readmission, len(self.dx_ids),
             len(self.rx_ids), len(self.labs), len(self.physicals),
             len(self.treatments)))


def process_patient(infile, encounter_dict, hour_threshold=24):
  """Creates one EncounterInfo per qualifying row of the patient CSV.

  The file is read once. A first pass groups encounters by patient and derives
  the readmission label: for each patient the (timestamp, encounter id) tuples
  are sorted, every encounter except the last one is labelled True and the
  last one False. A second pass builds the EncounterInfo objects.

  Args:
    infile: Path of a CSV file with the columns 'patienthealthsystemstayid',
      'patientunitstayid', 'hospitaladmitoffset', 'unitdischargestatus' and
      'unitdischargeoffset'.
    encounter_dict: Dict mapping encounter id -> EncounterInfo; updated in
      place.
    hour_threshold: Rows whose 'unitdischargeoffset' (treated as minutes) is
      greater than 60 * hour_threshold are skipped.

  Returns:
    The same `encounter_dict` object that was passed in.

  Raises:
    AssertionError: If an encounter id that passes the filter is already
      present in `encounter_dict`.
  """
  # newline='' is what the csv module expects so that quoted fields containing
  # line breaks are parsed correctly; `with` closes the file even on errors.
  with open(infile, 'r', newline='') as inff:
    rows = list(csv.DictReader(inff))
  print('close1')

  # Pass 1: collect every encounter of every patient (no duration filter).
  patient_dict = {}
  for line in rows:
    patient_id = line['patienthealthsystemstayid']
    encounter_id = line['patientunitstayid']
    # The timestamp is the negated admit offset.
    encounter_timestamp = -int(line['hospitaladmitoffset'])
    patient_dict.setdefault(patient_id, []).append(
        (encounter_timestamp, encounter_id))

  enc_readmission_dict = {}
  for time_enc_tuples in patient_dict.values():
    ordered = sorted(time_enc_tuples)
    for _, enc_id in ordered[:-1]:
      enc_readmission_dict[enc_id] = True
    enc_readmission_dict[ordered[-1][1]] = False

  # Pass 2: build the encounters that pass the duration filter.
  count = 0
  for line in rows:
    patient_id = line['patienthealthsystemstayid']
    encounter_id = line['patientunitstayid']
    encounter_timestamp = -int(line['hospitaladmitoffset'])
    expired = line['unitdischargestatus'] == 'Expired'
    duration_minute = float(line['unitdischargeoffset'])
    readmission = enc_readmission_dict[encounter_id]

    if duration_minute > 60. * hour_threshold:
      continue

    if encounter_id in encounter_dict:
      print('Duplicate encounter ID!!')
      # Raised explicitly (rather than `assert False`) so the check still
      # fires when Python runs with assertions disabled.
      raise AssertionError('Stopping cell execution')

    encounter_dict[encounter_id] = EncounterInfo(
        patient_id, encounter_id, encounter_timestamp, expired, readmission)
    count += 1

  print(count)

  return encounter_dict


def process_admission_dx(infile, encounter_dict):
  """Appends admission-diagnosis codes to the matching encounters.

  Args:
    infile: Path of a CSV file with the columns 'patientunitstayid' and
      'admitdxpath'.
    encounter_dict: Dict mapping encounter id -> object with a `dx_ids` list;
      the lists are extended in place. Rows whose encounter id is not a key of
      this dict are counted and skipped.

  Returns:
    The same `encounter_dict` object that was passed in.
  """
  count = 0
  missing_eid = 0
  # newline='' is what the csv module expects; `with` closes the file even if
  # a row is malformed.
  with open(infile, 'r', newline='') as inff:
    for line in csv.DictReader(inff):
      encounter_id = line['patientunitstayid']
      dx_id = line['admitdxpath'].lower()  # codes are stored lower-cased

      if encounter_id not in encounter_dict:
        missing_eid += 1
        continue
      encounter_dict[encounter_id].dx_ids.append(dx_id)
      count += 1

  # `%` (not a comma) so the number is substituted into the message.
  print('Count : %d' % count)
  print('Admission Diagnosis without Encounter ID: %d' % missing_eid)

  return encounter_dict


def process_diagnosis(infile, encounter_dict):
  """Appends diagnosis codes to the matching encounters.

  Args:
    infile: Path of a CSV file with the columns 'patientunitstayid' and
      'diagnosisstring'.
    encounter_dict: Dict mapping encounter id -> object with a `dx_ids` list;
      the lists are extended in place. Rows whose encounter id is not a key of
      this dict are counted and skipped.

  Returns:
    The same `encounter_dict` object that was passed in.
  """
  missing_eid = 0
  # newline='' is what the csv module expects; `with` closes the file even if
  # a row is malformed.
  with open(infile, 'r', newline='') as inff:
    for line in csv.DictReader(inff):
      encounter_id = line['patientunitstayid']
      dx_id = line['diagnosisstring'].lower()  # codes are stored lower-cased

      if encounter_id not in encounter_dict:
        missing_eid += 1
        continue
      encounter_dict[encounter_id].dx_ids.append(dx_id)

  print('')
  print('Diagnosis without Encounter ID: %d' % missing_eid)

  return encounter_dict


def process_treatment(infile, encounter_dict):
  """Appends treatment codes to the matching encounters.

  Args:
    infile: Path of a CSV file with the columns 'patientunitstayid' and
      'treatmentstring'.
    encounter_dict: Dict mapping encounter id -> object with a `treatments`
      list; the lists are extended in place. Rows whose encounter id is not a
      key of this dict are counted and skipped.

  Returns:
    The same `encounter_dict` object that was passed in.
  """
  count = 0
  missing_eid = 0
  # newline='' is what the csv module expects; `with` closes the file even if
  # a row is malformed.
  with open(infile, 'r', newline='') as inff:
    for line in csv.DictReader(inff):
      encounter_id = line['patientunitstayid']
      treatment_id = line['treatmentstring'].lower()  # stored lower-cased

      if encounter_id not in encounter_dict:
        missing_eid += 1
        continue
      encounter_dict[encounter_id].treatments.append(treatment_id)
      count += 1

  print('')
  print('Treatment without Encounter ID: %d' % missing_eid)
  print('Accepted treatments: %d' % count)

  return encounter_dict


def build_seqex(enc_dict,
                skip_duplicate=False,
                min_num_codes=1,
                max_num_codes=50):
  """Converts encounters into tf.train.SequenceExample protos.

  Each kept encounter becomes one SequenceExample with
    * context features 'patientId' ("<patient_id>:<encounter_id>" as UTF-8
      bytes), 'label.expired' and 'label.readmission' (0/1), and
    * feature lists 'dx_ids' / 'proc_ids' (unique code strings as bytes) and
      'dx_ints' / 'proc_ints' (the same codes mapped to integers), each
      holding a single step.

  Codes are de-duplicated in first-occurrence order, so the order of codes
  inside a record is the same on every run (iterating a `set` of strings is
  not) and the string and integer lists are aligned position by position.

  Args:
    enc_dict: Dict whose values expose `patient_id`, `encounter_id`,
      `expired`, `readmission`, `dx_ids` and `treatments`.
    skip_duplicate: If True, encounters that contain a repeated diagnosis or
      treatment code are dropped.
    min_num_codes: Encounters with fewer unique diagnosis codes, or fewer
      unique treatment codes, than this are dropped.
    max_num_codes: Encounters with more unique diagnosis codes, or more unique
      treatment codes, than this are dropped.

  Returns:
    A tuple (key_list, seqex_list, dx_str2int, treat_str2int): the 'patientId'
    bytes of every kept encounter, the matching SequenceExamples, and the
    diagnosis / treatment string -> integer vocabularies built from the kept
    encounters.
  """
  key_list = []
  seqex_list = []
  dx_str2int = {}
  treat_str2int = {}
  num_duplicate = 0
  count = 0
  num_dx_ids = 0
  num_treatments = 0
  num_unique_dx_ids = 0
  num_unique_treatments = 0
  min_dx_cut = 0
  min_treatment_cut = 0
  max_dx_cut = 0
  max_treatment_cut = 0
  num_expired = 0
  num_readmission = 0

  for enc in enc_dict.values():
    # De-duplicate once per encounter, keeping first-occurrence order.
    unique_dx_ids = list(dict.fromkeys(enc.dx_ids))
    unique_treatments = list(dict.fromkeys(enc.treatments))

    if skip_duplicate and (len(enc.dx_ids) > len(unique_dx_ids) or
                           len(enc.treatments) > len(unique_treatments)):
      num_duplicate += 1
      continue

    if len(unique_dx_ids) < min_num_codes:
      min_dx_cut += 1
      continue

    if len(unique_treatments) < min_num_codes:
      min_treatment_cut += 1
      continue

    if len(unique_dx_ids) > max_num_codes:
      max_dx_cut += 1
      continue

    if len(unique_treatments) > max_num_codes:
      max_treatment_cut += 1
      continue

    count += 1
    num_dx_ids += len(enc.dx_ids)
    num_treatments += len(enc.treatments)
    num_unique_dx_ids += len(unique_dx_ids)
    num_unique_treatments += len(unique_treatments)

    # Grow the vocabularies; a new code gets the next free integer.
    for dx_id in unique_dx_ids:
      dx_str2int.setdefault(dx_id, len(dx_str2int))
    for treat_id in unique_treatments:
      treat_str2int.setdefault(treat_id, len(treat_str2int))

    seqex = tf.train.SequenceExample()
    context = seqex.context.feature
    patient_id_enc = (enc.patient_id + ':' + enc.encounter_id).encode('utf-8')
    context['patientId'].bytes_list.value.append(patient_id_enc)

    context['label.expired'].int64_list.value.append(1 if enc.expired else 0)
    if enc.expired:
      num_expired += 1

    context['label.readmission'].int64_list.value.append(
        1 if enc.readmission else 0)
    if enc.readmission:
      num_readmission += 1

    feature_list = seqex.feature_lists.feature_list
    feature_list['dx_ids'].feature.add().bytes_list.value.extend(
        [dx_id.encode('utf-8') for dx_id in unique_dx_ids])
    feature_list['dx_ints'].feature.add().int64_list.value.extend(
        [dx_str2int[dx_id] for dx_id in unique_dx_ids])
    feature_list['proc_ids'].feature.add().bytes_list.value.extend(
        [treatment.encode('utf-8') for treatment in unique_treatments])
    feature_list['proc_ints'].feature.add().int64_list.value.extend(
        [treat_str2int[treatment] for treatment in unique_treatments])

    seqex_list.append(seqex)
    key_list.append(patient_id_enc)

  # Total number of encounters removed by the min/max code-count thresholds.
  num_cut = min_dx_cut + min_treatment_cut + max_dx_cut + max_treatment_cut
  # Avoid dividing by zero when nothing was kept; all sums are 0 in that case,
  # so the averages are reported as 0.
  denom = float(count) if count else 1.0

  print('Filtered encounters due to duplicate codes: %d' % num_duplicate)
  print('Filtered encounters due to thresholding: %d' % num_cut)
  print('Average num_dx_ids: %f' % (num_dx_ids / denom))
  print('Average num_treatments: %f' % (num_treatments / denom))
  print('Average num_unique_dx_ids: %f' % (num_unique_dx_ids / denom))
  print('Average num_unique_treatments: %f' % (num_unique_treatments / denom))
  print('Min dx cut: %d' % min_dx_cut)
  print('Min treatment cut: %d' % min_treatment_cut)
  print('Max dx cut: %d' % max_dx_cut)
  print('Max treatment cut: %d' % max_treatment_cut)
  print('Number of expired: %d' % num_expired)
  print('Number of readmission: %d' % num_readmission)

  return key_list, seqex_list, dx_str2int, treat_str2int


def select_train_valid_test(key_list, random_seed=1234):
  """Splits `key_list` into train, validation and test key lists.

  20% of the keys are held out first; that held-out part is then cut in half
  to form the validation and test sets. Both splits use `random_seed` as the
  random state.

  Args:
    key_list: Sequence of keys to split.
    random_seed: Random state passed to both splitting steps.

  Returns:
    A tuple (train_keys, validation_keys, test_keys).
  """
  train_keys, holdout_keys = ms.train_test_split(
      key_list, test_size=0.2, random_state=random_seed)
  valid_keys, test_keys = ms.train_test_split(
      holdout_keys, test_size=0.5, random_state=random_seed)
  return train_keys, valid_keys, test_keys


def count_conditional_prob_dp(seqex_list, output_path, train_key_set=None):
  """Computes empirical diagnosis/procedure statistics and pickles them.

  Only records whose 'patientId' is in `train_key_set` contribute (all records
  when it is None). Five dictionaries are written to `output_path`:

    dx_probs.empirical.p       P(dx), keyed by the code as stored in the record
    proc_probs.empirical.p     P(proc), keyed the same way
    dp_probs.empirical.p       P(dx, proc), keyed by the text "dx,proc"
    dp_cond_probs.empirical.p  P(proc | dx), keyed by "dx,proc"
    pd_cond_probs.empirical.p  P(dx | proc), keyed by "proc,dx"

  The two conditional tables contain an entry (0.0 if the pair never
  co-occurred) for every combination of an observed dx with an observed proc.

  The probabilities are normalised by the number of records that actually
  contributed. Dividing by the number of records seen (including the skipped
  ones) would underestimate the marginal and joint probabilities; the
  conditional ratios do not depend on that choice.

  Args:
    seqex_list: Iterable of SequenceExample-like records with a 'patientId'
      context feature and 'dx_ids' / 'proc_ids' feature lists.
    output_path: Existing directory that receives the pickle files.
    train_key_set: Optional container of 'patientId' values to restrict to.
  """

  def _as_text(code):
    # Codes are read from bytes lists; pair keys are built from text.
    return code.decode('utf-8') if isinstance(code, bytes) else code

  dx_freqs = {}
  proc_freqs = {}
  dp_freqs = {}
  total_visit = 0   # records seen so far; drives the progress output only
  num_counted = 0   # records that contributed to the statistics
  for seqex in seqex_list:
    if total_visit % 1000 == 0:
      print('Visit count: %d\r' % total_visit)
    total_visit += 1

    key = seqex.context.feature['patientId'].bytes_list.value[0]
    if train_key_set is not None and key not in train_key_set:
      continue
    num_counted += 1

    feature_list = seqex.feature_lists.feature_list
    dx_ids = feature_list['dx_ids'].feature[0].bytes_list.value
    proc_ids = feature_list['proc_ids'].feature[0].bytes_list.value

    for dx in dx_ids:
      dx_freqs[dx] = dx_freqs.get(dx, 0) + 1

    for proc in proc_ids:
      proc_freqs[proc] = proc_freqs.get(proc, 0) + 1

    proc_strs = [_as_text(proc) for proc in proc_ids]
    for dx in dx_ids:
      dx_str = _as_text(dx)
      for proc_str in proc_strs:
        dp = dx_str + ',' + proc_str
        dp_freqs[dp] = dp_freqs.get(dp, 0) + 1

  # The frequency tables are empty whenever num_counted is 0, so no division
  # by zero can happen below.
  denom = float(num_counted)
  dx_probs = {k: v / denom for k, v in dx_freqs.items()}
  proc_probs = {k: v / denom for k, v in proc_freqs.items()}
  dp_probs = {k: v / denom for k, v in dp_freqs.items()}

  dp_cond_probs = {}
  pd_cond_probs = {}
  proc_items = [(_as_text(proc), prob) for proc, prob in proc_probs.items()]
  for dx, dx_prob in dx_probs.items():
    dx_str = _as_text(dx)
    for proc_str, proc_prob in proc_items:
      dp = dx_str + ',' + proc_str
      pd = proc_str + ',' + dx_str
      joint = dp_probs.get(dp)
      if joint is not None:
        dp_cond_probs[dp] = joint / dx_prob
        pd_cond_probs[pd] = joint / proc_prob
      else:
        dp_cond_probs[dp] = 0.0
        pd_cond_probs[pd] = 0.0

  tables = (
      ('dx_probs', dx_probs),
      ('proc_probs', proc_probs),
      ('dp_probs', dp_probs),
      ('dp_cond_probs', dp_cond_probs),
      ('pd_cond_probs', pd_cond_probs),
  )
  for name, table in tables:
    # `with` guarantees the file is flushed and closed before it is read back.
    with open(output_path + '/' + name + '.empirical.p', 'wb') as outf:
      pickle.dump(table, outf, -1)


def add_sparse_prior_guide_dp(seqex_list,
                              stats_path,
                              key_set=None,
                              max_num_codes=50):
  """Returns copies of the selected records with a sparse prior guide added.

  For every record whose 'patientId' is in `key_set` (all records when it is
  None) two single-step feature lists are added:

    'prior_indices': flattened (row, col) pairs. A dx at position i paired
      with a proc at position j yields (i, max_num_codes + j); the reverse
      direction yields (max_num_codes + i, j).
    'prior_values': one conditional probability per pair, looked up in
      dp_cond_probs ("dx,proc") or pd_cond_probs ("proc,dx"); 0.0 when the
      pair is unknown.

  The input records are never modified: each selected record is deep-copied
  first. Adding the features to the caller's objects would make every further
  call on the same list (e.g. once per fold) append one more step to the same
  records.

  Args:
    seqex_list: Iterable of SequenceExample-like records with a 'patientId'
      context feature and 'dx_ids' / 'proc_ids' feature lists.
    stats_path: Directory containing 'dp_cond_probs.empirical.p' and
      'pd_cond_probs.empirical.p'.
    key_set: Optional container of 'patientId' values to keep.
    max_num_codes: Column/row offset applied to procedure positions.

  Returns:
    A list with one augmented copy per selected record, in input order.
  """
  import copy  # local import: only this function needs it

  def _as_text(code):
    # Codes are read from bytes lists; the probability tables use text keys.
    return code.decode('utf-8') if isinstance(code, bytes) else code

  print('Loading conditional probabilities.')
  # NOTE: pickle.load can execute arbitrary code; only point stats_path at
  # files produced by this notebook.
  with open(stats_path + '/dp_cond_probs.empirical.p', 'rb') as inf:
    dp_cond_probs = pickle.load(inf)
  with open(stats_path + '/pd_cond_probs.empirical.p', 'rb') as inf:
    pd_cond_probs = pickle.load(inf)

  print('Adding prior guide.')
  total_visit = 0
  new_seqex_list = []
  for original in seqex_list:
    if total_visit % 1000 == 0:
      print('Visit count: %d\r' % total_visit)
    total_visit += 1

    key = original.context.feature['patientId'].bytes_list.value[0]
    if key_set is not None and key not in key_set:
      continue

    seqex = copy.deepcopy(original)
    feature_list = seqex.feature_lists.feature_list
    dx_strs = [
        _as_text(dx) for dx in feature_list['dx_ids'].feature[0].bytes_list.value
    ]
    proc_strs = [
        _as_text(proc)
        for proc in feature_list['proc_ids'].feature[0].bytes_list.value
    ]

    indices = []  # flat: row0, col0, row1, col1, ...
    values = []
    for i, dx_str in enumerate(dx_strs):
      for j, proc_str in enumerate(proc_strs):
        indices.extend((i, max_num_codes + j))
        values.append(dp_cond_probs.get(dx_str + ',' + proc_str, 0.0))

    for i, proc_str in enumerate(proc_strs):
      for j, dx_str in enumerate(dx_strs):
        indices.extend((max_num_codes + i, j))
        values.append(pd_cond_probs.get(proc_str + ',' + dx_str, 0.0))

    indices_feature = feature_list['prior_indices']
    values_feature = feature_list['prior_values']
    # Drop any prior guide the input already carried so the result always
    # holds exactly one step.
    del indices_feature.feature[:]
    del values_feature.feature[:]
    indices_feature.feature.add().int64_list.value.extend(indices)
    values_feature.feature.add().float_list.value.extend(values)

    new_seqex_list.append(seqex)

  return new_seqex_list

In [None]:
# Number of train/validation/test folds produced later in the notebook.
num_fold = 4

# Input tables, all located directly under the data directory.
patient_file = f'{path_to_data}/patient.csv'
admission_dx_file = f'{path_to_data}/admissionDx.csv'
diagnosis_file = f'{path_to_data}/diagnosis.csv'
treatment_file = f'{path_to_data}/treatment.csv'

# Create the encounters from the patient table first ...
encounter_dict = {}
print('Processing patient.csv')
encounter_dict = process_patient(patient_file, encounter_dict, hour_threshold=24)

# ... then let each remaining table attach its codes to them.
for label, loader, table_path in (
        ('admission diagnosis.csv', process_admission_dx, admission_dx_file),
        ('diagnosis.csv', process_diagnosis, diagnosis_file),
        ('treatment.csv', process_treatment, treatment_file),
):
    print('Processing ' + label)
    encounter_dict = loader(table_path, encounter_dict)

# Preview the first five entries of the encounter table.
count1 = 0
for count1, (key, value) in enumerate(encounter_dict.items(), start=1):
    print(key, value)
    if count1 == 5:
        break


Processing patient.csv
close1
68076
Processing admission diagnosis.csv
Count : %d 176269
Admission Diagnosis without Encounter ID: 450589
Processing diagnosis.csv

Diagnosis without Encounter ID: 2483092
Processing treatment.csv

Treatment without Encounter ID: 3372000
Accepted treatments: 316745
141178 <__main__.EncounterInfo object at 0x79daf8a47910>
141197 <__main__.EncounterInfo object at 0x79daf8a462f0>
141208 <__main__.EncounterInfo object at 0x79daf8a4cd00>
141229 <__main__.EncounterInfo object at 0x79daf8a4caf0>
141260 <__main__.EncounterInfo object at 0x79daf8a4ca90>


In [None]:
# Dump every field of the first ten encounters for a manual sanity check.
count1 = 0
for count1, (key, encounter_info) in enumerate(encounter_dict.items(), start=1):
    print(f"Key: {key}, Patient ID: {encounter_info.patient_id}, Encounter ID: {encounter_info.encounter_id}, Timestamp: {encounter_info.encounter_timestamp}, Expired: {encounter_info.expired}, Readmission: {encounter_info.readmission}, Dx IDs: {encounter_info.dx_ids}, Rx IDs: {encounter_info.rx_ids}, Labs: {encounter_info.labs}, Physicals: {encounter_info.physicals}, Treatments: {encounter_info.treatments}")
    if count1 == 10:
        break

Key: 141178, Patient ID: 128927, Encounter ID: 141178, Timestamp: 14, Expired: False, Readmission: True, Dx IDs: [], Rx IDs: [], Labs: {}, Physicals: [], Treatments: []
Key: 141197, Patient ID: 128943, Encounter ID: 141197, Timestamp: 25, Expired: False, Readmission: True, Dx IDs: ['admission diagnosis|non-operative organ systems|organ system|cardiovascular', 'admission diagnosis|was the patient admitted from the o.r. or went to the o.r. within 4 hours of admission?|no', 'admission diagnosis|all diagnosis|non-operative|diagnosis|cardiovascular|sepsis, pulmonary'], Rx IDs: [], Labs: {}, Physicals: [], Treatments: []
Key: 141208, Patient ID: 128952, Encounter ID: 141208, Timestamp: 1, Expired: False, Readmission: False, Dx IDs: ['admission diagnosis|was the patient admitted from the o.r. or went to the o.r. within 4 hours of admission?|no', 'admission diagnosis|non-operative organ systems|organ system|neurologic', 'admission diagnosis|all diagnosis|non-operative|diagnosis|neurology|overd

In [None]:
# len111=len(encounter_dict['141197'].dx_ids)
# print(len111)
# len112=len(set(encounter_dict['141197'].dx_ids))
# print(len112)

In [None]:
# Turn the encounters into SequenceExamples. Encounters are kept only if they
# have between 1 and 50 unique diagnosis codes and between 1 and 50 unique
# treatment codes; dx_map / proc_map are the resulting code -> integer
# vocabularies.
key_list, seqex_list, dx_map, proc_map = build_seqex(
    encounter_dict, skip_duplicate=False, min_num_codes=1, max_num_codes=50)


Filtered encounters due to duplicate codes: 0
Filtered encounters due to thresholding: 0
Average num_dx_ids: 8.253912
Average num_treatments: 7.697826
Average num_unique_dx_ids: 6.462268
Average num_unique_treatments: 5.026276
Min dx cut: 16670
Min treatment cut: 10373
Max dx cut: 1
Max treatment cut: 6
Number of expired: 2983
Number of readmission: 7051


In [None]:
# Persist the code -> integer vocabularies next to the data. The files are
# opened with `with` so they are flushed and closed as soon as they are
# written.
with open(path_to_data + '/dx_map.p', 'wb') as outf:
    pickle.dump(dx_map, outf, -1)
with open(path_to_data + '/proc_map.p', 'wb') as outf:
    pickle.dump(proc_map, outf, -1)

In [None]:
# Write `num_fold` independent train/validation/test splits as TFRecord files.
# Each fold uses its own seed (base_seed + fold index): with one fixed seed
# every fold would contain exactly the same split. Fold 0 keeps seed 42.
base_seed = 42
for i in range(num_fold):
    fold_path = path_to_data + '/fold_' + str(i)
    # NOTE: the statistics directory is shared by all folds, so after the loop
    # it holds the statistics of the last fold only. They are consumed right
    # after being written, so each fold still uses its own statistics.
    stats_path = path_to_data + '/train_stats'
    os.makedirs(stats_path, exist_ok=True)
    os.makedirs(fold_path, exist_ok=True)

    key_train, key_valid, key_test = select_train_valid_test(
        key_list, random_seed=base_seed + i)
    print(len(key_list))
    print(len(key_train))
    print(len(key_valid))
    print(len(key_test))

    # Build each key set once; they are only used for membership tests.
    train_key_set = set(key_train)
    valid_key_set = set(key_valid)
    test_key_set = set(key_test)

    # Statistics come from the training split only and are then applied to
    # all three splits.
    count_conditional_prob_dp(seqex_list, stats_path, train_key_set)
    train_seqex = add_sparse_prior_guide_dp(
        seqex_list, stats_path, train_key_set, max_num_codes=50)
    validation_seqex = add_sparse_prior_guide_dp(
        seqex_list, stats_path, valid_key_set, max_num_codes=50)
    test_seqex = add_sparse_prior_guide_dp(
        seqex_list, stats_path, test_key_set, max_num_codes=50)

    with tf.io.TFRecordWriter(fold_path + '/train.tfrecord') as writer:
        for seqex in train_seqex:
            writer.write(seqex.SerializeToString())

    with tf.io.TFRecordWriter(fold_path + '/validation.tfrecord') as writer:
        for seqex in validation_seqex:
            writer.write(seqex.SerializeToString())

    with tf.io.TFRecordWriter(fold_path + '/test.tfrecord') as writer:
        for seqex in test_seqex:
            writer.write(seqex.SerializeToString())


41026
32820
4103
4103
Visit count: 0
Visit count: 1000
Visit count: 2000
Visit count: 3000
Visit count: 4000
Visit count: 5000
Visit count: 6000
Visit count: 7000
Visit count: 8000
Visit count: 9000
Visit count: 10000
Visit count: 11000
Visit count: 12000
Visit count: 13000
Visit count: 14000
Visit count: 15000
Visit count: 16000
Visit count: 17000
Visit count: 18000
Visit count: 19000
Visit count: 20000
Visit count: 21000
Visit count: 22000
Visit count: 23000
Visit count: 24000
Visit count: 25000
Visit count: 26000
Visit count: 27000
Visit count: 28000
Visit count: 29000
Visit count: 30000
Visit count: 31000
Visit count: 32000
Visit count: 33000
Visit count: 34000
Visit count: 35000
Visit count: 36000
Visit count: 37000
Visit count: 38000
Visit count: 39000
Visit count: 40000
Visit count: 41000
Loading conditional probabilities.
Adding prior guide.
Visit count: 0
Visit count: 1000
Visit count: 2000
Visit count: 3000
Visit count: 4000
Visit count: 5000
Visit count: 6000
Visit count: 7

In [None]:
# Spot-check a single encounter; the dictionary keys are the encounter ids as
# strings.
print(encounter_dict['141178'])

<__main__.EncounterInfo object at 0x79daf8a47910>


In [None]:
# List the attributes available on an encounter object.
encinfo=encounter_dict['141178']
print(dir(encinfo))

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'dx_ids', 'encounter_id', 'encounter_timestamp', 'expired', 'labs', 'patient_id', 'physicals', 'readmission', 'rx_ids', 'treatments']


In [None]:
# Total number of encounters currently held in encounter_dict.
print(len(encounter_dict))

68076
