In [None]:
import numpy as np
import pandas as pd
import os
import time
import string
import pickle

In [None]:
# Inputs produced by earlier pipeline steps (both are pickle files), and the
# directory where this notebook writes its result.
patients_f, matching_dict_f, OUTPATH = (
    '[the output from pretraining_labeling]',   # pickled patient records
    '[the output from case_control_matching]',  # pickled case -> controls matching (as used below)
    '[output directory]',
)

In [None]:
import pickle

def _load_pickle(path):
    """Deserialize and return the single object stored in the pickle file at ``path``."""
    # NOTE(review): pickle can execute arbitrary code on load; only point this at
    # files produced by this pipeline.
    with open(path, 'rb') as f:
        return pickle.load(f)


patients = _load_pickle(patients_f)
matching_dict = _load_pickle(matching_dict_f)

In [None]:
# Days from lung-cancer diagnosis ('lung_dt') to the final encounter ('final_dt'),
# split by label: label 0 goes to control_l, any other label to case_l.
case_l, control_l = [], []
for patient in patients:
    gap_days = (patient['final_dt'] - patient['lung_dt']).days
    target = control_l if patient['label'] == 0 else case_l
    target.append(gap_days)

### Plot the number of days from lung cancer diagnosis to the last encounter, for cases vs. controls

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Box plot of days from lung-cancer diagnosis to the last encounter, cases vs.
# controls; written as a PDF to the current working directory.
# An explicit figure/axes pair avoids drawing onto whatever figure happens to be current.
fig, ax = plt.subplots()
ax.boxplot([case_l, control_l])
# Boxes sit at x = 1, 2 by default. Label them explicitly rather than via boxplot's
# `labels=` keyword, which is deprecated since Matplotlib 3.9 (renamed `tick_labels`);
# this form works on both older and newer Matplotlib versions.
ax.set_xticks([1, 2])
ax.set_xticklabels(['Case', 'Control'])
ax.set_ylabel('Number of days')
ax.set_title('The number of days from lung cancer diagnosis to the last encounter')
fig.savefig('num_day_distribution.pdf')

### Add lab-test and clinical-event features

In [None]:
lab_test_f = '[the path of lab test features]'

# Load the pickled lab-test table (presumably a DataFrame: it is grouped and its
# `code_result` column is read later). Index it by encounter_id, one
# sub-frame per encounter, so per-encounter lookups are a dict hit.
with open(lab_test_f, 'rb') as f:
    lab_df = pickle.load(f)
lab_dict = dict(iter(lab_df.groupby('encounter_id')))

In [None]:
event_f = '[the path of clinical event features]'

# Load the pickled clinical-event table and drop every row containing a missing value.
# Then index the remaining rows by encounter_id (one sub-frame per encounter).
with open(event_f, 'rb') as f:
    event_df = pickle.load(f).dropna()
event_dict = dict(iter(event_df.groupby('encounter_id')))

In [None]:
def binage(age):
    """Map an age in years to a 5-year bucket token.

    Ages 0-4 -> 'A_A', 5-9 -> 'A_B', ..., 125-129 -> 'A_Z' (one letter per bucket,
    26 buckets). Fractional ages are truncated toward zero before bucketing.

    Returns:
        'A_<letter>' for a valid age, or 'A_ERROR' when the age is negative, missing
        (None / NaN / NaT), non-numeric, infinite, or 130 or older.
    """
    try:
        if age < 0:
            return 'A_ERROR'
        return 'A_' + string.ascii_uppercase[int(age) // 5]
    except (TypeError, ValueError, OverflowError, IndexError):
        # TypeError: None/NaT cannot be compared or converted; ValueError: int(NaN);
        # OverflowError: int(inf); IndexError: bucket beyond 'Z' (age >= 130).
        # Narrowed from a bare `except:` so KeyboardInterrupt etc. still propagate.
        return 'A_ERROR'

In [None]:
def icd_dict(k=9, path='./ICD9CMtoICD10CM/icd9to10dictionary.txt'):
    """Load the ICD-9-CM <-> ICD-10-CM crosswalk file into a lookup dict.

    Each non-blank line of ``path`` is pipe-delimited as ``icd9 | icd10 | description``.
    Surrounding whitespace is stripped from each field; blank lines are ignored.

    Args:
        k: which code version to key on. ``9`` maps ICD-9 -> [ICD-10, description];
            ``10`` maps ICD-10 -> [ICD-9, description]. When several lines share a
            key, the last one wins.
        path: crosswalk file location (defaults to the original hard-coded path).

    Returns:
        dict mapping the chosen code to a two-element ``[other_code, description]`` list.

    Raises:
        ValueError: if ``k`` is neither 9 nor 10. This is checked before the file is
            opened, so it fires even for an empty file.
        IndexError: if a non-blank line has fewer than three ``|``-separated fields.
    """
    if k not in (9, 10):
        raise ValueError('ICD should be 9 or 10 Version!')
    dict_icd = {}
    with open(path, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # e.g. a trailing empty line would otherwise raise IndexError
            fields = line.split('|')
            nine, ten, desc = fields[0].strip(), fields[1].strip(), fields[2].strip()
            if k == 9:
                dict_icd[nine] = [ten, desc]
            else:
                dict_icd[ten] = [nine, desc]
    return dict_icd
icd9_dict = icd_dict(k=9)    # ICD-9 code -> [ICD-10 code, description]
icd9_mismatches = list()     # ICD-9 codes icd9to10 failed to map; one entry per failed lookup

In [None]:
def icd9to10(key, mapping=None, mismatches=None):
    """Translate an ICD-9 code to ICD-10 using the loaded crosswalk.

    Args:
        key: ICD-9 code to look up.
        mapping: ICD-9 -> [ICD-10, description] dict. Defaults to the module-level
            ``icd9_dict``, resolved at call time.
        mismatches: list that receives ``key`` whenever no mapping exists. Defaults to
            the module-level ``icd9_mismatches``. A key is appended once per failed
            call, so repeated misses appear repeatedly.

    Returns:
        The mapped ICD-10 code, or ``key`` unchanged when it has no mapping.
    """
    if mapping is None:
        mapping = icd9_dict
    if mismatches is None:
        mismatches = icd9_mismatches
    try:
        return mapping[key][0]
    except KeyError:
        # Narrowed from a bare `except:`: only a missing code is an expected miss.
        # Keep the ICD-9 code itself and record it for later inspection.
        mismatches.append(key)
        return key

In [None]:
def _diag_token(diag):
    """Return the DIAG_ token for one diagnosis record; ICD-9 codes are converted via icd9to10."""
    if diag['diagnosis_type'] == 'ICD9':
        code = icd9to10(diag['diagnosis_code'])
    else:
        code = diag['diagnosis_code']
    return 'DIAG_' + str(code).upper()


def _encounter_codes(enc):
    """Return the de-duplicated DIAG/MED/SURG/LAB/EVENT tokens recorded for one encounter.

    The categories are concatenated in that order. Tokens are sorted within each
    category so the output is reproducible: iterating a set of strings yields a
    different order in each interpreter session (string hash randomisation).
    """
    diag_codes, med_codes, surg_codes, lab_codes, event_codes = [], [], [], [], []

    diags = enc.get('DIAGNOSIS')
    if diags is not None:
        diag_codes = sorted({_diag_token(diag) for diag in diags})

    meds = enc.get('MEDICATION')
    if meds is not None:
        med_codes = sorted({'MED_' + med['generic_name'].upper() for med in meds})

    surgs = enc.get('SURGICAL')
    if surgs is not None:
        surg_codes = sorted({'SURG_' + str(surg['surgical_procedure_id']).upper()
                             for surg in surgs})

    # Lab tests and clinical events are looked up by encounter_id in the tables
    # loaded earlier; every distinct `code_result` value becomes one token.
    lab_rows = lab_dict.get(enc['encounter_id'])
    if lab_rows is not None:
        lab_codes = sorted({'LAB_' + str(value) for value in lab_rows.code_result})

    event_rows = event_dict.get(enc['encounter_id'])
    if event_rows is not None:
        event_codes = sorted({'EVENT_' + str(value) for value in event_rows.code_result})

    return diag_codes + med_codes + surg_codes + lab_codes + event_codes


def data_extraction(d, label, dt):
    """Build the per-encounter token sequence and inter-encounter day gaps for one patient.

    Args:
        d: patient record. Reads 'gender', 'yob' and 'ENCOUNTER' (a sequence of encounter
            dicts with at least 'discharged_dt_tm' and 'encounter_id').
        label: 0 applies the control cutoff; any other value applies the case cutoff.
        dt: index date. With label 0, encounters discharged up to 30 days after ``dt``
            are kept. Otherwise only encounters discharged strictly before ``dt`` are kept.
            Iteration stops at the first encounter past the cutoff, which presumes
            ``d['ENCOUNTER']`` is in chronological order (TODO confirm upstream).

    Returns:
        (p, times):
        - ``p`` holds one token list per kept encounter: its clinical codes followed by
          the gender token and the 5-year age-bucket token.
        - ``times[i]`` is the number of days between encounter i and the previous kept
          encounter that had a valid discharge date. It is 0 until such an encounter has
          been seen. When encounter i's own date is NaT, it is the mean of the earlier
          ``times`` entries instead.
        Encounters without any clinical code are skipped entirely and do not
        affect ``times``.
    """
    gender = 'GEND_' + str(d['gender']).lower()
    yob = d['yob']
    p = []
    times = []
    last_enc_dt = None
    for enc in d['ENCOUNTER']:
        enc_dt = enc['discharged_dt_tm']
        # NaT discharge dates compare False here, so they never trigger the cutoff.
        if label == 0:
            if (enc_dt - dt).days > 30:
                break
        elif enc_dt >= dt:
            break

        agecode = binage(enc_dt.year - yob)
        codes = _encounter_codes(enc)
        if not codes:
            continue
        p.append(codes + [gender, agecode])

        # Gap (in days) to the previous kept encounter with a valid discharge date.
        str_dt = str(enc_dt)
        if last_enc_dt is None:
            times.append(0)
        elif str_dt == 'NaT':
            times.append(np.mean(times))  # impute a missing date with the mean gap so far
        else:
            times.append((enc_dt - last_enc_dt).days)
        if str_dt != 'NaT':
            last_enc_dt = enc_dt
    return p, times

In [None]:
# Index patient records by patient_sk (a later duplicate key replaces an earlier one).
new_patients = dict((rec['patient_sk'], rec) for rec in patients)

In [None]:
from tqdm.auto import tqdm  # `tqdm_notebook` is deprecated; tqdm.auto picks the notebook UI when available
import warnings

# Extract every matched case and its controls into res[patient_sk] = (label, (p, times)).
# Controls use their matched case's 'bm_dt' as the index date, so each matched set
# shares one cutoff.
res = {}
n_overwritten = 0
for case_psk, control_psks in tqdm(matching_dict.items(), total=len(matching_dict)):
    case = new_patients[case_psk]
    index_dt = case['bm_dt']
    n_overwritten += case_psk in res
    res[case_psk] = (1, data_extraction(case, 1, index_dt))
    for control_psk in control_psks:
        n_overwritten += control_psk in res
        res[control_psk] = (0, data_extraction(new_patients[control_psk], 0, index_dt))

# A patient_sk that occurs in several matched sets is silently overwritten above.
# Surface it so reused controls or case/control overlaps are noticed.
if n_overwritten:
    warnings.warn(f'{n_overwritten} patient record(s) occur in more than one matched set; '
                  'res keeps only the last extraction for each.')


In [None]:
print('Saving data by pickle')

# Persist the extraction: res[patient_sk] = (label, (per-encounter tokens, day gaps)).
# Create the output directory first so a missing path does not lose the long
# extraction run above. An empty OUTPATH means the working directory.
if OUTPATH:
    os.makedirs(OUTPATH, exist_ok=True)
out_file = os.path.join(OUTPATH, 'data_time_with_label_lab_event_same_hospital.pickle')
with open(out_file, 'wb') as f:
    pickle.dump(res, f, protocol=pickle.HIGHEST_PROTOCOL)
    