In [1]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split


# Reproducibility: one seed value shared by every RNG this notebook touches.
seed = 100
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
# NOTE: PYTHONHASHSEED is read when an interpreter starts, so assigning it here
# only affects subprocesses launched later, not hash randomization in this kernel.
os.environ["PYTHONHASHSEED"] = "%d" % seed

# Directory holding the preprocessed pickle files.
DATA_PATH = "./preprocessed/"
from scipy.stats.stats import pearsonr

  from scipy.stats.stats import pearsonr


In [2]:
def _load_pickle(filename):
    """Load one pickled object from DATA_PATH and close the file afterwards.

    NOTE: pickle.load can execute arbitrary code embedded in the file; only use
    it on trusted, locally produced preprocessing outputs.
    """
    with open(os.path.join(DATA_PATH, filename), 'rb') as fh:
        return pickle.load(fh)


pids = _load_pickle('pid.pkl')
x_dem = _load_pickle('x_dem.pkl')
x_per = _load_pickle('x_per_added.pkl')
x_d = _load_pickle('x_d_added.pkl')

# The filtering cells index x_dem, x_per and x_d with the same patient position,
# so the three lists must be aligned one-to-one.
assert len(x_dem) == len(x_per) == len(x_d), "per-patient lists are not index-aligned"

In [3]:
### Keep only patients who have both examination and prescription records.
# Collect the positions whose entry is empty in each per-patient list. All
# three lists are probed with the same positions, bounded by len(x_dem).
dem_no_idx = {idx for idx in range(len(x_dem)) if len(x_dem[idx]) == 0}
per_no_idx = {idx for idx in range(len(x_dem)) if len(x_per[idx]) == 0}
d_no_idx = {idx for idx in range(len(x_dem)) if len(x_d[idx]) == 0}

In [4]:
# Drop every patient whose x_per or x_d entry is empty. Surviving positions are
# kept in ascending order so the three filtered lists stay index-aligned.
tot_no_idx = per_no_idx.union(d_no_idx)
all_idx = set(range(len(x_dem)))
alive_idx = all_idx.difference(tot_no_idx)

alive_idx_list = sorted(alive_idx)

x_dem_f, x_per_f, x_d_f = (
    [records[i] for i in alive_idx_list] for records in (x_dem, x_per, x_d)
)


In [5]:
# Sanity check: the three filtered lists should have the same length.
print(len(x_dem_f), len(x_per_f), len(x_d_f), sep="\n")


19432
19432
19432


In [28]:
## Correlation between examination values and the per-visit prescription total
## 4/19 examination metrics
# For every visit of every filtered patient, collect the first N_SIGNALS
# examination values (signal_k holds value k of each visit) together with the
# sum of that visit's prescription vector. Visits are walked over x_per_f, and
# the matching x_d_f visit is looked up by the same position (an x_d_f visit
# list shorter than x_per_f's raises IndexError).
N_SIGNALS = 10  # the signal_0..signal_9 unpacking below expects exactly 10

signals = [[] for _ in range(N_SIGNALS)]
sum_prescription = []

for patient_exams, patient_rx in zip(x_per_f, x_d_f):
    for v_idx, exam in enumerate(patient_exams):
        for k, column in enumerate(signals):
            column.append(exam[k])
        sum_prescription.append(sum(patient_rx[v_idx]))

# Keep the individual names used by the correlation cells.
(signal_0, signal_1, signal_2, signal_3, signal_4,
 signal_5, signal_6, signal_7, signal_8, signal_9) = signals


In [23]:
# Every collected series has one entry per visit, so these lengths should agree.
print(len(signal_0), len(signal_1), len(sum_prescription), sep="\n")

105101
105101
105101


In [7]:
from scipy import stats

In [29]:
# Pearson r and two-sided p-value of each examination signal against the
# visit's prescription total, one line per signal (signal_0 first).
for _signal in (signal_0, signal_1, signal_2, signal_3, signal_4,
                signal_5, signal_6, signal_7, signal_8, signal_9):
    print(stats.pearsonr(_signal, sum_prescription))

(0.0005614566878054583, 0.8555685292643915)
(0.2571851660934447, 0.0)
(0.39731365175755795, 0.0)
(0.22764478054341095, 0.0)
(0.24459958746913443, 0.0)
(0.09014999436986289, 1.59713967319103e-188)
(0.05174021166753502, 3.1502892271935005e-63)
(0.041472581000349396, 3.048036413527396e-41)
(-0.006905286810450665, 0.02517904786970362)
(-0.009515598685796255, 0.0020360544355439093)


In [26]:
# Point-biserial correlation for signals 0-3 vs. the prescription total.
# SciPy computes this as a Pearson correlation, so the values equal the first
# four rows of the pearsonr cell.
for _signal in (signal_0, signal_1, signal_2, signal_3):
    print(stats.pointbiserialr(_signal, sum_prescription))


PointbiserialrResult(correlation=0.0005614566878054583, pvalue=0.8555685292643915)
PointbiserialrResult(correlation=0.2571851660934447, pvalue=0.0)
PointbiserialrResult(correlation=0.39731365175755795, pvalue=0.0)
PointbiserialrResult(correlation=0.22764478054341095, pvalue=0.0)


In [30]:
# Per-visit values of the first three prescription columns. Visits are walked
# over x_per_f exactly as in the signal-collection cell, so prescription_k[i]
# lines up with signal_j[i].
prescription_1 = []
prescription_2 = []
prescription_3 = []

for patient_exams, patient_rx in zip(x_per_f, x_d_f):
    for v_idx in range(len(patient_exams)):
        rx = patient_rx[v_idx]  # prescription vector of the matching visit
        prescription_1.append(rx[0])
        prescription_2.append(rx[1])
        prescription_3.append(rx[2])

In [41]:
# Point-biserial correlation of examination signals 0-6 against each of the
# first three prescription columns. A "#### signal_k" header precedes each
# group of three results so every output line can be attributed to its signal.
# (scipy documents pointbiserialr's first argument as the binary variable; the
# value is a Pearson r, so it is symmetric in its two arguments.)
_signals_to_check = (signal_0, signal_1, signal_2, signal_3,
                     signal_4, signal_5, signal_6)
_prescription_columns = (prescription_1, prescription_2, prescription_3)

for k, _signal in enumerate(_signals_to_check):
    print(f"#### signal_{k}")
    for _rx in _prescription_columns:
        print(stats.pointbiserialr(_signal, _rx))

PointbiserialrResult(correlation=-0.001484819185381471, pvalue=0.630259029821445)
PointbiserialrResult(correlation=-0.0020092603779998915, pvalue=0.5148007361976643)
PointbiserialrResult(correlation=-0.0015774108329600414, pvalue=0.6090853186923579)
####
PointbiserialrResult(correlation=0.0888419060495691, pvalue=3.947225610981999e-183)
PointbiserialrResult(correlation=-0.07633525716313086, pvalue=1.3600771856671149e-135)
PointbiserialrResult(correlation=0.015506046817262914, pvalue=4.977144126960538e-07)
2
PointbiserialrResult(correlation=-0.0018792043045574814, pvalue=0.5423800304285233)
PointbiserialrResult(correlation=0.04697210730782794, pvalue=2.032045605289511e-52)
PointbiserialrResult(correlation=0.020081780019347795, pvalue=7.467776994708218e-11)
PointbiserialrResult(correlation=0.073917750693716, pvalue=3.0503684347772116e-127)
PointbiserialrResult(correlation=-0.07843558638292139, pvalue=4.55388550783013e-143)
PointbiserialrResult(correlation=0.0010743020260381563, pvalue=0.