In [82]:
import os
import re
from collections import Counter

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from tqdm import trange

In [2]:
# Lookup tables later merged (on their 'id' column) to attach miRNA /
# disease names to the prediction tables.
pwd = '/home/chujunyi/4_GNN/GAEMDA-miRNA-disease/data/'
disease_id_name, mirna_id_name = (
    pd.read_csv(pwd + csv_name) for csv_name in ('disease_name.csv', 'miRNA_name.csv')
)

In [3]:
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0.

    Keeps an empty class (e.g. no predicted positives) from emitting a
    division warning; the resulting metric is simply NaN.
    """
    return numerator / denominator if denominator else float('nan')


def metrics(y_true, y_pred, y_prob):
    """Print binary-classification metrics for one set of predictions.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels. Assumed to contain only 0 and 1; other values are
        ignored by the confusion matrix.
    y_pred : array-like
        Hard predicted labels, encoded the same way as ``y_true``.
    y_prob : array-like of float
        Scores for the positive class, used for ROC-AUC and the PR curve.

    Returns
    -------
    dict
        All printed scalar metrics, keyed by name. The function previously
        returned None, so existing callers are unaffected.

    Raises
    ------
    ValueError
        From ``roc_auc_score`` when ``y_true`` contains a single class.

    Notes
    -----
    ``pos_acc`` is ``tp / #(y_true == 1)`` (recall), whereas ``neg_acc`` is
    ``tn / #(y_pred == 0)`` (negative predictive value), so the two are not
    mirror images. AUPR is the trapezoidal area under the precision-recall
    curve.
    """
    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs in
    # y_true and y_pred, so the four-way unpack cannot fail.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    n_true_pos, n_true_neg = tp + fn, tn + fp
    n_pred_pos, n_pred_neg = tp + fp, tn + fn

    pos_acc = _safe_div(tp, n_true_pos)
    neg_acc = _safe_div(tn, n_pred_neg)  # [y_true=0 & y_pred=0] / y_pred=0
    accuracy = _safe_div(tp + tn, tn + fp + fn + tp)

    recall = _safe_div(tp, tp + fn)
    precision = _safe_div(tp, tp + fp)
    f1 = _safe_div(2 * precision * recall, precision + recall)

    roc_auc = roc_auc_score(y_true, y_prob)
    prec, reca, _ = precision_recall_curve(y_true, y_prob)
    aupr = auc(reca, prec)
    average1 = (accuracy + precision + recall + roc_auc + aupr) / 5
    average2 = (accuracy + f1 + roc_auc + aupr) / 4
    average3 = (f1 + aupr) / 2

    print('tn = {}, fp = {}, fn = {}, tp = {}'.format(tn, fp, fn, tp))
    print('y_pred: 0 = {} | 1 = {}'.format(n_pred_neg, n_pred_pos))
    print('y_true: 0 = {} | 1 = {}'.format(n_true_neg, n_true_pos))
    print('acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}|pos_acc={:.4f}|neg_acc={:.4f}'.format(accuracy, precision, recall, f1, roc_auc, aupr, pos_acc, neg_acc))
    print('{:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}'.format(accuracy, precision, recall, f1, roc_auc, aupr, average1, average2, average3))

    return {
        'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp,
        'accuracy': accuracy, 'precision': precision, 'recall': recall,
        'f1': f1, 'auc': roc_auc, 'aupr': aupr,
        'pos_acc': pos_acc, 'neg_acc': neg_acc,
        'average1': average1, 'average2': average2, 'average3': average3,
    }

In [4]:
def train_test_file(task, balance,
                    data_dir='/home/chujunyi/4_GNN/GraphSAINT/miRNA_disease_data'):
    """Load the saved cross-validation split for one task.

    Parameters
    ----------
    task : str
        Task name used in the file name (e.g. ``'Tp'``, ``'Tm'``, ``'Td'``).
    balance : str
        Suffix appended directly after ``task`` in the file name
        (``''`` in the calls in this notebook).
    data_dir : str, optional
        Directory containing the split ``.npz`` files; defaults to the
        directory that was previously hard-coded.

    Returns
    -------
    tuple
        ``(test_index_all, test_id_all, (train_index_all, train_id_all))``,
        the arrays exactly as stored. The ``*_id_all`` arrays hold
        (miRNA, disease) id pairs.
    """
    file_name = 'task_' + task + balance + '__testlabel0_knn_edge_train_test_index_all.npz'
    # The context manager closes the underlying zip archive once every
    # array has been read into memory.
    with np.load(os.path.join(data_dir, file_name), allow_pickle=True) as split:
        train_index_all = split['train_index_all']
        test_index_all = split['test_index_all']
        train_id_all = split['train_id_all']  # 'miRNA', 'disease'
        test_id_all = split['test_id_all']  # 'miRNA', 'disease'
    return test_index_all, test_id_all, (train_index_all, train_id_all)

In [5]:
def balanced_results_file(task, knn, lr, fold):
    """Load one fold's saved predictions, print their metrics, and return them.

    The file read is ``task_<task>_balanced_<knn>_lr<lr>_fold<fold>.npz``
    (relative to the working directory). Rows 0, 1 and 2 of its ``ys_train``
    and ``ys_test`` arrays are taken as the true labels, predicted labels and
    predicted probabilities.

    Returns
    -------
    tuple
        ``(y_true_test, y_pred_test, y_prob_test,
        (y_true_train, y_pred_train, y_prob_train))``.
    """
    path = 'task_' + task + '_balanced_' + knn + '_lr' + str(lr) + '_fold' + str(fold) + '.npz'
    # Read each member once: indexing an NpzFile re-reads and decompresses
    # the member on every access. The `with` closes the archive afterwards.
    with np.load(path) as results:
        ys_train = results['ys_train']
        ys_test = results['ys_test']

    y_true_train, y_pred_train, y_prob_train = ys_train[0], ys_train[1], ys_train[2]
    y_true_test, y_pred_test, y_prob_test = ys_test[0], ys_test[1], ys_test[2]

    print('Train:')
    metrics(y_true_train, y_pred_train, y_prob_train)
    print('Test:')
    metrics(y_true_test, y_pred_test, y_prob_test)

    return y_true_test, y_pred_test, y_prob_test, (y_true_train, y_pred_train, y_prob_train)

In [8]:
def sample(random_seed,
           path='/home/chujunyi/4_GNN/GAEMDA-miRNA-disease/data/all_mirna_disease_pairs.csv'):
    """Build a class-balanced table of (miRNA, disease, label) pairs.

    All rows with ``label == 1`` come first, followed by an equal number of
    ``label == 0`` rows drawn without replacement using ``random_seed``. The
    index is reset to ``0 .. n-1``.

    Parameters
    ----------
    random_seed : int
        ``random_state`` for the negative draw; the same seed and input file
        give the same table.
    path : str, optional
        Header-less CSV with columns miRNA, disease, label; defaults to the
        previously hard-coded location.

    Raises
    ------
    ValueError
        From ``DataFrame.sample`` if there are fewer negative than positive rows.
    """
    all_associations = pd.read_csv(path, names=['miRNA', 'disease', 'label'])
    known_associations = all_associations.loc[all_associations['label'] == 1]
    unknown_associations = all_associations.loc[all_associations['label'] == 0]
    random_negative = unknown_associations.sample(n=known_associations.shape[0], random_state=random_seed, axis=0)

    # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement.
    return pd.concat([known_associations, random_negative]).reset_index(drop=True)

In [6]:
def run_balanced_Tp(task, balance, knn, lr, n_folds=5):
    """Gather Tp-task test predictions from every fold, attach names, and save.

    Parameters
    ----------
    task, balance : str
        Forwarded to ``train_test_file`` to locate the saved split; ``task``
        is also forwarded to ``balanced_results_file`` and used in the
        output file name.
    knn : str
    lr : float
        Forwarded to ``balanced_results_file`` to locate each fold's
        prediction file.
    n_folds : int, default 5
        Number of folds to read (fold indices ``0 .. n_folds - 1``).

    Returns
    -------
    pandas.DataFrame
        One row per test pair, with the numeric ids (``miRNA_x``,
        ``disease_x``), ``y_true``/``y_pred``/``y_prob``, and the names from
        the lookup tables (``miRNA_y``, ``disease_y``). Pairs whose ids are
        missing from a lookup table are dropped (inner merge). Rows are
        sorted by ``disease_x`` then ``y_prob``, both descending. The frame
        is also written to ``<task>_balanced_case_study_0.csv``.
    """
    _, test_id_all, _ = train_test_file(task, balance)

    fold_true, fold_pred, fold_prob = [], [], []
    for fold in range(n_folds):
        print('==== Fold ', fold)
        y_true_test, y_pred_test, y_prob_test, _ = balanced_results_file(task, knn, lr, fold=fold)
        fold_true.append(y_true_test)
        fold_pred.append(y_pred_test)
        fold_prob.append(y_prob_test)

    # NOTE(review): assumes test_id_all lists the pairs fold by fold in the
    # same order as the saved predictions -- confirm against the split code.
    # Stacking once (instead of growing an array each fold) gives the same
    # fold-major flattening; vstack requires equally sized folds.
    results_df = pd.DataFrame(test_id_all.reshape(-1, 2), columns=['miRNA', 'disease'])
    results_df['y_true'] = np.vstack(fold_true).reshape(-1)
    results_df['y_pred'] = np.vstack(fold_pred).reshape(-1)
    results_df['y_prob'] = np.vstack(fold_prob).reshape(-1)

    results_df = (
        results_df
        .merge(mirna_id_name, left_on='miRNA', right_on='id')
        .merge(disease_id_name, left_on='disease', right_on='id')
        .drop(columns=['id_x', 'id_y'])
        .sort_values(by=['disease_x', 'y_prob'], ascending=False)
    )

    results_df.to_csv(task + '_balanced_case_study_0.csv')

    return results_df

In [9]:
def run_balanced_Tmd(task, balance, knn, lr, n_folds=5, random_seed=1234):
    """Gather Tm/Td-task test predictions from every fold, attach names, and save.

    The (miRNA, disease) pair of each test sample is recovered by re-drawing
    the balanced pair table with ``sample(random_seed)`` and indexing it with
    the saved per-fold ``test_index_all``.

    Parameters
    ----------
    task, balance : str
        Forwarded to ``train_test_file`` to locate the saved split; ``task``
        is also forwarded to ``balanced_results_file`` and used in the
        output file name.
    knn : str
    lr : float
        Forwarded to ``balanced_results_file`` to locate each fold's
        prediction file.
    n_folds : int, default 5
        Number of folds to read (fold indices ``0 .. n_folds - 1``).
    random_seed : int, default 1234
        Seed passed to ``sample``. NOTE(review): presumably this must match
        the seed used when the folds were built, otherwise the test indices
        point at different pairs -- confirm.

    Returns
    -------
    pandas.DataFrame
        One row per test pair, with the numeric ids (``miRNA_x``,
        ``disease_x``), ``y_true``/``y_pred``/``y_prob``, and the names from
        the lookup tables (``miRNA_y``, ``disease_y``). Pairs whose ids are
        missing from a lookup table are dropped (inner merge). Rows are
        sorted by ``disease_x`` then ``y_prob``, both descending. The frame
        is also written to ``<task>_balanced_case_study_0.csv``.
    """
    dtp = sample(random_seed=random_seed)
    test_index_all, _, _ = train_test_file(task, balance)

    fold_pairs, fold_true, fold_pred, fold_prob = [], [], [], []
    for fold in range(n_folds):
        print('==== Fold ', fold)
        y_true_test, y_pred_test, y_prob_test, _ = balanced_results_file(task, knn, lr, fold=fold)
        fold_pairs.append(dtp.iloc[test_index_all[fold]][['miRNA', 'disease']])
        fold_true.append(y_true_test)
        fold_pred.append(y_pred_test)
        fold_prob.append(y_prob_test)

    # One concat builds a new frame (rather than a slice of `dtp`), so the
    # column assignments below are safe even when n_folds == 1.
    results_df = pd.concat(fold_pairs, axis=0)
    results_df['y_true'] = np.hstack(fold_true).reshape(-1)
    results_df['y_pred'] = np.hstack(fold_pred).reshape(-1)
    results_df['y_prob'] = np.hstack(fold_prob).reshape(-1)

    results_df = (
        results_df
        .merge(mirna_id_name, left_on='miRNA', right_on='id')
        .merge(disease_id_name, left_on='disease', right_on='id')
        .drop(columns=['id_x', 'id_y'])
        .sort_values(by=['disease_x', 'y_prob'], ascending=False)
    )

    results_df.to_csv(task + '_balanced_case_study_0.csv')

    return results_df

# Run the balanced case studies (Tp, Tm, Td)

In [7]:
# Tp task; knn and lr select which saved per-fold prediction files are read.
results_Tp_balanced = run_balanced_Tp(
    task='Tp',
    balance='',
    knn='10knn',
    lr=0.001,
)
results_Tp_balanced

==== Fold  0
Train:
tn = 2366, fp = 5, fn = 20, tp = 2382
y_pred: 0 = 2386 | 1 = 2387
y_true: 0 = 2371 | 1 = 2402
acc=0.9948|precision=0.9979|recall=0.9917|f1=0.9948|auc=0.9999|aupr=0.9999|pos_acc=0.9917|neg_acc=0.9916
0.9948, 0.9979, 0.9917, 0.9948, 0.9999, 0.9999, 0.9968, 0.9973, 0.9973
Test:
tn = 1080, fp = 14, fn = 26, tp = 1052
y_pred: 0 = 1106 | 1 = 1066
y_true: 0 = 1094 | 1 = 1078
acc=0.9816|precision=0.9869|recall=0.9759|f1=0.9813|auc=0.9972|aupr=0.9976|pos_acc=0.9759|neg_acc=0.9765
0.9816, 0.9869, 0.9759, 0.9813, 0.9972, 0.9976, 0.9878, 0.9894, 0.9895
==== Fold  1
Train:
tn = 2424, fp = 3, fn = 7, tp = 2346
y_pred: 0 = 2431 | 1 = 2349
y_true: 0 = 2427 | 1 = 2353
acc=0.9979|precision=0.9987|recall=0.9970|f1=0.9979|auc=1.0000|aupr=1.0000|pos_acc=0.9970|neg_acc=0.9971
0.9979, 0.9987, 0.9970, 0.9979, 1.0000, 1.0000, 0.9987, 0.9989, 0.9989
Test:
tn = 1075, fp = 21, fn = 20, tp = 1056
y_pred: 0 = 1095 | 1 = 1077
y_true: 0 = 1096 | 1 = 1076
acc=0.9811|precision=0.9805|recall=0.9814|f

Unnamed: 0,miRNA_x,disease_x,y_true,y_pred,y_prob,miRNA_y,disease_y
7250,14,383,1.0,1.0,0.999627,hsa-mir-21,['Wounds and Injuries']
7253,116,383,1.0,1.0,0.999554,hsa-mir-483,['Wounds and Injuries']
7249,10,383,1.0,1.0,0.999477,hsa-mir-145,['Wounds and Injuries']
7258,148,383,1.0,1.0,0.996552,hsa-mir-9,['Wounds and Injuries']
7248,9,383,1.0,1.0,0.996511,hsa-mir-143,['Wounds and Injuries']
...,...,...,...,...,...,...,...
12,327,1,0.0,0.0,0.000531,hsa-mir-512,"['Abortion, Habitual']"
10,164,1,0.0,0.0,0.000500,hsa-let-7i,"['Abortion, Habitual']"
9,108,1,0.0,0.0,0.000489,hsa-mir-942,"['Abortion, Habitual']"
13,371,1,0.0,0.0,0.000476,hsa-mir-1275,"['Abortion, Habitual']"


In [10]:
# Tm task; knn and lr select which saved per-fold prediction files are read.
results_Tm_balanced = run_balanced_Tmd(
    task='Tm',
    balance='',
    knn='7knn',
    lr=0.01,
)
results_Tm_balanced

==== Fold  0
Train:
tn = 2283, fp = 9, fn = 16, tp = 2349
y_pred: 0 = 2299 | 1 = 2358
y_true: 0 = 2292 | 1 = 2365
acc=0.9946|precision=0.9962|recall=0.9932|f1=0.9947|auc=0.9999|aupr=0.9999|pos_acc=0.9932|neg_acc=0.9930
0.9946, 0.9962, 0.9932, 0.9947, 0.9999, 0.9999, 0.9968, 0.9973, 0.9973
Test:
tn = 1093, fp = 10, fn = 32, tp = 887
y_pred: 0 = 1125 | 1 = 897
y_true: 0 = 1103 | 1 = 919
acc=0.9792|precision=0.9889|recall=0.9652|f1=0.9769|auc=0.9941|aupr=0.9945|pos_acc=0.9652|neg_acc=0.9716
0.9792, 0.9889, 0.9652, 0.9769, 0.9941, 0.9945, 0.9844, 0.9862, 0.9857
==== Fold  1
Train:
tn = 2292, fp = 5, fn = 9, tp = 2369
y_pred: 0 = 2301 | 1 = 2374
y_true: 0 = 2297 | 1 = 2378
acc=0.9970|precision=0.9979|recall=0.9962|f1=0.9971|auc=1.0000|aupr=1.0000|pos_acc=0.9962|neg_acc=0.9961
0.9970, 0.9979, 0.9962, 0.9971, 1.0000, 1.0000, 0.9982, 0.9985, 0.9985
Test:
tn = 1097, fp = 7, fn = 72, tp = 850
y_pred: 0 = 1169 | 1 = 857
y_true: 0 = 1104 | 1 = 922
acc=0.9610|precision=0.9918|recall=0.9219|f1=0.955

Unnamed: 0,miRNA_x,disease_x,y_true,y_pred,y_prob,miRNA_y,disease_y
9224,148,383,1.0,1.0,0.999978,hsa-mir-9,['Wounds and Injuries']
9226,9,383,1.0,1.0,0.999975,hsa-mir-143,['Wounds and Injuries']
9223,116,383,1.0,1.0,0.999960,hsa-mir-483,['Wounds and Injuries']
9231,10,383,1.0,1.0,0.999892,hsa-mir-145,['Wounds and Injuries']
9237,14,383,1.0,1.0,0.999791,hsa-mir-21,['Wounds and Injuries']
...,...,...,...,...,...,...,...
1402,24,1,0.0,0.0,0.000030,hsa-mir-126,"['Abortion, Habitual']"
1393,108,1,0.0,0.0,0.000026,hsa-mir-942,"['Abortion, Habitual']"
1390,164,1,0.0,0.0,0.000025,hsa-let-7i,"['Abortion, Habitual']"
1405,327,1,0.0,0.0,0.000025,hsa-mir-512,"['Abortion, Habitual']"


In [12]:
# Td task; knn and lr select which saved per-fold prediction files are read.
results_Td_balanced = run_balanced_Tmd(
    task='Td',
    balance='',
    knn='5knn',
    lr=0.001,
)
results_Td_balanced

==== Fold  0
Train:
tn = 2249, fp = 4, fn = 10, tp = 2162
y_pred: 0 = 2259 | 1 = 2166
y_true: 0 = 2253 | 1 = 2172
acc=0.9968|precision=0.9982|recall=0.9954|f1=0.9968|auc=1.0000|aupr=1.0000|pos_acc=0.9954|neg_acc=0.9956
0.9968, 0.9982, 0.9954, 0.9968, 1.0000, 1.0000, 0.9981, 0.9984, 0.9984
Test:
tn = 1078, fp = 12, fn = 25, tp = 1119
y_pred: 0 = 1103 | 1 = 1131
y_true: 0 = 1090 | 1 = 1144
acc=0.9834|precision=0.9894|recall=0.9781|f1=0.9837|auc=0.9974|aupr=0.9979|pos_acc=0.9781|neg_acc=0.9773
0.9834, 0.9894, 0.9781, 0.9837, 0.9974, 0.9979, 0.9893, 0.9906, 0.9908
==== Fold  1
Train:
tn = 2187, fp = 2, fn = 13, tp = 2354
y_pred: 0 = 2200 | 1 = 2356
y_true: 0 = 2189 | 1 = 2367
acc=0.9967|precision=0.9992|recall=0.9945|f1=0.9968|auc=1.0000|aupr=1.0000|pos_acc=0.9945|neg_acc=0.9941
0.9967, 0.9992, 0.9945, 0.9968, 1.0000, 1.0000, 0.9981, 0.9984, 0.9984
Test:
tn = 997, fp = 26, fn = 13, tp = 722
y_pred: 0 = 1010 | 1 = 748
y_true: 0 = 1023 | 1 = 735
acc=0.9778|precision=0.9652|recall=0.9823|f1=0

Unnamed: 0,miRNA_x,disease_x,y_true,y_pred,y_prob,miRNA_y,disease_y
8048,116,383,1.0,1.0,0.999104,hsa-mir-483,['Wounds and Injuries']
8049,9,383,1.0,1.0,0.999097,hsa-mir-143,['Wounds and Injuries']
8037,148,383,1.0,1.0,0.999097,hsa-mir-9,['Wounds and Injuries']
8036,14,383,1.0,1.0,0.999079,hsa-mir-21,['Wounds and Injuries']
8034,10,383,1.0,1.0,0.998415,hsa-mir-145,['Wounds and Injuries']
...,...,...,...,...,...,...,...
8554,482,1,0.0,0.0,0.001152,hsa-mir-570,"['Abortion, Habitual']"
8543,24,1,0.0,0.0,0.001151,hsa-mir-126,"['Abortion, Habitual']"
8551,164,1,0.0,0.0,0.001149,hsa-let-7i,"['Abortion, Habitual']"
8553,454,1,0.0,0.0,0.001148,hsa-mir-767,"['Abortion, Habitual']"
