#### Prediction performance of Enformer (ten-fold evaluation)

In [1]:
import numpy as np
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import precision_score, accuracy_score,recall_score, f1_score
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.metrics import roc_curve, auc
from prettytable import PrettyTable

In [2]:
# Directory containing the saved per-fold result arrays (.npy) that the
# evaluation cells load. The trailing '/' is required: file names are appended
# to this string by plain concatenation, not by a path-joining helper.
result_path = '../../model/pred_results_tenfold_enformer/'

In [3]:
# Ten-fold evaluation of the 'small' model.
# For every fold: load the saved labels / predictions / scores, compute
# ACC, precision, recall, F1 and AUC (each rounded to 3 decimals), print them
# as a one-row table, and collect them so that mean and std across folds can be
# reported at the end.
model_size = 'small'
acc_list = []
precision_list = []
recall_list = []
f1_list = []
auc_list = []
for f in range(10):
    fold = str(f+1)  # files on disk are numbered from 1, not 0
    # The three arrays of one fold share the same file-name prefix.
    prefix = result_path + model_size + '_split' + fold
    label = np.load(prefix + '_label.npy')
    y_score = np.load(prefix + '_score.npy')          # passed to the metrics below as the predicted labels
    y_score_pro = np.load(prefix + '_score_pro.npy')  # presumably per-class probabilities -- TODO confirm
    y_one_hot = to_categorical(label)
    # NOTE: the previous version also one-hot encoded `y_score` here; that
    # result was never read, so the call has been dropped.

    acc = np.round(accuracy_score(label, y_score),3)
    precision = np.round(precision_score(label, y_score),3)
    recall = np.round(recall_score(label, y_score),3)
    f1 = np.round(f1_score(label, y_score),3)
    # ROC is computed on the flattened one-hot labels against the flattened
    # score array, so both arrays must have the same number of elements
    # (assumes y_score_pro has one column per class -- TODO confirm).
    fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(),y_score_pro.ravel())
    auc_ = np.round(auc(fpr, tpr),3)

    table = PrettyTable(['ACC','Precision','Recall','F1-score','AUC'])
    table.add_row([acc,precision,recall,f1,auc_])
    print(table)

    acc_list.append(acc)
    precision_list.append(precision)
    recall_list.append(recall)
    f1_list.append(f1)
    auc_list.append(auc_)

# Summary across folds. These statistics are taken over the already-rounded
# per-fold values, exactly as they appear in the tables printed above.
print('(mean) ACC: ', np.mean(acc_list), 'Precision: ', np.mean(precision_list), 'Recall: ', np.mean(recall_list), 'F1: ', np.mean(f1_list), 'AUC: ', np.mean(auc_list))
print('(std) ACC: ', np.std(acc_list), 'Precision: ', np.std(precision_list), 'Recall: ', np.std(recall_list), 'F1: ', np.std(f1_list), 'AUC: ', np.std(auc_list))

+-------+-----------+--------+----------+------+
|  ACC  | Precision | Recall | F1-score | AUC  |
+-------+-----------+--------+----------+------+
| 0.895 |   0.914   | 0.878  |  0.896   | 0.96 |
+-------+-----------+--------+----------+------+
+-------+-----------+--------+----------+-------+
|  ACC  | Precision | Recall | F1-score |  AUC  |
+-------+-----------+--------+----------+-------+
| 0.901 |   0.919   | 0.886  |  0.902   | 0.963 |
+-------+-----------+--------+----------+-------+
+-------+-----------+--------+----------+-------+
|  ACC  | Precision | Recall | F1-score |  AUC  |
+-------+-----------+--------+----------+-------+
| 0.884 |   0.893   | 0.873  |  0.883   | 0.951 |
+-------+-----------+--------+----------+-------+
+-------+-----------+--------+----------+-------+
|  ACC  | Precision | Recall | F1-score |  AUC  |
+-------+-----------+--------+----------+-------+
| 0.895 |   0.908   | 0.878  |  0.893   | 0.964 |
+-------+-----------+--------+----------+-------+
+----

In [4]:
# Ten-fold evaluation of the 'large' model: one metric table per fold, followed
# by the mean and standard deviation of every metric across the folds.
model_size = 'large'
acc_list, precision_list, recall_list, f1_list, auc_list = [], [], [], [], []

for f in range(10):
    fold = str(f + 1)  # on-disk fold numbering starts at 1
    # Common file-name stem for the three arrays belonging to this fold.
    stem = f'{result_path}{model_size}_split{fold}'
    label = np.load(stem + '_label.npy')
    y_score = np.load(stem + '_score.npy')
    y_score_pro = np.load(stem + '_score_pro.npy')

    y_one_hot = to_categorical(label)
    # Not read anywhere in this cell; kept so the notebook-level name stays defined.
    y_score_one_hot = to_categorical(y_score)

    # Label-vs-prediction metrics, each rounded to three decimals.
    acc, precision, recall, f1 = (
        np.round(metric(label, y_score), 3)
        for metric in (accuracy_score, precision_score, recall_score, f1_score)
    )

    # ROC/AUC over the flattened one-hot labels and the flattened score array
    # (assumes both have the same number of elements -- TODO confirm shape of
    # y_score_pro).
    fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(), y_score_pro.ravel())
    auc_ = np.round(auc(fpr, tpr), 3)

    table = PrettyTable(['ACC', 'Precision', 'Recall', 'F1-score', 'AUC'])
    table.add_row([acc, precision, recall, f1, auc_])
    print(table)

    # Record this fold's (rounded) values for the cross-fold summary.
    for history, value in zip(
        (acc_list, precision_list, recall_list, f1_list, auc_list),
        (acc, precision, recall, f1, auc_),
    ):
        history.append(value)

# One summary line per statistic; the layout matches the per-metric labels
# used in the tables above.
for tag, reduce_fn in (('(mean) ACC: ', np.mean), ('(std) ACC: ', np.std)):
    print(tag, reduce_fn(acc_list),
          'Precision: ', reduce_fn(precision_list),
          'Recall: ', reduce_fn(recall_list),
          'F1: ', reduce_fn(f1_list),
          'AUC: ', reduce_fn(auc_list))

+-------+-----------+--------+----------+-------+
|  ACC  | Precision | Recall | F1-score |  AUC  |
+-------+-----------+--------+----------+-------+
| 0.829 |   0.849   | 0.807  |  0.828   | 0.911 |
+-------+-----------+--------+----------+-------+
+-------+-----------+--------+----------+-------+
|  ACC  | Precision | Recall | F1-score |  AUC  |
+-------+-----------+--------+----------+-------+
| 0.823 |   0.838   | 0.815  |  0.826   | 0.911 |
+-------+-----------+--------+----------+-------+
+-------+-----------+--------+----------+-------+
|  ACC  | Precision | Recall | F1-score |  AUC  |
+-------+-----------+--------+----------+-------+
| 0.837 |   0.829   |  0.83  |  0.829   | 0.913 |
+-------+-----------+--------+----------+-------+
+-------+-----------+--------+----------+-------+
|  ACC  | Precision | Recall | F1-score |  AUC  |
+-------+-----------+--------+----------+-------+
| 0.834 |   0.845   | 0.816  |   0.83   | 0.918 |
+-------+-----------+--------+----------+-------+
