<a href="https://colab.research.google.com/github/Shiy-Li/Bi-SGTAR/blob/main/Pub_Bi_SGTAR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Colab only: mount Google Drive so the project folder used by the next cell is reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [1]:
# Work from the project folder: load_association() reads './data/...' paths relative to this directory.
%cd /content/drive/MyDrive/Bi-SGTAR/

/content/drive/MyDrive/Bi-SGTAR


# Utils

In [13]:
import random
import torch
import numpy as np
import torch.nn.functional as F


def load_dict(data):
    """Return the cancer-name -> integer-index mapping for dataset ``data``.

    The integers are presumably the column of that disease in the dataset's
    association matrix (TODO confirm against the data files). Unknown dataset
    ids yield an empty dict. A new dict is returned on every call, so callers
    may mutate the result freely.
    """
    per_dataset = {
        1: {'glioma': 7, 'bladder cancer': 9, 'breast cancer': 10, 'cervical cancer': 53,
            'cervical carcinoma': 64, 'colorectal cancer': 11, 'gastric cancer': 19},
        2: {'glioma': 23, 'bladder cancer': 2, 'breast cancer': 4, 'cervical cancer': 6,
            'colorectal cancer': 12, 'gastric cancer': 20},
        3: {'glioma': 20, 'bladder cancer': 19, 'breast cancer': 6, 'cervical cancer': 16,
            'colorectal cancer': 1, 'gastric cancer': 0},
        # circ2Traits
        4: {'bladder cancer': 58, 'breast cancer': 46, 'glioma': 89, 'glioblastoma': 88,
            'glioblastoma multiforme': 59, 'cervical cancer': 23, 'colorectal cancer': 6,
            'gastric cancer': 15},
        # circad
        5: {'bladder cancer': 94, 'breast cancer': 53, 'triple-negative breast cancer': 111,
            'gliomas': 56, 'glioma': 76, 'cervical cancer': 65, 'colorectal cancer': 143,
            'gastric cancer': 28},
    }
    selected = per_dataset.get(data, {})
    return dict(selected)


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

    Any SciPy sparse format is accepted; it is converted with ``tocoo()``.
    Built with ``torch.sparse_coo_tensor``; the legacy
    ``torch.sparse.FloatTensor`` constructor is deprecated.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))


def build_model(model_type):
    """Return the model class registered under ``model_type``.

    Raises:
        ValueError: if ``model_type`` is not a known model name. Failing here
            is clearer than returning None and failing later at the call site.
    """
    if model_type == 'BiSGTAR':
        return BiSGTAR
    raise ValueError(f"Unknown model type {model_type!r}; expected 'BiSGTAR'")

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch.

    When CUDA is available, every CUDA device is also seeded and cuDNN is
    switched to deterministic, non-benchmarking mode.
    """
    # Python's `random` is imported alongside this helper; seed it too so any
    # use of the global `random` functions is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def sort_matrix(score_matrix, interact_matrix):
    """Rank each column of ``score_matrix`` in descending order.

    The same per-column permutation is applied to ``interact_matrix``.

    Returns:
        (labels sorted per column, scores sorted per column, the argsort
        index matrix). Both sorted matrices are float64.
    """
    order = np.argsort(-score_matrix, axis=0)
    labels_out = np.zeros(interact_matrix.shape)
    scores_out = np.zeros(score_matrix.shape)
    for col in range(interact_matrix.shape[1]):
        rows = order[:, col]
        scores_out[:, col] = score_matrix[rows, col]
        labels_out[:, col] = interact_matrix[rows, col]
    return labels_out, scores_out, order


def load_association(args):
    """Load the comma-separated association matrix for dataset ``args.data``.

    Paths are relative to the current working directory.

    Raises:
        ValueError: if ``args.data`` is not a known dataset id. Previously an
            unknown id surfaced as an UnboundLocalError.
    """
    dataset_paths = {
        1: './data/Dataset1/533_89_circrna_disease_assoication.csv',
        2: './data/Dataset2/514_62_circrna_disease_assoication.csv',
        3: './data/Dataset3/312_40_circrna_disease_assoication.csv',
        4: './data/Dataset4/923_104_circrna_disease_assoication.csv',
        5: './data/Dataset5/1265_151_circrna_disease_assoication.csv',
        6: './data/KGET-Dataset1/330_79_circrna_disease_assoication.csv',
        7: './data/KGET-Dataset2/561_190_circrna_disease_assoication.csv',
        8: './data/Dataset6/l_d2.csv',
        9: './data/Dataset7/C_D2.csv',
    }
    try:
        path = dataset_paths[args.data]
    except KeyError:
        raise ValueError(
            f'Unknown dataset id {args.data!r}; expected one of {sorted(dataset_paths)}'
        ) from None
    return np.loadtxt(path, delimiter=',')


# Model

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPAE(nn.Module):
    """Single-hidden-layer autoencoder with an extra confidence head.

    An input of width ``out_dim`` is encoded to ``hid_dim`` features (ELU).
    From that code, two sigmoid heads each map back to width ``out_dim``:
    ``d1`` (reconstruction) and ``Confidence``. ``forward`` returns
    ``(confidence, reconstruction)``.
    """

    def __init__(self, hid_dim, out_dim, bias=False):
        super().__init__()
        # Creation order (e1, d1, Confidence) fixes the order in which seeded
        # parameter initialisation draws random numbers; keep it.
        self.e1 = nn.Linear(out_dim, hid_dim, bias=bias)
        self.d1 = nn.Linear(hid_dim, out_dim, bias=bias)
        self.Confidence = nn.Linear(hid_dim, out_dim, bias=bias)
        self.act1 = nn.ELU()
        self.act2 = nn.Sigmoid()
        self.act3 = nn.Sigmoid()

    def encoder(self, x):
        """ELU(e1(x)): the hidden code."""
        return self.act1(self.e1(x))

    def decoder(self, z):
        """Sigmoid(d1(z)): reconstruction of the input."""
        return self.act2(self.d1(z))

    def confidencer(self, z):
        """Sigmoid(Confidence(z)): per-entry confidence scores."""
        return self.act3(self.Confidence(z))

    def forward(self, x):
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        confidence = self.confidencer(code)
        return confidence, reconstruction

class BiSGTAR(nn.Module):
    """Two-view scorer for an association matrix of shape (rna_num, dis_num).

    Each view (rows = circRNAs, columns = diseases) learns a sigmoid "quality"
    gate over its own entries, multiplies it element-wise into the features,
    and passes the gated features through its own MLPAE.

    Reads from ``args``: dis_num, rna_num, hidden (MLPAE width) and dropout
    (applied to the gate logits, in training mode only).
    """

    def __init__(self, args):
        super(BiSGTAR, self).__init__()
        dis_num = args.dis_num
        rna_num = args.rna_num
        # Not used by forward(); kept so the module layout is unchanged.
        self.input_drop = nn.Dropout(0.)
        self.att_drop = nn.Dropout(0.)
        # Creation order is kept so seeded initialisation matches earlier runs.
        self.FeatQC_rna = nn.Linear(dis_num, dis_num, bias=True)
        self.FeatQC_dis = nn.Linear(rna_num, rna_num, bias=True)
        self.AE_rna = MLPAE(args.hidden, dis_num)
        self.AE_dis = MLPAE(args.hidden, rna_num)
        self.act = nn.Sigmoid()
        self.dropout = args.dropout

    def forward(self, feat):
        """Score ``feat`` (rna_num, dis_num) from both views.

        Returns:
            (yc, rna_sparse_feat, rna_quality, hc,
             yd, dis_sparse_feat, dis_quality, hd).
            The RNA-view tensors are (rna_num, dis_num). The disease-view
            tensors are (dis_num, rna_num). ``yc``/``yd`` are the MLPAE
            confidence heads and ``hc``/``hd`` the reconstructions.
        """
        feat_t = feat.t()
        # F.dropout defaults to training=True; tie it to the module mode so
        # model.eval() really disables dropout at inference time.
        rna_quality = self.act(F.dropout(self.FeatQC_rna(feat), self.dropout,
                                         training=self.training))
        dis_quality = self.act(F.dropout(self.FeatQC_dis(feat_t), self.dropout,
                                         training=self.training))

        rna_sparse_feat = torch.mul(rna_quality, feat)
        dis_sparse_feat = torch.mul(dis_quality, feat_t)

        yc, hc = self.AE_rna(rna_sparse_feat)
        yd, hd = self.AE_dis(dis_sparse_feat)

        return yc, rna_sparse_feat, rna_quality, hc, yd, dis_sparse_feat, dis_quality, hd

# Metrics

In [17]:
from sklearn.metrics import average_precision_score, roc_auc_score
import numpy as np


def metrics(score_matrix, roc_circrna_disease_matrix):
    """Per-cutoff ranking metrics over a column-wise ranked score matrix.

    Each column of ``score_matrix`` is ranked in descending order, using the
    same ``np.argsort(-score, axis=0)`` ordering as ``sort_matrix``. The
    label matrix is permuted with it. For every cutoff k = 1..n_rows, the top
    k rows of all columns are the predicted positives. Label 1 counts as a
    positive and label 0 as a negative. Any other label is ignored in every
    count; main() uses 2 for training positives.

    Returns:
        Six lists of length n_rows: tpr, fpr, recall, precision, accuracy, F1.
        A ratio with a zero denominator comes out as nan, except F1, which is 0.
    """
    order = np.argsort(-score_matrix, axis=0)
    labels_sorted = np.take_along_axis(np.asarray(roc_circrna_disease_matrix), order, axis=0)

    pos_per_row = np.sum(labels_sorted == 1, axis=1)
    neg_per_row = np.sum(labels_sorted == 0, axis=1)
    # Cumulative sums give the confusion counts for every cutoff at once,
    # instead of re-scanning the whole matrix per cutoff.
    tp = np.cumsum(pos_per_row)
    fp = np.cumsum(neg_per_row)
    fn = pos_per_row.sum() - tp
    tn = neg_per_row.sum() - fp

    with np.errstate(divide='ignore', invalid='ignore'):
        tpr = tp / (tp + fn)
        fpr = fp / (fp + tn)
        precision = tp / (tp + fp)
        accuracy = (tn + tp) / (tn + tp + fn + fp)
        f1_den = 2 * tp + fp + fn
        f1 = np.where(f1_den == 0, 0.0, 2 * tp / f1_den)

    # Recall equals TPR; it is returned separately to keep the signature.
    return (tpr.tolist(), fpr.tolist(), tpr.tolist(), precision.tolist(),
            accuracy.tolist(), f1.tolist())


def calculate_performace(y_prob, y_test):
    """Confusion-matrix metrics at a 0.5 threshold plus threshold-free AUC/AUPRC.

    Args:
        y_prob: predicted scores; an entry >= 0.5 is predicted positive.
        y_test: ground-truth labels; 1 is positive, anything else negative.

    Returns:
        (tp, fp, tn, fn, acc, precision, sens, f1_score, MCC, AUC, auprc, spec).
        ``sens`` is recall, tp / (tp + fn). ``spec`` is specificity,
        tn / (tn + fp). A ratio with a zero denominator is reported as 0.0.
    """
    def _ratio(num, den):
        # A degenerate confusion matrix (e.g. no predicted positives) scores
        # 0.0 for that ratio. Previously it raised, or overwrote every metric
        # with a sentinel of 100.
        return float(num) / den if den else 0.0

    y_prob = np.array(y_prob)
    y_test = np.array(y_test)
    y_pred = np.where(y_prob >= 0.5, 1., 0.)
    actual_pos = y_test == 1
    correct = y_test == y_pred
    tp = int(np.sum(actual_pos & correct))
    fn = int(np.sum(actual_pos & ~correct))
    tn = int(np.sum(~actual_pos & correct))
    fp = int(np.sum(~actual_pos & ~correct))

    acc = _ratio(tp + tn, len(y_prob))
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1_score = _ratio(2 * precision * recall, precision + recall)
    # A zero MCC denominator gives nan/inf under numpy rather than
    # ZeroDivisionError, so it must be guarded explicitly.
    mcc_den = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    MCC = _ratio(tp * tn - fp * fn, mcc_den)
    sens = recall
    # Specificity is TN / (TN + FP); the old code divided by (TN + TP).
    spec = _ratio(tn, tn + fp)
    AUC = roc_auc_score(y_test, y_prob)
    auprc = average_precision_score(y_test, y_prob)

    return tp, fp, tn, fn, acc, precision, sens, f1_score, MCC, AUC, auprc, spec


# Parser

In [18]:
import argparse


def parameter_parser():
    """Build the experiment's command-line options and parse them.

    Uses ``parse_known_args`` so unrelated arguments, such as the ones a
    Jupyter/Colab kernel is launched with, are ignored rather than rejected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='Disables CUDA training.')
    # (flag, type, default, help) in registration order; help=None means no help text.
    typed_options = (
        ('--model-type', str, 'BiSGTAR', 'choose the model.'),
        ('--rna-num', int, 0, 'circrna number.'),
        ('--dis-num', int, 0, 'disease number.'),
        ('--seed', int, 1, 'Random seed.'),
        ('--epochs', int, 500, 'Number of epochs to train.'),
        ('--lr', float, 0.001, 'Learning rate.'),
        ('--weight_decay', float, 1e-8, 'Weight decay (L2 loss on parameters).'),
        ('--hidden', int, 256, 'Dimension of representations'),
        ('--alpha', float, 0.5, 'Weight between lncRNA space and disease space'),
        ('--beta', float, 0.4, 'Weight to balance the modules SPC and TAR'),
        ('--gama', float, 0.05, 'Weight of the fusion part'),
        ('--dropout', float, 0., None),
        ('--data', int, 5, 'Dataset'),
        ('--para', float, 1e-2, 'Smooth Factor'),
    )
    for flag, value_type, default_value, help_text in typed_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    known_args, _ignored = parser.parse_known_args()
    return known_args

# Train

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def find_key(i, cancer_dict):
    """Reverse lookup: return the first key of ``cancer_dict`` whose value is ``i``.

    Raises ValueError (from ``list.index``) when no value equals ``i``.
    """
    keys = list(cancer_dict)
    values = list(cancer_dict.values())
    position = values.index(i)
    return keys[position]


def train(model, y0, args, alpha, i, rel):
    """Fit ``model`` on one cross-validation fold and return its fused scores.

    Args:
        model: module whose ``forward(y0)`` returns the 8-tuple
            (yc, rna_feat, rna_quality, hc, yd, dis_feat, dis_quality, hd),
            ordered like BiSGTAR.forward.
        y0: model input (main() passes a label-smoothed copy of ``rel``).
        args: namespace providing lr, weight_decay, epochs, beta and gama.
        alpha: weight of the RNA view; the disease view gets ``1 - alpha``.
        i: fold index; unused, kept for call compatibility.
        rel: training association matrix, used as the BCE/MSE target.

    Returns:
        (y, model). ``y = alpha * yc + (1 - alpha) * yd.t()`` comes from a
        final eval-mode forward pass run without gradient tracking.
    """
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    criterion = nn.BCELoss()
    rel_t = rel.t()
    for _epoch in range(args.epochs):
        model.train()
        yl, _rna_feat, rna_quality, hc, yd, _dis_feat, dis_quality, hd = model(y0)
        y = alpha * yl + (1 - alpha) * yd.t()

        rna_confidence = torch.mul(hc, rel)
        dis_confidence = torch.mul(hd, rel_t)

        # Intra-view losses. TAR = reconstruction BCE + confidence MSE.
        # SPC = mean gate value; minimising it (weight 1 - beta) pushes the
        # quality gates towards 0.
        rna_SPC = torch.mean(rna_quality)
        rna_TAR = criterion(hc, rel) + F.mse_loss(yl, rna_confidence)
        rna_loss = args.beta * rna_TAR + (1 - args.beta) * rna_SPC

        dis_SPC = torch.mean(dis_quality)
        dis_TAR = criterion(hd, rel_t) + F.mse_loss(yd, dis_confidence)
        dis_loss = args.beta * dis_TAR + (1 - args.beta) * dis_SPC

        # Blend the two views, then blend with the loss on the fused prediction.
        loss_inter = alpha * rna_loss + (1 - alpha) * dis_loss
        loss_cls = criterion(y, rel)
        loss = args.gama * loss_cls + (1 - args.gama) * loss_inter

        opt.zero_grad()
        loss.backward()
        opt.step()

    # Inference only: no autograd graph is needed for the returned scores.
    model.eval()
    with torch.no_grad():
        yli, _, _, _, ydi, _, _, _ = model(y0)
        y = alpha * yli + (1 - alpha) * ydi.t()
    return y, model


# Main

In [20]:
import math
import h5py

import numpy as np
import torch
from sklearn.metrics import roc_curve, average_precision_score, auc
import random


def main(args):
    """Cross-validate the configured model on dataset ``args.data`` and print a summary.

    Side effects: sets ``args.rna_num`` / ``args.dis_num`` from the loaded
    matrix and prints progress and results.

    Protocol (Evaluation Option 2, the active one):
      * entries == 1 (known positives) and entries == 0 are each shuffled with
        a fixed ``random.Random(0)``. The positives are split into ``n_fold``
        chunks (5 folds, or 10 for dataset 8);
      * for each fold, that chunk is zeroed in a copy of the matrix and a
        fresh model is trained on the copy. Every entry that is 0 in the
        original matrix then counts as a negative when ranking;
      * the per-cutoff curves returned by ``metrics`` are averaged over folds.

    Returns:
        (roc_auc, AUPR, mean_accuracy, mean_recall, mean_precision, mean_F1)
    """
    circrna_disease_matrix = load_association(args)

    print('Now Load Dataset ' + str(args.data))
    args.rna_num = circrna_disease_matrix.shape[0]
    args.dis_num = circrna_disease_matrix.shape[1]
    print('rna_num', args.rna_num)
    print('dis_num', args.dis_num)

    n_fold = 5
    if args.data == 8:
        n_fold = 10
    index_tuple = (np.where(circrna_disease_matrix == 1))
    index_tuple_0 = (np.where(circrna_disease_matrix == 0))
    one_list = list(zip(index_tuple[0], index_tuple[1]))
    zero_list = list(zip(index_tuple_0[0], index_tuple_0[1]))
    # Dedicated RNG with a fixed seed: the fold split is the same on every run,
    # independent of set_seed().
    rnd_state = random.Random(0)
    rnd_state.shuffle(one_list)
    rnd_state.shuffle(zero_list)

    split = math.ceil(len(one_list) / n_fold)
    split_0 = math.ceil(len(zero_list) / n_fold)
    print('split: ', split)
    print('split_0: ', split_0)
    # Per-fold curves, filled by Evaluation Option 2 below (the active evaluation).
    all_tpr = []
    all_fpr = []
    all_recall = []
    all_precision = []
    all_accuracy = []
    all_F1 = []

    # Accumulators for Evaluation Option 1 (disabled; re-enable together with that block)
    # y_prob = []
    # y_test = []
    # Cross-validation loop over n_fold folds
    for i in range(n_fold):
        test_index = one_list[i * split:(i + 1) * split]
        # Held-out negatives; only consumed by the disabled Evaluation Option 1.
        test_index_0 = zero_list[i * split_0:(i + 1) * split_0]
        new_circrna_disease_matrix = circrna_disease_matrix.copy()

        # Hide this fold's positives from the training matrix.
        for index in test_index:
            new_circrna_disease_matrix[index[0], index[1]] = 0
        # Evaluation labels (for a 0/1 input matrix): 2 = training positive,
        # 1 = held-out positive, 0 = unlabelled / negative.
        roc_circrna_disease_matrix = new_circrna_disease_matrix + circrna_disease_matrix
        rel_matrix = new_circrna_disease_matrix
        # circnum / disnum are not used below.
        circnum = rel_matrix.shape[0]
        disnum = rel_matrix.shape[1]
        rel_matrix_tensor = torch.tensor(np.array(rel_matrix).astype(np.float32))

        # A fresh, untrained model for every fold.
        model_init = build_model(args.model_type)
        model = model_init(args)
        if args.cuda:
            model = model.cuda()
            rel_matrix_tensor = rel_matrix_tensor.cuda()

        # Label-smoothed model input: 0 -> para, 1 -> 1 - para.
        smooth_factor = args.para
        norm_rel = smooth_factor + (1 - 2 * smooth_factor) * rel_matrix_tensor
        resi,model = train(model, norm_rel, args, args.alpha, i, rel_matrix_tensor)
        if args.cuda:
            ymat = resi.cpu().detach().numpy()
        else:
            ymat = resi.detach().numpy()
###--------------------------------Evaluation Option 1------------------------------------------------------###
    # NOTE: disabled. To re-enable it, the per-fold collection (S = ymat ... y_test.append)
    # must be indented into the fold loop, the summary must run after the loop,
    # and Evaluation Option 2 must be disabled.
    #     S = ymat
    #     for index_0 in test_index_0:
    #         y_prob.append(S[index_0[0],index_0[1]])
    #         y_test.append(circrna_disease_matrix[index_0[0],index_0[1]])
    #     for index in test_index:
    #         y_prob.append(S[index[0],index[1]])
    #         y_test.append(circrna_disease_matrix[index[0],index[1]])

    # y_prob = np.array(y_prob)
    # y_test = np.array(y_test)
    # fpr, tpr, threshold = roc_curve(y_test, y_prob)
    # auc_val = auc(fpr, tpr)
    # aupr_val = average_precision_score(y_test, y_prob)
    # print('Final: \n  auc_val = \t'+ str(auc_val)+'\n  avpr_val = \t'+ str(aupr_val))
    # print('-' * 200)
    # return auc_val

###--------------------------------Evaluation Option 2------------------------------------------------------###
        S = ymat
        prediction_matrix = S
        zero_matrix = np.zeros((prediction_matrix.shape[0], prediction_matrix.shape[1]))
        score_matrix_temp = prediction_matrix.copy()
        # Adding float64 zeros leaves the values unchanged; it only upcasts the copy to float64.
        score_matrix = score_matrix_temp + zero_matrix
        minvalue = np.min(score_matrix)
        # Push training positives (label 2) below every other score so they
        # rank last and do not take up top-k positions.
        score_matrix[np.where(roc_circrna_disease_matrix == 2)] = minvalue - 20
        tpr_list, fpr_list, recall_list, precision_list, accuracy_list, F1_list = metrics(score_matrix,roc_circrna_disease_matrix)
        all_tpr.append(tpr_list)
        all_fpr.append(fpr_list)
        all_recall.append(recall_list)
        all_precision.append(precision_list)
        all_accuracy.append(accuracy_list)
        all_F1.append(F1_list)

    tpr_arr = np.array(all_tpr)
    fpr_arr = np.array(all_fpr)
    recall_arr = np.array(all_recall)
    precision_arr = np.array(all_precision)
    accuracy_arr = np.array(all_accuracy)
    F1_arr = np.array(all_F1)

    # Fold-averaged curves (one value per cutoff).
    mean_cross_tpr = np.mean(tpr_arr, axis=0)
    mean_cross_fpr = np.mean(fpr_arr, axis=0)
    mean_cross_recall = np.mean(recall_arr, axis=0)
    mean_cross_precision = np.mean(precision_arr, axis=0)
    mean_cross_accuracy = np.mean(accuracy_arr, axis=0)

    # Scalar summaries: mean over cutoffs within each fold, then mean over folds.
    # std_accuracy is computed but not reported.
    mean_accuracy = np.mean(np.mean(accuracy_arr, axis=1), axis=0)
    std_accuracy = np.std(np.mean(accuracy_arr, axis=1), axis=0)
    mean_recall = np.mean(np.mean(recall_arr, axis=1), axis=0)
    mean_precision = np.mean(np.mean(precision_arr, axis=1), axis=0)
    mean_F1 = np.mean(np.mean(F1_arr, axis=1), axis=0)

    print("K-fold cross-validation performance")
    print("accuracy:%.4f,recall:%.4f,precision:%.4f,F1:%.4f" % (mean_accuracy, mean_recall, mean_precision, mean_F1))

    # Trapezoidal area under the fold-averaged ROC and PR curves.
    # NOTE(review): np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    roc_auc = np.trapz(mean_cross_tpr, mean_cross_fpr)
    AUPR = np.trapz(mean_cross_precision, mean_cross_recall)
    print("AUC:%.4f,AUPR:%.4f" % (roc_auc, AUPR))
    print('-' * 200)
    return roc_auc, AUPR, mean_accuracy, mean_recall, mean_precision, mean_F1


# Start

In [21]:
import torch
import pandas as pd
if __name__ == "__main__":
    import warnings

    warnings.filterwarnings("ignore")
    args = parameter_parser()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    model_types = ['BiSGTAR']
    args.seed = 1

    # Dataset ids to benchmark (ids understood by load_association).
    DATASETS = range(1, 6)
    # Hyperparameters tuned for specific datasets. Any dataset not listed runs
    # with the parser defaults; entries only apply when DATASETS includes them.
    DATASET_OVERRIDES = {
        8: {'alpha': 0.8, 'beta': 0.8, 'gama': 0.8, 'weight_decay': 1e-8, 'epochs': 600},
        9: {'alpha': 0.8, 'beta': 0.6, 'gama': 0.8, 'weight_decay': 1e-10, 'epochs': 400},
    }
    # args is shared and mutated in place. Snapshot the defaults so one
    # dataset's overrides never leak into a later dataset.
    base_hparams = {name: getattr(args, name)
                    for name in ('alpha', 'beta', 'gama', 'weight_decay', 'epochs')}

    results = []
    for data_id in DATASETS:
        args.data = data_id
        for name, value in {**base_hparams, **DATASET_OVERRIDES.get(data_id, {})}.items():
            setattr(args, name, value)
        for each in model_types:
            args.model_type = each
            # Re-seed before every run.
            set_seed(args.seed)
            print('Now model is: ', args.model_type)
            roc_auc, AUPR, mean_accuracy, mean_recall, mean_precision, mean_F1 = main(args)
            results.append({'data': data_id, 'model': each, 'AUC': roc_auc, 'AUPR': AUPR,
                            'accuracy': mean_accuracy, 'recall': mean_recall,
                            'precision': mean_precision, 'F1': mean_F1})

    # One row per (dataset, model) run.
    print(pd.DataFrame(results).to_string(index=False))

Now model is:  BiSGTAR
Now Load Dataset 1
rna_num 533
dis_num 89
split:  119
split_0:  9369
K-fold cross-validation performance
accuracy:0.4960,recall:0.8645,precision:0.0062,F1:0.0121
AUC:0.8631,AUPR:0.0110
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Now model is:  BiSGTAR
Now Load Dataset 2
rna_num 514
dis_num 63
split:  130
split_0:  6347
K-fold cross-validation performance
accuracy:0.4941,recall:0.8440,precision:0.0097,F1:0.0186
AUC:0.8418,AUPR:0.0173
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Now model is:  BiSGTAR
Now Load Dataset 3
rna_num 312
dis_num 40
split:  67
split_0:  2430
K-fold cross-validation performance
accuracy:0.4925,recall:0.8872,precision:0.0134,F1:0.025