In [2]:
import numpy as np
import pandas as pd

from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from collections import Counter
import re
from tqdm import trange

In [3]:
# Folder with the miRNA/disease CSV files. The trailing slash matters because
# file names are appended to this string directly.
pwd = '/home/chujunyi/4_GNN/GAEMDA-miRNA-disease/data/'

# Lookup tables that the case-study functions merge against via their 'id' column.
disease_id_name = pd.read_csv('{}disease_name.csv'.format(pwd))
mirna_id_name = pd.read_csv('{}miRNA_name.csv'.format(pwd))

In [4]:
def metrics(y_true, y_pred, y_prob):
    """Print binary-classification metrics for one set of predictions.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels, expected to be 0/1.
    y_pred : array-like
        Hard 0/1 predictions, same length as ``y_true``.
    y_prob : array-like
        Scores for the positive class; used for ROC-AUC and for the area
        under the precision-recall curve.

    Returns
    -------
    None
        Results are only printed: the confusion counts, the class counts of
        ``y_pred`` and ``y_true``, a labelled metric line, and a bare
        comma-separated line that also contains three averaged scores.

    Notes
    -----
    Ratios whose denominator is zero are reported as ``nan``.
    """

    def _ratio(numerator, denominator):
        # An undefined ratio becomes nan without a NumPy RuntimeWarning.
        return numerator / denominator if denominator else float('nan')

    # labels=[0, 1] forces a 2x2 matrix. Without it, inputs that contain a
    # single class give a 1x1 matrix and the 4-way unpacking below raises.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    # For 0/1 labels: tp + fn == sum(y_true) and tn + fn == number of predicted 0s.
    pos_acc = _ratio(tp, tp + fn)
    neg_acc = _ratio(tn, tn + fn)  # [y_true=0 & y_pred=0] / y_pred=0
    accuracy = _ratio(tp + tn, tn + fp + fn + tp)

    recall = _ratio(tp, tp + fn)
    precision = _ratio(tp, tp + fp)
    f1 = _ratio(2 * precision * recall, precision + recall)

    roc_auc = roc_auc_score(y_true, y_prob)
    prec, reca, _ = precision_recall_curve(y_true, y_prob)
    aupr = auc(reca, prec)

    # Composite scores printed on the last line.
    average1 = (accuracy + precision + recall + roc_auc + aupr) / 5
    average2 = (accuracy + f1 + roc_auc + aupr) / 4
    average3 = (f1 + aupr) / 2

    # Count each label array once instead of once per printed value.
    pred_counts = Counter(y_pred)
    true_counts = Counter(y_true)

    print('tn = {}, fp = {}, fn = {}, tp = {}'.format(tn, fp, fn, tp))
    print('y_pred: 0 = {} | 1 = {}'.format(pred_counts[0], pred_counts[1]))
    print('y_true: 0 = {} | 1 = {}'.format(true_counts[0], true_counts[1]))
    print('acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}|pos_acc={:.4f}|neg_acc={:.4f}'.format(accuracy, precision, recall, f1, roc_auc, aupr, pos_acc, neg_acc))
    print('{:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}'.format(accuracy, precision, recall, f1, roc_auc, aupr, average1, average2, average3))

In [5]:
def train_test_file(task, balance, data_dir='/home/chujunyi/4_GNN/GraphSAINT/miRNA_disease_data/'):
    """Load the train/test index and id arrays stored for one task.

    Parameters
    ----------
    task : str
        Task tag used in the file name (the notebook passes 'Tp', 'Tm', 'Td').
    balance : str
        Balance tag appended right after ``task`` (e.g. '__nobalance').
    data_dir : str, optional
        Directory containing the ``.npz`` file. Defaults to the location the
        notebook originally hard-coded.

    Returns
    -------
    tuple
        ``(test_index_all, test_id_all, (train_index_all, train_id_all))``.
        According to the original inline comments the ``*_id_all`` arrays hold
        ('miRNA', 'disease') pairs; the exact shapes are not visible here.
    """
    import os

    filename = 'task_' + task + balance + '__testlabel0_knn_edge_train_test_index_all.npz'

    # NOTE(security): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code. Only load files from a trusted source.
    # The context manager closes the underlying zip file. Arrays are read
    # inside the block because NpzFile loads them lazily on item access.
    with np.load(os.path.join(data_dir, filename), allow_pickle=True) as archive:
        train_index_all = archive['train_index_all']
        test_index_all = archive['test_index_all']
        train_id_all = archive['train_id_all']  # 'miRNA', 'disease'
        test_id_all = archive['test_id_all']  # 'miRNA', 'disease'

    return test_index_all, test_id_all, (train_index_all, train_id_all)

In [6]:
def unbalanced_results_file(task, knn, lr, weight, fold):
    """Load one fold's saved results, print its metrics, and return the arrays.

    The file is looked up in the current working directory under the name
    ``task_<task>_unbalanced_<knn>_lr<lr>_weight<weight>_fold<fold>.npz``.

    Parameters
    ----------
    task, knn : str
        Tags used verbatim in the file name.
    lr, weight, fold
        Converted with ``str()`` and embedded in the file name.

    Returns
    -------
    tuple
        ``(y_true_test, y_pred_test, y_prob_test,
        (y_true_train, y_pred_train, y_prob_train))``, taken from rows 0, 1
        and 2 of the stored ``ys_test`` and ``ys_train`` arrays.
    """
    path = 'task_' + task + '_unbalanced_' + knn + '_lr' + str(lr) + '_weight' + str(weight) + '_fold' + str(fold) + '.npz'

    # Read each array exactly once: every NpzFile item access re-reads it from
    # disk. The context manager also closes the file handle.
    with np.load(path) as archive:
        ys_train = archive['ys_train']
        ys_test = archive['ys_test']

    y_true_train, y_pred_train, y_prob_train = ys_train[0], ys_train[1], ys_train[2]
    y_true_test, y_pred_test, y_prob_test = ys_test[0], ys_test[1], ys_test[2]

    print('Train:')
    metrics(y_true_train, y_pred_train, y_prob_train)
    print('Test:')
    metrics(y_true_test, y_pred_test, y_prob_test)

    return y_true_test, y_pred_test, y_prob_test, (y_true_train, y_pred_train, y_prob_train)

In [7]:
def run_unbalanced_Tp(task, balance, knn, lr, weight, n_folds=5):
    """Build, save and return the per-pair case-study table for the Tp task.

    For every fold the saved predictions are loaded (which also prints the
    train/test metrics), stacked fold by fold, and attached to the
    (miRNA, disease) id pairs in ``test_id_all``. The ids are then merged with
    the module-level ``mirna_id_name`` and ``disease_id_name`` tables.

    Parameters
    ----------
    task, balance : str
        Forwarded to ``train_test_file`` (``balance`` is e.g. '__nobalance').
    knn, lr, weight
        Forwarded to ``unbalanced_results_file`` to select the result files.
    n_folds : int, optional
        Number of fold files to read (folds ``0 .. n_folds - 1``). Default 5.

    Returns
    -------
    pandas.DataFrame
        Sorted by ``disease_x`` and then ``y_prob``, both descending. The same
        frame is written to ``<task>_unbalanced_case_study_0.csv``.

    Notes
    -----
    Assumes ``test_id_all`` lists the pairs in the same fold-by-fold order as
    the stacked predictions -- TODO confirm against the code that wrote it.
    """
    _, test_id_all, _ = train_test_file(task, balance)

    # Collect the folds and stack once instead of re-stacking on every iteration.
    true_folds, pred_folds, prob_folds = [], [], []
    for fold in range(n_folds):
        print('==== Fold ', fold)
        y_true_test, y_pred_test, y_prob_test, _ = unbalanced_results_file(task, knn, lr, weight, fold = fold)
        true_folds.append(y_true_test)
        pred_folds.append(y_pred_test)
        prob_folds.append(y_prob_test)

    y_true_test_all = np.vstack(true_folds)
    y_pred_test_all = np.vstack(pred_folds)
    y_prob_test_all = np.vstack(prob_folds)

    # Sanity check: row i of the stacked array is fold i's probabilities.
    for fold, fold_prob in enumerate(prob_folds):
        assert (y_prob_test_all[fold] == fold_prob).all()

    results_df = pd.DataFrame(test_id_all.reshape(-1, 2), columns = ['miRNA', 'disease'])
    results_df['y_true'] = y_true_test_all.reshape(-1)
    results_df['y_pred'] = y_pred_test_all.reshape(-1)
    results_df['y_prob'] = y_prob_test_all.reshape(-1)

    # Attach names; overlapping column names get pandas' _x/_y suffixes.
    results_df = pd.merge(results_df, mirna_id_name, left_on = 'miRNA', right_on = 'id')
    results_df = pd.merge(results_df, disease_id_name, left_on = 'disease', right_on = 'id')
    results_df = results_df.drop(labels = ['id_x', 'id_y'], axis = 1)
    results_df = results_df.sort_values(by = ['disease_x', 'y_prob'], ascending = False)

    results_df.to_csv(task + '_unbalanced_case_study_0.csv')

    return results_df

In [8]:
def run_unbalanced_Tmd(task, balance, knn, lr, weight, n_folds=5,
                       pairs_csv='/home/chujunyi/4_GNN/GAEMDA-miRNA-disease/data/all_mirna_disease_pairs.csv'):
    """Build, save and return the per-pair case-study table for Tm/Td tasks.

    For every fold the saved predictions are loaded (which also prints the
    train/test metrics) and the (miRNA, disease) pairs of that fold are taken
    from the pairs table by position, using ``test_index_all[fold]``. The
    pairs and predictions are concatenated fold by fold, then merged with the
    module-level ``mirna_id_name`` and ``disease_id_name`` tables.

    Parameters
    ----------
    task, balance : str
        Forwarded to ``train_test_file`` (``balance`` is e.g. '__nobalance').
    knn, lr, weight
        Forwarded to ``unbalanced_results_file`` to select the result files.
    n_folds : int, optional
        Number of fold files to read (folds ``0 .. n_folds - 1``). Default 5.
    pairs_csv : str, optional
        Header-less CSV read with the columns 'miRNA', 'disease', 'label'.
        Defaults to the path the notebook originally hard-coded.

    Returns
    -------
    pandas.DataFrame
        Sorted by ``disease_x`` and then ``y_prob``, both descending. The same
        frame is written to ``<task>_unbalanced_case_study_0.csv``.
    """
    dtp = pd.read_csv(pairs_csv, names=['miRNA', 'disease', 'label'])
    test_index_all, _, _ = train_test_file(task, balance)

    # Collect per-fold pieces and concatenate once instead of growing a
    # DataFrame and three arrays inside the loop.
    pair_folds, true_folds, pred_folds, prob_folds = [], [], [], []
    for fold in range(n_folds):
        print('==== Fold ', fold)
        y_true_test, y_pred_test, y_prob_test, _ = unbalanced_results_file(task, knn, lr, weight, fold = fold)

        pair_folds.append(dtp.iloc[test_index_all[fold]][['miRNA', 'disease']])
        true_folds.append(y_true_test)
        pred_folds.append(y_pred_test)
        prob_folds.append(y_prob_test)

    # pd.concat returns a new frame, so the column assignments below never
    # write into a slice of dtp.
    results_df = pd.concat(pair_folds, axis = 0)
    results_df['y_true'] = np.hstack(true_folds).reshape(-1)
    results_df['y_pred'] = np.hstack(pred_folds).reshape(-1)
    results_df['y_prob'] = np.hstack(prob_folds).reshape(-1)

    # Attach names; overlapping column names get pandas' _x/_y suffixes.
    results_df = pd.merge(results_df, mirna_id_name, left_on = 'miRNA', right_on = 'id')
    results_df = pd.merge(results_df, disease_id_name, left_on = 'disease', right_on = 'id')
    results_df = results_df.drop(labels = ['id_x', 'id_y'], axis = 1)
    results_df = results_df.sort_values(by = ['disease_x', 'y_prob'], ascending = False)

    results_df.to_csv(task + '_unbalanced_case_study_0.csv')

    return results_df

# Run unbalanced

In [20]:
# Case-study table for task 'Tp'; also writes Tp_unbalanced_case_study_0.csv.
results_Tp_unbalanced = run_unbalanced_Tp(
    task = 'Tp',
    balance = '__nobalance',
    knn = '15knn',
    lr = 0.001,
    weight = 10,
)
results_Tp_unbalanced

==== Fold  0
Train:
tn = 6947, fp = 43, fn = 9, tp = 206
y_pred: 0 = 6956 | 1 = 249
y_true: 0 = 6990 | 1 = 215
acc=0.9928|precision=0.8273|recall=0.9581|f1=0.8879|auc=0.9990|aupr=0.9697|pos_acc=0.9581|neg_acc=0.9987
0.9928, 0.8273, 0.9581, 0.8879, 0.9990, 0.9697, 0.9494, 0.9624, 0.9288
Test:
tn = 36163, fp = 699, fn = 365, tp = 690
y_pred: 0 = 36528 | 1 = 1389
y_true: 0 = 36862 | 1 = 1055
acc=0.9719|precision=0.4968|recall=0.6540|f1=0.5646|auc=0.9476|aupr=0.6046|pos_acc=0.6540|neg_acc=0.9900
0.9719, 0.4968, 0.6540, 0.5646, 0.9476, 0.6046, 0.7350, 0.7722, 0.5846
==== Fold  1
Train:
tn = 6946, fp = 46, fn = 1, tp = 199
y_pred: 0 = 6947 | 1 = 245
y_true: 0 = 6992 | 1 = 200
acc=0.9935|precision=0.8122|recall=0.9950|f1=0.8944|auc=0.9996|aupr=0.9857|pos_acc=0.9950|neg_acc=0.9999
0.9935, 0.8122, 0.9950, 0.8944, 0.9996, 0.9857, 0.9572, 0.9683, 0.9400
Test:
tn = 36247, fp = 642, fn = 373, tp = 655
y_pred: 0 = 36620 | 1 = 1297
y_true: 0 = 36889 | 1 = 1028
acc=0.9732|precision=0.5050|recall=0.637

Unnamed: 0,miRNA_x,disease_x,y_true,y_pred,y_prob,miRNA_y,disease_y
155938,14,383,1.0,1.0,0.999916,hsa-mir-21,['Wounds and Injuries']
155948,24,383,0.0,1.0,0.997668,hsa-mir-126,['Wounds and Injuries']
155934,10,383,1.0,1.0,0.708075,hsa-mir-145,['Wounds and Injuries']
155947,23,383,0.0,0.0,0.268306,hsa-mir-10b,['Wounds and Injuries']
156343,419,383,0.0,0.0,0.197201,hsa-mir-641,['Wounds and Injuries']
...,...,...,...,...,...,...,...
35953,314,1,0.0,0.0,0.000004,hsa-mir-330,"['Abortion, Habitual']"
35908,269,1,0.0,0.0,0.000004,hsa-mir-526a,"['Abortion, Habitual']"
35808,169,1,0.0,0.0,0.000004,hsa-mir-497,"['Abortion, Habitual']"
36022,383,1,0.0,0.0,0.000003,hsa-mir-520g,"['Abortion, Habitual']"


In [12]:
# Case-study table for task 'Tm'; also writes Tm_unbalanced_case_study_0.csv.
results_Tm_unbalanced = run_unbalanced_Tmd(
    task = 'Tm',
    balance = '__nobalance',
    knn = '15knn',
    lr = 0.01,
    weight = 10,
)
results_Tm_unbalanced

==== Fold  0
Train:
tn = 6958, fp = 55, fn = 7, tp = 190
y_pred: 0 = 6965 | 1 = 245
y_true: 0 = 7013 | 1 = 197
acc=0.9914|precision=0.7755|recall=0.9645|f1=0.8597|auc=0.9990|aupr=0.9703|pos_acc=0.9645|neg_acc=0.9990
0.9914, 0.7755, 0.9645, 0.8597, 0.9990, 0.9703, 0.9401, 0.9551, 0.9150
Test:
tn = 36190, fp = 752, fn = 327, tp = 648
y_pred: 0 = 36517 | 1 = 1400
y_true: 0 = 36942 | 1 = 975
acc=0.9715|precision=0.4629|recall=0.6646|f1=0.5457|auc=0.9502|aupr=0.6141|pos_acc=0.6646|neg_acc=0.9910
0.9715, 0.4629, 0.6646, 0.5457, 0.9502, 0.6141, 0.7327, 0.7704, 0.5799
==== Fold  1
Train:
tn = 6899, fp = 55, fn = 0, tp = 226
y_pred: 0 = 6899 | 1 = 281
y_true: 0 = 6954 | 1 = 226
acc=0.9923|precision=0.8043|recall=1.0000|f1=0.8915|auc=0.9997|aupr=0.9909|pos_acc=1.0000|neg_acc=1.0000
0.9923, 0.8043, 1.0000, 0.8915, 0.9997, 0.9909, 0.9574, 0.9686, 0.9412
Test:
tn = 36603, fp = 470, fn = 308, tp = 536
y_pred: 0 = 36911 | 1 = 1006
y_true: 0 = 37073 | 1 = 844
acc=0.9795|precision=0.5328|recall=0.6351|

Unnamed: 0,miRNA_x,disease_x,y_true,y_pred,y_prob,miRNA_y,disease_y
76649,308,383,0.0,1.0,6.435882e-01,hsa-mir-208a,['Wounds and Injuries']
76272,65,383,0.0,1.0,5.558605e-01,hsa-mir-125b,['Wounds and Injuries']
76542,138,383,0.0,1.0,5.408920e-01,hsa-mir-199a,['Wounds and Injuries']
76628,47,383,0.0,0.0,4.035085e-01,hsa-mir-223,['Wounds and Injuries']
76265,29,383,0.0,0.0,3.324498e-01,hsa-mir-181b,['Wounds and Injuries']
...,...,...,...,...,...,...,...
70028,212,1,0.0,0.0,1.406977e-07,hsa-mir-1471,"['Abortion, Habitual']"
70078,302,1,0.0,0.0,1.170808e-07,hsa-mir-1271,"['Abortion, Habitual']"
70084,385,1,0.0,0.0,1.085159e-07,hsa-mir-873,"['Abortion, Habitual']"
70015,107,1,0.0,0.0,1.038046e-07,hsa-mir-744,"['Abortion, Habitual']"


In [22]:
# Case-study table for task 'Td'; also writes Td_unbalanced_case_study_0.csv.
results_Td_unbalanced = run_unbalanced_Tmd(
    task = 'Td',
    balance = '__nobalance',
    knn = '7knn',
    lr = 0.0001,
    weight = 10,
)
results_Td_unbalanced

==== Fold  0
Train:
tn = 6475, fp = 68, fn = 5, tp = 197
y_pred: 0 = 6480 | 1 = 265
y_true: 0 = 6543 | 1 = 202
acc=0.9892|precision=0.7434|recall=0.9752|f1=0.8437|auc=0.9990|aupr=0.9668|pos_acc=0.9752|neg_acc=0.9992
0.9892, 0.7434, 0.9752, 0.8437, 0.9990, 0.9668, 0.9347, 0.9497, 0.9053
Test:
tn = 34794, fp = 1648, fn = 422, tp = 756
y_pred: 0 = 35216 | 1 = 2404
y_true: 0 = 36442 | 1 = 1178
acc=0.9450|precision=0.3145|recall=0.6418|f1=0.4221|auc=0.9269|aupr=0.4593|pos_acc=0.6418|neg_acc=0.9880
0.9450, 0.3145, 0.6418, 0.4221, 0.9269, 0.4593, 0.6575, 0.6883, 0.4407
==== Fold  1
Train:
tn = 6510, fp = 65, fn = 1, tp = 207
y_pred: 0 = 6511 | 1 = 272
y_true: 0 = 6575 | 1 = 208
acc=0.9903|precision=0.7610|recall=0.9952|f1=0.8625|auc=0.9991|aupr=0.9713|pos_acc=0.9952|neg_acc=0.9998
0.9903, 0.7610, 0.9952, 0.8625, 0.9991, 0.9713, 0.9434, 0.9558, 0.9169
Test:
tn = 35115, fp = 1677, fn = 330, tp = 498
y_pred: 0 = 35445 | 1 = 2175
y_true: 0 = 36792 | 1 = 828
acc=0.9467|precision=0.2290|recall=0.60

Unnamed: 0,miRNA_x,disease_x,y_true,y_pred,y_prob,miRNA_y,disease_y
129262,488,383,1.0,1.0,0.881445,hsa-mir-938,['Wounds and Injuries']
129261,176,383,1.0,1.0,0.881387,hsa-mir-193a,['Wounds and Injuries']
129683,389,383,1.0,1.0,0.879987,hsa-mir-1322,['Wounds and Injuries']
129655,120,383,1.0,1.0,0.879839,hsa-mir-106b,['Wounds and Injuries']
129438,75,383,1.0,1.0,0.879612,hsa-mir-19b,['Wounds and Injuries']
...,...,...,...,...,...,...,...
134536,101,1,0.0,0.0,0.120544,hsa-mir-619,"['Abortion, Habitual']"
134145,438,1,0.0,0.0,0.120346,hsa-mir-511,"['Abortion, Habitual']"
134235,479,1,0.0,0.0,0.120238,hsa-mir-518f,"['Abortion, Habitual']"
134275,68,1,0.0,0.0,0.120039,hsa-mir-146b,"['Abortion, Habitual']"
