In [110]:
import os
import pickle
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import OneHotEncoder

In [111]:
# Root directory of the shared MIMIC-IV exports. Set the MIMIC_IV_DIR environment
# variable to run this notebook against a copy stored elsewhere.
path = os.environ.get('MIMIC_IV_DIR', '/data/notebook/shared/MIMIC-IV')

In [112]:
# Load the inputs produced by the upstream preprocessing step.
# NOTE(review): pickle.load can execute arbitrary code stored in the file; only
# load these internally produced, trusted exports.
# The `with` blocks close the files; no explicit close() is needed.
with open(os.path.join(path, 'dict_types_mimic_240408_clinic_3_years.pkl'), 'rb') as f:
    # Presumably code -> integer index; its max value is used below to size the vocabulary.
    dtype_dict = pickle.load(f)

with open(os.path.join(path, 'total_data_dict_with_timedelta_nomedi_with_code_label_240423_clinic_3_years.pkl'), 'rb') as f:
    # sample_id -> per-sample record dict (keys such as 'year', 'code', 'time',
    # 'timedelta', 'total_label', 'code_label', 'clinical_label').
    data_dict_d = pickle.load(f)

In [114]:
# Inspect one raw sample record. The record is picked explicitly (the last one,
# i.e. what the `visits` loop variable holds after the label-collection loop
# finishes) so this cell also works on a fresh kernel, before that loop has run.
list(data_dict_d.values())[-1]

{'year': [[11138,
   4912,
   4102,
   1074,
   12382,
   1121,
   6778,
   2120,
   1074,
   12382,
   1121,
   6778,
   2120,
   1074,
   12382,
   1121,
   6778,
   2120],
  [4912, 9277, 9733, 12701, 12701, 12701],
  [4912, 3176, 9733, 12701, 12701, 12701],
  [9733, 4912, 10835, 7203, 7203, 7203]],
 'code': [['d_41011',
   'd_486',
   'd_42821',
   'p_0066',
   'p_3606',
   'p_3723',
   'p_8856',
   'p_0045',
   'p_0066',
   'p_3606',
   'p_3723',
   'p_8856',
   'p_0045',
   'p_0066',
   'p_3606',
   'p_3723',
   'p_8856',
   'p_0045'],
  ['d_486', 'd_5849', 'd_1628', 'procedure_0', 'procedure_0', 'procedure_0'],
  ['d_486', 'd_51881', 'd_1628', 'procedure_0', 'procedure_0', 'procedure_0'],
  ['d_1628', 'd_486', 'd_51884', 'p_3491', 'p_3491', 'p_3491']],
 'time': [Timestamp('2129-08-04 00:00:00'),
  Timestamp('2129-12-23 22:40:00'),
  Timestamp('2130-09-24 03:55:00'),
  Timestamp('2131-03-10 00:00:00')],
 'timedelta': [583, 441, 166, 0],
 'total_label': 0,
 'code_label': 0,
 'clini

In [119]:
# Collect per-sample labels and sequence sizes; these feed the statistics and
# padding dimensions computed in the next cell.
total_labels = []
code_labels = []
length_list = []
clinical_labels = []
code_length_list = []
for sample_id, visits in tqdm(data_dict_d.items()):
    # Add labels
    total_label = visits['total_label']
    code_label = visits['code_label']
    clinical_label = visits['clinical_label']
    total_labels.append(total_label)
    code_labels.append(code_label)
    clinical_labels.append(clinical_label)
    # visits['year'] is what pad_sequence receives: one list of code indices per
    # visit, so its length is the number of visits.
    length_list.append(len(visits['year']))
    # Largest number of codes in a single visit (0 for a sample with no visits).
    code_length_list.append(max((len(seq) for seq in visits['year']), default=0))

100%|██████████| 8037/8037 [00:00<00:00, 296119.23it/s]


In [120]:
# Label distribution and visits-per-sample statistics.
label_values, label_counts = np.unique(total_labels, return_counts=True)
print('label counts:', dict(zip(label_values.tolist(), label_counts.tolist())))
print('visits per sample: mean {:.2f}, std {:.2f}'.format(np.mean(length_list), np.std(length_list)))

# Padding dimensions shared by every sample.
max_visits_length = max(length_list)
max_index = max(dtype_dict.values())
max_code_len = max(code_length_list)
# Printed as max_index + 1 — presumably the vocabulary size when indices start at 0; confirm.
print('max_index:', max_index+1)
print('max_visit:', max_visits_length)
print('max_code_len:', max_code_len)

max_index: 15373
max_visit: 49
max_code_len: 27


In [121]:
def pad_sequence(seq_diagnosis_codes, maxlen, maxcode):
    """Pad one sample's visit sequence into fixed-size arrays.

    Args:
        seq_diagnosis_codes: sequence of visits, each a sequence of integer
            code indices.
        maxlen: number of visit rows in the padded output.
        maxcode: number of code slots per visit row.

    Returns:
        diagnosis_codes: (maxlen, maxcode) int64 array of code indices, 0-padded.
        seq_mask_code: (maxlen, maxcode) int8 array, 1 where a real code sits.
        seq_mask: (maxlen,) int8 array, 1 for each real visit.
        seq_mask_final: (maxlen,) int8 array, 1 only at the last real visit
            (all zeros when there are no visits).

    Raises:
        ValueError: if there are more than `maxlen` visits or a visit has more
            than `maxcode` codes.
    """
    # Length comes from the argument itself, not from any global variable.
    lengths = len(seq_diagnosis_codes)
    if lengths > maxlen:
        raise ValueError('sequence has %d visits, more than maxlen=%d' % (lengths, maxlen))
    diagnosis_codes = np.zeros((maxlen, maxcode), dtype=np.int64)
    seq_mask_code = np.zeros((maxlen, maxcode), dtype=np.int8)
    seq_mask = np.zeros((maxlen), dtype=np.int8)
    seq_mask_final = np.zeros((maxlen), dtype=np.int8)
    for pid, subseq in enumerate(seq_diagnosis_codes):
        n_codes = len(subseq)
        if n_codes > maxcode:
            raise ValueError('visit %d has %d codes, more than maxcode=%d' % (pid, n_codes, maxcode))
        diagnosis_codes[pid, :n_codes] = subseq
        seq_mask_code[pid, :n_codes] = 1
    seq_mask[:lengths] = 1
    # Guard: with no visits, index -1 would wrongly mark the last padding slot.
    if lengths:
        seq_mask_final[lengths - 1] = 1
    return diagnosis_codes, seq_mask_code, seq_mask, seq_mask_final

In [122]:
def keep_last_one_in_columns(a):
    """Keep only the last 1 in each column of a 2-D array.

    Args:
        a: 2-D array (e.g. a visit-by-year one-hot matrix).

    Returns:
        Array of the same shape and dtype as `a` that is all zeros except for
        a 1 at the highest row index where each column of `a` equals 1.
        Columns containing no 1 stay all zeros.
    """
    kept = np.zeros_like(a)
    n_cols = a.shape[1]
    for j in range(n_cols):
        # Row indices (ascending) where this column equals 1.
        rows_with_one = np.flatnonzero(a[:, j] == 1)
        if rows_with_one.size > 0:
            kept[rows_with_one[-1], j] = 1
    return kept

In [123]:
# Build one model-ready record per sample.
# - Samples whose visits all fall in a single calendar year (or that have no
#   visits) are skipped.
# - Samples spanning exactly two calendar years get an extra year column
#   (last year + 1), so their year one-hot has three columns.
new_data_dict_d = {}
year_list = []
for sample_id, data in tqdm(data_dict_d.items()):
    pad_seq, seq_mask_code, seq_mask, seq_mask_final = pad_sequence(data['year'], max_visits_length, max_code_len)

    # Calendar year of every visit, and the sorted distinct years.
    visit_years = np.array([timestamp.year for timestamp in data['time']])
    unique_year = np.unique(visit_years)
    if len(unique_year) < 2:
        # One (or zero) distinct years: skipped, as before. Checking this first
        # also avoids np.pad failing on the 1-D empty time_feature of an
        # empty-visit sample.
        continue
    if len(unique_year) == 2:
        unique_year = np.append(unique_year, unique_year[-1] + 1)

    # Rows of padding appended after the real visits.
    n_pad = max_visits_length - len(data['time'])

    # Per-visit calendar features: (year, month, day, week).
    time_feature = np.array([[timestamp.year, timestamp.month, timestamp.day, timestamp.week]
                             for timestamp in data['time']])

    # One-hot of each visit's year over the sorted `unique_year` categories.
    # This is the same float64 matrix that
    # OneHotEncoder(categories=[unique_year]).fit_transform returned, but it does
    # not depend on the `sparse` keyword, which scikit-learn >= 1.4 no longer accepts.
    year_onehot = (visit_years[:, None] == unique_year[None, :]).astype(np.float64)
    # Marks, for each year column, only the last visit that falls in that year.
    last_year_visit = keep_last_one_in_columns(year_onehot)

    new_data_dict_d[sample_id] = {
        'code_index': pad_seq,
        'code': data['code'],
        'time': data['time'],
        'timedelta': data['timedelta'],
        'time_feature': np.pad(time_feature, pad_width=((0, n_pad), (0, 0))),
        'year': visit_years,
        'seq_mask': seq_mask,
        'seq_mask_final': seq_mask_final,
        'seq_mask_code': seq_mask_code,
        'year_onehot': np.pad(year_onehot, pad_width=((0, n_pad), (0, 0))),
        'last_year_onehot': np.pad(last_year_visit, pad_width=((0, n_pad), (0, 0))),
        'label': data['total_label'],
        'code_label': data['code_label'],
        'clinical_label': data['clinical_label'],
    }

100%|██████████| 8037/8037 [00:02<00:00, 2828.19it/s]


In [124]:
# Persist the model-ready records. The relative ./data/ directory is created if
# it does not exist yet; the `with` block closes the file.
os.makedirs('./data/', exist_ok=True)
with open(os.path.join('./data/', 'preprocessed_nomedi_240423_clinic_3_years.pkl'), 'wb') as f:
    pickle.dump(new_data_dict_d, f)

In [107]:
# Scratch check: full path of the 240421 export. The loader cell reads the
# 240423 "with_code_label" file instead.
export_240421_path = os.path.join(path, 'total_data_dict_with_timedelta_nomedi_240421_clinic_3_years.pkl')
export_240421_path

'/data/notebook/shared/MIMIC-IV/total_data_dict_with_timedelta_nomedi_240421_clinic_3_years.pkl'

In [108]:
# Pairwise visit mask: entry (i, j) is 1 only when both visit i and visit j are
# real (unpadded). Uses the `seq_mask` left over from the last iteration of the
# preprocessing loop.
seq_mask[:, None] * seq_mask[None, :]

array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int8)