In [1]:
import csv
import os
import json
import re
import copy
import pickle
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

# Get the age & sex info of subjects

In [2]:
# Regexes are compiled once here instead of once per note row.
# Gendered tokens counted (on lower-cased text) to guess sex when no explicit "sex:" field exists.
_RE_FEMALE = [
    re.compile(r'\sfemale'),
    re.compile(r'\swoman\s'),
    re.compile(r'\sshe\s'),
    re.compile(r'\sher\s'),
    re.compile(r'\s*elderly f[,\s]'),
]
_RE_MALE = [
    re.compile(r'\smale'),
    re.compile(r'\sman\s'),
    re.compile(r'\she\s'),
    re.compile(r'\shis\s'),
    re.compile(r'\s*elderly m\s*'),
]
# De-identified "[**Age over N **]" placeholder: only a lower bound of the age is known.
_RE_AGE_OVER = re.compile(r'\[\*\*age\s*over\s*(\d*?)\s\S*\*\*\]')
# Direct age mentions on the lower-cased text; the first hit (in this list order) wins.
_RE_AGE_LOWER = [
    re.compile(r'\D(\d+?) y.\s*o.\s*'),
    re.compile(r'\D*(\d+?)\s*yo\b'),
    re.compile(r'\D*(\d+?)\s*yo[fm]\b'),
    re.compile(r'\D(\d+?)\Dyear\Dold\s'),
]
# "65M" / "70 F" shorthand; applied to the ORIGINAL-case text because it needs upper-case F/M.
_RE_AGE_SHORTHAND = re.compile(r'\D*(\d+?)\s*[FM]\b')
_RE_BIRTHDAY = re.compile(r'date\s*of\s*birth:\s*\[\*\*(.*?)\*\*\]')
_RE_ELDERLY = re.compile(r'\s*elderly\s*')
# Explicit single-letter sex field, e.g. "sex: f". (The original pattern had a stray "]".)
_RE_SEX = re.compile(r'sex:\s*([a-z])\n')
# Applied to repr(text) so that "." can also match the escaped newlines.
_RE_PI = re.compile(r'discharge\s*instructions.{0,30}:')

_MAX_AGE = 200      # ages >= this are treated as parse garbage
_ELDERLY_AGE = 60   # age assigned when a note only says "elderly"
_DEFAULT_AGE = 70   # age assigned when nothing could be inferred


def _count_sex_mentions(text):
    """Return ``(n_female, n_male)``: gendered-token counts in lower-cased ``text``."""
    n_female = sum(len(p.findall(text)) for p in _RE_FEMALE)
    n_male = sum(len(p.findall(text)) for p in _RE_MALE)
    return n_female, n_male


def _infer_age(text, raw_text, chartyear):
    """Infer a patient's age from one note, or return None.

    Tries, in order: the "[**age over N**]" lower bound, direct age mentions,
    ``chartyear - birth year`` from a "date of birth" field, and finally the
    word "elderly".

    Args:
        text: the note lower-cased.
        raw_text: the note in its original case (needed for the "65M" shorthand).
        chartyear: integer year of the note, used with the birth year.
    """
    age_over = _RE_AGE_OVER.findall(text)
    # The lazy capture can be empty; int('') would raise, so treat that as "no match".
    if age_over and age_over[0] and int(age_over[0]) < _MAX_AGE:
        return int(age_over[0])

    ages = [m for p in _RE_AGE_LOWER for m in p.findall(text)]
    ages += _RE_AGE_SHORTHAND.findall(raw_text)
    if ages and int(ages[0]) < _MAX_AGE:
        return int(ages[0])

    birthday = _RE_BIRTHDAY.findall(text)
    if birthday:
        birthday_year = int(birthday[0].split('-')[0])
        assert chartyear >= birthday_year
        return chartyear - birthday_year

    if _RE_ELDERLY.search(text):
        return _ELDERLY_AGE
    return None


def prepare_subject_key_info(
        splits_path='./splits',
        splits=('train', 'val', 'test'),
        subject_list_fn='subject_list.txt',
        read_path='./subjects_with_PI',
        info_fn='age_sex_info.json',
    ):
    """Infer each subject's sex and age from their notes and write them to JSON.

    The subject list is read from ``./<subject_list_fn>`` if that file exists.
    Otherwise it is built from the ``key`` column (``"<subject>_<admission>"``) of
    ``<splits_path>/<split>.csv`` and cached to that file.

    For every subject, ``<read_path>/<subject>.csv`` is scanned. The column layout
    used is: [1] subject id, [2] admission id, [3] chart date starting with the
    year, [6] category, [7] description, [-1] note text. (This looks like
    MIMIC-III NOTEEVENTS — TODO confirm.)

    Each "discharge summary"/"report" note that contains a discharge-instructions
    section produces a key ``"<subject>_<admission>"``.

    The output ``<splits_path>/<info_fn>`` is updated in place if it already exists:
      * ``subject_info[sid]``: ``keys``, ``sex`` ('f'/'m'), ``age``, and ``year``
        (the year the age refers to).
      * ``key_info[key]``: ``chartyear``, ``sex``, and ``age``. The age is
        shifted by the year difference and clamped to be at least 1.
    Subjects that already have an age below 200 in an existing JSON keep it.
    """
    assert os.path.exists(splits_path)

    subject_list_path = os.path.join('./', subject_list_fn)
    if os.path.exists(subject_list_path):
        with open(subject_list_path, 'r') as f:
            subject_list = f.read().strip().split('\n')
    else:
        subjects = set()
        for split in tqdm(splits):
            split_df = pd.read_csv(os.path.join(splits_path, split + '.csv'))
            for key in tqdm(split_df['key']):
                subjects.add(key.split('_')[0])
        subject_list = sorted(subjects)
        with open(subject_list_path, 'w') as f:
            f.write('\n'.join(subject_list))

    info_path = os.path.join(splits_path, info_fn)
    if os.path.exists(info_path):
        with open(info_path, 'r') as f:
            info = json.load(f)
    else:
        info = {'subject_info': {}, 'key_info': {}}

    failed_age = set()

    for sid in tqdm(subject_list):
        prior = info['subject_info'].get(sid, {})
        already_has_age = 'age' in prior and prior['age'] < _MAX_AGE

        subject_entry = info['subject_info'].setdefault(sid, {})
        # Reset once per subject. The previous version reset this on every row,
        # which kept only the last admission's key.
        subject_entry['keys'] = []

        explicit_sex, possible_sex, sex_diff = None, None, 0
        age, age_chartyear = None, None
        first_discharge_year = None

        with open(os.path.join(read_path, sid + '.csv'), 'r') as f:
            reader = csv.reader(f, delimiter=',')
            next(reader)  # header row

            for line in reader:
                assert line[1] == sid
                text = line[-1].lower()

                # Keep the sex guess from the note with the most one-sided pronoun ratio.
                n_female, n_male = _count_sex_mentions(text)
                diff = abs(n_female - n_male) / max(1, min(n_female, n_male))
                if diff > sex_diff:
                    sex_diff = diff
                    if n_female > n_male:
                        possible_sex = 'f'
                    elif n_male > n_female:
                        possible_sex = 'm'

                # The age is tied to the chart year of the first note it was found in.
                if age is None and not already_has_age:
                    age_chartyear = int(line[3].split('-')[0])
                    age = _infer_age(text, line[-1], age_chartyear)

                if line[6].lower() != 'discharge summary' or line[7].lower() != 'report':
                    continue
                if _RE_PI.search(repr(text)) is None:
                    continue

                chartyear = int(line[3].split('-')[0])
                if first_discharge_year is None:
                    first_discharge_year = chartyear

                sex_match = _RE_SEX.findall(text)
                if sex_match:
                    if explicit_sex is not None:
                        assert explicit_sex == sex_match[0], sid
                    else:
                        explicit_sex = sex_match[0]

                key = f'{sid}_{line[2]}'
                subject_entry['keys'].append(key)
                info['key_info'][key] = {'chartyear': chartyear}

        sex = explicit_sex or possible_sex
        assert sex, sid
        subject_entry['sex'] = sex

        if not already_has_age:
            if age is None:
                failed_age.add(sid)
                subject_entry['age'] = _DEFAULT_AGE
                subject_entry['year'] = first_discharge_year
            else:
                subject_entry['age'] = age
                subject_entry['year'] = age_chartyear

    if failed_age:
        print('- Warning: we can not infer the age of subjects', failed_age)
        print('- We set the age of these subjects to the default of 70')

    print('-' * 100)
    print('- Updating key info ...')
    for key in tqdm(info['key_info']):
        key_entry = info['key_info'][key]
        subject = info['subject_info'][key.split('_')[0]]
        key_entry['sex'] = subject['sex']
        # Shift the subject-level age to this admission's chart year.
        key_entry['age'] = subject['age'] + key_entry['chartyear'] - subject['year']
        if key_entry['age'] <= 0:
            print(subject['age'], subject['year'], key_entry['chartyear'])
            key_entry['age'] = 1

    with open(info_path, 'w') as f:
        json.dump(info, f)

In [3]:
# Build ./splits/age_sex_info.json from the per-subject note files.
# The recorded run over the full subject list took ~43 minutes.
prepare_subject_key_info(splits_path='./splits', read_path='./subjects_with_PI')

100%|██████████| 28029/28029 [43:09<00:00, 10.82it/s]  


- We deafult set the age of these subjects to 70
----------------------------------------------------------------------------------------------------
- Updating key info ...


100%|██████████| 36748/36748 [00:00<00:00, 309887.76it/s]


0 2103 2103
0 2181 2181
5 2206 2200


# Subtask for different sexes

In [11]:
# Split every set's admissions by the subject-level sex stored in age_sex_info.json.
# Only the test-set lists are written out as subtask files.
with open('./splits/age_sex_info.json', 'r') as f:
    info = json.load(f)

for split in ['train', 'val', 'test']:
    df = pd.read_csv(os.path.join('./splits/', split + '.csv'))

    unique_subjects = set()
    sex_count = {'f': 0, 'm': 0}
    subtask_sex = {'f': [], 'm': []}
    # Keys look like "<subject>_<admission>"; sex is stored per subject.
    for key in df['key']:
        subject_id = key.split('_')[0]
        sex = info['subject_info'][subject_id]['sex']
        if subject_id not in unique_subjects:
            unique_subjects.add(subject_id)
            sex_count[sex] += 1
        subtask_sex[sex].append(key)

    print(f'===== {split} =====')
    print('There are %d subjects, %d females and %d males.' % (len(unique_subjects), sex_count['f'], sex_count['m']))
    print('There are %d hospital admissions, %d females and %d males.' % (len(df), len(subtask_sex['f']), len(subtask_sex['m'])))

    if split == 'test':
        save_dir = f'./splits/subtasks/sex/{split}'
        os.makedirs(save_dir, exist_ok=True)
        for sex, keys in subtask_sex.items():
            save_path = os.path.join(save_dir, f'{sex}.txt')
            print('- Saving the subtask file to', save_path)
            with open(save_path, 'w') as f:
                f.write('\n'.join(keys))

===== train =====
There are 22423 subjects, 9629 females and 12794 males.
There are 28673 hospital admissions, 12435 females and 16238 males.
===== val =====
There are 2803 subjects, 1223 females and 1580 males.
There are 3557 hospital admissions, 1607 females and 1950 males.
===== test =====
There are 2803 subjects, 1267 females and 1536 males.
There are 3621 hospital admissions, 1649 females and 1972 males.
- Saving the subtask file to ./splits/subtasks/sex/test/f.txt
- Saving the subtask file to ./splits/subtasks/sex/test/m.txt


# Subtask for different ages

In [5]:
import plotly.offline as py
import plotly.graph_objs as go
import matplotlib.pyplot as plt

# Bucket every set's admissions by admission-level age, plot the age histogram,
# and write the test-set buckets as subtask files.
with open('./splits/age_sex_info.json', 'r') as f:
    info = json.load(f)
# Half-open [begin, end) age buckets.
age_range = [(0, 55), (55, 70), (70, 200)]

for split in ['train', 'val', 'test']:
    df = pd.read_csv(os.path.join('./splits/', split + '.csv'))

    # A dict rather than a fixed [0] * 200 list, so ages >= 200 cannot raise IndexError.
    age_count = defaultdict(int)
    # Pre-seeded so buckets are always printed and saved in age order.
    subtask_age = {f'{begin}_{end}': [] for begin, end in age_range}
    for key in df['key']:
        age = info['key_info'][key]['age']
        age_count[age] += 1
        for begin, end in age_range:
            if begin <= age < end:
                subtask_age[f'{begin}_{end}'].append(key)

    print(f'===== {split} =====')
    print('There are %d hospital admissions' % len(df))
    for k, v in subtask_age.items():
        print('%10s: %d (%.3f)' % (k, len(v), len(v) / max(1, len(df))))

    if age_count:
        min_age, max_age = min(age_count), max(age_count)
        # x holds the real ages. The previous version plotted 0-based offsets from min_age.
        x = list(range(min_age, max_age + 1))
        y = [age_count.get(a, 0) for a in x]
        fig = go.Figure(
            data=[go.Bar(x=x, y=y, marker={'color': 'red', 'line': {'width': 0.1, 'color': 'red'}})],
            layout=go.Layout(title=f'{split}: admissions per age',
                             xaxis={'title': 'age'}, yaxis={'title': '# admissions'}),
        )
        py.iplot(fig)

    if split == 'test':
        save_dir = f'./splits/subtasks/age/{split}'
        os.makedirs(save_dir, exist_ok=True)
        for k, v in subtask_age.items():
            save_path = os.path.join(save_dir, f'{k}.txt')
            print('- Saving the subtask file to', save_path)
            with open(save_path, 'w') as f:
                f.write('\n'.join(v))


===== train =====
There are 28673 hospital admissions
      0_55: 8447 (0.295)
     55_70: 8752 (0.305)
    70_200: 11474 (0.400)


===== val =====
There are 3557 hospital admissions
    70_200: 1375 (0.387)
     55_70: 1070 (0.301)
      0_55: 1112 (0.313)


===== test =====
There are 3621 hospital admissions
    70_200: 1368 (0.378)
      0_55: 1144 (0.316)
     55_70: 1109 (0.306)


- Saving the subtask file to ./splits/subtasks/age/test/70_200.txt
- Saving the subtask file to ./splits/subtasks/age/test/0_55.txt
- Saving the subtask file to ./splits/subtasks/age/test/55_70.txt


# Subtask for different diseases

In [9]:
from collections import Counter, defaultdict

def find(data, mapping, string):
    """Print a prefixed ICD-9 diagnosis code with its short and long titles.

    ``string`` is a prefixed code such as ``'D_401.9'``. Its first two characters
    are dropped and dots removed to form the key into ``mapping``. ``mapping``
    maps that dot-free code to a row position in ``data``, which must have
    ``SHORT_TITLE`` and ``LONG_TITLE`` columns.

    Raises:
        KeyError: if the normalised code is not in ``mapping``.
    """
    lookup_key = string[2:].replace('.', '').strip()
    row = data.iloc[mapping[lookup_key]]
    print(string, row['SHORT_TITLE'], '--------', row['LONG_TITLE'])

# Build per-disease subtasks: group admissions by ICD-9 diagnosis prefix (the code
# before the "."), keep the `topk` prefixes covering the most admissions, and
# save each prefix's admission keys for the test split.
data = pd.read_csv('./D_ICD_DIAGNOSES.csv')
# Dot-free ICD9 code -> row position in `data`, used by find() to print titles.
mapping = {str(v).strip(): i for i, v in enumerate(data['ICD9_CODE'])}

# NOTE(review): pickle.load can execute arbitrary code; only load these locally generated files.
with open('./diagnose-procedure-medication/admDxMap_mimic3.pk', 'rb') as f:
    aid2dx = pickle.load(f)
# you should run the cmd below to get adjacent_matrix_code_map.pkl
# python pretreatments/prepare_codes_adjacent_matrix.py
with open('./info/adjacent_matrix_code_map.pkl', 'rb') as f:
    code_map = pickle.load(f)

topk = 10
splits = ['test']  # add 'train' / 'val' to inspect those splits as well

for split in splits:
    key2code = {}
    code2key = defaultdict(set)
    # Reset per split. Previously this was shared across splits and rebound to a
    # sorted list, which broke any second split.
    code_freq = Counter()
    df = pd.read_csv(os.path.join('./splits/', split + '.csv'))

    for key in df['key']:
        aid = int(key.split('_')[1])
        if aid not in aid2dx:
            continue

        key2code[key] = set(aid2dx[aid])
        for code in key2code[key]:
            if code not in code_map:
                continue
            code_freq[code] += 1
            code2key[code.split('.')[0]].add(key)

    print(f'===== {split} =====')
    print('There are %d hospital admissions, only %d of them have Dx codes' % (len(df), len(key2code)))

    codes_by_freq = sorted(code_freq.items(), key=lambda x: -x[1])
    prefixes_by_cover = sorted(code2key.items(), key=lambda x: -len(x[1]))

    cover_keys = set()
    subtask_disease = {}
    # Slicing (rather than range(topk)) tolerates fewer than topk prefixes.
    for rank, (code_prefix, this_keys) in enumerate(prefixes_by_cover[:topk], start=1):
        # The most frequent full code under this prefix, used only to print a readable title.
        # Exact prefix equality matches how code2key was keyed; a substring test could match other codes.
        one_of_code = next(k for k, _ in codes_by_freq if k.split('.')[0] == code_prefix)
        find(data, mapping, one_of_code)

        cover_keys |= this_keys
        print(rank, 'this coverage:', len(this_keys), len({k.split('_')[0] for k in this_keys}))
        print(rank, ' all coverage:', len(cover_keys))

        # Sorted so the saved file is identical across runs (set order depends on the hash seed).
        subtask_disease[code_prefix] = sorted(this_keys)

    if split == 'test':
        save_dir = f'./splits/subtasks/disease/{split}'
        os.makedirs(save_dir, exist_ok=True)
        for k, v in subtask_disease.items():
            save_path = os.path.join(save_dir, f'{k}.txt')
            print('- Saving the subtask file to', save_path)
            with open(save_path, 'w') as f:
                f.write('\n'.join(v))


===== test =====
There are 3621 hospital admissions, only 3621 of them have Dx codes
D_401.9 Hypertension NOS -------- Unspecified essential hypertension
1 this coverage: 1600 1354
1  all coverage: 1600
D_272.4 Hyperlipidemia NEC/NOS -------- Other and unspecified hyperlipidemia
2 this coverage: 1187 981
2  all coverage: 2031
D_427.31 Atrial fibrillation -------- Atrial fibrillation
3 this coverage: 1162 961
3  all coverage: 2457
D_414.01 Crnry athrscl natve vssl -------- Coronary atherosclerosis of native coronary artery
4 this coverage: 1103 898
4  all coverage: 2605
D_250.00 DMII wo cmp nt st uncntr -------- Diabetes mellitus without mention of complication, type II or unspecified type, not stated as uncontrolled
5 this coverage: 1072 768
5  all coverage: 2771
D_276.2 Acidosis -------- Acidosis
6 this coverage: 1019 870
6  all coverage: 3001
D_285.9 Anemia NOS -------- Anemia, unspecified
7 this coverage: 1018 868
7  all coverage: 3131
D_428.0 CHF NOS -------- Congestive heart failu