In [1]:
import os
import numpy as np
import json
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

In [2]:
# Directory holding the label file read below and every artifact this notebook
# writes (cleaned splits, label counts, hierarchy mappings).
MIMIC_3_DIR = '/home/xueren/Desktop/EMNLP/dataset/MIMIC3-3737'

# exist_ok=True makes creation idempotent and avoids the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(MIMIC_3_DIR, exist_ok=True)

In [3]:
# Read ICD9CODES.json from the dataset directory. Its keys are the label
# vocabulary used to filter the splits below.
# NOTE(review): presumably maps ICD-9 code -> description (a later cell iterates
# it as code/description pairs) - confirm against the file.
label_file = os.path.join(MIMIC_3_DIR, 'ICD9CODES.json')
with open(label_file, 'r') as handle:
    label2desc = json.load(handle)
label = list(label2desc)

In [4]:
# Source directory of the input splits: MIMIC_3_DIR with its last five
# characters removed (the '-3737' suffix of the path configured above).
path_6668 = MIMIC_3_DIR[:-5]

# Load the train / validation / test splits from that directory.
train_df, val_df, test_df = (
    pd.read_csv(os.path.join(path_6668, csv_name))
    for csv_name in ('clean_train.csv', 'clean_val.csv', 'clean_test.csv')
)

In [5]:
# Display the training split as loaded, as a quick visual check of its columns.
train_df

Unnamed: 0.1,Unnamed: 0,SUBJECT_ID,HADM_ID,TEXT,LABELS,length
0,0,158,169433,admission date discharge date date of birth se...,532.40;493.20;V45.81;412;401.9,51
1,1,2896,178124,name known lastname known firstname unit no nu...,211.3;427.31;578.9;560.1;496;584.9;428.0;276.5...,55
2,2,6495,139808,admission date discharge date date of birth se...,998.59;998.32;905.4;E929.0;041.85,60
3,3,3564,117638,admission date discharge date service doctor l...,038.49;041.6;785.59;518.81;507.0;592.1;591;276...,68
4,4,7995,190945,admission date discharge date date of birth se...,440.22;492.8;401.9;714.0,74
...,...,...,...,...,...,...
47713,47713,16655,105131,admission date discharge date date of birth se...,320.3;996.81;428.0;599.0;038.11;421.0;995.92;4...,7858
47714,47714,339,112625,admission date discharge date date of birth se...,577.0;995.94;574.21;518.81;584.9;482.83;511.9;...,8097
47715,47715,59970,128930,admission date discharge date date of birth se...,444.0;348.30;584.5;532.40;568.81;585.6;518.81;...,8774
47716,47716,25030,172599,admission date discharge date date of birth se...,431;331.4;996.81;403.91;707.0;250.81;780.39;51...,8783


In [6]:
def clean_df(df, label_list):
    """Keep only the labels of each row that belong to ``label_list``.

    Args:
        df: DataFrame with the columns ``SUBJECT_ID``, ``HADM_ID``, ``TEXT``,
            ``LABELS`` (a ';'-separated string of codes, or missing) and
            ``length``.
        label_list: Iterable of label strings to keep.

    Returns:
        A list of ``[SUBJECT_ID, HADM_ID, TEXT, LABELS, length]`` rows in the
        original row order, where ``LABELS`` is the ';'-joined subset of the
        row's labels found in ``label_list``. Rows whose labels are missing,
        or that keep no label at all, are omitted.
    """
    # Set membership is O(1) per lookup; a list would be scanned for every token.
    allowed = set(label_list)

    # Iterate the column values positionally so the function also works on
    # frames whose index is not 0..n-1 (e.g. after filtering).
    records = zip(df['SUBJECT_ID'], df['HADM_ID'], df['TEXT'], df['LABELS'], df['length'])

    rows = []
    for sub_id, hadm_id, text, labels, length in tqdm(records, total=len(df)):
        if pd.isna(labels):
            continue
        # Strip BEFORE the membership test so that a padded token such as
        # ' 412' is matched against the vocabulary instead of being dropped.
        kept = [code for code in (raw.strip() for raw in labels.split(';')) if code in allowed]
        if kept:
            rows.append([sub_id, hadm_id, text, ';'.join(kept), length])
    return rows

In [7]:
# Filter every split down to the label vocabulary loaded from ICD9CODES.json.
clean_train_list = clean_df(train_df, label)
clean_val_list = clean_df(val_df, label)
clean_test_list = clean_df(test_df, label)

# Column order of the rows returned by clean_df.
split_columns = ['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'length']

# Rebuild each split as a DataFrame, drop rows without labels and renumber
# the index from zero.
clean_train = (
    pd.DataFrame(clean_train_list, columns=split_columns)
    .dropna(subset=['LABELS'])
    .reset_index(drop=True)
)
clean_val = (
    pd.DataFrame(clean_val_list, columns=split_columns)
    .dropna(subset=['LABELS'])
    .reset_index(drop=True)
)
clean_test = (
    pd.DataFrame(clean_test_list, columns=split_columns)
    .dropna(subset=['LABELS'])
    .reset_index(drop=True)
)

100%|█████████████████████████████████████████████████████████| 47718/47718 [00:04<00:00, 9893.79it/s]
100%|███████████████████████████████████████████████████████████| 1631/1631 [00:00<00:00, 8693.56it/s]
100%|███████████████████████████████████████████████████████████| 3372/3372 [00:00<00:00, 8416.88it/s]


In [8]:
# Count how often each label occurs across the three cleaned splits.
clean_label_dict = dict()

for split in (clean_train, clean_val, clean_test):
    for i, labels in enumerate(split['LABELS']):
        if pd.isna(labels):
            # Report the position of any row that still has no labels.
            print(i)
            continue
        for l in labels.split(';'):
            clean_label_dict[l] = clean_label_dict.get(l, 0) + 1

# Every label of the vocabulary must occur at least once after cleaning.
assert len(clean_label_dict) == len(label)

In [9]:
# Persist the cleaned splits into MIMIC_3_DIR.
# NOTE(review): to_csv is called with its default index=True, so the index is
# written as an unnamed first column and comes back as 'Unnamed: 0' on re-read.
for csv_name, split in (
    ('clean_train.csv', clean_train),
    ('clean_val.csv', clean_val),
    ('clean_test.csv', clean_test),
):
    split.to_csv(os.path.join(MIMIC_3_DIR, csv_name))

In [10]:
# Stack the three cleaned splits into one frame with a fresh 0..n-1 index.
df = pd.concat([clean_train, clean_val, clean_test], ignore_index=True)
df

Unnamed: 0,SUBJECT_ID,HADM_ID,TEXT,LABELS,length
0,158,169433,admission date discharge date date of birth se...,493.20;412;401.9,51
1,2896,178124,name known lastname known firstname unit no nu...,211.3;427.31;578.9;560.1;496;584.9;428.0;276.5...,55
2,6495,139808,admission date discharge date date of birth se...,998.32,60
3,3564,117638,admission date discharge date service doctor l...,038.49;785.59;518.81;592.1;591;276.2;584.9,68
4,7995,190945,admission date discharge date date of birth se...,440.22;492.8;401.9;714.0,74
...,...,...,...,...,...
52398,96777,176399,admission date discharge date date of birth se...,480.1;996.85;780.39;117.7;204.01;117.3;078.5;2...,5890
52399,95323,142423,admission date discharge date date of birth se...,518.81;486;292.0;453.42;112.0;292.81;276.3;263...,6116
52400,91074,106110,admission date discharge date date of birth se...,486;518.81;584.9;518.0;491.21;428.32;112.2;519...,6117
52401,92316,158581,admission date discharge date date of birth se...,427.41;785.51;807.4;584.9;276.2;276.0;427.5;41...,6227


In [11]:
def cnt_instance_per_label(df, column_name):
    """Count how many rows each label appears in.

    Args:
        df: DataFrame containing ``column_name``.
        column_name: Name of a column whose cells are ';'-separated label
            strings. Cells that are not strings (e.g. NaN or None) are skipped.

    Returns:
        dict mapping each whitespace-stripped label to its number of
        occurrences, in order of first appearance.
    """
    label_cnt = {}
    # Iterate the values directly so the result does not depend on the frame
    # having a default 0..n-1 index.
    for cell in df[column_name]:
        if not isinstance(cell, str):
            continue
        for piece in cell.split(';'):
            piece = piece.strip()
            # A trailing or doubled ';' yields an empty token; it is not a label.
            if not piece:
                continue
            label_cnt[piece] = label_cnt.get(piece, 0) + 1
    return label_cnt

def sortBy(l1, l2, reverse=True):
    """Sort two parallel sequences together, keyed on ``l2`` then ``l1``.

    Args:
        l1: First sequence (tie-breaker of the sort key).
        l2: Second sequence (primary sort key), paired element-wise with ``l1``.
        reverse: Sort in descending order when true (the default).

    Returns:
        A tuple ``(sorted_l1, sorted_l2)`` of lists. Both lists are empty when
        either input is empty.
    """
    if not l1 or not l2:
        return [], []
    pairs = sorted(zip(l1, l2), key=lambda pair: (pair[1], pair[0]), reverse=reverse)
    firsts = [pair[0] for pair in pairs]
    seconds = [pair[1] for pair in pairs]
    return firsts, seconds

In [12]:
# Occurrences of every label over the combined splits.
label_cnt = cnt_instance_per_label(df, column_name='LABELS')
ICD9CODE, total_num = list(label_cnt.keys()), list(label_cnt.values())

# Order the labels from most to least frequent (ties broken by code, descending).
sort_ICD9CODE, sort_total_num = sortBy(ICD9CODE, total_num, reverse=True)
sorted_label_cnt = dict(zip(sort_ICD9CODE, sort_total_num))

# Save the frequency-ordered counts next to the cleaned splits.
label_cnt_path = os.path.join(MIMIC_3_DIR, 'MIMIC3_Label_cnt.json')
with open(label_cnt_path, 'w') as out_file:
    json.dump(sorted_label_cnt, out_file, indent=4)

In [13]:
# Number of ';'-separated labels on each row of the combined frame, then the
# average number of labels per row.
all_list_len = [len(ls.split(';')) for ls in df['LABELS']]
np.mean(all_list_len)

7.996183424613094

In [14]:
# Exclusive upper bound of each numeric ICD-9 chapter, in ascending order,
# paired with the chapter name used as the hierarchy key. A numeric code
# belongs to the first chapter whose bound is greater than the code.
ICD9_NUMERIC_CHAPTERS = (
    (140, 'infectious and parasitic diseases'),
    (240, 'neoplasms'),
    (280, 'endocrine, nutritional and metabolic diseases, and immunity disorders'),
    (290, 'diseases of the blood and blood-forming organs'),
    (320, 'mental disorders'),
    (390, 'diseases of the nervous system and sense organs'),
    (460, 'diseases of the circulatory system'),
    (520, 'diseases of the respiratory system'),
    (580, 'diseases of the digestive system'),
    (630, 'diseases of the genitourinary system'),
    (680, 'complications of pregnancy, childbirth, and the puerperium'),
    (710, 'diseases of the skin and subcutaneous tissue'),
    (740, 'diseases of the musculoskeletal system and connective tissue'),
    (760, 'congenital anomalies'),
    (780, 'certain conditions originating in the perinatal period'),
    (800, 'symptoms, signs, and ill-defined conditions'),
    (1000, 'injury and poisoning'),
)

# Single bucket used for every code starting with 'E' or 'V'.
ICD9_SUPPLEMENTARY_CHAPTER = 'external causes of injury and supplemental classification'


def icd9_chapter(code):
    """Return the chapter name for an ICD-9 code string, or None if it fits none.

    Codes starting with 'E' or 'V' share one supplementary chapter. Any other
    code is parsed as a number and matched against ICD9_NUMERIC_CHAPTERS;
    codes that are not numeric, negative, or >= 1000 yield None.
    """
    if code.startswith(('E', 'V')):
        return ICD9_SUPPLEMENTARY_CHAPTER
    try:
        value = float(code)
    except ValueError:
        return None
    # Written as "not >=" so that NaN is rejected along with negative values.
    if not value >= 0:
        return None
    for upper_bound, chapter in ICD9_NUMERIC_CHAPTERS:
        if value < upper_bound:
            return chapter
    return None


# chapter name -> list of codes, in the order the codes appear in label2desc.
hierarchy2ICD9CODE = defaultdict(list)

for ICD9_CODE in label2desc:
    chapter = icd9_chapter(ICD9_CODE)
    if chapter is None:
        # Report the code and keep going: stopping here would silently leave
        # every remaining code out of the mappings written below.
        print('anomaly code {}'.format(ICD9_CODE))
        continue
    hierarchy2ICD9CODE[chapter].append(ICD9_CODE)

# Inverse mapping: code -> chapter name.
ICD9CODE2hierarchy = {
    code: hier
    for hier, codes in hierarchy2ICD9CODE.items()
    for code in codes
}

with open(os.path.join(MIMIC_3_DIR, 'p2hier.json'), 'w') as f:
    json.dump(ICD9CODE2hierarchy, f, indent=4)

with open(os.path.join(MIMIC_3_DIR, 'hier2p.json'), 'w') as f:
    json.dump(hierarchy2ICD9CODE, f, indent=4)

In [15]:
# Number of distinct chapters that received at least one code.
len(hierarchy2ICD9CODE)

17