In [2]:
from collections import defaultdict

import dill
import numpy as np
import pandas as pd
import nltk
import re
from tqdm import tqdm
import random
tqdm.pandas()

In [3]:
##### process medications #####
# load med data
def med_process(med_file):
    med_pd = pd.read_csv(med_file, dtype={'NDC': 'category'})

    med_pd.drop(columns=['ROW_ID', 'DRUG_TYPE', 'DRUG_NAME_POE', 'DRUG_NAME_GENERIC',
                         'FORMULARY_DRUG_CD', 'PROD_STRENGTH', 'DOSE_VAL_RX',
                         'DOSE_UNIT_RX', 'FORM_VAL_DISP', 'FORM_UNIT_DISP', 'GSN', 'FORM_UNIT_DISP',
                         'ROUTE', 'ENDDATE', 'DRUG'], axis=1, inplace=True)
    med_pd.drop(index=med_pd[med_pd['NDC'] == '0'].index, axis=0, inplace=True)
    med_pd.fillna(method='pad', inplace=True)
    med_pd.dropna(inplace=True)
    med_pd.drop_duplicates(inplace=True)
    med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
    med_pd['STARTDATE'] = pd.to_datetime(med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')
    med_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
    med_pd = med_pd.reset_index(drop=True)

    med_pd = med_pd.drop(columns=['ICUSTAY_ID'])
    med_pd = med_pd.drop_duplicates()
    med_pd = med_pd.reset_index(drop=True)

    return med_pd

# medication mapping
def ndc2atc4(med_pd):
    with open(ndc_rxnorm_file, 'r') as f:
        ndc2rxnorm = eval(f.read())
    med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
    med_pd.dropna(inplace=True)

    rxnorm2atc = pd.read_csv(ndc2atc_file)
    rxnorm2atc = rxnorm2atc.drop(columns=['YEAR', 'MONTH', 'NDC'])
    rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
    med_pd.drop(index=med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)

    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
    med_pd = med_pd.reset_index(drop=True)
    med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
    med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)
    med_pd = med_pd.rename(columns={'ATC4': 'NDC'})
    med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4])
    med_pd = med_pd.drop_duplicates()
    med_pd = med_pd.reset_index(drop=True)
    return med_pd

# visit >= 2
def process_visit_lg2(med_pd):
    a = med_pd[['SUBJECT_ID', 'HADM_ID']].groupby(by='SUBJECT_ID')['HADM_ID'].unique().reset_index()
    a['HADM_ID_Len'] = a['HADM_ID'].map(lambda x: len(x))
    a = a[a['HADM_ID_Len'] > 1]
    return a

# most common medications
def filter_300_most_med(med_pd):
    med_count = med_pd.groupby(by=['NDC']).size().reset_index().rename(columns={0: 'count'}).sort_values(by=['count'],
                                                                                                         ascending=False).reset_index(
        drop=True)
    med_pd = med_pd[med_pd['NDC'].isin(med_count.loc[:299, 'NDC'])]

    return med_pd.reset_index(drop=True)

##### process diagnosis #####
def diag_process(diag_file):
    diag_pd = pd.read_csv(diag_file)
    diag_pd.dropna(inplace=True)
    diag_pd.drop(columns=['SEQ_NUM', 'ROW_ID'], inplace=True)
    diag_pd.drop_duplicates(inplace=True)
    diag_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID'], inplace=True)
    diag_pd = diag_pd.reset_index(drop=True)

    def filter_2000_most_diag(diag_pd):
        diag_count = diag_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(columns={0: 'count'}).sort_values(
            by=['count'], ascending=False).reset_index(drop=True)
        diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(diag_count.loc[:1999, 'ICD9_CODE'])]

        return diag_pd.reset_index(drop=True)

    diag_pd = filter_2000_most_diag(diag_pd)

    return diag_pd

##### process procedure #####
def procedure_process(procedure_file):
    pro_pd = pd.read_csv(procedure_file, dtype={'ICD9_CODE': 'category'})
    pro_pd.drop(columns=['ROW_ID'], inplace=True)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'SEQ_NUM'], inplace=True)
    pro_pd.drop(columns=['SEQ_NUM'], inplace=True)
    pro_pd.drop_duplicates(inplace=True)
    pro_pd.reset_index(drop=True, inplace=True)

    return pro_pd

def statistics(data):
    print('#patients ', data['SUBJECT_ID'].unique().shape)
    print('#clinical events ', len(data))

    diag = data['ICD9_CODE'].values
    med = data['NDC'].values
    pro = data['PRO_CODE'].values
    sym = data['SYM_LIST'].values

    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    unique_pro = set([j for i in pro for j in list(i)])
    unique_sym = set([j for i in sym for j in list(i)])

    print('#diagnosis ', len(unique_diag))
    print('#med ', len(unique_med))
    print('#procedure', len(unique_pro))
    print('#symptoms', len(unique_sym))

    avg_diag, avg_med, avg_pro, avg_sym, max_diag, max_med, max_pro, max_sym, cnt, max_visit, avg_visit = [0 for i in range(11)]

    for subject_id in data['SUBJECT_ID'].unique():
        item_data = data[data['SUBJECT_ID'] == subject_id]
        x, y, z,s = [], [], [], []
        visit_cnt = 0
        for index, row in item_data.iterrows():
            visit_cnt += 1
            cnt += 1
            x.extend(list(row['ICD9_CODE']))
            y.extend(list(row['NDC']))
            z.extend(list(row['PRO_CODE']))
            s.extend(list(row['SYM_LIST']))
        x, y, z, s = set(x), set(y), set(z), set(s)
        avg_diag += len(x)
        avg_med += len(y)
        avg_pro += len(z)
        avg_sym += len(s)                              
        avg_visit += visit_cnt
        if len(x) > max_diag:
            max_diag = len(x)
        if len(y) > max_med:
            max_med = len(y)
        if len(z) > max_pro:
            max_pro = len(z)
        if len(s) > max_sym:
            max_sym = len(s)
        if visit_cnt > max_visit:
            max_visit = visit_cnt

    print('#avg of diagnoses ', avg_diag / cnt)
    print('#avg of medicines ', avg_med / cnt)
    print('#avg of procedures ', avg_pro / cnt)
    print('#avg of symptoms ', avg_sym / cnt)
    print('#avg of vists ', avg_visit / len(data['SUBJECT_ID'].unique()))

    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of procedures ', max_pro)
    print('#max of symptoms ', max_sym)
    print('#max of visit ', max_visit)

In [4]:
###### combine three tables #####
def combine_process(med_pd, diag_pd, pro_pd):
    med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()

    combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    # flatten and merge
    diag_pd = diag_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['ICD9_CODE'].unique().reset_index()
    med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
    pro_pd = pro_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(
        columns={'ICD9_CODE': 'PRO_CODE'})

    med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
    pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))

    data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    #     data['ICD9_CODE_Len'] = data['ICD9_CODE'].map(lambda x: len(x))
    data['NDC_Len'] = data['NDC'].map(lambda x: len(x))

    return data

In [5]:
def get_side(source_df, side_df, side_columns, aligh_column):

    side_df = side_df[side_columns]
    source_df = pd.merge(source_df, side_df, how="left", on=aligh_column)

    return source_df

In [6]:
def profile_tokenization(df, profile_columns):
    prof_dict = {"word2idx":{}, "idx2word": {}}
    for prof in profile_columns:
        prof_dict["idx2word"][prof] = dict(zip(range(df[prof].nunique()), df[prof].unique()))
        prof_dict["word2idx"][prof] = dict(zip(df[prof].unique(), range(df[prof].nunique())))
    return prof_dict

In [7]:
def match_symptoms(search):

    """re for symptom matching"""

    symptomList = []

    for item in search:
        pattern = r'<c>(.*?)<\/c>'
        noun_prases = re.findall(pattern, item, flags=0)
        for s in noun_prases:
            if s.lower() in symptoms_list:
                symptomList.append(s)

    return list(set(symptomList))


def main_text(full_text):
    sentence = full_text.replace('"','')

    sentences = nltk.sent_tokenize(sentence)
    word_tokens = [nltk.word_tokenize(sent) for sent in sentences]
    sentences = [nltk.pos_tag(sent) for sent in word_tokens]

    def chunk(text):
        """ Return the noun phrases using reguler expressoins"""

        '''pattern = ['JJ', 'NN', 'VB', 'NN']
            matches = []

            for i in range(len(tagged)):
                if tagged[i:i+len(pattern)] == pattern:
                    matches.append(sentences[i:i+len(pattern)])

            matches = [' '.join(match) for match in matches]
            print(matches)'''

        grammar = """NP: {<V.*>+(<RP?><NN>)?}
                    NP: {(<NN.*><DT>)?(<NN.*><IN>)?<NN.*>?<JJ.>*<NN.*>+}
                    NP: {<V.*>}
                    ENTITY: {<NN.*>}"""

        parser = nltk.RegexpParser(grammar)
        result = parser.parse(text)
        t_sent = ' '.join(word for word, pos in text)
        for subtree in result.subtrees():
            if subtree.label() == 'NP':
                noun_phrases_list = ' '.join(word for word, pos in subtree.leaves())
                t_sent = t_sent.replace(noun_phrases_list, "<c>"+noun_phrases_list+"</c>", 1)
        return t_sent
    
    chunk_sent = []
    for sentence in sentences:
        chunk_sent.append(chunk(sentence))
    return chunk_sent

def symptoms_tagger(x):

    search = main_text(x)

    tagged_symptom_list = match_symptoms(search)
    return list(set(tagged_symptom_list))


def text_to_symptom(text):
    text_list = text.split('\n')
    sym_list = []
    for i in range(len(text_list)):
        sym = symptoms_tagger(text_list[i])
        sym_list += sym
    return list(set([sym.lower() for sym in sym_list]))

In [8]:
def get_ddi_matrix(records, med_voc, ddi_file):
    TOPK = 40  # topk drug-drug interaction
    cid2atc_dic = defaultdict(set)
    med_voc_size = len(med_voc.idx2word)
    med_unique_word = [med_voc.idx2word[i] for i in range(med_voc_size)]
    atc3_atc4_dic = defaultdict(set)
    for item in med_unique_word:
        atc3_atc4_dic[item[:4]].add(item)

    with open(cid_atc, 'r') as f:
        for line in f:
            line_ls = line[:-1].split(',')
            cid = line_ls[0]
            atcs = line_ls[1:]
            for atc in atcs:
                if len(atc3_atc4_dic[atc[:4]]) != 0:
                    cid2atc_dic[cid].add(atc[:4])

    # ddi load
    ddi_df = pd.read_csv(ddi_file)
    # fliter sever side effect
    ddi_most_pd = ddi_df.groupby(by=['Polypharmacy Side Effect', 'Side Effect Name']).size().reset_index().rename(
        columns={0: 'count'}).sort_values(by=['count'], ascending=False).reset_index(drop=True)
    ddi_most_pd = ddi_most_pd.iloc[-TOPK:, :]
    # ddi_most_pd = pd.DataFrame(columns=['Side Effect Name'], data=['as','asd','as'])
    fliter_ddi_df = ddi_df.merge(ddi_most_pd[['Side Effect Name']], how='inner', on=['Side Effect Name'])
    ddi_df = fliter_ddi_df[['STITCH 1', 'STITCH 2']].drop_duplicates().reset_index(drop=True)
    ddi_adj = np.zeros((med_voc_size, med_voc_size))
    for index, row in ddi_df.iterrows():
        # ddi
        cid1 = row['STITCH 1']
        cid2 = row['STITCH 2']

        # cid -> atc_level3
        for atc_i in cid2atc_dic[cid1]:
            for atc_j in cid2atc_dic[cid2]:

                # atc_level3 -> atc_level4
                for i in atc3_atc4_dic[atc_i]:
                    for j in atc3_atc4_dic[atc_j]:
                        if med_voc.word2idx[i] != med_voc.word2idx[j]:
                            ddi_adj[med_voc.word2idx[i], med_voc.word2idx[j]] = 1
                            ddi_adj[med_voc.word2idx[j], med_voc.word2idx[i]] = 1
    dill.dump(ddi_adj, open('./output/ddi_A_final.pkl', 'wb'))

    return ddi_adj

In [9]:
med_file = '/mimiciii/PRESCRIPTIONS.csv'
diag_file = '/mimiciii/DIAGNOSES_ICD.csv'
procedure_file = 'mimiciii/PROCEDURES_ICD.csv'
profile_file = '/mimiciii/ADMISSIONS.csv'
text_file = '/mimiciii/NOTEEVENTS.csv'

symptom_file = './inputs/symptoms_list.pkl'
med_structure_file = './inputs/idx2drug.pkl'
# drug code mapping files
ndc2atc_file = './inputs/ndc2atc_level4.csv'
cid_atc = './inputs/drug-atc.csv'
ndc_rxnorm_file = './inputs/ndc2rxnorm_mapping.txt'
# ddi information
ddi_file = './inputs/drug-DDI.csv'

In [11]:
med_pd = med_process(med_file)
med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)
med_pd = med_pd.merge(med_pd_lg2[['SUBJECT_ID']], on='SUBJECT_ID', how='inner').reset_index(drop=True)
med_pd = ndc2atc4(med_pd)
NDCList = dill.load(open(med_structure_file, 'rb'))
med_pd = med_pd[med_pd.NDC.isin(list(NDCList.keys()))]
med_pd = filter_300_most_med(med_pd)
med_pd

  med_pd = pd.read_csv(med_file, dtype={'NDC': 'category'})


Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC
0,17,161087,2135-05-09,N02B
1,17,194023,2134-12-27,N02B
2,21,111970,2135-02-06,N02B
3,23,152223,2153-09-03,N02B
4,36,122659,2131-05-15,N02B
...,...,...,...,...
704656,97547,112445,2125-11-11,N05A
704657,97547,112445,2125-11-19,N05A
704658,97547,112445,2125-11-10,N05A
704659,97547,127852,2125-10-29,N05A


In [12]:
diag_pd = diag_process(diag_file)
diag_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,V3001
1,2,163353,V053
2,2,163353,V290
3,3,145834,0389
4,3,145834,78559
...,...,...,...
625429,99995,137810,41401
625430,99999,113369,7861
625431,99999,113369,4019
625432,99999,113369,25000


In [13]:
pro_pd = procedure_process(procedure_file)
pro_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,9955
1,3,145834,9604
2,3,145834,9962
3,3,145834,8964
4,3,145834,9672
...,...,...,...
228674,99999,113369,8108
228675,99999,113369,8051
228676,99999,113369,8162
228677,99999,113369,9979


In [14]:
data = combine_process(med_pd, diag_pd, pro_pd)

In [15]:
admission = pd.read_csv(profile_file)
data = get_side(data, admission, 
                ["HADM_ID", "INSURANCE", "LANGUAGE", "RELIGION", "MARITAL_STATUS", "ETHNICITY", "DIAGNOSIS"],
                "HADM_ID"
                )
data.fillna(value="unknown", inplace=True)

In [16]:
#combine symp
symptoms_list = dill.load(open(symptom_file, 'rb'))
symptoms_list = list(set([sym.lower() for sym in symptoms_list]))
notes = pd.read_csv(text_file, usecols=['SUBJECT_ID','HADM_ID','CATEGORY','TEXT'])
notes1 = notes[notes.CATEGORY=='Discharge summary'].sort_values(by=['SUBJECT_ID', 'HADM_ID'])
notes2 = notes1.groupby(by=['SUBJECT_ID','HADM_ID'])['TEXT'].apply('\n'.join).reset_index()
data1 = data.merge(notes2, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
data1['SYM_LIST'] = data1['TEXT'].progress_apply(text_to_symptom) + data1['DIAGNOSIS'].map(text_to_symptom)
data1['SYM_len'] = data1['SYM_LIST'].map(len)
data = data1[data1['SYM_len'] > 0].reset_index()
data = data.drop(columns=['index', 'TEXT'])
# data.to_pickle('./output/data_final.pkl')
# statistics(data)
print('complete combining')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14699/14699 [54:50<00:00,  4.47it/s]


complete combining


In [17]:
data

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,NDC,PRO_CODE,NDC_Len,INSURANCE,LANGUAGE,RELIGION,MARITAL_STATUS,ETHNICITY,DIAGNOSIS,SYM_LIST,SYM_len
0,17,161087,"[4239, 5119, 78551, 4589, 311, 7220, 71946, 2724]","[N02B, A01A, A02B, A06A, B05C, A12A, A12C, C01...","[3731, 8872, 3893]",15,Private,ENGL,CATHOLIC,MARRIED,WHITE,PERICARDIAL EFFUSION,"[sob, pain, depression, chest pain]",4
1,17,194023,"[7455, 45829, V1259, 2724]","[N02B, A01A, A02B, A06A, A12A, B05C, A12C, C01...","[3571, 3961, 8872]",16,Private,ENGL,CATHOLIC,MARRIED,WHITE,PATIENT FORAMEN OVALE\ PATENT FORAMEN OVALE MI...,[depression],1
2,21,109451,"[41071, 78551, 5781, 5849, 40391, 4280, 4592, ...","[A06A, B05C, C07A, A12B, C03C, A12A, A02A, J01...","[0066, 3761, 3950, 3606, 0042, 0047, 3895, 399...",23,Medicare,unknown,JEWISH,MARRIED,WHITE,CONGESTIVE HEART FAILURE,"[weakness, chills, chest pain, dyspnea, vomiti...",12
3,21,111970,"[0388, 78552, 40391, 42731, 70709, 5119, 6823,...","[N02B, A06A, B05C, A12C, A07A, A02A, B01A, N06...","[3995, 8961, 0014]",19,Medicare,unknown,JEWISH,MARRIED,WHITE,SEPSIS,"[febrile, weakness, ulcers, cold, shock, sleep...",9
4,23,124321,"[2252, 3485, 78039, 4241, 4019, 2720, 2724, V4...","[B05C, A07A, C07A, A06A, N02B, C02D, B01A, A02...",[0151],17,Medicare,ENGL,CATHOLIC,MARRIED,WHITE,BRAIN MASS,"[hallucinations, weakness, tenderness, dizzine...",18
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14594,99923,164914,"[45829, 4532, 2761, 5723, 4561, 45621, 5849, 7...","[N02B, A02A, B01A, A06A, J01M, H01C, A07A, C01C]","[5491, 4513]",8,Private,ENGL,CATHOLIC,MARRIED,WHITE,HYPONATREMIA,"[sob, constipation, orthopnea, pnd, pain, asci...",13
14595,99923,192053,"[5712, 5856, 5724, 40391, 9974, 5601, 30393, V...","[A06A, A12A, A12C, N01A, C07A, C03C, B01A, A02...","[5059, 504, 5569, 0093]",24,Private,ENGL,CATHOLIC,MARRIED,WHITE,END STAGE LIVER DISEASE,"[chills, jaundice, cyanosis, constipation, spl...",14
14596,99982,112748,"[4280, 42823, 5849, 4254, 2763, 42731, 78729, ...","[A01A, B05C, C01C, C03C, A12C, A06A, A02B, B01...",[3721],14,Medicare,ENGL,CATHOLIC,MARRIED,WHITE,SHORTNESS OF BREATH,"[f/c, dyspnea, dysphagia, dyspnea on exertion,...",12
14597,99982,151454,"[42823, 4254, 2875, 42731, 3970, 5303, 4280, V...","[N02B, A01A, A06A, B05C, A12A, A12C, C01C, N01...","[3527, 3961]",20,Medicare,ENGL,CATHOLIC,MARRIED,WHITE,TVR,"[dysphagia, constipation, fever, pain]",4


In [18]:
statistics(data)

#patients  (6313,)
#clinical events  14599
#diagnosis  1957
#med  131
#procedure 1416
#symptoms 434
#avg of diagnoses  10.670456880608262
#avg of medicines  11.664566066168916
#avg of procedures  3.8904719501335707
#avg of symptoms  7.879169806151106
#avg of vists  2.3125297006177727
#max of diagnoses  128
#max of medicines  65
#max of procedures  50
#max of symptoms  78
#max of visit  29


In [128]:
# #和以前的研究同数据集，但是就诊中会包括只有一次就诊记录的患者

# med_pd = med_process(med_file)
# med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)
# med_pd = med_pd.merge(
#     med_pd_lg2[["SUBJECT_ID"]], on="SUBJECT_ID", how="inner"
# ).reset_index(drop=True)
# med_pd = ndc2atc4(med_pd)
# NDCList = dill.load(open(med_structure_file, 'rb'))
# med_pd = med_pd[med_pd.NDC.isin(list(NDCList.keys()))]
# med_pd = filter_300_most_med(med_pd)
# med_pd
# multi_data = combine_process(med_pd, diag_pd, pro_pd)
# multi_data = pd.merge(multi_data[['SUBJECT_ID', 'HADM_ID']], data1[data1['SYM_len'] > 0], on=['SUBJECT_ID', 'HADM_ID'], how='inner')
# statistics(multi_data)

In [34]:
subject_counts = data['SUBJECT_ID'].value_counts()

# 将SUBJECT_ID仅出现过一次的行划分到single_df
single_df = data[data['SUBJECT_ID'].isin(subject_counts[subject_counts == 1].index)]

# 将SUBJECT_ID出现多次的行划分到multi_df
multi_df = data[data['SUBJECT_ID'].isin(subject_counts[subject_counts > 1].index)]

In [19]:
dill.dump(data, open('./outputs/data_final.pkl', 'wb'))
# dill.dump(single_df, open('./outputs/single_data_final.pkl', 'wb'))
# dill.dump(multi_df, open('./outputs/multi_data_final.pkl', 'wb'))

In [20]:
class Voc(object):
    def __init__(self):
        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        for word in sentence:
            if word not in self.word2idx:
                self.idx2word[len(self.word2idx)] = word
                self.word2idx[word] = len(self.word2idx)
                
def create_str_token_mapping(df):
    diag_voc = Voc()
    med_voc = Voc()
    pro_voc = Voc()
    sym_voc = Voc()
    
    for index, row in df.iterrows():
        diag_voc.add_sentence(["MASK"])
        diag_voc.add_sentence(row['ICD9_CODE'])
        pro_voc.add_sentence(["MASK"])
        pro_voc.add_sentence(row['PRO_CODE'])
        sym_voc.add_sentence(["MASK"])
        sym_voc.add_sentence(row['SYM_LIST'])
        med_voc.add_sentence(["MASK"])
        med_voc.add_sentence(row['NDC'])
    
    dill.dump(obj={'diag_voc':diag_voc, 'med_voc':med_voc ,'pro_voc':pro_voc, 'sym_voc':sym_voc}, file=open('./outputs/voc_final.pkl','wb'))
    return diag_voc, pro_voc, sym_voc, med_voc 

In [24]:
diag_voc, pro_voc, sym_voc, med_voc = create_str_token_mapping(data)

len(diag_voc.idx2word), len(pro_voc.idx2word), len(sym_voc.idx2word), len(med_voc.idx2word)

(1998, 1955, 505, 132)

In [22]:
def random_word(tokens, vocab, prob_input=0.15, replace_prob=0.1):
    tmp_tokens = tokens[:]
    num_tokens = len(tokens)
    
    # 随机选取 60% 的 token
    num_select = int(num_tokens * prob_input)
    selected_indices = random.sample(range(num_tokens), num_select)
    
    for i in selected_indices:
        prob = random.random()
        # 替换为随机字符的概率
        if prob < replace_prob:
            # 选择随机字符
            tmp_tokens[i] = random.choice(list(vocab.idx2word.items()))[0]
        else:
            # 替换为 mask token
            tmp_tokens[i] = diag_voc.word2idx['MASK']
    
    return tmp_tokens

In [28]:
def create_patient_record(df, diag_voc, med_voc, pro_voc, sym_voc, profile_tokenizer, df_type):
    records = [] # (patient, code_kind:3, codes)  code_kind:diag, proc, med
    for subject_id in df['SUBJECT_ID'].unique():
        item_df = df[df['SUBJECT_ID'] == subject_id]
        patient = []
        for index, row in item_df.iterrows():
            admission = []
            profile = []
            admission.append([diag_voc.word2idx[i] for i in row['ICD9_CODE']])
            admission.append([pro_voc.word2idx[i] for i in row['PRO_CODE']])
            admission.append([sym_voc.word2idx[i] for i in row['SYM_LIST']])
            admission.append([med_voc.word2idx[i] for i in row['NDC']])
            profile.append(profile_tokenizer["word2idx"]['INSURANCE'].get(row['INSURANCE'],-1))
            profile.append(profile_tokenizer["word2idx"]['LANGUAGE'].get(row['LANGUAGE'],-1))
            profile.append(profile_tokenizer["word2idx"]['RELIGION'].get(row['RELIGION'],-1))
            profile.append(profile_tokenizer["word2idx"]['MARITAL_STATUS'].get(row['MARITAL_STATUS'],-1))
            profile.append(profile_tokenizer["word2idx"]['ETHNICITY'].get(row['ETHNICITY'],-1))
            admission.append(profile)
            patient.append(admission)
        records.append(patient) 
    if (df_type == 'single'):
        for step, patient in enumerate(records):
            for idx, adm in enumerate(patient):
                adm.append(random_word(adm[0],diag_voc,0.15,0.1))
                adm.append(random_word(adm[1],pro_voc,0.15,0.1))
                adm.append(random_word(adm[2],sym_voc,0.15,0.1))
                adm.append(random_word(adm[3],med_voc,0.5,0.4))
        dill.dump(obj=records, file=open('./outputs/mask_single_records_final.pkl', 'wb'))
    elif (df_type == 'multi'):
        dill.dump(obj=records, file=open('./outputs/records_final.pkl', 'wb'))
    elif (df_type == 'all'):
        dill.dump(obj=records, file=open('./outputs/all_records_final.pkl', 'wb'))
    return records

In [29]:
profile_tokenizer = profile_tokenization(data, ["INSURANCE", "LANGUAGE", "RELIGION", "MARITAL_STATUS", "ETHNICITY"])
all_records = create_patient_record(data, diag_voc, med_voc, pro_voc, sym_voc, profile_tokenizer, "all")
single_records = create_patient_record(single_df, diag_voc, med_voc, pro_voc, sym_voc, profile_tokenizer, "single")
multi_records = create_patient_record(multi_df, diag_voc, med_voc, pro_voc, sym_voc, profile_tokenizer, "multi")

In [3]:
def get_ddi_matrix(records, med_voc, ddi_file):
    TOPK = 40  # topk drug-drug interaction
    cid2atc_dic = defaultdict(set)
    med_voc_size = len(med_voc.idx2word)
    med_unique_word = [med_voc.idx2word[i] for i in range(med_voc_size)]
    atc3_atc4_dic = defaultdict(set)
    for item in med_unique_word:
        atc3_atc4_dic[item[:4]].add(item)

    with open(cid_atc, 'r') as f:
        for line in f:
            line_ls = line[:-1].split(',')
            cid = line_ls[0]
            atcs = line_ls[1:]
            for atc in atcs:
                if len(atc3_atc4_dic[atc[:4]]) != 0:
                    cid2atc_dic[cid].add(atc[:4])

    # ddi load
    ddi_df = pd.read_csv(ddi_file)
    # fliter sever side effect
    ddi_most_pd = ddi_df.groupby(by=['Polypharmacy Side Effect', 'Side Effect Name']).size().reset_index().rename(
        columns={0: 'count'}).sort_values(by=['count'], ascending=False).reset_index(drop=True)
    ddi_most_pd = ddi_most_pd.iloc[-TOPK:, :]
    # ddi_most_pd = pd.DataFrame(columns=['Side Effect Name'], data=['as','asd','as'])
    fliter_ddi_df = ddi_df.merge(ddi_most_pd[['Side Effect Name']], how='inner', on=['Side Effect Name'])
    ddi_df = fliter_ddi_df[['STITCH 1', 'STITCH 2']].drop_duplicates().reset_index(drop=True)
    ddi_adj = np.zeros((med_voc_size, med_voc_size))
    for index, row in ddi_df.iterrows():
        # ddi
        cid1 = row['STITCH 1']
        cid2 = row['STITCH 2']

        # cid -> atc_level3
        for atc_i in cid2atc_dic[cid1]:
            for atc_j in cid2atc_dic[cid2]:

                # atc_level3 -> atc_level4
                for i in atc3_atc4_dic[atc_i]:
                    for j in atc3_atc4_dic[atc_j]:
                        if med_voc.word2idx[i] != med_voc.word2idx[j]:
                            ddi_adj[med_voc.word2idx[i], med_voc.word2idx[j]] = 1
                            ddi_adj[med_voc.word2idx[j], med_voc.word2idx[i]] = 1
    dill.dump(ddi_adj, open('./outputs/ddi_A_final.pkl', 'wb'))
    return ddi_adj

In [5]:
def ddi_rate_score(record, path="./outputs/ddi_A_final.pkl"):
    # ddi rate
    ddi_A = dill.load(open(path, "rb"))
    all_cnt = 0
    dd_cnt = 0
    for patient in record:
        for adm in patient:
            med_code_set = adm
            for i, med_i in enumerate(med_code_set):
                for j, med_j in enumerate(med_code_set):
                    if j <= i:
                        continue
                    all_cnt += 1
                    if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                        dd_cnt += 1
    if all_cnt == 0:
        return 0
    return dd_cnt / all_cnt

In [7]:
len(all_records)

6313

In [8]:
ddi_sum=0
visit_num=0
for step, patient in enumerate(all_records):
    for idx, adm in enumerate(patient):
        ddi_sum+=ddi_rate_score([[adm[3]]])
        visit_num+=1
print("DDI rate:",ddi_sum/visit_num)


DDI rate: 0.09040181569538848
