In [None]:
import math
import os
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm

from code.models.RsFPN import Res_FPN

# Switch every visible GPU to on-demand memory growth via TensorFlow's
# experimental config API.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # Looping over an empty device list does nothing, so no separate
    # "any GPUs?" check is required.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Report the problem and carry on rather than aborting the notebook.
    print(err)

In [None]:
# Save the prediction
def save_predict_result(data, output):
    """Write labels and scores of one or more result tables to a text file.

    Args:
        data: Sequence of result tables. Each table is a sequence of rows;
            row element 0 is written with ``%d`` and row element 2 with ``%s``.
        output: Path of the text file to create or overwrite.

    Returns:
        None.
    """
    several_tables = len(data) > 1
    with open(output, 'w') as handle:
        for table_no, table in enumerate(data, start=1):
            # More than one table is reported per fold; a single table is
            # reported as a plain prediction run.
            if several_tables:
                handle.write('# result for fold %d\n' % table_no)
            else:
                handle.write('# result for predict\n')
            for row in table:
                handle.write('%d\t%s\n' % (row[0], row[2]))
    return None


# Plot the ROC curve and return the AUC value
def plot_roc_curve(data, output, label_column=0, score_column=2):
    """Plot one ROC curve per result table, save the figure, and return the AUCs.

    When more than one table is given, each curve is labelled as a fold and
    the mean ROC curve with a +/- 1 standard-deviation band is drawn as well.

    Args:
        data: Sequence of 2-D arrays supporting ``table[:, column]`` indexing.
        output: Path the figure is saved to.
        label_column: Column index holding the true labels.
        score_column: Column index holding the prediction scores.

    Returns:
        tuple: ``(mean_auc, aucs)`` where ``mean_auc`` is the AUC of the mean
        ROC curve on a 100-point FPR grid and ``aucs`` is the list of
        per-table AUC values.
    """
    datasize = len(data)
    mean_fpr = np.linspace(0, 1, 100)
    curves = []   # (fpr, tpr) pairs exactly as returned by roc_curve
    tprs = []     # TPR of each curve resampled onto mean_fpr
    aucs = []
    for table in data:
        fpr, tpr, _ = roc_curve(table[:, label_column], table[:, score_column])
        curves.append((fpr, tpr))
        resampled = np.interp(mean_fpr, fpr, tpr)
        resampled[0] = 0.0  # pin every resampled curve to the origin
        tprs.append(resampled)
        aucs.append(auc(fpr, tpr))

    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'blueviolet', 'deeppink'])
    fig = plt.figure(figsize=(7, 7), dpi=300)
    try:
        for i, ((fpr, tpr), color) in enumerate(zip(curves, colors)):
            if datasize > 1:
                curve_label = 'ROC fold %d (AUC = %0.4f)' % (i + 1, aucs[i])
            else:
                curve_label = 'ROC (AUC = %0.4f)' % aucs[i]
            plt.plot(fpr, tpr, lw=1, alpha=0.7, color=color, label=curve_label)
        plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
                 label='Random', alpha=.8)

        mean_tpr = np.mean(tprs, axis=0)
        mean_tpr[-1] = 1.0  # pin the mean curve to (1, 1)
        mean_auc = auc(mean_fpr, mean_tpr)
        # Spread of the per-table AUCs and of the resampled TPRs.
        std_auc = np.std(aucs)
        std_tpr = np.std(tprs, axis=0)
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        if datasize > 1:
            plt.plot(mean_fpr, mean_tpr, color='blue',
                     label=r'Mean ROC (AUC = %0.4f $\pm$ %0.3f)' % (mean_auc, std_auc),
                     lw=2, alpha=.9)
            plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                             label=r'$\pm$ 1 std. dev.')

        plt.xlim([0, 1.0])
        plt.ylim([0, 1.0])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc="lower right")
        plt.savefig(output)
    finally:
        # Close the figure created above. The previous plt.close(0) addressed
        # figure number 0, which pyplot never assigns by default (numbering
        # starts at 1), so every call left its figure open.
        plt.close(fig)
    return mean_auc, aucs


# Calculate and save performance metrics
def calculate_metrics(labels, scores, cutoff=0.5, po_label=1):
    """Compute binary-classification metrics at a fixed score cutoff.

    A sample counts as predicted positive when ``score >= cutoff`` and as
    actually positive when ``label == po_label``.

    Args:
        labels: Indexable sequence of true labels, at least as long as ``scores``.
        scores: Sequence of prediction scores.
        cutoff: Decision threshold applied to ``scores``.
        po_label: Label value treated as the positive class.

    Returns:
        dict: Keys ``SN``, ``SP``, ``ACC``, ``MCC``, ``Recall``, ``Precision``,
        ``F1-score`` and ``Cutoff`` (in that order). A metric whose
        denominator is zero is reported as the string ``'NA'``.
    """
    tp, tn, fp, fn = 0, 0, 0, 0
    for i, score in enumerate(scores):
        predicted_positive = score >= cutoff
        if labels[i] == po_label:
            if predicted_positive:
                tp += 1
            else:
                fn += 1
        elif predicted_positive:
            fp += 1
        else:
            tn += 1

    total = tp + tn + fp + fn
    mcc_denominator = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    sensitivity = tp / (tp + fn) if (tp + fn) != 0 else 'NA'

    # Key order is kept as-is: the report writer emits columns in dict order.
    return {
        'SN': sensitivity,
        'SP': tn / (fp + tn) if (fp + tn) != 0 else 'NA',
        # Empty input used to raise ZeroDivisionError here while every other
        # metric fell back to 'NA'; ACC now follows the same convention.
        'ACC': (tp + tn) / total if total != 0 else 'NA',
        # math.sqrt replaces np.math.sqrt: the np.math alias was deprecated
        # in NumPy 1.25 and removed in NumPy 2.0.
        'MCC': (tp * tn - fp * fn) / math.sqrt(mcc_denominator) if mcc_denominator != 0 else 'NA',
        'Recall': sensitivity,
        'Precision': tp / (tp + fp) if (tp + fp) != 0 else 'NA',
        'F1-score': 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) != 0 else 'NA',
        'Cutoff': cutoff,
    }


def calculate_metrics_list(data, label_column=0, score_column=2, cutoff=0.5, po_label=1):
    """Compute metrics for every result table and, for several tables, their mean and std.

    Args:
        data: Sequence of 2-D arrays supporting ``table[:, column]`` indexing.
        label_column: Column index holding the true labels.
        score_column: Column index holding the prediction scores.
        cutoff: Decision threshold passed to ``calculate_metrics``.
        po_label: Label value treated as the positive class.

    Returns:
        list of dict: One metrics dict per table. Unless exactly one table was
        given, two more dicts follow: the per-metric mean and the per-metric
        standard deviation. Tables whose value for a metric is the ``'NA'``
        placeholder are left out of that metric's mean/std; if no table has a
        value, the aggregate is ``'NA'`` as well.
    """
    metrics_list = [
        calculate_metrics(table[:, label_column], table[:, score_column], cutoff=cutoff, po_label=po_label)
        for table in data
    ]
    if len(metrics_list) == 1:
        return metrics_list

    per_table = list(metrics_list)  # snapshot taken before the aggregates are appended
    mean_dict = {}
    std_dict = {}
    for key in per_table[0]:
        # calculate_metrics reports undefined metrics as the string 'NA'.
        # Summing an array that contains such a string raises, so only the
        # numeric entries take part in the aggregation.
        values = np.array([m[key] for m in per_table if not isinstance(m[key], str)], dtype=float)
        if values.size:
            mean_dict[key] = values.sum() / values.size
            std_dict[key] = values.std()
        else:
            mean_dict[key] = 'NA'
            std_dict[key] = 'NA'
    metrics_list.append(mean_dict)
    metrics_list.append(std_dict)
    return metrics_list


def save_prediction_metrics_list(metrics_list, output):
    """Write a list of metric dicts to ``output`` as a tab-separated table.

    The header row lists the keys of the first dict. A single-entry list is
    written as one ``value`` row under a ``Result`` header. Otherwise the
    header starts with ``Fold``, the last two entries are labelled ``mean``
    and ``std``, and the entries before them are numbered from 1.

    Args:
        metrics_list: List of dicts mapping metric name to value.
        output: Path of the text file to create or overwrite.

    Returns:
        None.
    """
    count = len(metrics_list)
    if count == 1:
        corner = 'Result'
        row_labels = ['value']
    else:
        corner = 'Fold'
        row_labels = []
        for idx in range(count):
            if idx <= count - 3:
                row_labels.append('%d' % (idx + 1))
            elif idx == count - 2:
                row_labels.append('mean')
            else:
                row_labels.append('std')

    with open(output, 'w') as handle:
        handle.write(corner)
        for key in metrics_list[0]:
            handle.write('\t%s' % key)
        handle.write('\n')
        for label, row in zip(row_labels, metrics_list):
            handle.write(label)
            for key in row:
                handle.write('\t%s' % row[key])
            handle.write('\n')
    return None


# Fixed SP value, computing performance
def fixed_sp_calculate_metrics_list(data, cutoffs, label_column=0, score_column=1, po_label=1):
    """Compute metrics for every result table using a separate cutoff per table.

    Args:
        data: Sequence of 2-D arrays supporting ``table[:, column]`` indexing.
        cutoffs: Indexable sequence; ``cutoffs[k]`` is the threshold applied
            to ``data[k]``.
        label_column: Column index holding the true labels.
        score_column: Column index holding the prediction scores.
            NOTE(review): the default here is 1 while the other metric
            helpers in this notebook default to 2 - confirm this is intended.
        po_label: Label value treated as the positive class.

    Returns:
        list of dict: One metrics dict per table. Unless exactly one table was
        given, two more dicts follow: the per-metric mean and the per-metric
        standard deviation. Tables whose value for a metric is the ``'NA'``
        placeholder are left out of that metric's mean/std; if no table has a
        value, the aggregate is ``'NA'`` as well.
    """
    metrics_list = [
        calculate_metrics(table[:, label_column], table[:, score_column],
                          cutoff=cutoffs[index], po_label=po_label)
        for index, table in enumerate(data)
    ]
    if len(metrics_list) == 1:
        return metrics_list

    per_table = list(metrics_list)  # snapshot taken before the aggregates are appended
    mean_dict = {}
    std_dict = {}
    for key in per_table[0]:
        # calculate_metrics reports undefined metrics as the string 'NA'.
        # Summing an array that contains such a string raises, so only the
        # numeric entries take part in the aggregation.
        values = np.array([m[key] for m in per_table if not isinstance(m[key], str)], dtype=float)
        if values.size:
            mean_dict[key] = values.sum() / values.size
            std_dict[key] = values.std()
        else:
            mean_dict[key] = 'NA'
            std_dict[key] = 'NA'
    metrics_list.append(mean_dict)
    metrics_list.append(std_dict)
    return metrics_list


def save_result(cv_res, ind_res, outPath, codename):
    """Write predictions, ROC figure and metrics table for the given results.

    Output files are named ``<outPath>/<codename lower-cased><suffix>``.

    Args:
        cv_res: Cross-validation result tables; always exported.
        ind_res: Independent-test result tables, or ``None`` to skip them.
        outPath: Directory the files are written to.
        codename: Base name of the output files (lower-cased before use).
    """
    out = os.path.join(outPath, codename.lower())

    # (results, predictions suffix, ROC figure suffix, metrics suffix)
    exports = [(cv_res, '_pre_cv.txt', '_roc_cv.png', '_metrics_cv.txt')]
    if ind_res is not None:
        exports.append((ind_res, '_pre_ind.txt', '_roc_ind.png', '_metrics_ind.txt'))

    for results, pre_suffix, roc_suffix, metrics_suffix in exports:
        save_predict_result(results, out + pre_suffix)
        plot_roc_curve(results, out + roc_suffix, label_column=0, score_column=2)
        metrics = calculate_metrics_list(results, label_column=0, score_column=2, cutoff=0.5, po_label=1)
        save_prediction_metrics_list(metrics, out + metrics_suffix)


# Create folder
def mkdir(path):
    """Create directory ``path`` (including parents) unless the path already exists.

    Surrounding whitespace and trailing backslashes are removed from ``path``
    before it is used.

    Args:
        path: Directory path to create.
    """
    target = path.strip().rstrip("\\")
    if os.path.exists(target):
        # Something is already there; leave it untouched.
        return
    os.makedirs(target)

In [None]:
# Absolute path of the directory one level above the current working directory.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

In [None]:
# CSV inputs, given relative to the working directory the notebook runs in:
# the cross-validation set first, the independent set second.
trainfilepath, testfilepath = (
    r'../../dataset/five_fold_cross_validation.csv',
    r'../../dataset/independent.csv',
)

# Load AAF

In [None]:
from code.feature_extraction.aaindex import AAIndex
# Build the AAIndex feature objects for both CSV files. train() later indexes
# `aaindex_train` with fold indices and passes both objects to Res_FPN.
# NOTE(review): the return type and shape of AAIndex(...) are defined in the
# project module - presumably a NumPy-style array; confirm.
aaindex_train = AAIndex(trainfilepath)
aaindex_test = AAIndex(testfilepath)

# Load ZSF and labels

In [None]:
from code.feature_extraction.zsf import ZScale
# Each ZScale call is unpacked into two values: the features and the vector
# that train() uses as labels (`y` / `y_test`).
# NOTE(review): the meaning of the second argument (1) is defined in
# code.feature_extraction.zsf - confirm before changing it.
zscale_train, y = ZScale(trainfilepath, 1)
zscale_test, y_test = ZScale(testfilepath, 1)

# Load PBF

In [None]:
from code.feature_extraction.pbf import extract_embedding_features

# Cross-validation set: read the 'Sequence' column, embed the sequences, and
# cast the result to float32.
train_seqs = pd.read_csv(trainfilepath, sep=',')['Sequence']
protein_bert_train = np.float32(extract_embedding_features(train_seqs.values.tolist()))

# Independent set: same steps as above.
test_seqs = pd.read_csv(testfilepath, sep=',')['Sequence']
protein_bert_test = np.float32(extract_embedding_features(test_seqs.values.tolist()))

# Experimental results of 5-fold cross-validation and independent test

In [None]:
def train(n_splits=5, epochs=1000, batch_size=128, patience=50):
    """Run stratified k-fold cross-validation and independent testing of the ensemble.

    For every fold, one ``Res_FPN`` network is trained per feature set
    (``protein_bert_train``, ``zscale_train``, ``aaindex_train``). Each
    network's three outputs are averaged, and the three per-network averages
    are averaged again to give the fold's prediction. The same fold models
    also score the independent set. All results are written with
    ``save_result`` into ``<parent_dir>/Results``.

    The function reads these notebook-level variables: ``parent_dir``, the
    three ``*_train`` / ``*_test`` feature objects, ``y`` and ``y_test``.

    Args:
        n_splits: Number of stratified folds.
        epochs: Maximum number of epochs for each fit.
        batch_size: Mini-batch size for each fit.
        patience: Early-stopping patience in epochs.

    Returns:
        list: One entry per fold. Each entry is a list holding the
        ``History.history`` dict of the three fits, in the order
        protein_bert, zscale, aaindex. (Previously this list was returned
        without ever being filled.)
    """
    suboutput = os.path.join(parent_dir, 'Results')
    mkdir(suboutput)
    prediction_result_cv = []
    prediction_result_ind = []
    file_Name = 'iSumo-RsFPN'
    historys = []

    # (weight-file prefix, cross-validation features, independent-set features)
    branches = [
        ('model1', protein_bert_train, protein_bert_test),
        ('model2', zscale_train, zscale_test),
        ('model3', aaindex_train, aaindex_test),
    ]

    folds = StratifiedKFold(n_splits=n_splits).split(zscale_train, y)
    # tqdm wraps the fold generator itself (with an explicit total) so the
    # progress bar knows how many folds to expect.
    for i, (train_idx, valid_idx) in enumerate(tqdm(folds, total=n_splits)):
        train_y, valid_y = y[train_idx], y[valid_idx]

        networks = []
        valid_preds = []
        fold_histories = []
        for prefix, features, _ in branches:
            weights_path = os.path.join(suboutput, prefix + str(i + 1) + '.h5')
            network, history = _fit_branch(features[train_idx], train_y,
                                           features[valid_idx], valid_y,
                                           weights_path, epochs, batch_size, patience)
            networks.append(network)
            valid_preds.append(_average_outputs(network, features[valid_idx]))
            fold_histories.append(history)
        historys.append(fold_histories)

        # Column 0 holds the labels; the remaining columns hold the ensemble output.
        tmp_result = np.zeros((len(valid_y), 3))
        tmp_result[:, 0] = valid_y
        tmp_result[:, 1:] = (valid_preds[0] + valid_preds[1] + valid_preds[2]) / 3
        prediction_result_cv.append(tmp_result)

        # Score the independent set with this fold's three models.
        test_preds = [_average_outputs(network, test_features)
                      for network, (_, _, test_features) in zip(networks, branches)]
        tmp_result1 = np.zeros((len(y_test), 3))
        tmp_result1[:, 0] = y_test
        tmp_result1[:, 1:] = (test_preds[0] + test_preds[1] + test_preds[2]) / 3
        prediction_result_ind.append(tmp_result1)

    save_result(prediction_result_cv, prediction_result_ind, suboutput, file_Name)

    return historys


def _average_outputs(network, features):
    """Return the element-wise mean of the three arrays ``network.predict`` yields."""
    p1, p2, p3 = network.predict(features)
    return (p1 + p2 + p3) / 3


def _fit_branch(train_x, train_y, valid_x, valid_y, weights_path, epochs, batch_size, patience):
    """Fit one Res_FPN, reload its best checkpoint, and return ``(network, history_dict)``."""
    network = Res_FPN(train_x)
    # NOTE(review): both callbacks monitor 'val_acc'. Whether the compiled
    # model actually logs a metric under that name depends on Res_FPN, which
    # is not visible here - confirm, otherwise no checkpoint is written and
    # load_weights below fails.
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_acc', patience=patience)
    best_saving = tf.keras.callbacks.ModelCheckpoint(weights_path, monitor='val_acc', mode='auto',
                                                     verbose=0, save_best_only=True,
                                                     save_weights_only=True)
    history = network.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=epochs,
                          batch_size=batch_size, shuffle=True,
                          callbacks=[best_saving, early_stopping], verbose=0)
    # Restore the best weights saved by the checkpoint callback, not the last epoch's.
    network.load_weights(weights_path)
    # Keep only the plain metrics dict so the returned history holds no model reference.
    return network, history.history

In [None]:
# Run the full cross-validation and independent-test pipeline. `flag` receives
# the list returned by train().
flag = train()