In [1]:
import pickle
import numpy as np
import pandas as pd
from datetime import datetime
from tqdm import tqdm

In [2]:
# Load the preprocessed per-patient visit records (features) ...
with open('../data/preprocessed_X_visit_over3.pkl', 'rb') as x_file:
    data_x = pickle.load(x_file)

# ... and the matching per-patient label records.
with open('../data/preprocessed_y_visit_over3.pkl', 'rb') as y_file:
    labels = pickle.load(y_file)

In [3]:
# For every patient, keep only the visits that carry more than one code and
# store de-duplicated diagnosis / procedure / drug lists for those visits.
visit_gt2t_new = {}
for patient_id, patient_visits in data_x.items():
    kept_visits = {}
    for visit_id, visit in patient_visits.items():
        unique_diagnoses = list(set(visit['diagnoses']))
        unique_procedures = list(set(visit['procedures']))
        # NOTE(review): the threshold counts the raw drug list (duplicates
        # included) while the stored drug list is de-duplicated below -
        # confirm this asymmetry is intended.
        n_codes = len(unique_diagnoses) + len(unique_procedures) + len(visit['drugs'])
        if n_codes <= 1:
            continue
        kept_visits[visit_id] = {
            'diagnoses': unique_diagnoses,
            'procedures': unique_procedures,
            'drugs': list(set(visit['drugs'])),
            'admitdate': visit['admitdate'],
        }
    # Patients are kept even when none of their visits survive the filter.
    visit_gt2t_new[patient_id] = kept_visits

In [4]:
# Admission-level table: only the columns used further down are read.
adm = pd.read_csv(
    './admissions.csv.gz',
    compression='gzip',
    usecols=['subject_id', 'hadm_id', 'admittime', 'admission_type', 'race'],
)

# Patient-level table with a derived year of birth (anchor_year - anchor_age).
# The frame is built in a single chain with .assign() so the new column is
# added to a fresh object instead of being set on a column-subset of another
# frame, which is what triggers pandas' SettingWithCopyWarning.
patients = (
    pd.read_csv('./patients.csv')
    .loc[:, ['subject_id', 'gender', 'anchor_age', 'anchor_year']]
    .assign(yob=lambda frame: frame['anchor_year'] - frame['anchor_age'])
)

In [5]:
# Keep only the patients that still have at least two visits.
visit_gt2t_last = {
    patient_id: visits
    for patient_id, visits in visit_gt2t_new.items()
    if len(visits) >= 2
}

In [7]:
# Build lookup tables once so the loop below does dictionary lookups instead of
# re-scanning `patients` / `adm` with a boolean mask for every patient and
# every visit.  drop_duplicates() keeps the first row per key, i.e. the same
# row a first-match selection (`.unique()[0]` / `.values[0]`) would return.
patients_first = patients.drop_duplicates('subject_id').set_index('subject_id')
gender_by_patient = patients_first['gender'].to_dict()
yob_by_patient = patients_first['yob'].to_dict()
race_by_patient = (
    adm.drop_duplicates('subject_id')
    .set_index('subject_id')['race']
    .to_dict()
)
adm_type_by_visit = (
    adm.drop_duplicates(['subject_id', 'hadm_id'])
    .set_index(['subject_id', 'hadm_id'])['admission_type']
    .to_dict()
)

# Assemble, per patient: the code sequence of every visit, the enriched visit
# records, the admission dates, static demographics and the labels.
# A missing patient / visit raises KeyError here.
data_dict = {}
for p_id, v_data in tqdm(visit_gt2t_last.items()):
    visit_data = {'seq': [], 'data': [], 'time': [],
                  'gender': gender_by_patient[p_id],
                  'race': race_by_patient[p_id]}
    yob = yob_by_patient[p_id]

    # NOTE(review): visits are consumed in the dict's iteration order and the
    # last one is treated as the reference time below - presumably the order
    # is chronological; confirm against the preprocessing step.
    for v_id, record in v_data.items():
        dt_obj = datetime.strptime(record['admitdate'], '%Y-%m-%d')
        age_at_adm = dt_obj.year - yob
        # NOTE(review): negative ages are replaced by 90; the rationale is
        # not visible in this notebook - confirm.
        if age_at_adm < 0:
            age_at_adm = 90

        # Build a new record instead of adding keys to the dict that is shared
        # with visit_gt2t_last / visit_gt2t_new, so this cell has no side
        # effects on its inputs.
        enriched = {
            **record,
            'visit_id': v_id,
            'admission_type': adm_type_by_visit[(p_id, v_id)],
            'age': age_at_adm,
        }
        visit_data['seq'].append(
            enriched['diagnoses'] + enriched['procedures'] + enriched['drugs']
        )
        visit_data['data'].append(enriched)
        visit_data['time'].append(dt_obj)

    # Append a trailing 'global' pseudo-visit that reuses the last visit's time.
    visit_data['seq'].append(['global'])
    visit_data['time'].append(visit_data['time'][-1])

    # Days between each entry and the last entry (0 for the last visit and
    # for the 'global' entry).
    last_time = visit_data['time'][-1]
    visit_data['timedelta'] = [(last_time - dt).days for dt in visit_data['time']]
    visit_data['labels_origin'] = labels[p_id]['top100_label']
    visit_data['labels_bin'] = labels[p_id]['top100_label_bin']
    data_dict[p_id] = visit_data

In [34]:
# Persist the assembled per-patient dictionary.
with open('../data/data_dict.pkl', 'wb') as out_file:
    pickle.dump(data_dict, out_file)

In [4]:
# Reload the per-patient dictionary from disk.
with open('../data/data_dict.pkl', 'rb') as in_file:
    data_dict = pickle.load(in_file)

In [36]:
# Collect the distinct diagnosis, procedure and drug codes over all visits.
diagnosis_codes, procedure_codes, drug_codes = set(), set(), set()
for visit in data_dict.values():
    for v in visit['data']:
        diagnosis_codes.update(v['diagnoses'])
        procedure_codes.update(v['procedures'])
        drug_codes.update(v['drugs'])

# Sort instead of relying on set iteration order: for string codes that order
# changes between interpreter sessions (hash randomisation), which would make
# any index mapping derived from these lists non-reproducible.
# key=str keeps the sort well-defined even if codes are not all one type.
# NOTE(review): artefacts built from an earlier, unsorted ordering (saved
# index mappings, embeddings aligned to them) must be regenerated.
d_list = sorted(diagnosis_codes, key=str)
p_list = sorted(procedure_codes, key=str)
dr_list = sorted(drug_codes, key=str)

In [37]:
# Vocabulary sizes: (diagnoses, procedures, drugs).
tuple(len(codes) for codes in (d_list, p_list, dr_list))

(869, 747, 3528)

In [38]:
# Map every code to a contiguous integer id: 0 is reserved for padding, then
# diagnoses, procedures and drugs in that order, and finally the 'global'
# token takes the next free id.
# NOTE(review): a code that appears in more than one list keeps only its last
# assigned id, leaving the earlier id unused - confirm the vocabularies are
# disjoint.
code2idx = {'padding': 0}

idx = 1
for code_group in (d_list, p_list):
    for code in code_group:
        code2idx[code] = idx
        idx += 1

# The drug loop keeps `dr` as its loop variable, so `dr` is still bound to the
# last drug code after this cell runs.
for dr in dr_list:
    code2idx[dr] = idx
    idx += 1

code2idx['global'] = idx

In [45]:
# Id of the last drug code, i.e. the highest id before the 'global' token.
# Indexes dr_list explicitly rather than relying on a loop variable that
# happens to be left over from an earlier cell.
code2idx[dr_list[-1]]

5144

In [46]:
# Persist the code -> integer id mapping.
with open('../data/code_indices/code2idx.pkl', 'wb') as out_file:
    pickle.dump(code2idx, out_file)

In [2]:
import pickle

In [7]:
# Reload the code -> integer id mapping from disk.
with open('../data/code_indices/code2idx.pkl', 'rb') as in_file:
    code2idx = pickle.load(in_file)

In [8]:
# Copy the outer dict AND each per-patient dict.  A plain data_dict.copy() is
# shallow: the per-patient dicts would be shared, so adding a key to a patient
# entry of data_dict_idx would also modify data_dict.
# The lists stored inside each patient entry are still shared (not copied).
data_dict_idx = {p_id: dict(visit) for p_id, visit in data_dict.items()}

In [9]:
# Translate every visit's code list into integer ids.
# Direct indexing (instead of dict.get) makes an out-of-vocabulary code raise
# KeyError immediately rather than silently inserting None into the sequence.
for k, visit in tqdm(data_dict_idx.items()):
    visit['seq_idx'] = [[code2idx[code] for code in v] for v in visit['seq']]

  0%|          | 0/43096 [00:00<?, ?it/s]

100%|██████████| 43096/43096 [00:00<00:00, 46217.90it/s]


In [10]:
# Reshape each patient entry into the flat record used downstream.
# 'visit', 'visit_idx' and 'code_length' have one element per real visit;
# 'seq', 'seq_idx' and 'timedelta' are passed through unchanged (they are
# assigned once per patient rather than on every visit iteration).
data_dict_new = dict()
for p_id in tqdm(list(data_dict_idx.keys())):
    patient = data_dict_idx[p_id]
    visits = patient['data']
    n_visits = len(visits)

    data_dict_new[p_id] = {
        'visit': [v_data['visit_id'] for v_data in visits],
        'visit_idx': list(range(n_visits)),
        'visit_length': n_visits,
        # Number of codes in each real visit's sequence.
        'code_length': [len(patient['seq'][v_id]) for v_id in range(n_visits)],
        'seq': patient['seq'],
        'seq_idx': patient['seq_idx'],
        'code_types': None,  # placeholder, never filled in this notebook
        'timedelta': patient['timedelta'],
        # np.squeeze drops size-1 axes from the stored binary label array.
        'label': np.squeeze(patient['labels_bin']),
    }

100%|██████████| 43096/43096 [00:00<00:00, 100583.68it/s]


In [11]:
# Truncate every patient to the most recent 50 visits, printing the patients
# that exceed the cap.  The per-visit fields keep the last 50 entries, while
# 'seq', 'seq_idx' and 'timedelta' keep the last 51.
# NOTE(review): the extra entry and the +1 on 'visit_length' presumably
# account for a trailing 'global' element in those sequences - confirm.
data_dict_max50 = dict()
for p_id, data in data_dict_new.items():
    n_visits = len(data['visit'])
    if n_visits > 50:
        print(p_id, n_visits)

    recent_visits = data['visit'][-50:]
    data_dict_max50[p_id] = {
        'visit': recent_visits,
        'visit_idx': data['visit_idx'][-50:],
        'visit_length': len(recent_visits) + 1,
        'code_length': data['code_length'][-50:],
        'seq': data['seq'][-51:],
        'seq_idx': data['seq_idx'][-51:],
        'timedelta': data['timedelta'][-51:],
        'label': data['label'],
    }

10123949 55
10264646 76
10577647 78
10578325 67
10580201 59
10714009 113
10913302 64
11021643 63
11296936 76
11413236 77
11553072 69
11582633 81
11714071 62
11761621 56
11818101 64
11888614 56
11890447 66
11965254 90
12251785 79
12468016 88
12547294 69
12563258 60
12596559 63
13166511 53
13297743 91
13470788 52
13475033 96
13813803 68
13877234 60
13999829 54
14029474 52
14318651 81
14394983 124
15084163 53
15107347 51
15108590 52
15114531 67
15229574 79
15464144 75
15496609 133
15935768 56
16124481 62
16233333 67
16439884 58
16615356 55
16662316 129
16675371 66
16809525 57
16924675 54
17011846 56
17051420 54
17204468 59
17340686 51
17477304 51
17517983 94
17716210 80
17937834 61
18001923 69
18136887 62
18284271 82
18376342 52
18553055 55
18655830 69
18656167 55
18676703 59
18902344 70
18970086 68
19127408 51
19133405 78
19610016 52
19713100 62
19759225 63
19921471 57


In [12]:
# Number of patient entries in the truncated dataset.
len(data_dict_max50)

43096

In [13]:
with open('../data/data_dict_preprocess_maxlen50.pkl', 'wb') as f:
    pickle.dump(data_dict_max50, f)

In [11]:
import torch

In [77]:
# Scratch example: build a one-hot vector marking the last position of `seq`
# that is not the sentinel value 100000.
seq = torch.tensor([123, 78, 32, 0, 0, 100000, 100000, 100000, 100000])

# 1.0 where the entry is a real value, 0.0 where it is the sentinel.
is_real = seq != 100000
seq_mask = is_real.float()

# Highest index whose mask value is 1.
real_positions = torch.where(seq_mask == 1)[0]
last_non_one_idx = real_positions.max().item()

seq_final = torch.zeros_like(seq_mask)
seq_final[last_non_one_idx] = 1

In [1]:
import pickle

In [3]:
# Load the pickled embedding object.
with open('./gpt_emb/gpt4o_te3_large_v2.pkl', 'rb') as emb_file:
    gpt_emb = pickle.load(emb_file)

In [10]:
# Size of the first dimension of the loaded embedding object.
# NOTE(review): presumably one row per diagnosis code plus a padding row
# (the recorded value is one more than the diagnosis vocabulary size shown
# earlier in this notebook) - confirm against how the embeddings were built.
gpt_emb.size(0)

870