In [1]:
import os, json, argparse, random, math, re
import pickle
from preprocess import save_sparse, save_data
from preprocess.parse_csv import Mimic3Parser, Mimic4Parser, EICUParser, Mimic4NoteParser
from preprocess.encode import encode_code
from preprocess.build_dataset import split_patients, build_code_xy, build_heart_failure_y
from preprocess.auxiliary import generate_code_code_adjacent, generate_neighbors, normalize_adj, divide_middle, generate_code_levels
import pandas as pd
from tqdm import tqdm
import numpy as np
from datetime import datetime, timedelta
import simple_icd_10_cm as cm
from create_data_clibench_1 import clean_note

In [2]:
save_path_parsed = 'data/mimic4/parsed'
check_dist_fields_base = [
    '_services_simple',
    '_transfers_careunits',
    'admission_type',
    'hospital_expire_flag',
    'admission_location',
    'discharge_location',
    'patient_insurance',
    'patient_lang',
    'patient_marital',
    'patient_race',
    'patient_gender'
]
debug_mode = False

In [3]:
# these loadings are not needed for data generation
# admission_codes = pickle.load(open(os.path.join(save_path_parsed, 'admission_codes.pkl'), 'rb'))
# admission_metadata = pickle.load(open(os.path.join(save_path_parsed, 'admission_metadata.pkl'), 'rb'))
# patient_metadata = pickle.load(open(os.path.join(save_path_parsed, 'patient_metadata.pkl'), 'rb'))
# admission_labevents = pickle.load(open(os.path.join(save_path_parsed, 'admission_labevents.pkl'), 'rb'))
# admission_prescriptions = pickle.load(open(os.path.join(save_path_parsed, 'admission_prescriptions.pkl'), 'rb'))
# admission_procedures = pickle.load(open(os.path.join(save_path_parsed, 'admission_procedures.pkl'), 'rb'))
# print(f'using intermediate files saved in {save_path_parsed} for note')
# admission_notes = pickle.load(open(os.path.join(save_path_parsed, 'admission_notes.pkl'), 'rb'))
# admission_radiology_notes = pickle.load(open(os.path.join(save_path_parsed, 'admission_radiology_notes.pkl'), 'rb'))
# print(f'using intermediate files saved in {save_path_parsed} for dict')
# with open(os.path.join(save_path_parsed, 'diagcode_longtitle.json')) as f:
#     diagcode_longtitle = json.load(f)
# with open(os.path.join(save_path_parsed, 'procedurecode_longtitle.json')) as f:
#     procedurecode_longtitle = json.load(f)

In [4]:
# load csv for LOINC code
if os.path.exists(os.path.join(save_path_parsed, 'loinc_metadata.json')):
    with open(os.path.join(save_path_parsed, 'loinc_metadata.json'), 'r') as f:
        loinc_metadata = json.load(f)
else:
    print('preparing lab item metadata from LOINC code')
    loinc_hierarchy_df = pd.read_csv('code_sys/LOINC/ComponentHierarchyBySystem.csv')
    loinc_df = pd.read_csv('code_sys/LOINC/Loinc.csv')
    with open(os.path.join(save_path_parsed, 'labitem_labels.json')) as f:
        labitem_labels = json.load(f)

    # traverse loinc_df dataframe to create a dictionary mapping loinc code to its parent loinc code
    loinc_metadata = {}
    for index, row in loinc_df.iterrows():
        loinc_code = row['LOINC_NUM']
        loinc_code_simple = loinc_code.split('-')[0] # remove the version number
        loinc_metadata[loinc_code_simple] = {
            'code': loinc_code,
            'component': row['COMPONENT'],
            'property': row['PROPERTY'],
            'time': row['TIME_ASPCT'],
            'system': row['SYSTEM'],
            'scale': row['SCALE_TYP'],
            'method': row['METHOD_TYP'],
        }

    count_new = 0
    for index, row in loinc_hierarchy_df.iterrows():
        loinc_code = row['CODE']
        loinc_code_simple = loinc_code.split('-')[0] # remove the version number
        loinc_ancestors = row['PATH_TO_ROOT'].split('.')[::-1] if isinstance(row['PATH_TO_ROOT'], str) else [] # from root to immedidate parent
        if loinc_code_simple in loinc_metadata:
            loinc_metadata[loinc_code_simple]['ancestors'] = loinc_ancestors
        else:
            count_new += 1
            loinc_metadata[loinc_code_simple] = {
                'code': loinc_code,
                'component': row['CODE_TEXT'],
                'ancestors': loinc_ancestors,
            }
    print(f'Loaded LOINC code count: {len(loinc_metadata)}, {count_new} from hierarchy only')

    count_common = 0
    for li, text_mimic in labitem_labels.items():
        if li in loinc_metadata:
            loinc_metadata[li]['text_mimic'] = text_mimic
            count_common += 1
        else:
            loinc_metadata[li] = {
                'component': text_mimic,
            }
            print(f'Not included in LOINC code but happen in MIMIC: {li}')
    print(f'{count_common} / {len(labitem_labels)} lab items in the dataset can be found on LOINC coding system')

    with open(os.path.join(save_path_parsed, 'loinc_metadata.json'), 'w') as f:
        json.dump(loinc_metadata, f, indent=4)

In [5]:
with open(os.path.join('code_sys/NDC/ndc_metadata.json'), 'r') as f:
    ndc_metadata = json.load(f)

In [6]:
def drug_ancestors(ndc_code, level=1):
    list_of_ancestors = []
    if ndc_code in ndc_metadata:
        if 'atc' in ndc_metadata[ndc_code]:
            for atc_this in ndc_metadata[ndc_code]['atc'][0]:
                ancestors = []
                atc_id = atc_this['id']
                # A, A10, A10B, A10BA, A10BA02
                ancestors.append(atc_id[0])
                if len(atc_id) == 3:
                    ancestors.append(atc_id[:2])
                if len(atc_id) == 4:
                    ancestors.append(atc_id[:3])
                if len(atc_id) == 5:
                    ancestors.append(atc_id[:4])
                if len(atc_id) == 7:
                    ancestors.append(atc_id[:6])
                if len(atc_id) not in [1, 3, 4, 5, 7]:
                    print('Cannot find ancestors for ATC code with current implementation', atc_id)
                ancestors = ancestors[::-1]
                list_of_ancestors.append(ancestors)
    ancestors_at_level = list(set([ancestors[-level] for ancestors in list_of_ancestors]))
    return ancestors_at_level

In [7]:
drug_ancestors('00641036725')

['A', 'R', 'H', 'D', 'S', 'C']

In [8]:
def labitem_ancestors(code):
    # V1 implementation, use the category mentioned in MIMIC data source
    # lab_def = labitem_labels[str(code)]
    # category = re.findall(r'\((.*?)\)', lab_def)[0]
    # return [category]
    if 'ancestors' in loinc_metadata[str(code)]:
        return loinc_metadata[str(code)]['ancestors']
    else:
        return []

In [9]:
def icd10pcs_ancestors(code):
    ancestor_1 = code[:1]
    ancestor_2 = code[:2]
    ancestor_3 = code[:3]
    return [ancestor_3, ancestor_2, ancestor_1]
    

In [10]:
def process_service_name(raw):
    name = raw[0].replace('.', '').replace('service', '').replace('services', '').replace('department', '').strip()
    if name in ['bilogics', 'biologic', 'biologics']:
        name = 'biologics'
    if name in ['denies', 'dental']:
        name = 'dental'
    if name in ['ed', 'ed consulting ortho', 'emergency']:
        name = 'emergency'
    if name in ['general', 'general: lying comfortably in bed', 'general surgery']:
        name = 'general surgery'
    if name in ['med', 'medicine']:
        name = 'medicine'
    if name in ['ob-gyn', 'obstetrics/gynecology']:
        name = 'obstetrics/gynecology'
    if name in ['podiatric surgery', 'podiatry']:
        name = 'podiatry'
    if name.startswith('ort'):
        name = 'orthopaedics'

    return name

In [11]:
def group_data(adm_data, group_key):
    print(f'------- filtering for {group_key}')
    if type(adm_data[0][group_key]) == list:
        all_labels = []
        for dp in adm_data:
            all_labels.extend(dp[group_key])
        all_labels = sorted(list(set(all_labels)))
    else:
        all_labels = sorted(list(set([dp[group_key] for dp in adm_data])))
    data_index_map = {}
    count_map = {}
    for label in all_labels:
        data_index_map[label] = []
    for i, dp in enumerate(adm_data):
        if type(adm_data[i][group_key]) == list:
            for label_this in dp[group_key]:
                data_index_map[label_this].append(i)
        else:
            data_index_map[dp[group_key]].append(i)
    count_per_label = [len(data_index_map[label]) for label in all_labels]
    count_small_bound = min(count_per_label)
    for label, count in zip(all_labels, count_per_label):
        count_map[label] = count
        print(f'{label}: {count}')
    print(f'total number of labels: {len(all_labels)}')
    print('lowest category count:', count_small_bound)
    return all_labels, data_index_map, count_small_bound, count_map

In [12]:
if debug_mode:
    adm_data = pickle.load(open(os.path.join(save_path_parsed, 'adm_data_100.pkl'), 'rb'))
    print(adm_data[80])
else:
    adm_data = pickle.load(open(os.path.join(save_path_parsed, 'adm_data.pkl'), 'rb'))
    with open(os.path.join(save_path_parsed, 'adm_data_100.pkl'), 'wb') as f:
        pickle.dump(adm_data[:100], f)

## Adding intermediate data fields

In [13]:
# Extract service departments from services and transfers two table
for i, dp in enumerate(adm_data):
    adm_data[i]['_services_simple'] = [s[-1] for s in dp['services'] if len(s[-1]) > 0]
    adm_data[i]['_transfers_careunits'] = [s[3] for s in dp['transfers'] if len(s[3]) > 0]

In [14]:
# Generate  target_laborders info, for lab ordering task
for i, dp in enumerate(adm_data):
    # new item for laborder would be [itemid, charttime, storetime]
    adm_data[i]['target_laborders'] = [[s[2], s[3], s[4]] for s in dp['labevents']]

In [15]:
# Only keep the procedures, prescriptions, lab items right after admit
filter_stats = {'full': {}, 'keep': {}}
filter_stats_sum = {'full': {}, 'keep': {}}
for field in ['target_procedures', 'target_prescriptions', 'target_laborders']:
    print(f'processing {field}')
    filter_stats['full'][field] = []
    filter_stats['keep'][field] = []
    for i, dp in enumerate(tqdm(adm_data)):
        full_records = adm_data[i][field]
        admittime = adm_data[i]['admittime']
        if field == 'target_laborders':
            decisiontime_idx = 1
        elif field == 'target_prescriptions':
            decisiontime_idx = 1
        elif field == 'target_procedures':
            decisiontime_idx = -1

        # keep the record within the first 24 hour after admit time, or the first batch in the record list
        if len(full_records) > 0:
            first_batch_time = min([e[decisiontime_idx] for e in full_records])
            adm_data[i][field] = [e for e in full_records if e[decisiontime_idx] - first_batch_time <= timedelta(hours=2) or e[decisiontime_idx] - admittime <= timedelta(days=1)]
            filter_stats['full'][field].append(len(full_records))
            filter_stats['keep'][field].append(len(adm_data[i][field]))
    filter_stats_sum['full'][field] = np.mean(filter_stats['full'][field])
    filter_stats_sum['keep'][field] = np.mean(filter_stats['keep'][field])

# Stats: how many of records are covered in the first batch time range
filter_stats_sum

processing target_procedures


100%|██████████| 331793/331793 [00:01<00:00, 175805.04it/s]


processing target_prescriptions


100%|██████████| 331793/331793 [01:38<00:00, 3355.39it/s]


processing target_laborders


100%|██████████| 331793/331793 [06:28<00:00, 854.30it/s] 


{'full': {'target_procedures': 3.076277796582731,
  'target_prescriptions': 44.65936252574113,
  'target_laborders': 188.33024589225033},
 'keep': {'target_procedures': 2.0837098412044326,
  'target_prescriptions': 21.470245967474078,
  'target_laborders': 50.019972381667536}}

In [16]:
# Clean discharge note
for i, dp in enumerate(adm_data):
    for note_i, note_item in enumerate(dp['notes_discharge']):
        cleaned_note, _ = clean_note(note_item[0])
        adm_data[i]['notes_discharge'][note_i][0] = cleaned_note

In [17]:
# add processed service department name
for i, dp in enumerate(adm_data):
    adm_data[i]['_service_processed'] = process_service_name(dp['_service'])

In [18]:
# Filter adm_data by legit _service_processed -> discarded
# prev_count = len(adm_data)
# adm_data = [dp for dp in adm_data if 
#             len(dp['_service_processed']) > 0 and 
#             len(dp['_service_processed']) < 40 and 
#             'time' not in dp['_service_processed'] and
#             '___' not in dp['_service_processed'] and 
#             '===' not in dp['_service_processed'] and
#             '30 days' not in dp['_service_processed'] and 
#             'l olecranon and r patella fracture' not in dp['_service_processed']
#             ]
# print(f'Removed: we keep {len(adm_data)}/{prev_count} admissions with valid _service_processed')

In [19]:
if debug_mode:
    print(adm_data[80])

## Processing branch 1: target task is diagnosis decision

In [None]:
save_path_final = 'data/mimic4/target_diagnoses'

In [None]:
adm_data_branch = [dp for dp in adm_data if len(dp['target_diagnoses']) > 0]
print(f'Removed: we keep {len(adm_data_branch)}/{len(adm_data)} admissions with at least 1 diagnosis code')

# Balance by major diagnosis
adm_data_branch_new = []
for i, dp in enumerate(adm_data_branch):
    diag_codes = [item[0] for item in dp['target_diagnoses']]
    diag_main = dp['target_diagnoses'][0][0]
    # Check whether there is illegal ICD codes
    illegal_flag = False
    for code in diag_codes:
        if not cm.is_valid_item(code):
            illegal_flag = True
            break
    if illegal_flag:
        continue
    diag_ancestors = cm.get_ancestors(cm.add_dot(diag_main))
    # chapter_diag = f"{diag_ancestors[-1]}\t{diag_ancestors[-2]}"
    chapter_diag = f"{diag_ancestors[-1]}"
    dp['_target_diagnoses_major_chapter'] = chapter_diag

    uniq_chapter = list(set([cm.get_ancestors(cm.add_dot(c))[-1] for c in diag_codes]))
    dp['_target_diagnoses_chapters'] = uniq_chapter
    
    adm_data_branch_new.append(dp)
print(f'Removed: we keep {len(adm_data_branch_new)}/{len(adm_data_branch)} admissions with all valid ICD-10 codes')
adm_data_branch = adm_data_branch_new

In [None]:
# Balance by major diagnosis
group_key = '_target_diagnoses_major_chapter'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
print(f'there are {len(all_labels)} labels in total')

param_dp_each_label = 47

selected_indexes = []
for selected_label in all_labels:
    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], param_dp_each_label))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    if '_filter_flags' not in dp:
        adm_data_branch[i]['_filter_flags'] = {}
    adm_data_branch[i]['_filter_flags'][group_key] = False

for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
# Balance by service department
group_key = '_services_simple'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
adm_data_branch_current = [dp for dp in adm_data_branch if dp['_filter_flags']['_target_diagnoses_major_chapter']]
_, _, _, count_map_current = group_data(adm_data_branch_current, group_key)

param_dp_each_label = 20

selected_indexes = []
for selected_label in all_labels:
    if selected_label not in count_map_current:
        sample_size = param_dp_each_label
    elif count_map_current[selected_label] < param_dp_each_label:
        sample_size = param_dp_each_label - count_map_current[selected_label]
    else:
        continue

    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], sample_size))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    adm_data_branch[i]['_filter_flags'][group_key] = False
for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
# Balance by care units
group_key = '_transfers_careunits'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
adm_data_branch_current = [dp for dp in adm_data_branch if dp['_filter_flags']['_target_diagnoses_major_chapter'] or dp['_filter_flags']['_services_simple']]
_, _, _, count_map_current = group_data(adm_data_branch_current, group_key)

param_dp_each_label = 20

selected_indexes = []
for selected_label in all_labels:
    if selected_label not in count_map_current:
        sample_size = param_dp_each_label
    elif count_map_current[selected_label] < param_dp_each_label:
        sample_size = param_dp_each_label - count_map_current[selected_label]
    else:
        continue

    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], sample_size))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    adm_data_branch[i]['_filter_flags'][group_key] = False
for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
# Include all instances with at least one True flag. For the left instaces, sample from each major diagnosis chapter
adm_data_branch_final = [dp for dp in adm_data_branch if any([flag for flag in dp['_filter_flags'].values()])]
adm_data_branch_left = [dp for dp in adm_data_branch if not any([flag for flag in dp['_filter_flags'].values()])]
print(f"{len(adm_data_branch_final)} dp for evaluation")
print(f"{len(adm_data_branch_left)} dp for training")

In [None]:
if not os.path.exists(save_path_final):
    os.makedirs(save_path_final, exist_ok=True)
pickle.dump(adm_data_branch_final, open(os.path.join(save_path_final, 'test.pkl'), 'wb'))
pickle.dump(adm_data_branch_left, open(os.path.join(save_path_final, 'train.pkl'), 'wb'))
# with open(os.path.join(save_path_final, 'test.json'), 'w') as f:
#     json.dump(adm_data_branch_final, f, indent=4)

In [None]:
check_dist_fields = check_dist_fields_base + ['_target_diagnoses_major_chapter', '_target_diagnoses_chapters']

for field in check_dist_fields:
    _, _, _, count_map_current = group_data(adm_data_branch_final, field)
    with open(os.path.join(save_path_final, f'distribution_{field}.json'), 'w') as f:
        json.dump(count_map_current, f, indent=4)

## Processing branch 2: target task is procedure decision

In [None]:
save_path_final = 'data/mimic4/target_procedures'

In [None]:
adm_data_branch = [dp for dp in adm_data if len(dp['target_procedures']) > 0]
print(f'Removed: we keep {len(adm_data_branch)}/{len(adm_data)} admissions with at least 1 procedure code')

for i, dp in enumerate(adm_data_branch):
    proc_codes = [item[0] for item in dp['target_procedures']]
    uniq_chapter = list(set([icd10pcs_ancestors(c)[-1] for c in proc_codes]))
    adm_data_branch[i]['_target_procedures_chapters'] = uniq_chapter

In [None]:
# Balance by procedures chapters
group_key = '_target_procedures_chapters'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
print(f'there are {len(all_labels)} labels in total')

param_dp_each_label = 40

selected_indexes = []
for selected_label in all_labels:
    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], param_dp_each_label))
selected_indexes = list(set(selected_indexes))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    if '_filter_flags' not in dp:
        adm_data_branch[i]['_filter_flags'] = {}
    adm_data_branch[i]['_filter_flags'][group_key] = False

for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
# Balance by service department
group_key = '_services_simple'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
adm_data_branch_current = [dp for dp in adm_data_branch if dp['_filter_flags']['_target_procedures_chapters']]
_, _, _, count_map_current = group_data(adm_data_branch_current, group_key)

param_dp_each_label = 20

selected_indexes = []
for selected_label in all_labels:
    if selected_label not in count_map_current:
        sample_size = param_dp_each_label
    elif count_map_current[selected_label] < param_dp_each_label:
        sample_size = param_dp_each_label - count_map_current[selected_label]
    else:
        continue

    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], sample_size))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    adm_data_branch[i]['_filter_flags'][group_key] = False
for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
# Balance by service department
group_key = '_transfers_careunits'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
adm_data_branch_current = [dp for dp in adm_data_branch if dp['_filter_flags']['_target_procedures_chapters'] or dp['_filter_flags']['_services_simple']]
_, _, _, count_map_current = group_data(adm_data_branch_current, group_key)

param_dp_each_label = 20

selected_indexes = []
for selected_label in all_labels:
    if selected_label not in count_map_current:
        sample_size = param_dp_each_label
    elif count_map_current[selected_label] < param_dp_each_label:
        sample_size = param_dp_each_label - count_map_current[selected_label]
    else:
        continue

    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], sample_size))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    adm_data_branch[i]['_filter_flags'][group_key] = False
for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
# Include all instances with at least one True flag. For the left instaces, sample from each major diagnosis chapter
adm_data_branch_final = [dp for dp in adm_data_branch if any([flag for flag in dp['_filter_flags'].values()])]
adm_data_branch_left = [dp for dp in adm_data_branch if not any([flag for flag in dp['_filter_flags'].values()])]
print(f"{len(adm_data_branch_final)} dp for evaluation")
print(f"{len(adm_data_branch_left)} dp for training")

In [None]:
if not os.path.exists(save_path_final):
    os.makedirs(save_path_final, exist_ok=True)
pickle.dump(adm_data_branch_final, open(os.path.join(save_path_final, 'test.pkl'), 'wb'))
pickle.dump(adm_data_branch_left, open(os.path.join(save_path_final, 'train.pkl'), 'wb'))
# with open(os.path.join(save_path_final, 'test.json'), 'w') as f:
#     json.dump(adm_data_branch_final, f, indent=4)

In [None]:
check_dist_fields = check_dist_fields_base + ['_target_procedures_chapters']

for field in check_dist_fields:
    _, _, _, count_map_current = group_data(adm_data_branch_final, field)
    with open(os.path.join(save_path_final, f'distribution_{field}.json'), 'w') as f:
        json.dump(count_map_current, f, indent=4)

## Processing branch 3: target task is lab test orders

In [None]:
save_path_final = 'data/mimic4/target_laborders'

In [None]:
adm_data_branch = [dp for dp in adm_data if len(dp['target_laborders']) > 0]
print(f'Removed: we keep {len(adm_data_branch)}/{len(adm_data)} admissions with at least 1 laborders')

for i, dp in enumerate(adm_data_branch):
    lab_codes = [item[0] for item in dp['target_laborders']]
    # -1 would be the highest level, too abstract {component}
    # -2 would be Laboratory, Clinical, Attachments, Survey instruments
    # -3 would be categories like Skin challenge, Drug doses, Allergy etc
    uniq_chapter_list = []
    for c in lab_codes:
        an_this = labitem_ancestors(c)
        if len(an_this) >=3 :
            uniq_chapter_list.append(an_this[-3])
    # uniq_chapter = list(set([labitem_ancestors(c)[-3] for c in lab_codes]))
    adm_data_branch[i]['_target_labs_categories'] = list(set(uniq_chapter_list))

In [None]:
# Balance by procedures chapters
group_key = '_target_labs_categories'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
print(f'there are {len(all_labels)} labels in total')

param_dp_each_label = 45

selected_indexes = []
for selected_label in all_labels:
    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], param_dp_each_label))
selected_indexes = list(set(selected_indexes))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    if '_filter_flags' not in dp:
        adm_data_branch[i]['_filter_flags'] = {}
    adm_data_branch[i]['_filter_flags'][group_key] = False

for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
adm_data_branch[0]

In [None]:
# Balance by service department
group_key = '_services_simple'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
adm_data_branch_current = [dp for dp in adm_data_branch if dp['_filter_flags']['_target_labs_categories']]
print(len(adm_data_branch_current))
_, _, _, count_map_current = group_data(adm_data_branch_current, group_key)

param_dp_each_label = 20

selected_indexes = []
for selected_label in all_labels:
    if selected_label not in count_map_current:
        sample_size = param_dp_each_label
    elif count_map_current[selected_label] < param_dp_each_label:
        sample_size = param_dp_each_label - count_map_current[selected_label]
    else:
        continue

    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], sample_size))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    adm_data_branch[i]['_filter_flags'][group_key] = False
for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
# Balance by care units
group_key = '_transfers_careunits'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
adm_data_branch_current = [dp for dp in adm_data_branch if dp['_filter_flags']['_target_labs_categories'] or dp['_filter_flags']['_services_simple']]
print(len(adm_data_branch_current))
_, _, _, count_map_current = group_data(adm_data_branch_current, group_key)

param_dp_each_label = 20

selected_indexes = []
for selected_label in all_labels:
    if selected_label not in count_map_current:
        sample_size = param_dp_each_label
    elif count_map_current[selected_label] < param_dp_each_label:
        sample_size = param_dp_each_label - count_map_current[selected_label]
    else:
        continue

    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], sample_size))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    adm_data_branch[i]['_filter_flags'][group_key] = False
for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

In [None]:
# Include all instances with at least one True flag. For the left instaces, sample from each major diagnosis chapter
adm_data_branch_final = [dp for dp in adm_data_branch if any([flag for flag in dp['_filter_flags'].values()])]
adm_data_branch_left = [dp for dp in adm_data_branch if not any([flag for flag in dp['_filter_flags'].values()])]
print(f"{len(adm_data_branch_final)} dp for evaluation")
print(f"{len(adm_data_branch_left)} dp for training")

In [None]:
if not os.path.exists(save_path_final):
    os.makedirs(save_path_final, exist_ok=True)
pickle.dump(adm_data_branch_final, open(os.path.join(save_path_final, 'test.pkl'), 'wb'))
pickle.dump(adm_data_branch_left, open(os.path.join(save_path_final, 'train.pkl'), 'wb'))
# with open(os.path.join(save_path_final, 'test.json'), 'w') as f:
#     json.dump(adm_data_branch_final, f, indent=4)

In [None]:
check_dist_fields = check_dist_fields_base + ['_target_labs_categories']

for field in check_dist_fields:
    _, _, _, count_map_current = group_data(adm_data_branch_final, field)
    with open(os.path.join(save_path_final, f'distribution_{field}.json'), 'w') as f:
        json.dump(count_map_current, f, indent=4)

## Processing branch 4: target task is prescription decision

In [20]:
save_path_final = 'data/mimic4/target_prescriptions'

In [21]:
adm_data_branch = [dp for dp in adm_data if len(dp['target_prescriptions']) > 0]
print(f'Removed: we keep {len(adm_data_branch)}/{len(adm_data)} admissions with at least 1 prescription')

all_unique_ndc = []
for i, dp in enumerate(adm_data_branch):
    for p in dp['target_prescriptions']:
        if p[6] not in all_unique_ndc:
            all_unique_ndc.append(p[6])
with open(os.path.join(save_path_parsed, 'ndc_in_data.json'), 'w') as f:
    json.dump(all_unique_ndc, f, indent=4)

print(f'there are {len(all_unique_ndc)} unique NDC codes in the mimic dataset')

Removed: we keep 331181/331793 admissions with at least 1 prescription
there are 5331 unique NDC codes in the mimic dataset


In [22]:
for i, dp in enumerate(adm_data_branch):
    pres_codes = [item[6] for item in dp['target_prescriptions']]
    uniq_chapters = []
    for pres_code in pres_codes:
        uniq_chapters.extend(drug_ancestors(pres_code))
    adm_data_branch[i]['_target_prescriptions_categories'] = list(set(uniq_chapters))

In [23]:
# Balance by procedures chapters
group_key = '_target_prescriptions_categories'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
print(f'there are {len(all_labels)} labels in total')

param_dp_each_label = 55

selected_indexes = []
for selected_label in all_labels:
    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], param_dp_each_label))
selected_indexes = list(set(selected_indexes))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    if '_filter_flags' not in dp:
        adm_data_branch[i]['_filter_flags'] = {}
    adm_data_branch[i]['_filter_flags'][group_key] = False

for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

------- filtering for _target_prescriptions_categories
A: 322629
B: 316266
C: 292414
D: 190047
G: 156352
H: 145942
J: 153841
L: 22761
M: 59195
N: 313901
P: 28830
R: 161113
S: 263629
V: 189465
total number of labels: 14
lowest category count: 22761
there are 14 labels in total
770 samples are selected in total according to _target_prescriptions_categories distribution


In [24]:
# Balance by service department
group_key = '_services_simple'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
adm_data_branch_current = [dp for dp in adm_data_branch if dp['_filter_flags']['_target_prescriptions_categories']]
_, _, _, count_map_current = group_data(adm_data_branch_current, group_key)

param_dp_each_label = 20

selected_indexes = []
for selected_label in all_labels:
    if selected_label not in count_map_current:
        sample_size = param_dp_each_label
    elif count_map_current[selected_label] < param_dp_each_label:
        sample_size = param_dp_each_label - count_map_current[selected_label]
    else:
        continue

    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], sample_size))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    adm_data_branch[i]['_filter_flags'][group_key] = False
for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

------- filtering for _services_simple
CMED: 35148
CSURG: 10497
DENT: 18
ENT: 994
EYE: 34
GU: 5220
GYN: 5793
MED: 158774
NMED: 20168
NSURG: 12206
OBS: 4232
OMED: 25655
ORTHO: 20027
PSURG: 3346
PSYCH: 7541
SURG: 37110
TRAUM: 6498
TSURG: 4154
VSURG: 9728
total number of labels: 19
lowest category count: 18
------- filtering for _services_simple
CMED: 68
CSURG: 23
ENT: 3
GU: 13
GYN: 21
MED: 386
NMED: 37
NSURG: 24
OBS: 13
OMED: 73
ORTHO: 40
PSURG: 6
PSYCH: 13
SURG: 93
TRAUM: 17
TSURG: 10
VSURG: 27
total number of labels: 17
lowest category count: 3
103 samples are selected in total according to _services_simple distribution


In [25]:
# Balance by service department
group_key = '_transfers_careunits'
all_labels, data_index_map, count_small_bound, count_map = group_data(adm_data_branch, group_key)
adm_data_branch_current = [dp for dp in adm_data_branch if dp['_filter_flags']['_target_prescriptions_categories'] or dp['_filter_flags']['_services_simple']]
_, _, _, count_map_current = group_data(adm_data_branch_current, group_key)

param_dp_each_label = 20

selected_indexes = []
for selected_label in all_labels:
    if selected_label not in count_map_current:
        sample_size = param_dp_each_label
    elif count_map_current[selected_label] < param_dp_each_label:
        sample_size = param_dp_each_label - count_map_current[selected_label]
    else:
        continue

    if len(data_index_map[selected_label]) <= param_dp_each_label:
        selected_indexes.extend(data_index_map[selected_label])
    else:
        selected_indexes.extend(random.sample(data_index_map[selected_label], sample_size))
print(f'{len(selected_indexes)} samples are selected in total according to {group_key} distribution')

for i, dp in enumerate(adm_data_branch):
    adm_data_branch[i]['_filter_flags'][group_key] = False
for i in selected_indexes:
    adm_data_branch[i]['_filter_flags'][group_key] = True

------- filtering for _transfers_careunits
Cardiac Surgery: 16733
Cardiac Vascular Intensive Care Unit (CVICU): 15965
Cardiology: 598
Cardiology Surgery Intermediate: 4484
Coronary Care Unit (CCU): 10632
Discharge Lounge: 48247
Emergency Department: 223656
Emergency Department Observation: 16650
Hematology/Oncology: 35058
Hematology/Oncology Intermediate: 11439
Labor & Delivery: 3690
Med/Surg: 46830
Med/Surg/GYN: 26099
Med/Surg/Trauma: 22829
Medical Intensive Care Unit (MICU): 20910
Medical/Surgical (Gynecology): 7711
Medical/Surgical Intensive Care Unit (MICU/SICU): 15974
Medicine: 144986
Medicine/Cardiology: 37831
Medicine/Cardiology Intermediate: 6346
Neuro Intermediate: 3810
Neuro Stepdown: 1872
Neuro Surgical Intensive Care Unit (Neuro SICU): 2308
Neurology: 37617
Observation: 3422
Obstetrics (Postpartum & Antepartum): 3936
Obstetrics Antepartum: 1076
Obstetrics Postpartum: 591
PACU: 20421
Psychiatry: 13006
Surgery: 9233
Surgery/Pancreatic/Biliary/Bariatric: 7057
Surgery/Trauma: 1

In [26]:
# Include all instances with at least one True flag. For the left instaces, sample from each major diagnosis chapter
adm_data_branch_final = [dp for dp in adm_data_branch if any([flag for flag in dp['_filter_flags'].values()])]
adm_data_branch_left = [dp for dp in adm_data_branch if not any([flag for flag in dp['_filter_flags'].values()])]
print(f"{len(adm_data_branch_final)} dp for evaluation")
print(f"{len(adm_data_branch_left)} dp for training")

1036 dp for evaluation
330145 dp for training


In [27]:
if not os.path.exists(save_path_final):
    os.makedirs(save_path_final, exist_ok=True)
pickle.dump(adm_data_branch_final, open(os.path.join(save_path_final, 'test.pkl'), 'wb'))
pickle.dump(adm_data_branch_left, open(os.path.join(save_path_final, 'train.pkl'), 'wb'))
# with open(os.path.join(save_path_final, 'test.json'), 'w') as f:
#     json.dump(adm_data_branch_final, f, indent=4)

In [28]:
check_dist_fields = check_dist_fields_base + ['_target_prescriptions_categories']

for field in check_dist_fields:
    _, _, _, count_map_current = group_data(adm_data_branch_final, field)
    with open(os.path.join(save_path_final, f'distribution_{field}.json'), 'w') as f:
        json.dump(count_map_current, f, indent=4)

------- filtering for _services_simple
CMED: 110
CSURG: 30
DENT: 18
ENT: 24
EYE: 20
GU: 22
GYN: 25
MED: 467
NMED: 67
NSURG: 47
OBS: 52
OMED: 77
ORTHO: 42
PSURG: 23
PSYCH: 20
SURG: 104
TRAUM: 26
TSURG: 27
VSURG: 28
total number of labels: 19
lowest category count: 18
------- filtering for _transfers_careunits
Cardiac Surgery: 32
Cardiac Vascular Intensive Care Unit (CVICU): 57
Cardiology: 29
Cardiology Surgery Intermediate: 22
Coronary Care Unit (CCU): 30
Discharge Lounge: 158
Emergency Department: 634
Emergency Department Observation: 46
Hematology/Oncology: 100
Hematology/Oncology Intermediate: 34
Labor & Delivery: 63
Med/Surg: 126
Med/Surg/GYN: 89
Med/Surg/Trauma: 70
Medical Intensive Care Unit (MICU): 64
Medical/Surgical (Gynecology): 40
Medical/Surgical Intensive Care Unit (MICU/SICU): 67
Medicine: 362
Medicine/Cardiology: 93
Medicine/Cardiology Intermediate: 54
Neuro Intermediate: 47
Neuro Stepdown: 35
Neuro Surgical Intensive Care Unit (Neuro SICU): 46
Neurology: 111
Observation: