In [9]:
import pandas as pd
import numpy as np
import pickle

import random
# Seed Python's built-in RNG; random.shuffle is used further down to pick
# which subjects are held out, so this keeps that selection repeatable.
random.seed(10)

In [10]:
path_to_data = '/root/data/mimic-iv_data/hosp'

# Diagnoses table: one row per ICD code recorded on an admission.
# `icd_code` is read explicitly as text. Some ICD-9 codes consist only of
# digits with leading zeros (e.g. "07070"); leaving the dtype to inference
# risks those being parsed as integers and the zeros being silently dropped.
dfd = pd.read_csv(f'{path_to_data}/diagnoses_icd.csv', dtype={'icd_code': str})

# Add ICD version to code
dfd['icd_code'] = dfd['icd_code'] + '-' + dfd['icd_version'].astype(str)

# HADM IDs are not sorted numerically, need to sort by admission time
dfa = pd.read_csv(f'{path_to_data}/admissions.csv')

# Merge relevant columns from all datasets.
# The join keys are spelled out so the merge does not depend on whichever
# column names the two frames happen to share.
dfd = dfd.drop('icd_version', axis=1)
dfa = dfa[['subject_id', 'hadm_id', 'admittime']]
df = pd.merge(dfd, dfa, how='inner', on=['subject_id', 'hadm_id'])
df['admittime'] = pd.to_datetime(df['admittime'])

# Ensure sorted: per subject, admissions in chronological order, and within
# an admission the codes in their recorded sequence order.
df = df.sort_values(by=['subject_id', 'admittime', 'seq_num'], ascending=True)

In [11]:
# Load the pickled `icd_to_vsa_data` mapping; its keys are the codes we keep
with open("./data/006_icd_to_vsa_data.pkl", "rb") as f:
    icd_to_vsa_data = pickle.load(f)

# Drop every diagnosis row whose code has no entry in the mapping
clean_codes = np.array(list(icd_to_vsa_data.keys()))
df_clean = df.loc[df['icd_code'].isin(clean_codes)]

# Number of distinct codes left after filtering
# (Previously, there were 25460 unique codes)
print(len(df_clean['icd_code'].unique()))

df = df_clean

25493


In [12]:
# Preview the first 50 filtered diagnosis rows
df.head(50)

Unnamed: 0,subject_id,hadm_id,seq_num,icd_code,admittime
0,10000032,22595853,1,5723-9,2180-05-06 22:23:00
1,10000032,22595853,2,78959-9,2180-05-06 22:23:00
2,10000032,22595853,3,5715-9,2180-05-06 22:23:00
3,10000032,22595853,4,07070-9,2180-05-06 22:23:00
4,10000032,22595853,5,496-9,2180-05-06 22:23:00
5,10000032,22595853,6,29680-9,2180-05-06 22:23:00
6,10000032,22595853,7,30981-9,2180-05-06 22:23:00
7,10000032,22595853,8,V1582-9,2180-05-06 22:23:00
8,10000032,22841357,1,07071-9,2180-06-26 18:27:00
9,10000032,22841357,2,78959-9,2180-06-26 18:27:00


In [13]:
# A subject needs at least this many distinct admissions to be eligible for
# the held-out set, and this many eligible subjects are held out.
MIN_VISITS_FOR_TEST = 5
NUM_TEST_SUBJECTS = 5000

# subject_id -> list of that subject's hadm_ids, in first-seen order.
# `df` has one row per diagnosis, so (subject_id, hadm_id) pairs are
# de-duplicated first; drop_duplicates keeps the first occurrence, which
# preserves the row order of `df`.
res = {}
admission_pairs = df[['subject_id', 'hadm_id']].drop_duplicates().values.tolist()
for subject_id, hadm_id in admission_pairs:
    res.setdefault(subject_id, []).append(hadm_id)

# Subjects with enough admissions to be candidates for the held-out set
res_5 = {k: v for k, v in res.items() if len(v) >= MIN_VISITS_FOR_TEST}

# Randomly choose the held-out subjects among the candidates
res_5_keys = list(res_5.keys())
random.shuffle(res_5_keys)

res_5_test = res_5_keys[:NUM_TEST_SUBJECTS]
res_5_train = res_5_keys[NUM_TEST_SUBJECTS:]

print(len(res.keys()))
# After this, `res` holds every subject except the held-out ones and
# `res_5` holds only the held-out ones.
for k in res_5_test:
    res.pop(k, None)
for k in res_5_train:
    res_5.pop(k, None)
print(len(res.keys()))
print(len(res_5.keys()))

190121
185121
5000


In [15]:
# Record the held-out subject ids, one per line. A newline is written after
# each id so the file can be read back without the ids running together.
with open("./data/mimic-iv_data/test_subject_id.txt", 'w') as f:
    for subject_id in res_5_test:
        f.write(f"{subject_id}\n")

In [16]:
# Inspect the subject_id -> [hadm_id, ...] mapping that remains after the
# held-out subjects were popped from it
res

{10000032: [22595853, 22841357, 29079034, 25742920],
 10000068: [25022803],
 10000084: [23052089, 29888819],
 10000108: [27250926],
 10000117: [22927623, 27988844],
 10000248: [20600184],
 10000280: [25852320],
 10000560: [28979390],
 10000635: [26134563],
 10000719: [24558333],
 10000724: [20823482],
 10000764: [27897940],
 10000826: [20032235, 21086876, 28289260],
 10000883: [29957930, 25221576],
 10000886: [21927847],
 10000904: [28328117],
 10000935: [29541074, 24955974, 21738619, 26381316, 25849114],
 10001176: [23334588],
 10001180: [21102262, 25534937, 27864856],
 10001186: [24906418, 24016413, 21334040],
 10001217: [24597018, 27703517],
 10001319: [29230609, 23005466, 24591241],
 10001338: [22119639, 28835314, 29335220, 27987619],
 10001401: [21544441, 26840593, 24818636, 27060146, 28058085, 27012892],
 10001472: [23506139],
 10001492: [27463908],
 10001663: [23405714],
 10001667: [22672901],
 10001725: [25563031],
 10001843: [21728396],
 10001860: [21441082],
 10001877: [25679

In [17]:
# Number each admission within its subject: hadm_id -> 1-based position in
# that subject's admission list
d = {
    hadm_id: visit_number
    for admissions in res.values()
    for visit_number, hadm_id in enumerate(admissions, start=1)
}

# Same numbering for the held-out subjects
d_5 = {
    hadm_id: visit_number
    for admissions in res_5.values()
    for visit_number, hadm_id in enumerate(admissions, start=1)
}

In [18]:
# Inspect the hadm_id -> visit position mapping
d

{22595853: 1,
 22841357: 2,
 29079034: 3,
 25742920: 4,
 25022803: 1,
 23052089: 1,
 29888819: 2,
 27250926: 1,
 22927623: 1,
 27988844: 2,
 20600184: 1,
 25852320: 1,
 28979390: 1,
 26134563: 1,
 24558333: 1,
 20823482: 1,
 27897940: 1,
 20032235: 1,
 21086876: 2,
 28289260: 3,
 29957930: 1,
 25221576: 2,
 21927847: 1,
 28328117: 1,
 29541074: 1,
 24955974: 2,
 21738619: 3,
 26381316: 4,
 25849114: 5,
 23334588: 1,
 21102262: 1,
 25534937: 2,
 27864856: 3,
 24906418: 1,
 24016413: 2,
 21334040: 3,
 24597018: 1,
 27703517: 2,
 29230609: 1,
 23005466: 2,
 24591241: 3,
 22119639: 1,
 28835314: 2,
 29335220: 3,
 27987619: 4,
 21544441: 1,
 26840593: 2,
 24818636: 3,
 27060146: 4,
 28058085: 5,
 27012892: 6,
 23506139: 1,
 27463908: 1,
 23405714: 1,
 22672901: 1,
 25563031: 1,
 21728396: 1,
 21441082: 1,
 25679292: 1,
 21320596: 2,
 21268656: 1,
 26679629: 2,
 23594368: 3,
 21577720: 4,
 24325811: 5,
 26812645: 6,
 27765344: 7,
 25758848: 8,
 29675586: 9,
 26170293: 10,
 27016754: 11,
 262

In [19]:
# Turn the hadm_id -> visit position dicts into two-column lookup frames
visit_columns = ['hadm_id', 'visit_order']
dfv = pd.DataFrame(list(d.items()), columns=visit_columns)
dfv_5 = pd.DataFrame(list(d_5.items()), columns=visit_columns)

In [20]:
# Spot-check the visit position assigned to one admission
dfv.loc[dfv['hadm_id'] == 22595853]

Unnamed: 0,hadm_id,visit_order
0,22595853,1


In [21]:
# Attach `visit_order` to every diagnosis row. The inner join also performs
# the split: `dfv_5` only lists held-out admissions and `dfv` only lists the
# remaining ones, so rows of the other group are dropped.
#
# The key is named explicitly: without `on`, merge joins on every column the
# two frames share, so re-running this cell after `df` has gained
# `visit_order` (or lost `hadm_id`) would silently join on different keys.
# `validate` asserts each hadm_id appears at most once in the lookup frame.
df_5 = pd.merge(df, dfv_5, how='inner', on='hadm_id', validate='many_to_one')
df = pd.merge(df, dfv, how='inner', on='hadm_id', validate='many_to_one')

In [22]:
# Preview the first 50 rows now that `visit_order` is attached
df.head(50)

Unnamed: 0,subject_id,hadm_id,seq_num,icd_code,admittime,visit_order
0,10000032,22595853,1,5723-9,2180-05-06 22:23:00,1
1,10000032,22595853,2,78959-9,2180-05-06 22:23:00,1
2,10000032,22595853,3,5715-9,2180-05-06 22:23:00,1
3,10000032,22595853,4,07070-9,2180-05-06 22:23:00,1
4,10000032,22595853,5,496-9,2180-05-06 22:23:00,1
5,10000032,22595853,6,29680-9,2180-05-06 22:23:00,1
6,10000032,22595853,7,30981-9,2180-05-06 22:23:00,1
7,10000032,22595853,8,V1582-9,2180-05-06 22:23:00,1
8,10000032,22841357,1,07071-9,2180-06-26 18:27:00,2
9,10000032,22841357,2,78959-9,2180-06-26 18:27:00,2


In [23]:
# Keep only the columns needed downstream, in a fixed order
keep_cols = ['subject_id', 'visit_order', 'seq_num', 'icd_code']
df = df[keep_cols]
df_5 = df_5[keep_cols]

In [24]:
# Inspect the trimmed frame (subject_id, visit_order, seq_num, icd_code)
df

Unnamed: 0,subject_id,visit_order,seq_num,icd_code
0,10000032,1,1,5723-9
1,10000032,1,2,78959-9
2,10000032,1,3,5715-9
3,10000032,1,4,07070-9
4,10000032,1,5,496-9
...,...,...,...,...
4192810,19999987,1,7,41401-9
4192811,19999987,1,8,78039-9
4192812,19999987,1,9,0413-9
4192813,19999987,1,10,36846-9


In [26]:
from tqdm import tqdm

In [30]:
# Turn df into mapping
def _frame_to_subject_dict(frame):
    """Collapse a long-format diagnoses frame into per-subject column lists.

    Parameters
    ----------
    frame : pandas.DataFrame
        Must have exactly the columns (subject_id, visit_order, seq_num,
        icd_code), in that order; rows are unpacked positionally.

    Returns
    -------
    dict
        ``{subject_id: {"icd_code": [...], "visit_order": [...], "seq_num": [...]}}``.
        The three lists are parallel and follow the row order of ``frame``.
    """
    by_subject = {}
    for (_, subject_id, visit_order, seq_num, icd_code) in tqdm(frame.itertuples()):
        record = by_subject.get(subject_id)
        if record is None:
            # First row for this subject: create its three lists once, in a
            # fixed key order.
            record = by_subject[subject_id] = {"icd_code": [], "visit_order": [], "seq_num": []}
        record["icd_code"].append(icd_code)
        record["visit_order"].append(visit_order)
        record["seq_num"].append(seq_num)
    return by_subject


dataset_dict = _frame_to_subject_dict(df)
dataset_dict_test = _frame_to_subject_dict(df_5)

4192815it [00:07, 551179.55it/s]
560951it [00:00, 597798.76it/s]


In [31]:
import pickle
# Pickle both per-subject mappings next to each other on disk
for out_path, payload in (
    ('./data/mimic-iv_data/mimic_icd_data_dict.pkl', dataset_dict),
    ('./data/mimic-iv_data/mimic_icd_data_dict_test.pkl', dataset_dict_test),
):
    with open(out_path, 'wb') as f:
        pickle.dump(payload, f)

## Build a Hugging Face dataset

In [45]:
from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Tokenizer loaded from a local directory
tokenizer = AutoTokenizer.from_pretrained('/root/data/tokenizer-mimic-iv-icd-final/')

# Per-subject mapping to tokenize. To process the held-out subjects instead,
# load './data/mimic-iv_data/mimic_icd_data_dict_test.pkl' here.
with open('./data/mimic-iv_data/mimic_icd_data_dict.pkl', 'rb') as f:
    datadict = pickle.load(f)

# Keyword arguments forwarded to every tokenizer call
tokenization_params = dict(
    max_length=128,
    truncation=True,
    padding='max_length',
    is_split_into_words=True,
    return_special_tokens_mask=True,
)

In [46]:
# Pivot {subject: {column: values}} into {column: [values of subject 1, ...]}.
# The loop variables are named descriptively on purpose: a bare `d` here
# would overwrite the notebook-level `d` (hadm_id -> visit position) that
# earlier cells read.
dd = {}
for subject_record in datadict.values():
    for column, values in subject_record.items():
        dd.setdefault(column, []).append(values)

In [47]:
# Build the dataset from the column-oriented dict: each key of `dd` becomes a
# feature and each list element (one per subject) becomes a row
ds = Dataset.from_dict(dd)

In [48]:
def process_function(entry, tokenization_params):
    """Tokenize one subject's codes and tag each code with its visit number.

    Parameters
    ----------
    entry : mapping
        Must provide 'icd_code' (list of codes) and 'visit_order' (list of
        visit numbers, parallel to the codes). Updated in place.
    tokenization_params : dict
        Keyword arguments passed to the module-level ``tokenizer``; must
        contain 'max_length'.

    Returns
    -------
    The same ``entry`` object, with the tokenizer's outputs merged in and
    'token_type_ids' overwritten with visit numbers for the kept codes.
    """
    icd_codes = entry['icd_code']
    # Get coding embedding ids
    tokenized = tokenizer(icd_codes, **tokenization_params)

    # Two positions are left untouched (the first one and at least one at the
    # end), so at most max_length - 2 codes receive a visit number.
    # NOTE(review): this assumes one token per code plus two special tokens; confirm for the tokenizer in use.
    capacity = tokenization_params['max_length'] - 2
    kept = min(capacity, len(icd_codes))

    # Add token type ids for visit sequence, starting right after position 0
    visit_numbers = entry['visit_order'][:kept]
    tokenized['token_type_ids'][1:1 + kept] = visit_numbers

    entry.update(tokenized)
    return entry
# Tokenize every row in 4 worker processes and drop the raw columns, leaving
# only the tokenizer outputs
ds = ds.map(
    process_function,
    fn_kwargs={'tokenization_params': tokenization_params},
    num_proc=4,
    remove_columns=['icd_code', 'visit_order', 'seq_num'],
)

     

#0:   0%|          | 0/46281 [00:00<?, ?ex/s]

 

#1:   0%|          | 0/46280 [00:00<?, ?ex/s]

 

#2:   0%|          | 0/46280 [00:00<?, ?ex/s]

 

#3:   0%|          | 0/46280 [00:00<?, ?ex/s]

In [49]:
# Inspect the features and row count of the tokenized dataset
ds

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
    num_rows: 185121
})

In [50]:
# Persist the tokenized dataset built from the pickle loaded above.
# The commented-out line is the target path to use when the test pickle was
# loaded instead.
# NOTE(review): later cells load ".../mimic_icd_hf_dataset_train_v2" and
# ".../mimic_icd_hf_dataset_test_v2", which are not the paths written here;
# confirm how the _v2 copies are produced.
ds.save_to_disk('./data/mimic-iv_data/mimic_icd_hf_dataset_train')
# ds.save_to_disk('./data/mimic-iv_data/mimic_icd_hf_dataset_test')

## Split datasets into train and test

In [1]:
import datasets

In [3]:
# Re-save the "train" dataset as pretraining
ds = datasets.load_from_disk("./data/mimic-iv_data/mimic_icd_hf_dataset_train_v2")
# Hold out 10% of the rows. The split shuffles the rows, so a fixed seed is
# passed to make the same rows land in train/test on every run.
dsd = ds.train_test_split(test_size=0.1, seed=10)

In [4]:
# Inspect the sizes of the resulting train/test splits
dsd

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
        num_rows: 166608
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
        num_rows: 18513
    })
})

In [5]:
# Write both splits of the DatasetDict to disk.
# NOTE(review): this is an absolute /root/data path, while the input was
# loaded from the relative ./data tree; confirm both refer to the same place.
dsd.save_to_disk("/root/data/mimic-iv_data/mimic_icd_hf_dataset_pretraining_v2")

Flattening the indices:   0%|          | 0/167 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/19 [00:00<?, ?ba/s]

In [6]:
# Re-save "test" dataset as finetuning

ds = datasets.load_from_disk("./data/mimic-iv_data/mimic_icd_hf_dataset_test_v2")
# Hold out exactly 1000 rows. The split shuffles the rows, so a fixed seed is
# passed to make the same rows land in train/test on every run.
dsd = ds.train_test_split(test_size=1000, seed=10)

In [7]:
# Inspect the sizes of the resulting train/test splits
dsd

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
        num_rows: 4000
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
        num_rows: 1000
    })
})

In [8]:
# Write both splits of the DatasetDict to disk.
# NOTE(review): this is an absolute /root/data path, while the input was
# loaded from the relative ./data tree; confirm both refer to the same place.
dsd.save_to_disk("/root/data/mimic-iv_data/mimic_icd_hf_dataset_finetuning_v2")

Flattening the indices:   0%|          | 0/4 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]