In [29]:
import pandas as pd
import os, shutil
import numpy as np
import matplotlib
import sys; sys.path.append('..')
import pickle
import requests
import torch
import dill

from queue import Queue
from tqdm.notebook import tqdm, tqdm_notebook
from collections import defaultdict

path_ddi_dataset = r"/data/data2/041/datasets/DDI"
path_iii_dataset = r"/data/data2/041/datasets/mimic-iii-clinical-database-1.4"

ndc_rxnorm_file    = r"/data/data2/041/datasets/DDI/ndc2rxnorm_mapping.txt"
ndc2atc_file       = r'/data/data2/041/datasets/DDI/ndc2atc_level4.csv'
med_structure_file = r'/data/data2/041/datasets/DDI/idx2drug.pkl'   # 药物到分子式的映射
cid_atc            = r'/data/data2/041/datasets/DDI/drug-atc.csv'   # drug（CID） to ATC code mapping file，用于处理DDI表

In [2]:
med_pd = pd.read_csv(os.path.join(path_iii_dataset, "PRESCRIPTIONS.csv.gz"), dtype={'NDC':'category'})
med_pd.head()

  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ICUSTAY_ID,STARTDATE,ENDDATE,DRUG_TYPE,DRUG,DRUG_NAME_POE,DRUG_NAME_GENERIC,FORMULARY_DRUG_CD,GSN,NDC,PROD_STRENGTH,DOSE_VAL_RX,DOSE_UNIT_RX,FORM_VAL_DISP,FORM_UNIT_DISP,ROUTE
0,2214776,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Tacrolimus,Tacrolimus,Tacrolimus,TACR1,21796.0,469061711,1mg Capsule,2,mg,2,CAP,PO
1,2214775,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Warfarin,Warfarin,Warfarin,WARF5,6562.0,56017275,5mg Tablet,5,mg,1,TAB,PO
2,2215524,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Heparin Sodium,,,HEPAPREMIX,6522.0,338055002,"25,000 unit Premix Bag",25000,UNIT,1,BAG,IV
3,2216265,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,BASE,D5W,,,HEPBASE,,0,HEPARIN BASE,250,ml,250,ml,IV
4,2214773,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Furosemide,Furosemide,Furosemide,FURO20,8208.0,54829725,20mg Tablet,20,mg,1,TAB,PO


In [3]:
med_pd.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4156450 entries, 0 to 4156449
Data columns (total 19 columns):
 #   Column             Dtype   
---  ------             -----   
 0   ROW_ID             int64   
 1   SUBJECT_ID         int64   
 2   HADM_ID            int64   
 3   ICUSTAY_ID         float64 
 4   STARTDATE          object  
 5   ENDDATE            object  
 6   DRUG_TYPE          object  
 7   DRUG               object  
 8   DRUG_NAME_POE      object  
 9   DRUG_NAME_GENERIC  object  
 10  FORMULARY_DRUG_CD  object  
 11  GSN                object  
 12  NDC                category
 13  PROD_STRENGTH      object  
 14  DOSE_VAL_RX        object  
 15  DOSE_UNIT_RX       object  
 16  FORM_VAL_DISP      object  
 17  FORM_UNIT_DISP     object  
 18  ROUTE              object  
dtypes: category(1), float64(1), int64(3), object(14)
memory usage: 578.9+ MB


# `med_process`

In [4]:
med_pd.drop(columns=['ROW_ID','DRUG_TYPE','DRUG_NAME_POE','DRUG_NAME_GENERIC',
                    'FORMULARY_DRUG_CD','PROD_STRENGTH','DOSE_VAL_RX',
                    'DOSE_UNIT_RX','FORM_VAL_DISP','FORM_UNIT_DISP', 'GSN', 'FORM_UNIT_DISP',
                    'ROUTE','ENDDATE','DRUG'], axis=1, inplace=True)
med_pd.drop(index = med_pd[med_pd['NDC'] == '0'].index, axis=0, inplace=True)
med_pd.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 3569864 entries, 0 to 4156449
Data columns (total 5 columns):
 #   Column      Dtype   
---  ------      -----   
 0   SUBJECT_ID  int64   
 1   HADM_ID     int64   
 2   ICUSTAY_ID  float64 
 3   STARTDATE   object  
 4   NDC         category
dtypes: category(1), float64(1), int64(2), object(1)
memory usage: 143.1+ MB


In [5]:
med_pd.fillna(method='pad', inplace=True)
med_pd.dropna(inplace=True)
med_pd.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 3569840 entries, 30 to 4156449
Data columns (total 5 columns):
 #   Column      Dtype   
---  ------      -----   
 0   SUBJECT_ID  int64   
 1   HADM_ID     int64   
 2   ICUSTAY_ID  float64 
 3   STARTDATE   object  
 4   NDC         category
dtypes: category(1), float64(1), int64(2), object(1)
memory usage: 143.1+ MB


After dropping many edge features columns, these previous medications recommendation works only consider the NDC identifier of drug, but ignore many variables which describe the dose of drug.

This is why they do following `drop_duplicates`.

In [6]:
med_pd.loc[med_pd.duplicated()]

Unnamed: 0,SUBJECT_ID,HADM_ID,ICUSTAY_ID,STARTDATE,NDC
44,13,143045,263738.0,2167-01-09 00:00:00,00074258702
47,13,143045,263738.0,2167-01-09 00:00:00,00074407532
53,13,143045,263738.0,2167-01-09 00:00:00,00338001702
68,6,107064,263738.0,2175-06-15 00:00:00,00004003822
78,9,150750,263738.0,2149-11-09 00:00:00,00456066270
...,...,...,...,...,...
4156442,98887,121032,238144.0,2144-09-06 00:00:00,61553020648
4156443,98887,121032,238144.0,2144-09-06 00:00:00,61553020648
4156445,98887,121032,238144.0,2144-09-06 00:00:00,00054001820
4156446,98887,121032,238144.0,2144-09-06 00:00:00,00487980125


In [7]:
med_pd.drop_duplicates(inplace=True)
med_pd.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 2987420 entries, 30 to 4156449
Data columns (total 5 columns):
 #   Column      Dtype   
---  ------      -----   
 0   SUBJECT_ID  int64   
 1   HADM_ID     int64   
 2   ICUSTAY_ID  float64 
 3   STARTDATE   object  
 4   NDC         category
dtypes: category(1), float64(1), int64(2), object(1)
memory usage: 119.8+ MB


In [8]:
med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
med_pd['STARTDATE'] = pd.to_datetime(med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')
med_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
med_pd = med_pd.reset_index(drop=True)
med_pd.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2987420 entries, 0 to 2987419
Data columns (total 5 columns):
 #   Column      Dtype         
---  ------      -----         
 0   SUBJECT_ID  int64         
 1   HADM_ID     int64         
 2   ICUSTAY_ID  int64         
 3   STARTDATE   datetime64[ns]
 4   NDC         category      
dtypes: category(1), datetime64[ns](1), int64(3)
memory usage: 97.0 MB


In [9]:
med_pd = med_pd.drop(columns=['ICUSTAY_ID'])
med_pd = med_pd.drop_duplicates()
med_pd = med_pd.reset_index(drop=True)
med_pd.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2961058 entries, 0 to 2961057
Data columns (total 4 columns):
 #   Column      Dtype         
---  ------      -----         
 0   SUBJECT_ID  int64         
 1   HADM_ID     int64         
 2   STARTDATE   datetime64[ns]
 3   NDC         category      
dtypes: category(1), datetime64[ns](1), int64(2)
memory usage: 73.6 MB


# `process_visit_lg2`

Filter out patient data with less than two admissions.

In [10]:
a = med_pd[['SUBJECT_ID', 'HADM_ID']].groupby(by='SUBJECT_ID')['HADM_ID'].unique().reset_index()
a

Unnamed: 0,SUBJECT_ID,HADM_ID
0,2,[163353]
1,4,[185777]
2,6,[107064]
3,8,[159514]
4,9,[150750]
...,...,...
39355,99985,[176670]
39356,99991,[151118]
39357,99992,[197084]
39358,99995,[137810]


In [11]:
a['HADM_ID_Len'] = a['HADM_ID'].map(lambda x:len(x))
a

Unnamed: 0,SUBJECT_ID,HADM_ID,HADM_ID_Len
0,2,[163353],1
1,4,[185777],1
2,6,[107064],1
3,8,[159514],1
4,9,[150750],1
...,...,...,...
39355,99985,[176670],1
39356,99991,[151118],1
39357,99992,[197084],1
39358,99995,[137810],1


In [12]:
a = a[a['HADM_ID_Len'] > 1]
a

Unnamed: 0,SUBJECT_ID,HADM_ID,HADM_ID_Len
9,17,"[161087, 194023]",2
13,21,"[109451, 111970]",2
14,23,"[124321, 152223]",2
22,34,"[115799, 144319]",2
24,36,"[122659, 165660, 182104]",3
...,...,...,...
39310,99822,"[146997, 163117, 195871]",3
39328,99883,"[150755, 198523]",2
39331,99897,"[162913, 181057]",2
39338,99923,"[164914, 192053]",2


In [13]:
med_pd_lg2 = a
med_pd_lg2

Unnamed: 0,SUBJECT_ID,HADM_ID,HADM_ID_Len
9,17,"[161087, 194023]",2
13,21,"[109451, 111970]",2
14,23,"[124321, 152223]",2
22,34,"[115799, 144319]",2
24,36,"[122659, 165660, 182104]",3
...,...,...,...
39310,99822,"[146997, 163117, 195871]",3
39328,99883,"[150755, 198523]",2
39331,99897,"[162913, 181057]",2
39338,99923,"[164914, 192053]",2


In [14]:
med_pd = med_pd.merge(med_pd_lg2[['SUBJECT_ID']], on='SUBJECT_ID', how='inner').reset_index(drop=True)
med_pd.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1135469 entries, 0 to 1135468
Data columns (total 4 columns):
 #   Column      Non-Null Count    Dtype         
---  ------      --------------    -----         
 0   SUBJECT_ID  1135469 non-null  int64         
 1   HADM_ID     1135469 non-null  int64         
 2   STARTDATE   1135469 non-null  datetime64[ns]
 3   NDC         1135469 non-null  category      
dtypes: category(1), datetime64[ns](1), int64(2)
memory usage: 28.3 MB


# `ndc2atc4`

In [15]:
with open(ndc_rxnorm_file, 'r') as f:
    ndc2rxnorm = eval(f.read())
    
# 根据ndc_rxnorm_file文件读取ndc到xnorm的映射（这个xnorm似乎等同于下面的RXCUI）
med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC,RXCUI
0,17,161087,2135-05-09,00713016550,209363
1,17,161087,2135-05-09,00904770418,1293665
2,17,161087,2135-05-09,00904404073,318272
3,17,161087,2135-05-09,00904526161,198191
4,17,161087,2135-05-09,00121075210,755272
...,...,...,...,...,...
1135464,99982,183791,2157-02-16,51991045757,876193
1135465,99982,183791,2157-02-16,00409490234,727517
1135466,99982,183791,2157-02-16,00904404073,318272
1135467,99982,183791,2157-02-16,63323026201,1361615


In [16]:
med_pd.dropna(inplace=True) # 实际上啥也没删掉
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC,RXCUI
0,17,161087,2135-05-09,00713016550,209363
1,17,161087,2135-05-09,00904770418,1293665
2,17,161087,2135-05-09,00904404073,318272
3,17,161087,2135-05-09,00904526161,198191
4,17,161087,2135-05-09,00121075210,755272
...,...,...,...,...,...
1135464,99982,183791,2157-02-16,51991045757,876193
1135465,99982,183791,2157-02-16,00409490234,727517
1135466,99982,183791,2157-02-16,00904404073,318272
1135467,99982,183791,2157-02-16,63323026201,1361615


In [17]:
med_pd.drop(index = med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)     # 删除特定的RXCUI
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC,RXCUI
0,17,161087,2135-05-09,00713016550,209363
1,17,161087,2135-05-09,00904770418,1293665
2,17,161087,2135-05-09,00904404073,318272
3,17,161087,2135-05-09,00904526161,198191
4,17,161087,2135-05-09,00121075210,755272
...,...,...,...,...,...
1135464,99982,183791,2157-02-16,51991045757,876193
1135465,99982,183791,2157-02-16,00409490234,727517
1135466,99982,183791,2157-02-16,00904404073,318272
1135467,99982,183791,2157-02-16,63323026201,1361615


In [18]:
med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64').reset_index(drop=True)
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC,RXCUI
0,17,161087,2135-05-09,00713016550,209363.0
1,17,161087,2135-05-09,00904770418,1293665.0
2,17,161087,2135-05-09,00904404073,318272.0
3,17,161087,2135-05-09,00904526161,198191.0
4,17,161087,2135-05-09,00121075210,755272.0
...,...,...,...,...,...
1135464,99982,183791,2157-02-16,51991045757,
1135465,99982,183791,2157-02-16,00409490234,
1135466,99982,183791,2157-02-16,00904404073,
1135467,99982,183791,2157-02-16,63323026201,


## rxnorm2atc

In [19]:
rxnorm2atc = pd.read_csv(ndc2atc_file)
rxnorm2atc = rxnorm2atc.drop(columns=['YEAR','MONTH','NDC'])    # NDC删了，直接从RXCUI映射到ATC
rxnorm2atc

Unnamed: 0,RXCUI,ATC4
0,853004,C01BA
1,1551300,A10BJ
2,1551300,A10BJ
3,1551306,A10BJ
4,1551306,A10BJ
...,...,...
318252,543484,D07AC
318253,543484,C05AA
318254,543484,D07AC
318255,1482689,D01AC


In [20]:
# 根据RXCUI删除重复列
rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
rxnorm2atc

Unnamed: 0,RXCUI,ATC4
0,853004,C01BA
1,1551300,A10BJ
3,1551306,A10BJ
5,1745108,L04AC
9,1599948,G03BA
...,...,...
318217,858374,A01AB
318223,858064,A01AB
318229,1014022,A01AB
318241,1014018,A01AB


In [21]:
med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])     # 合并两个表
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC,RXCUI,ATC4
0,17,161087,2135-05-09,00713016550,209363.0,N02BE
1,17,194023,2134-12-27,51079000220,209363.0,N02BE
2,21,111970,2135-02-06,63323022905,209363.0,N02BE
3,23,152223,2153-09-03,00406051262,209363.0,N02BE
4,36,122659,2131-05-15,00409176230,209363.0,N02BE
...,...,...,...,...,...,...
873668,96950,176286,2103-06-20,00574705050,198324.0,N05AB
873669,96950,190421,2103-09-04,00409739172,198324.0,N05AB
873670,96950,190421,2103-09-10,00186109039,198324.0,N05AB
873671,96958,102063,2131-03-25,51079025520,198324.0,N05AB


In [22]:
med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True) # 干掉NDC\RXCUI，只剩ATC4了
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,ATC4
0,17,161087,2135-05-09,N02BE
1,17,194023,2134-12-27,N02BE
2,21,111970,2135-02-06,N02BE
3,23,152223,2153-09-03,N02BE
4,36,122659,2131-05-15,N02BE
...,...,...,...,...
873668,96950,176286,2103-06-20,N05AB
873669,96950,190421,2103-09-04,N05AB
873670,96950,190421,2103-09-10,N05AB
873671,96958,102063,2131-03-25,N05AB


In [23]:
med_pd = med_pd.rename(columns={'ATC4':'NDC'})      # 重新命名为NDC
med_pd['NDC'] = med_pd['NDC'].map(lambda x: x[:4])  # 只保留前四位

med_pd = med_pd.drop_duplicates()
med_pd = med_pd.reset_index(drop=True)
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC
0,17,161087,2135-05-09,N02B
1,17,194023,2134-12-27,N02B
2,21,111970,2135-02-06,N02B
3,23,152223,2153-09-03,N02B
4,36,122659,2131-05-15,N02B
...,...,...,...,...
694306,96950,169823,2103-08-11,N05A
694307,96950,176286,2103-06-20,N05A
694308,96950,190421,2103-09-04,N05A
694309,96950,190421,2103-09-10,N05A


In [24]:
med_pd.NDC.value_counts()

A12C    49593
A06A    47786
B05C    40078
B01A    35704
C07A    34325
        ...  
L02A        2
P02C        1
B06A        1
R02A        1
S01K        1
Name: NDC, Length: 151, dtype: int64

In [25]:
NDCList = dill.load(open(med_structure_file, 'rb'))  # 药物到分子式的映射
NDCList

{'A01A': {'CC(=O)OC1=CC=CC=C1C(O)=O',
  '[F-].[Na+]',
  '[H][C@@]12C[C@@H](C)[C@](O)(C(=O)CO)[C@@]1(C)C[C@H](O)[C@@]1(F)[C@@]2([H])CCC2=CC(=O)C=C[C@]12C'},
 'A02A': {'[MgH2]', '[OH-].[OH-].[Mg++]'},
 'A02B': {'CC1=C(OCC(F)(F)F)C=CN=C1CS(=O)C1=NC2=CC=CC=C2N1',
  'COC1=C(OC)C(CS(=O)C2=NC3=C(N2)C=C(OC(F)F)C=C3)=NC=C1',
  'COC1=CC2=C(C=C1)N=C(N2)S(=O)CC1=NC=C(C)C(OC)=C1C'},
 'A03A': {'CCN(CC)CCOC(=O)C1(CCCCC1)C1CCCCC1'},
 'A03B': {'CCOC(=O)C1(CCN(CCC(C#N)(C2=CC=CC=C2)C2=CC=CC=C2)CC1)C1=CC=CC=C1',
  'CN1[C@H]2CC[C@@H]1C[C@@H](C2)OC(=O)C(CO)C1=CC=CC=C1',
  'CN1[C@H]2CC[C@@H]1C[C@@H](C2)OC(=O)[C@H](CO)C1=CC=CC=C1'},
 'A03F': {'CCN(CC)CCNC(=O)C1=CC(Cl)=C(N)C=C1OC'},
 'A04A': {'CN1C2=C(C3=CC=CC=C13)C(=O)C(CN1C=CN=C1C)CC2',
  'CN1[C@H]2C[C@@H](C[C@@H]1[C@H]1O[C@@H]21)OC(=O)[C@H](CO)C1=CC=CC=C1',
  '[H][C@@]12C=C(C)CC[C@@]1([H])C(C)(C)OC1=C2C(O)=CC(CCCCC)=C1'},
 'A05A': {'[H][C@@]1(CC[C@@]2([H])[C@]3([H])[C@@H](O)C[C@]4([H])C[C@H](O)CC[C@]4(C)[C@@]3([H])CC[C@]12C)[C@H](C)CCC(O)=O'},
 'A06A': {'CC

In [26]:
med_pd = med_pd[med_pd.NDC.isin(list(NDCList.keys()))]
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC
0,17,161087,2135-05-09,N02B
1,17,194023,2134-12-27,N02B
2,21,111970,2135-02-06,N02B
3,23,152223,2153-09-03,N02B
4,36,122659,2131-05-15,N02B
...,...,...,...,...
694306,96950,169823,2103-08-11,N05A
694307,96950,176286,2103-06-20,N05A
694308,96950,190421,2103-09-04,N05A
694309,96950,190421,2103-09-10,N05A


In [27]:
med_pd.NDC.value_counts()

A12C    49593
A06A    47786
B05C    40078
B01A    35704
C07A    34325
        ...  
V04C        3
S01A        3
R02A        1
P02C        1
S01K        1
Name: NDC, Length: 131, dtype: int64

# `filter_300_most_med`

In [28]:
# 按照NDC出现的次数降序排列，取前300
# 但由于此前只剩下131个不同的NDC（实际上是ATC4）了，这一步相当于什么也没删除
med_count = med_pd.groupby(by=['NDC'])\
                  .size()\
                  .reset_index()\
                  .rename(columns={0:'count'})\
                  .sort_values(by=['count'],ascending=False)\
                  .reset_index(drop=True)
med_pd = med_pd[med_pd['NDC'].isin(med_count.loc[:299, 'NDC'])]
med_pd.reset_index(drop=True)

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,NDC
0,17,161087,2135-05-09,N02B
1,17,194023,2134-12-27,N02B
2,21,111970,2135-02-06,N02B
3,23,152223,2153-09-03,N02B
4,36,122659,2131-05-15,N02B
...,...,...,...,...
673969,96950,169823,2103-08-11,N05A
673970,96950,176286,2103-06-20,N05A
673971,96950,190421,2103-09-04,N05A
673972,96950,190421,2103-09-10,N05A


# `diag_process`

In [35]:
diag_pd = pd.read_csv(os.path.join(path_iii_dataset, "DIAGNOSES_ICD.csv.gz"))
diag_pd.head()

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,1297,109,172335,1.0,40301
1,1298,109,172335,2.0,486
2,1299,109,172335,3.0,58281
3,1300,109,172335,4.0,5855
4,1301,109,172335,5.0,4254


In [36]:
diag_pd.dropna(inplace=True)
diag_pd.drop(columns=['SEQ_NUM','ROW_ID'],inplace=True)
diag_pd.drop_duplicates(inplace=True)
diag_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,109,172335,40301
1,109,172335,486
2,109,172335,58281
3,109,172335,5855
4,109,172335,4254
...,...,...,...
651042,97503,188195,20280
651043,97503,188195,V5869
651044,97503,188195,V1279
651045,97503,188195,5275


In [37]:
diag_pd.sort_values(by=['SUBJECT_ID','HADM_ID'], inplace=True)
diag_pd = diag_pd.reset_index(drop=True)
diag_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,V3001
1,2,163353,V053
2,2,163353,V290
3,3,145834,0389
4,3,145834,78559
...,...,...,...
650935,99999,113369,75612
650936,99999,113369,7861
650937,99999,113369,4019
650938,99999,113369,25000


In [38]:
def filter_2000_most_diag(diag_pd):
    diag_count = diag_pd.groupby(by=['ICD9_CODE'])\
                        .size()\
                        .reset_index()\
                        .rename(columns={0:'count'})\
                        .sort_values(by=['count'],ascending=False)\
                        .reset_index(drop=True)
    diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(diag_count.loc[:1999, 'ICD9_CODE'])]
    
    return diag_pd.reset_index(drop=True)

diag_pd = filter_2000_most_diag(diag_pd)
diag_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,V3001
1,2,163353,V053
2,2,163353,V290
3,3,145834,0389
4,3,145834,78559
...,...,...,...
625429,99995,137810,41401
625430,99999,113369,7861
625431,99999,113369,4019
625432,99999,113369,25000


# `procedure_process`

In [48]:
pro_pd = pd.read_csv(os.path.join(path_iii_dataset, "PROCEDURES_ICD.csv.gz"), dtype={'ICD9_CODE':'category'})
pro_pd

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,944,62641,154460,3,3404
1,945,2592,130856,1,9671
2,946,2592,130856,2,3893
3,947,55357,119355,1,9672
4,948,55357,119355,2,0331
...,...,...,...,...,...
240090,228330,67415,150871,5,3736
240091,228331,67415,150871,6,3893
240092,228332,67415,150871,7,8872
240093,228333,67415,150871,8,3893


In [49]:
pro_pd.drop(columns=['ROW_ID'], inplace=True)
pro_pd.drop_duplicates(inplace=True)
pro_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,62641,154460,3,3404
1,2592,130856,1,9671
2,2592,130856,2,3893
3,55357,119355,1,9672
4,55357,119355,2,0331
...,...,...,...,...
240090,67415,150871,5,3736
240091,67415,150871,6,3893
240092,67415,150871,7,8872
240093,67415,150871,8,3893


In [50]:
pro_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'SEQ_NUM'], inplace=True)
pro_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
95085,2,163353,1,9955
45149,3,145834,1,9604
45150,3,145834,2,9962
45151,3,145834,3,8964
45152,3,145834,4,9672
...,...,...,...,...
165967,99999,113369,1,8108
165968,99999,113369,2,8051
165969,99999,113369,3,8162
165970,99999,113369,4,9979


In [51]:
pro_pd.drop(columns=['SEQ_NUM'], inplace=True)
pro_pd.drop_duplicates(inplace=True)
pro_pd.reset_index(drop=True, inplace=True)
pro_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,2,163353,9955
1,3,145834,9604
2,3,145834,9962
3,3,145834,8964
4,3,145834,9672
...,...,...,...
228674,99999,113369,8108
228675,99999,113369,8051
228676,99999,113369,8162
228677,99999,113369,9979


# `combine_process`

"""药物、症状、proc的数据结合"""

In [52]:
med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()

在三个表中均存在的`SUBJECT_ID`, `HADM_ID`才能得到保留

In [55]:
combined_key = med_pd_key.merge(diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

combined_key

Unnamed: 0,SUBJECT_ID,HADM_ID
0,17,161087
1,17,194023
2,21,111970
3,23,152223
4,36,122659
...,...,...
15002,26200,115886
15003,21235,155959
15004,29129,131352
15005,26704,163846


In [56]:
diag_pd = diag_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
med_pd = med_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

In [57]:
diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE
0,17,161087,"[4239, 5119, 78551, 4589, 311, 7220, 71946, 2724]"
1,17,194023,"[7455, 45829, V1259, 2724]"
2,21,109451,"[41071, 78551, 5781, 5849, 40391, 4280, 4592, ..."
3,21,111970,"[0388, 78552, 40391, 42731, 70709, 5119, 6823,..."
4,23,124321,"[2252, 3485, 78039, 4241, 4019, 2720, 2724, V4..."
...,...,...,...
15002,99408,169240,"[1508, 5070, 42732, 5119, 99659, 45829, V444, ..."
15003,99439,143661,"[34830, 42842, 5849, 4168, 43883, 496, 32723, ..."
15004,99464,162179,"[44103, 3361, 43401, 4441, 3441, V4986, 311, 5..."
15005,99469,126023,"[99672, 41402, 4111, 42822, 41401, 4142, 4280,..."


In [58]:
med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC
0,17,161087,"[N02B, A01A, A02B, A06A, B05C, A12A, A12C, C01..."
1,17,194023,"[N02B, A01A, A02B, A06A, A12A, B05C, A12C, C01..."
2,21,109451,"[A06A, B05C, A12C, C07A, A12B, C03C, A12A, J01..."
3,21,111970,"[N02B, A06A, B05C, A12C, A07A, A02A, B01A, N06..."
4,23,124321,"[B05C, A07A, C07A, A06A, N02B, C02D, B01A, A02..."
...,...,...,...
15002,99408,169240,"[N02B, A02B, A06A, B05C, A12A, A12C, A07A, C07..."
15003,99439,143661,"[A12C, C07A, A06A, A02A, B01A, A02B, D01A, J01..."
15004,99464,162179,"[N02B, A01A, B05C, A12C, C01C, A06A, C03C, N07..."
15005,99469,126023,"[N02B, C03C, B01A, B05C, A06A, A04A, A02B, N02..."


In [59]:
diag_pd = diag_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()  
med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['NDC'].unique().reset_index()
pro_pd = pro_pd.groupby(by=['SUBJECT_ID','HADM_ID'])['ICD9_CODE'].unique().reset_index()\
                                                                          .rename(columns={'ICD9_CODE':'PRO_CODE'}) 

In [60]:
med_pd['NDC'] = med_pd['NDC'].map(lambda x: list(x))
med_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,NDC
0,17,161087,"[N02B, A01A, A02B, A06A, B05C, A12A, A12C, C01..."
1,17,194023,"[N02B, A01A, A02B, A06A, A12A, B05C, A12C, C01..."
2,21,109451,"[A06A, B05C, A12C, C07A, A12B, C03C, A12A, J01..."
3,21,111970,"[N02B, A06A, B05C, A12C, A07A, A02A, B01A, N06..."
4,23,124321,"[B05C, A07A, C07A, A06A, N02B, C02D, B01A, A02..."
...,...,...,...
15002,99408,169240,"[N02B, A02B, A06A, B05C, A12A, A12C, A07A, C07..."
15003,99439,143661,"[A12C, C07A, A06A, A02A, B01A, A02B, D01A, J01..."
15004,99464,162179,"[N02B, A01A, B05C, A12C, C01C, A06A, C03C, N07..."
15005,99469,126023,"[N02B, C03C, B01A, B05C, A06A, A04A, A02B, N02..."


In [61]:
pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))
pro_pd

Unnamed: 0,SUBJECT_ID,HADM_ID,PRO_CODE
0,17,161087,"[3731, 8872, 3893]"
1,17,194023,"[3571, 3961, 8872]"
2,21,109451,"[0066, 3761, 3950, 3606, 0042, 0047, 3895, 399..."
3,21,111970,"[3995, 8961, 0014]"
4,23,124321,[0151]
...,...,...,...
15002,99408,169240,"[4242, 9803, 9671, 3324, 4422, 4513, 966]"
15003,99439,143661,"[9671, 9604, 3893, 3322]"
15004,99464,162179,"[0110, 3893, 3891]"
15005,99469,126023,"[0066, 3723, 8856, 8857, 0040]"


## final data

In [63]:
data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
data['NDC_Len'] = data['NDC'].map(lambda x: len(x))

data

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,NDC,PRO_CODE,NDC_Len
0,17,161087,"[4239, 5119, 78551, 4589, 311, 7220, 71946, 2724]","[N02B, A01A, A02B, A06A, B05C, A12A, A12C, C01...","[3731, 8872, 3893]",15
1,17,194023,"[7455, 45829, V1259, 2724]","[N02B, A01A, A02B, A06A, A12A, B05C, A12C, C01...","[3571, 3961, 8872]",18
2,21,109451,"[41071, 78551, 5781, 5849, 40391, 4280, 4592, ...","[A06A, B05C, A12C, C07A, A12B, C03C, A12A, J01...","[0066, 3761, 3950, 3606, 0042, 0047, 3895, 399...",23
3,21,111970,"[0388, 78552, 40391, 42731, 70709, 5119, 6823,...","[N02B, A06A, B05C, A12C, A07A, A02A, B01A, N06...","[3995, 8961, 0014]",20
4,23,124321,"[2252, 3485, 78039, 4241, 4019, 2720, 2724, V4...","[B05C, A07A, C07A, A06A, N02B, C02D, B01A, A02...",[0151],17
...,...,...,...,...,...,...
15002,99408,169240,"[1508, 5070, 42732, 5119, 99659, 45829, V444, ...","[N02B, A02B, A06A, B05C, A12A, A12C, A07A, C07...","[4242, 9803, 9671, 3324, 4422, 4513, 966]",24
15003,99439,143661,"[34830, 42842, 5849, 4168, 43883, 496, 32723, ...","[A12C, C07A, A06A, A02A, B01A, A02B, D01A, J01...","[9671, 9604, 3893, 3322]",17
15004,99464,162179,"[44103, 3361, 43401, 4441, 3441, V4986, 311, 5...","[N02B, A01A, B05C, A12C, C01C, A06A, C03C, N07...","[0110, 3893, 3891]",19
15005,99469,126023,"[99672, 41402, 4111, 42822, 41401, 4142, 4280,...","[N02B, C03C, B01A, B05C, A06A, A04A, A02B, N02...","[0066, 3723, 8856, 8857, 0040]",17


In [68]:
data.to_csv(os.path.join(path_iii_dataset, "final_data_of_previous_works.csv.gz"))

# `create_str_token_mapping`

In [66]:
class Voc(object):
    def __init__(self):
        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        for word in sentence:
            
            # new word
            if word not in self.word2idx:
                self.idx2word[len(self.word2idx)] = word
                self.word2idx[word] = len(self.word2idx)
                
diag_voc = Voc()
med_voc = Voc()
pro_voc = Voc()

for index, row in data.iterrows():
    diag_voc.add_sentence(row['ICD9_CODE'])
    med_voc.add_sentence(row['NDC'])
    pro_voc.add_sentence(row['PRO_CODE'])
    
med_voc.idx2word

{0: 'N02B',
 1: 'A01A',
 2: 'A02B',
 3: 'A06A',
 4: 'B05C',
 5: 'A12A',
 6: 'A12C',
 7: 'C01C',
 8: 'A07A',
 9: 'M01A',
 10: 'N01A',
 11: 'C07A',
 12: 'C03C',
 13: 'A12B',
 14: 'N07A',
 15: 'C02D',
 16: 'N02A',
 17: 'N06A',
 18: 'A02A',
 19: 'J01M',
 20: 'C02A',
 21: 'B01A',
 22: 'A11C',
 23: 'C03A',
 24: 'A03B',
 25: 'C10A',
 26: 'C01B',
 27: 'N05C',
 28: 'C09A',
 29: 'D01A',
 30: 'H03A',
 31: 'J01D',
 32: 'B02B',
 33: 'R06A',
 34: 'J01X',
 35: 'N03A',
 36: 'N05A',
 37: 'C08C',
 38: 'D11A',
 39: 'C01D',
 40: 'A04A',
 41: 'M03A',
 42: 'A07E',
 43: 'R03A',
 44: 'B03B',
 45: 'D07A',
 46: 'N07B',
 47: 'N05B',
 48: 'R05C',
 49: 'D06A',
 50: 'A03F',
 51: 'R01A',
 52: 'G04B',
 53: 'C01E',
 54: 'L01A',
 55: 'A07D',
 56: 'D04A',
 57: 'A05A',
 58: 'P01C',
 59: 'D06B',
 60: 'L01B',
 61: 'C01A',
 62: 'C05A',
 63: 'C03D',
 64: 'P01A',
 65: 'J02A',
 66: 'J05A',
 67: 'L01X',
 68: 'H02A',
 69: 'V03A',
 70: 'J01F',
 71: 'G03A',
 72: 'J01E',
 73: 'J04A',
 74: 'D10A',
 75: 'P01B',
 76: 'R05D',
 77: 'N04

In [67]:
dill.dump(obj={'diag_voc':diag_voc, 
               'med_voc':med_voc ,
               'pro_voc':pro_voc}, 
          file=open(r'/data/data2/041/datasets/DDI/voc_final.pkl','wb'))

# `create_patient_record`

保存list类型的记录

每一项代表一个患者，患者中有多个visit，每个visit包含三者数组，按顺序分别表示诊断、proc与药物

存储的均为编号，可以通过voc_final.pkl来查看对应的具体word

In [72]:
records = [] # (patient, code_kind:3, codes)  code_kind:diag, proc, med

for subject_id in data['SUBJECT_ID'].unique():
    item_df = data[data['SUBJECT_ID'] == subject_id]  # all records of current patient
    
    patient = []  # addmission(s)
    for index, row in item_df.iterrows():
        admission = []
        admission.append([diag_voc.word2idx[i] for i in row['ICD9_CODE']])
        admission.append([pro_voc.word2idx[i] for i in row['PRO_CODE']])
        admission.append([med_voc.word2idx[i] for i in row['NDC']])
        
        patient.append(admission)
    records.append(patient)

records[0][0]

[[0, 1, 2, 3, 4, 5, 6, 7],
 [0, 1, 2],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]

In [73]:
records[0][1]

[[8, 9, 10, 7],
 [3, 4, 1],
 [0, 1, 2, 3, 5, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18]]

In [75]:
dill.dump(obj=records, file=open(os.path.join(path_iii_dataset, 'records_final_previous_works.pkl'), 'wb'))

# `get_ddi_matrix`

In [82]:
med_voc_size = len(med_voc.idx2word)
med_unique_word = [med_voc.idx2word[i] for i in range(med_voc_size)]    # 所有的药物的ATC4
med_unique_word

['N02B',
 'A01A',
 'A02B',
 'A06A',
 'B05C',
 'A12A',
 'A12C',
 'C01C',
 'A07A',
 'M01A',
 'N01A',
 'C07A',
 'C03C',
 'A12B',
 'N07A',
 'C02D',
 'N02A',
 'N06A',
 'A02A',
 'J01M',
 'C02A',
 'B01A',
 'A11C',
 'C03A',
 'A03B',
 'C10A',
 'C01B',
 'N05C',
 'C09A',
 'D01A',
 'H03A',
 'J01D',
 'B02B',
 'R06A',
 'J01X',
 'N03A',
 'N05A',
 'C08C',
 'D11A',
 'C01D',
 'A04A',
 'M03A',
 'A07E',
 'R03A',
 'B03B',
 'D07A',
 'N07B',
 'N05B',
 'R05C',
 'D06A',
 'A03F',
 'R01A',
 'G04B',
 'C01E',
 'L01A',
 'A07D',
 'D04A',
 'A05A',
 'P01C',
 'D06B',
 'L01B',
 'C01A',
 'C05A',
 'C03D',
 'P01A',
 'J02A',
 'J05A',
 'L01X',
 'H02A',
 'V03A',
 'J01F',
 'G03A',
 'J01E',
 'J04A',
 'D10A',
 'P01B',
 'R05D',
 'N04B',
 'G04C',
 'J01C',
 'S01E',
 'H05B',
 'M04A',
 'C09C',
 'J01G',
 'C08D',
 'N06D',
 'H01C',
 'L04A',
 'A10B',
 'C05B',
 'B02A',
 'D08A',
 'A16A',
 'A11D',
 'C02C',
 'J01A',
 'A11G',
 'H03B',
 'L01D',
 'N06B',
 'C03B',
 'N01B',
 'G03C',
 'N04A',
 'N02C',
 'M03B',
 'A07B',
 'A11H',
 'M05B',
 'S01F',
 

In [83]:
atc3_atc4_dic = defaultdict(set)
for item in med_unique_word:
    atc3_atc4_dic[item[:4]].add(item)

atc3_atc4_dic

defaultdict(set,
            {'N02B': {'N02B'},
             'A01A': {'A01A'},
             'A02B': {'A02B'},
             'A06A': {'A06A'},
             'B05C': {'B05C'},
             'A12A': {'A12A'},
             'A12C': {'A12C'},
             'C01C': {'C01C'},
             'A07A': {'A07A'},
             'M01A': {'M01A'},
             'N01A': {'N01A'},
             'C07A': {'C07A'},
             'C03C': {'C03C'},
             'A12B': {'A12B'},
             'N07A': {'N07A'},
             'C02D': {'C02D'},
             'N02A': {'N02A'},
             'N06A': {'N06A'},
             'A02A': {'A02A'},
             'J01M': {'J01M'},
             'C02A': {'C02A'},
             'B01A': {'B01A'},
             'A11C': {'A11C'},
             'C03A': {'C03A'},
             'A03B': {'A03B'},
             'C10A': {'C10A'},
             'C01B': {'C01B'},
             'N05C': {'N05C'},
             'C09A': {'C09A'},
             'D01A': {'D01A'},
             'H03A': {'H03A'},
             'J01D': {

In [84]:
cid2atc_dic = defaultdict(set)  # 去重
with open(cid_atc, 'r') as f:
    for line in f:
        line_ls = line[:-1].split(',')
        
        cid = line_ls[0]
        atcs = line_ls[1:]
        
        for atc in atcs:
            if len(atc3_atc4_dic[atc[:4]]) != 0:  # exist atc4
                cid2atc_dic[cid].add(atc[:4])
                
cid2atc_dic

defaultdict(set,
            {'CID000004011': {'N06A'},
             'CID000071273': {'N01B'},
             'CID000062816': {'C10A'},
             'CID000052421': {'D07A'},
             'CID000056339': {'N03A'},
             'CID000077992': {'N02C'},
             'CID000005300': {'L01A'},
             'CID000222786': {'H02A'},
             'CID000071301': {'C07A'},
             'CID000004870': {'C03A'},
             'CID000004873': {'A12B'},
             'CID000002142': {'D06A', 'J01G', 'S01A'},
             'CID000002141': {'V03A'},
             'CID000131536': {'J05A'},
             'CID000001775': {'N03A'},
             'CID000003305': {'M05B'},
             'CID000005454': {'N05A'},
             'CID000003308': {'M01A'},
             'CID000004909': {'D08A', 'N03A', 'R01A', 'R02A', 'S01A'},
             'CID000004908': {'P01B'},
             'CID000002099': {'A03A'},
             'CID000054688': {'J01F'},
             'CID000002092': {'G04C'},
             'CID000005503': {'A10B'},

In [87]:
# 加载DDI数据
ddi_df = pd.read_csv(os.path.join(path_ddi_dataset, "drug-DDI.csv"))
ddi_df

Unnamed: 0,STITCH 1,STITCH 2,Polypharmacy Side Effect,Side Effect Name
0,CID000002173,CID000003345,C0151714,hypermagnesemia
1,CID000002173,CID000003345,C0035344,retinopathy of prematurity
2,CID000002173,CID000003345,C0004144,atelectasis
3,CID000002173,CID000003345,C0002063,alkalosis
4,CID000002173,CID000003345,C0004604,Back Ache
...,...,...,...,...
4649436,CID000003461,CID000003954,C0149871,deep vein thromboses
4649437,CID000003461,CID000003954,C0035410,rhabdomyolysis
4649438,CID000003461,CID000003954,C0043096,loss of weight
4649439,CID000003461,CID000003954,C0003962,ascites


In [88]:
# fliter sever side effect，也是采取topK的形式
ddi_most_pd = ddi_df.groupby(by=['Polypharmacy Side Effect', 'Side Effect Name'])\
                    .size()\
                    .reset_index()\
                    .rename(columns={0:'count'})\
                    .sort_values(by=['count'],ascending=False)\
                    .reset_index(drop=True)

TOPK = 40 # topk drug-drug interaction
ddi_most_pd = ddi_most_pd.iloc[-TOPK:,:]
ddi_most_pd

Unnamed: 0,Polypharmacy Side Effect,Side Effect Name,count
1277,C0042755,Masculinization,28
1278,C0156273,bladder diverticulum,27
1279,C0155707,trifascicular block,26
1280,C0008513,chorioretinitis,25
1281,C0549398,meibomianitis,23
1282,C0027134,Myringitis,23
1283,C0032024,pityriasis,22
1284,C0014390,entropion,21
1285,C0008533,hemophilia B,21
1286,C0008497,choriocarcinoma,18


In [89]:
fliter_ddi_df = ddi_df.merge(ddi_most_pd[['Side Effect Name']], how='inner', on=['Side Effect Name'])
fliter_ddi_df

Unnamed: 0,STITCH 1,STITCH 2,Polypharmacy Side Effect,Side Effect Name
0,CID000002802,CID000003639,C0027086,myoma
1,CID000000937,CID000002250,C0027086,myoma
2,CID000005038,CID000027661,C0027086,myoma
3,CID000005038,CID000006691,C0027086,myoma
4,CID000000937,CID000027661,C0027086,myoma
...,...,...,...,...
471,CID000004044,CID000004542,C0025218,chloasma
472,CID000000450,CID000000853,C0025218,chloasma
473,CID000010631,CID000060852,C0014935,estrogen replacement
474,CID000000450,CID000060852,C0014935,estrogen replacement


In [94]:
ddi_df = fliter_ddi_df[['STITCH 1','STITCH 2']].drop_duplicates().reset_index(drop=True)
ddi_df

Unnamed: 0,STITCH 1,STITCH 2
0,CID000002802,CID000003639
1,CID000000937,CID000002250
2,CID000005038,CID000027661
3,CID000005038,CID000006691
4,CID000000937,CID000027661
...,...,...
455,CID000004044,CID000004542
456,CID000000450,CID000000853
457,CID000010631,CID000060852
458,CID000000450,CID000060852


In [91]:
# ddi adj，DDI表是CID编码的，因此需要将CID映射到ACT编码，才能记录数据集中药物之间的冲突信息
ddi_adj = np.zeros((med_voc_size,med_voc_size))
for index, row in ddi_df.iterrows():
    # ddi
    cid1 = row['STITCH 1']
    cid2 = row['STITCH 2']
    
    # cid -> atc_level3
    for atc_i in cid2atc_dic[cid1]:
        for atc_j in cid2atc_dic[cid2]:
            
            # atc_level3 -> atc_level4
            for i in atc3_atc4_dic[atc_i]:
                for j in atc3_atc4_dic[atc_j]:
                    if med_voc.word2idx[i] != med_voc.word2idx[j]:
                        ddi_adj[med_voc.word2idx[i], med_voc.word2idx[j]] = 1
                        ddi_adj[med_voc.word2idx[j], med_voc.word2idx[i]] = 1
                        
ddi_adj

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [93]:
ddi_adj.shape

(131, 131)

In [97]:
sum(sum(ddi_adj)) / ddi_adj.size

0.05221140959151565