In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader,Dataset
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GATConv
from sklearn.ensemble import RandomForestClassifier
torch.cuda.empty_cache()
import random
device=torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

In [2]:
class GCN(nn.Module):
    """Two-branch graph encoder producing circRNA and disease embeddings.

    Each branch owns a learnable node-feature matrix (randomly initialised),
    and runs GCNConv -> GATConv -> GCNConv over its graph with a ReLU after
    every layer. The outputs of the two GCN layers are stacked as input
    channels of a Conv2d whose kernel spans the full feature width, which
    yields one ``out_channels``-dimensional vector per node.

    Args:
        circRNA_number: number of circRNA nodes (rows of ``x_cir``).
        disease_number: number of disease nodes (rows of ``x_dis``).
        fcir: feature width used throughout the circRNA branch.
        fdis: feature width used throughout the disease branch.
        gcn_layers: number of input channels of the Conv2d layers. ``forward``
            concatenates exactly two GCN outputs per branch, so the reshapes
            only line up with the node counts when this is 2.
        out_channels: dimensionality of the returned node embeddings.
    """

    def __init__(self, circRNA_number,disease_number,fcir,fdis,gcn_layers,out_channels):
        super().__init__()
        self.circRNA_number = circRNA_number
        self.disease_number = disease_number
        self.fcir = fcir
        self.fdis = fdis
        self.gcn_layers = gcn_layers
        self.out_channels = out_channels

        # Learnable input features; there are no external node attributes.
        self.x_cir = nn.Parameter(torch.randn(circRNA_number, fcir), requires_grad=True)
        self.x_dis = nn.Parameter(torch.randn(disease_number, fdis), requires_grad=True)

        # Width-preserving graph convolutions, two per branch.
        self.gcn_cir1_f = GCNConv(fcir, fcir)
        self.gcn_cir2_f = GCNConv(fcir, fcir)
        self.gcn_dis1_f = GCNConv(fdis, fdis)
        self.gcn_dis2_f = GCNConv(fdis, fdis)

        # The kernel covers the whole feature axis, collapsing it to size 1.
        self.cnn_cir = nn.Conv2d(in_channels=gcn_layers, out_channels=out_channels,
                                 kernel_size=(fcir, 1), stride=1, bias=True)
        self.cnn_dis = nn.Conv2d(in_channels=gcn_layers, out_channels=out_channels,
                                 kernel_size=(fdis, 1), stride=1, bias=True)

        # Attention layers placed between the two GCN layers of each branch;
        # concat=False keeps the output width equal to the input width.
        self.gat_cir1_f = GATConv(fcir, fcir, heads=4, concat=False, edge_dim=1)
        self.gat_dis1_f = GATConv(fdis, fdis, heads=4, concat=False, edge_dim=1)

    def _embed(self, features, graph, gcn_first, gat, gcn_second, conv, width, node_count):
        """Run one branch and return its (node_count, out_channels) embeddings."""
        edges, weights = graph['edges'], graph['edge_value']
        first = torch.relu(gcn_first(features, edges, weights))
        attended = torch.relu(gat(first, edges, weights))
        second = torch.relu(gcn_second(attended, edges, weights))
        # (nodes, 2*width) -> (2*width, nodes) -> (1, gcn_layers, width, nodes)
        stacked = torch.cat((first, second), 1).t()
        stacked = stacked.view(1, self.gcn_layers, width, -1)
        pooled = conv(stacked)
        return pooled.view(self.out_channels, node_count).t()

    def forward(self, data):
        """Compute association scores and node embeddings.

        Args:
            data: dict with keys ``'cc'`` and ``'dd'``; each maps to a dict
                holding ``'edges'`` and ``'edge_value'``, which are passed to
                the graph layers of the circRNA and disease branch.

        Returns:
            Tuple ``(scores, cir_fea, dis_fea)``: ``cir_fea`` is
            (circRNA_number, out_channels), ``dis_fea`` is
            (disease_number, out_channels) and ``scores`` is their product
            ``cir_fea @ dis_fea.T``.
        """
        cir_fea = self._embed(self.x_cir, data['cc'],
                              self.gcn_cir1_f, self.gat_cir1_f, self.gcn_cir2_f,
                              self.cnn_cir, self.fcir, self.circRNA_number)
        dis_fea = self._embed(self.x_dis, data['dd'],
                              self.gcn_dis1_f, self.gat_dis1_f, self.gcn_dis2_f,
                              self.cnn_dis, self.fdis, self.disease_number)
        return torch.mm(cir_fea, dis_fea.t()), cir_fea, dis_fea
# _,cd,fea,tri,tei=torch.load('./data_circ/dataset/circ_CNN.pth')
# input={'cc':{},'dd':{}}
# data_matrix=fea[0][:834,:834]
# input['cc']['edges']=torch.argwhere(data_matrix>0.85).t()
# input['cc']['edge_value']=data_matrix[input['cc']['edges'][0], input['cc']['edges'][1]].float().to(device)
# input['cc']['edges']=input['cc']['edges'].to(device)
# data_matrix=fea[0][834:834+138,834:834+138]
# input['dd']['edges']=torch.argwhere(data_matrix>0).t()
# input['dd']['edge_value']=data_matrix[input['dd']['edges'][0], input['dd']['edges'][1]].float().to(device)
# input['dd']['edges']=input['dd']['edges'].to(device)
# # print(input['cc']['data_matrix'].shape,input['dd']['data_matrix'].shape)
# net=GCN(circRNA_number=834,
#         disease_number=138,
#         fcir=64,
#         fdis=64,
#         gcn_layers=2,
#         out_channels=2).to(device)
# net(input)[2].shape

In [3]:
import time


# Load the packed dataset. Only `cd`, `fea`, `tri` and `tei` are used below;
# the first element of the archive is ignored.
# NOTE(review): torch.load unpickles its input, so only load archives that
# come from a trusted source.
_,cd,fea,tri,tei=torch.load('circ_CNN.pth')
# # Set the random seed
# def set_seed(seed):
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed(seed)
#         torch.cuda.manual_seed_all(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False
# cd = np.load(rf'E:\CompeletedMethodsCodeAndPaper\data_circ\dataset\circRNA_disease.npy')
# fea = torch.load(rf'E:\CompeletedMethodsCodeAndPaper\data_circ\dataset\cover_feature_matrix.pth')
# tri = torch.load(rf'E:\CompeletedMethodsCodeAndPaper\data_circ\dataset\train_dataset.pth')
# tei = torch.load(rf'E:\CompeletedMethodsCodeAndPaper\data_circ\dataset\test_data.pth')

# Node counts. In each fea[i], rows/columns [0, CIRC_NUM) form the circRNA
# block and [CIRC_NUM, CIRC_NUM + DIS_NUM) the disease block.
CIRC_NUM = 834
DIS_NUM = 138


def _graph_from_matrix(matrix, threshold):
    """Build the edge dict expected by GCN.forward from a square matrix.

    Every entry strictly greater than `threshold` becomes an edge; its value
    is kept as the edge value. Both tensors are moved to `device`.
    """
    edges = torch.argwhere(matrix > threshold).t()
    edge_value = matrix[edges[0], edges[1]].float().to(device)
    return {'edges': edges.to(device), 'edge_value': edge_value}


seedIndex = [2048, 2048, 2048, 2048, 2048]
res=[]
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
# NOTE(review): only fold 0 is run, although seedIndex has five entries —
# confirm whether the remaining folds are meant to be produced here too.
for i in range(1):
    # set_seed(seedIndex[i])
    print('cross:%d'%i)
    net=GCN(circRNA_number=CIRC_NUM,
        disease_number=DIS_NUM,
        fcir=128,
        fdis=128,
        gcn_layers=2,
        out_channels=128).to(device)
    optimizer=torch.optim.Adam(net.parameters(),lr=0.005)#weight_decay=5e-5
    # Training target: the circRNA x disease block of this fold's matrix.
    feat=fea[i][:CIRC_NUM,CIRC_NUM:CIRC_NUM+DIS_NUM].float()
    # Moved to the device once instead of on every epoch.
    target=feat.to(device)
    # Named `graph_inputs` rather than `input` so the builtin is not shadowed
    # for the rest of the kernel session.
    graph_inputs={
        'cc': _graph_from_matrix(fea[i][:CIRC_NUM,:CIRC_NUM], 0.85),
        'dd': _graph_from_matrix(fea[i][CIRC_NUM:CIRC_NUM+DIS_NUM,CIRC_NUM:CIRC_NUM+DIS_NUM], 0),
    }

    # --- Stage 1: fit the graph encoder with a BCE-with-logits loss ---
    train_start_time = time.time()
    for e in range(100):
        score,x,y=net(graph_inputs)
        loss = criterion(score, target)
        optimizer.zero_grad()
        # Print the scalar value rather than the tensor repr with its grad_fn.
        print(loss.item())
        loss.backward()
        optimizer.step()
    train_end_time = time.time()
    print('train time:', train_end_time - train_start_time)

    # --- Stage 2: random forest on concatenated pair embeddings ---
    test_start_time = time.time()
    # Only the embeddings are needed here, so no autograd graph is recorded.
    with torch.no_grad():
        score,x,y=net(graph_inputs)
    x,y=x.detach().cpu(),y.detach().cpu()
    clf = RandomForestClassifier(n_estimators=200,n_jobs=11,max_depth=20)
    # tri[i] / tei[i] are indexed as row 0 = circRNA index, row 1 = disease index.
    clf.fit(torch.cat([x[tri[i][0,:]],y[tri[i][1,:]]],dim=1), cd[tri[i][0,:],tri[i][1,:]])
    y_prob = clf.predict_proba(torch.cat([x[tei[i][0,:]],y[tei[i][1,:]]],dim=1))
    # NOTE(review): predict_proba columns follow clf.classes_, so column 0 is
    # the first (lowest) class. If `cd` holds 0/1 labels this is the
    # probability of class 0 — confirm that this is the score the downstream
    # ROC/PR evaluation expects.
    fold_result=[y_prob[:,0],cd[tei[i][0,:],tei[i][1,:]]]
    res.append(fold_result)
    test_end_time = time.time()
    print('test time:', test_end_time - test_start_time)
    torch.save(fold_result,f'GraphCDA_{i}.pth')

  _,cd,fea,tri,tei=torch.load('circ_CNN.pth')


cross:0
tensor(0.7397, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.1343, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.0995, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.1448, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.1160, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.0683, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.1100, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.0674, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.0907, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.1012, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.1022, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.0971, device='c

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, auc, precision_recall_curve
import warnings
warnings.filterwarnings("ignore")
def roc_pr4_folder(test_x_ys, labels, pred_ys, ass_mat_shape):
	"""Row-averaged ROC and PR curves for one fold.

	The test pairs are scattered into two matrices of shape `ass_mat_shape`
	(labels and predictions). For every row that holds at least one test
	label equal to 1, a ROC curve and a PR curve are computed over that row's
	test entries; the curves are then interpolated onto a common 100-point
	grid on [0, 1] and averaged over rows. Both mean AUCs are printed.

	Args:
		test_x_ys: indexable of (row, column) index pairs, one per test sample.
		labels: label per test sample. -1 is used internally as the
			"not a test entry" marker, so labels must not be -1.
		pred_ys: prediction score per test sample, aligned with `labels`.
		ass_mat_shape: (n_rows, n_cols) of the scatter matrices.

	Returns:
		(mean_fpr, mean_tpr, mean_recall, mean_prec): the two x-grids and the
		row-averaged y-values, each of length 100.

	Raises:
		ValueError: if no row contains a positive test label.
	"""
	labels_mat, pred_ys_mat, test_num= torch.zeros((ass_mat_shape)) -1, torch.zeros((ass_mat_shape)) -1, len(labels)
	for i in range(test_num):
		labels_mat[test_x_ys[i][0], test_x_ys[i][1]]= labels[i]
		pred_ys_mat[test_x_ys[i][0], test_x_ys[i][1]]= pred_ys[i]
	# True where a test label was written.
	bool_mat4test= (labels_mat!= -1)
	fpr_ls, tpr_ls, recall_ls, prec_ls, effective_rows_len = [], [], [], [], 0
	for i in range(ass_mat_shape[0]):
		# Rows without any positive test label are skipped.
		if (labels_mat[i][bool_mat4test[i]]== 1).sum()> 0:
			effective_rows_len+= 1
			labels4test1rowi= labels_mat[i][bool_mat4test[i]]
			pred_y4test1rowi= pred_ys_mat[i][bool_mat4test[i]]
			fpr4rowi, tpr4rowi, _= roc_curve(labels4test1rowi, pred_y4test1rowi)
			fpr_ls.append(fpr4rowi)
			tpr_ls.append(tpr4rowi)
			precision4rowi, recall4rowi, _= precision_recall_curve(labels4test1rowi, pred_y4test1rowi)
			# Override the final precision value: 0 when the preceding value is 0,
			# otherwise 1. A plain conditional replaces indexing a list with a
			# NumPy boolean scalar, which NumPy deprecates.
			precision4rowi[-1]= 0 if precision4rowi[-2]== 0 else 1
			# Reversed so the x-values passed to np.interp below are the
			# reverse of what precision_recall_curve returned.
			prec_ls.append(precision4rowi[::-1])
			recall_ls.append(recall4rowi[::-1])
	if effective_rows_len== 0:
		raise ValueError('no row contains a positive test label; ROC/PR curves cannot be computed')
	mean_fpr, mean_recall= np.linspace(0, 1, 100), np.linspace(0, 1, 100)
	tpr_ls4mean_tpr, prec_ls4mean_prec= [], []
	for i in range(effective_rows_len):
		tpr_ls4mean_tpr.append(np.interp(mean_fpr, fpr_ls[i], tpr_ls[i]))
		# Precision is interpolated on the recall grid (same values as mean_fpr).
		prec_ls4mean_prec.append(np.interp(mean_recall, recall_ls[i], prec_ls[i]))
	mean_tpr, mean_prec= np.mean(tpr_ls4mean_tpr, axis= 0), np.mean(prec_ls4mean_prec, axis= 0)
	print(f'ROC平均值auc(mean_fpr, mean_tpr): {auc(mean_fpr, mean_tpr)}')
	print(f'pr平均值auc(mean_recall, mean_prec)：{auc(mean_recall, mean_prec)}')
	return mean_fpr, mean_tpr, mean_recall, mean_prec
def roc_pr4cross_val(mean_fpr_ts, mean_tpr_ts, mean_recall_ts, mean_prec_ts, k_fold):
	"""Plot per-fold and fold-averaged ROC and PR curves.

	The x-grids of the first fold are used for the averaged curves; the
	y-values are averaged over dim 0 of the stacked tensors. The averaged
	curves are saved to 'GraphCDA.pkl' before anything is plotted, then one
	ROC figure and one PR figure are shown.

	Args:
		mean_fpr_ts: per-fold FPR grids, indexed by fold along dim 0.
		mean_tpr_ts: per-fold TPR values, indexed by fold along dim 0.
		mean_recall_ts: per-fold recall grids, indexed by fold along dim 0.
		mean_prec_ts: per-fold precision values, indexed by fold along dim 0.
		k_fold: number of folds to draw.
	"""
	mean_fpr= mean_fpr_ts[0]
	mean_tpr= torch.mean(mean_tpr_ts, dim= 0)
	mean_recall= mean_recall_ts[0]
	mean_prec= torch.mean(mean_prec_ts, dim= 0)
	torch.save([mean_fpr, mean_tpr, mean_recall, mean_prec],'GraphCDA.pkl')

	def draw_figure(xs, ys, avg_x, avg_y, fold_fmt, avg_fmt, title, x_name, y_name):
		# One faint line per fold, then the averaged curve on top.
		fold_areas= []
		for fold in range(k_fold):
			area= auc(xs[fold], ys[fold])
			fold_areas.append(area)
			plt.plot(xs[fold], ys[fold], lw= 1, alpha= 0.3, label= fold_fmt % (fold+ 1, area))
		spread= np.std(fold_areas)
		avg_area= auc(avg_x, avg_y)
		plt.plot(avg_x, avg_y, color= 'b', lw= 2, alpha= 0.8, label= avg_fmt % (avg_area, spread))
		plt.title(title)
		plt.xlabel(x_name)
		plt.ylabel(y_name)
		plt.axis([0, 1, 0, 1])
		plt.legend(loc= 'lower right')
		plt.show()

	draw_figure(mean_fpr_ts, mean_tpr_ts, mean_fpr, mean_tpr,
		'ROC fold %d (AUC= %0.3f)', r'Mean ROC (AUC = %0.3f $\pm$ %0.3f)',
		'roc curve', 'fpr', 'tpr')
	draw_figure(mean_recall_ts, mean_prec_ts, mean_recall, mean_prec,
		'PR fold %d (AUPR= %0.3f)', r'Mean PR (AUPR = %0.3f $\pm$ %0.3f)',
		'pr curve', 'recall', 'precision')

In [2]:
# Collect the row-averaged curves of every fold, then plot them together.
mean_fprs, mean_tprs, mean_recalls, mean_precs= [], [], [], []
_,cd,fea,tri,tei=torch.load('circ_CNN.pth')
# cd = np.load(rf'E:\CompeletedMethodsCodeAndPaper\data_circ\dataset\circRNA_disease.npy')
# fea = torch.load(rf'E:\CompeletedMethodsCodeAndPaper\data_circ\dataset\cover_feature_matrix.pth')
# tri = torch.load(rf'E:\CompeletedMethodsCodeAndPaper\data_circ\dataset\train_dataset.pth')
# tei = torch.load(rf'E:\CompeletedMethodsCodeAndPaper\data_circ\dataset\test_data.pth')
for i in range(5):
    # Each file holds [predictions, labels] for the test pairs of fold i.
    pred, y=torch.load('GraphCDA_%d.pth'%i)
    # pred, y=res[i]
    # One (index0, index1) pair per row, then the two columns are swapped so
    # that the second stored index addresses rows of the (138, 834) matrix.
    test_idx= tei[i].T
    test_idx= torch.stack([test_idx[:, 1], test_idx[:, 0]], dim= 1)
    mean_fpr, mean_tpr, mean_recall, mean_prec= roc_pr4_folder(test_idx, y, pred, (138, 834))
    mean_fprs.append(torch.tensor(mean_fpr))
    mean_tprs.append(torch.tensor(mean_tpr))
    mean_recalls.append(torch.tensor(mean_recall))
    mean_precs.append(torch.tensor(mean_prec))
# Stack along a new leading fold dimension.
mean_fpr_ts= torch.stack(mean_fprs)
mean_tpr_ts= torch.stack(mean_tprs)
mean_recall_ts= torch.stack(mean_recalls, dim= 0)
mean_prec_ts= torch.stack(mean_precs, dim= 0)
roc_pr4cross_val(mean_fpr_ts, mean_tpr_ts, mean_recall_ts, mean_prec_ts, 5)

ROC平均值auc(mean_fpr, mean_tpr): 0.7107374459824202
pr平均值auc(mean_recall, mean_prec)：0.1511844474249854
ROC平均值auc(mean_fpr, mean_tpr): 0.7244262632296489
pr平均值auc(mean_recall, mean_prec)：0.1380788996662683
ROC平均值auc(mean_fpr, mean_tpr): 0.7770683717172567
pr平均值auc(mean_recall, mean_prec)：0.21064190403449687
ROC平均值auc(mean_fpr, mean_tpr): 0.7131902028154088
pr平均值auc(mean_recall, mean_prec)：0.17459663960364835
ROC平均值auc(mean_fpr, mean_tpr): 0.7191574755230659
pr平均值auc(mean_recall, mean_prec)：0.200606959125948


UnboundLocalError: cannot access local variable 'mean_fpr' where it is not associated with a value