In [1]:
import torch
import os
from configparser import ConfigParser
import numpy as np
import scipy.sparse as sp
import tensorflow as tf
import numpy as np
import networkx as nx
import pandas as pd



In [2]:


def load_data(sample_data_path):
    '''
    Load the lncRNA-disease edge list and build its adjacency matrix and node indices.

    The CSV file is read with pandas and must contain the columns "lncRNA"
    and "Disease"; every row is treated as one edge. The graph is built with
    networkx and each node is numbered by its position in ``g.nodes`` (per
    the original author's note, this is the order in which nodes first
    appear in the file).

    :param sample_data_path: path to the comma-separated edge-list file
    :return: tuple ``(adj, nodeIndex, nodeAll)`` where
        ``adj`` is the matrix returned by ``nx.adjacency_matrix``,
        ``nodeIndex`` is ``{"index_lncRNA": {name: id}, "index_disease": {name: id}}``,
        ``nodeAll`` is ``{name: id}`` for every node of the graph
    '''

    df = pd.read_csv(sample_data_path, sep=",")
    g = nx.from_pandas_edgelist(df, "lncRNA", "Disease")
    adj = nx.adjacency_matrix(g)

    # Classify nodes using the same named columns the graph was built from
    # (not column positions), so a different column order in the file cannot
    # mislabel nodes. Sets make each membership test O(1).
    lncrna_names = set(df["lncRNA"])
    disease_names = set(df["Disease"])

    nodeAll = dict()
    nodeLncRNA = dict()
    nodeDisease = dict()

    for i, node_name in enumerate(g.nodes):
        nodeAll[node_name] = i
        # The two checks are independent: a name that occurs in both columns
        # is recorded in both index dictionaries.
        if node_name in lncrna_names:
            nodeLncRNA[node_name] = i
        if node_name in disease_names:
            nodeDisease[node_name] = i

    nodeIndex = {
        "index_lncRNA": nodeLncRNA,
        "index_disease": nodeDisease,
    }

    return adj, nodeIndex, nodeAll



class Predict():
    """
    Score lncRNA-disease associations from saved node embeddings.

    ``load_model_adj`` fills the instance from the files named in a
    ``config.cfg``; ``predict`` reconstructs the full association matrix and
    ``topK`` ranks the best partners of a single node.
    """

    def __init__(self):
        # Embedding matrix loaded from "<model_path>/lncRNA_disease.npy".
        self.hidden_emb = None
        # Adjacency matrix of the training graph with its diagonal removed.
        self.adj_orig = None
        # {"index_lncRNA": {name: id}, "index_disease": {name: id}} from load_data.
        self.nodeList = None
        # {name: id} for every node, from load_data.
        self.nodeAll = None

    @staticmethod
    def _sigmoid(x):
        """Element-wise logistic function 1 / (1 + exp(-x))."""
        # exp(-x) may overflow to inf for very negative x; the quotient is
        # then 0, which is the correct limit, so the warning is only noise.
        with np.errstate(over='ignore'):
            return 1 / (1 + np.exp(-x))

    def load_model_adj(self, config_path):
        '''
        Load hidden_emb and the original adjacency matrix described by a config file.

        The first section of the config file must provide ``data_catalog``,
        ``train_file_name`` and ``model_path``.

        :param config_path: path to an existing file whose name starts with
            "config." and ends with ".cfg"
        :return: None; sets hidden_emb, adj_orig, nodeList and nodeAll
        :raises FileNotFoundError: if the config file, the model directory or
            the training file does not exist
        '''
        file_name = os.path.basename(config_path)
        # os.path.splitext yields "" for a name without extension, so this
        # comparison cannot raise IndexError the way indexing a split would.
        is_config_file = (
            os.path.exists(config_path)
            and file_name.split('.')[0] == 'config'
            and os.path.splitext(file_name)[1] == '.cfg'
        )
        if not is_config_file:
            raise FileNotFoundError('File config.cfg not found : ' + config_path)

        # load config file
        config = ConfigParser()
        config.read(config_path)
        section = config.sections()[0]

        # data catalog path
        data_catalog = config.get(section, "data_catalog")
        # train file path
        train_file_name = config.get(section, "train_file_name")
        # model save/load path
        model_path = config.get(section, "model_path")

        train_path = os.path.join(data_catalog, train_file_name)

        # Both inputs are required: fail when EITHER is missing and say which.
        missing = [p for p in (model_path, train_path) if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError('Not found file! Missing: ' + ', '.join(missing))

        print(model_path)

        model_path_ = "{}/lncRNA_disease.npy".format(model_path)
        print(model_path_)
        self.hidden_emb = np.load(model_path_)

        # Load the original adjacency matrix and remove its diagonal entries.
        print(train_path)
        adj, self.nodeList, self.nodeAll = load_data(train_path)
        self.adj_orig = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
        self.adj_orig.eliminate_zeros()

    def predict(self):
        """
        Reconstruct the association matrix from the embeddings.

        :return: tuple ``(adj_orig, adj_rec)`` where ``adj_rec`` is
            sigmoid(hidden_emb . hidden_emb^T)
        """
        # Inner product of every pair of node embeddings.
        adj_rec = np.dot(self.hidden_emb, self.hidden_emb.T)
        adj_rec = self._sigmoid(adj_rec)
        return self.adj_orig, adj_rec

    def topK(self, type='disease', node='YJL081C', k=10):
        """
        Rank the nodes of the opposite kind by predicted association with ``node``.

        The score of a candidate is sigmoid(candidate_embedding . node_embedding).
        The computation uses NumPy only (no TensorFlow session/eager mode is
        needed); ties keep the candidate that comes first in the index, and a
        ``k`` larger than the number of candidates returns all of them.

        @param type: disease | lncRNA -- the kind of ``node``
        @param node: node-name
        @param k: top-K
        @return: dict() mapping candidate name -> score, best first
        @raise KeyError: if ``type`` or ``node`` is unknown
        @raise ValueError: if ``k`` is negative
        """
        if k < 0:
            raise ValueError("k must be non-negative")

        node_id = self.nodeList["index_{}".format(type)][node]

        # Candidates are always the nodes of the other kind.
        target_type = 'lncRNA' if type == 'disease' else 'disease'
        name = "index_{}".format(target_type)
        # Explicit integer dtype keeps an empty candidate list usable as an index.
        node_to_list = np.array(list(self.nodeList[name].values()), dtype=np.intp)

        emb = np.asarray(self.hidden_emb)
        scores = self._sigmoid(emb[node_to_list] @ emb[node_id])

        # A stable sort on the negated scores gives descending order while
        # keeping the earlier candidate first when scores are equal.
        best_positions = np.argsort(-scores, kind='stable')[:k]

        # Resolve ids through an explicit id -> name map instead of relying
        # on the insertion order of self.nodeAll matching the ids.
        id_to_name = {idx: node_name for node_name, idx in self.nodeAll.items()}

        topKindex = {}
        for pos in best_positions:
            key = id_to_name[int(node_to_list[pos])]
            topKindex[key] = scores[pos]

        return topKindex

In [3]:
# Expect config.cfg in the directory the notebook is run from.
config_path = os.path.join(os.getcwd(), 'config.cfg')

In [4]:
# Build the predictor and load the embeddings and adjacency matrix named in
# config.cfg; load_model_adj prints the paths it reads.
predict = Predict()
predict.load_model_adj(config_path)

F:/Python/Python_learning/BJUT/Bjut_Graduation/Code/Vgae/model
F:/Python/Python_learning/BJUT/Bjut_Graduation/Code/Vgae/model/lncRNA_disease.npy
F:/Python/Python_learning/BJUT/Bjut_Graduation/dataset\lncRNA_disease_towsets.csv


In [5]:
# NOTE(review): this cell repeats cells In [3] and In [4] (same config path,
# same Predict setup) and reloads the same files -- it looks redundant.
config_path = os.path.join(os.getcwd(), 'config.cfg')
predict = Predict()
predict.load_model_adj(config_path)

F:/Python/Python_learning/BJUT/Bjut_Graduation/Code/Vgae/model
F:/Python/Python_learning/BJUT/Bjut_Graduation/Code/Vgae/model/lncRNA_disease.npy
F:/Python/Python_learning/BJUT/Bjut_Graduation/dataset\lncRNA_disease_towsets.csv


In [6]:
# NOTE(review): hard-coded absolute local path; the notebook only runs on a
# machine with this exact directory layout. The printed output of
# load_model_adj above shows the same CSV being read via config.cfg.
path = "F:/Python/Python_learning/BJUT/Bjut_Graduation/dataset/lncRNA_disease_towsets.csv"

# Read the edge list again to obtain the node index maps used by the next cell.
adj, nodeIndex, nodeAll = load_data(path)

In [7]:
# For every disease node, print the 10 lncRNAs with the highest predicted score.
for dis in nodeIndex['index_disease']:
    predlist = predict.topK('disease', dis, 10)
    print("Node: {} ".format(dis))
    print(predlist, "\n")

Node: hepatocellular carcinoma 
{'H19': 0.67929614, 'MALAT1': 0.67758703, 'HOTAIR': 0.6616018, 'MEG3': 0.6606383, 'NEAT1': 0.64298224, 'CDKN2B-AS1': 0.642965, 'PVT1': 0.63743126, 'UCA1': 0.63734215, 'XIST': 0.6329771, 'GAS5': 0.6322103} 

Node: stomach cancer 
{'H19': 0.6135414, 'MALAT1': 0.6119645, 'HOTAIR': 0.6016472, 'MEG3': 0.60088235, 'CDKN2B-AS1': 0.58930635, 'NEAT1': 0.58913577, 'UCA1': 0.58587927, 'PVT1': 0.58563447, 'XIST': 0.58323866, 'GAS5': 0.58219147} 

Node: breast cancer 
{'H19': 0.66032445, 'MALAT1': 0.65809596, 'HOTAIR': 0.64360255, 'MEG3': 0.6416696, 'NEAT1': 0.62738156, 'CDKN2B-AS1': 0.62622434, 'PVT1': 0.62178516, 'UCA1': 0.6216517, 'XIST': 0.61720455, 'GAS5': 0.61709607} 

Node: esophageal cancer 
{'H19': 0.80780894, 'MALAT1': 0.80266404, 'MEG3': 0.7804128, 'HOTAIR': 0.7804122, 'NEAT1': 0.75338626, 'CDKN2B-AS1': 0.7525956, 'PVT1': 0.7443969, 'UCA1': 0.74433756, 'XIST': 0.7358306, 'GAS5': 0.7353406} 

Node: prostate cancer 
{'H19': 0.60539895, 'MALAT1': 0.60444444, 

Node: testicular cancer 
{'H19': 0.5344735, 'MALAT1': 0.53365535, 'HOTAIR': 0.5304029, 'MEG3': 0.5298563, 'NEAT1': 0.5269152, 'CDKN2B-AS1': 0.5265666, 'UCA1': 0.52576286, 'PVT1': 0.52571535, 'GAS5': 0.52474177, 'XIST': 0.5246553} 

Node: chromophobe renal cell carcinoma 
{'H19': 0.54373616, 'MALAT1': 0.5428024, 'HOTAIR': 0.53837574, 'MEG3': 0.5382787, 'NEAT1': 0.53425896, 'CDKN2B-AS1': 0.5336409, 'PVT1': 0.5325056, 'UCA1': 0.53239155, 'GAS5': 0.5311399, 'XIST': 0.53106624} 

Node: head and neck squamous cell carcinoma 
{'H19': 0.5616882, 'MALAT1': 0.5607524, 'HOTAIR': 0.55491406, 'MEG3': 0.5542649, 'NEAT1': 0.54832476, 'CDKN2B-AS1': 0.5481134, 'PVT1': 0.5462386, 'UCA1': 0.54618615, 'XIST': 0.5446735, 'GAS5': 0.54435617} 

Node: gastric cardia adenocarcinoma 
{'H19': 0.55072576, 'MALAT1': 0.5483839, 'HOTAIR': 0.5437087, 'MEG3': 0.543411, 'NEAT1': 0.5386435, 'CDKN2B-AS1': 0.5384148, 'UCA1': 0.5374704, 'PVT1': 0.5368404, 'GAS5': 0.53560835, 'XIST': 0.5356001} 

Node: basal cell carcinoma 

Node: multiple myeloma 
{'H19': 0.5598158, 'MALAT1': 0.55856407, 'HOTAIR': 0.5530226, 'MEG3': 0.55232, 'NEAT1': 0.5467408, 'CDKN2B-AS1': 0.5464098, 'UCA1': 0.5448928, 'PVT1': 0.5446685, 'XIST': 0.54314965, 'GAS5': 0.54296577} 

Node: hereditary hemorrhagic telangiectasia 
{'H19': 0.5811626, 'MALAT1': 0.57966876, 'HOTAIR': 0.5719277, 'MEG3': 0.5716347, 'NEAT1': 0.56412995, 'CDKN2B-AS1': 0.56384647, 'UCA1': 0.5613325, 'PVT1': 0.5612029, 'XIST': 0.55913275, 'GAS5': 0.55899125} 

Node: bone disease 
{'H19': 0.5246958, 'MALAT1': 0.52461374, 'MEG3': 0.52214086, 'HOTAIR': 0.5220736, 'NEAT1': 0.519367, 'CDKN2B-AS1': 0.5193611, 'PVT1': 0.5185559, 'UCA1': 0.51840615, 'XIST': 0.51798177, 'GAS5': 0.5177463} 

Node: osteoporosis 
{'H19': 0.5254621, 'MALAT1': 0.5249905, 'HOTAIR': 0.52249014, 'MEG3': 0.52233785, 'NEAT1': 0.5201211, 'CDKN2B-AS1': 0.5197059, 'PVT1': 0.51894444, 'UCA1': 0.5188857, 'GAS5': 0.5181686, 'XIST': 0.5181038} 

Node: brain cancer 
{'H19': 0.528938, 'MALAT1': 0.52850884, 'HOTAIR

Node: skin disease 
{'H19': 0.51611114, 'MALAT1': 0.5154281, 'HOTAIR': 0.51391906, 'MEG3': 0.5136397, 'NEAT1': 0.5122884, 'CDKN2B-AS1': 0.51204926, 'UCA1': 0.51173115, 'PVT1': 0.5116931, 'GAS5': 0.5112037, 'XIST': 0.51111} 

Node: lung carcinoma 
{'H19': 0.5331229, 'MALAT1': 0.5316061, 'HOTAIR': 0.5286331, 'MEG3': 0.5280498, 'NEAT1': 0.52521497, 'CDKN2B-AS1': 0.52493423, 'UCA1': 0.52439946, 'PVT1': 0.5241195, 'XIST': 0.5231819, 'GAS5': 0.5231686} 

Node: trophoblastic neoplasm 
{'H19': 0.5295935, 'MALAT1': 0.52858543, 'HOTAIR': 0.525669, 'MEG3': 0.52543855, 'NEAT1': 0.5225598, 'CDKN2B-AS1': 0.5222748, 'PVT1': 0.52151823, 'UCA1': 0.5212478, 'XIST': 0.52064073, 'GAS5': 0.52030206} 

Node: infertility 
{'H19': 0.5160532, 'MALAT1': 0.5153826, 'HOTAIR': 0.51387954, 'MEG3': 0.5136036, 'NEAT1': 0.51223546, 'CDKN2B-AS1': 0.5120093, 'UCA1': 0.511676, 'PVT1': 0.5116512, 'GAS5': 0.5111507, 'XIST': 0.51107085} 

Node: aortic valve disease 
{'H19': 0.5330765, 'MALAT1': 0.53224814, 'HOTAIR': 0.52906

Node: monocytic leukemia 
{'MALAT1': 0.52988344, 'H19': 0.5297372, 'HOTAIR': 0.5267403, 'MEG3': 0.5267233, 'NEAT1': 0.52358454, 'CDKN2B-AS1': 0.5231963, 'PVT1': 0.5223853, 'UCA1': 0.5217873, 'XIST': 0.5213935, 'GAS5': 0.5211162} 

Node: diabetic retinopathy 
{'H19': 0.53126943, 'MALAT1': 0.53075397, 'HOTAIR': 0.52761483, 'MEG3': 0.5269726, 'NEAT1': 0.5245998, 'CDKN2B-AS1': 0.52406436, 'PVT1': 0.5232883, 'UCA1': 0.52326775, 'GAS5': 0.5224432, 'XIST': 0.52240294} 

Node: renal fibrosis 
{'H19': 0.5164321, 'MALAT1': 0.5162709, 'HOTAIR': 0.5146045, 'MEG3': 0.51442885, 'NEAT1': 0.51302356, 'CDKN2B-AS1': 0.5128092, 'PVT1': 0.5123527, 'UCA1': 0.5123201, 'GAS5': 0.51192427, 'XIST': 0.5118397} 

Node: myelodysplastic syndrome 
{'H19': 0.53061426, 'MALAT1': 0.53017604, 'HOTAIR': 0.52719843, 'MEG3': 0.5270592, 'NEAT1': 0.52396774, 'CDKN2B-AS1': 0.5238153, 'UCA1': 0.5228958, 'PVT1': 0.5228187, 'XIST': 0.52221286, 'GAS5': 0.5219151} 

Node: pituitary cancer 
{'H19': 0.5207817, 'MALAT1': 0.5206321, 

In [11]:
# Top-10 lncRNA candidates for the disease node 'lung cancer'.
# NOTE(review): this cell is labelled In [11] although it precedes In [9], so
# the cells were not executed top to bottom -- re-run the notebook to confirm.
predlist = predict.topK('disease', 'lung cancer', 10)
print("Node:Lung cancer")
predlist

Node:Lung cancer


{'H19': 0.60430837,
 'MALAT1': 0.6027492,
 'HOTAIR': 0.59295803,
 'MEG3': 0.59191,
 'NEAT1': 0.58232975,
 'CDKN2B-AS1': 0.58158594,
 'UCA1': 0.5785693,
 'PVT1': 0.5785343,
 'XIST': 0.57564354,
 'GAS5': 0.57546365}

In [9]:
# Top-20 lncRNA candidates for the disease node 'breast cancer'.
predlist = predict.topK('disease', 'breast cancer', 20)
print("Node:breast cancer")
predlist

Node:breast cancer


{'H19': 0.66032445,
 'MALAT1': 0.65809596,
 'HOTAIR': 0.64360255,
 'MEG3': 0.6416696,
 'NEAT1': 0.62738156,
 'CDKN2B-AS1': 0.62622434,
 'PVT1': 0.62178516,
 'UCA1': 0.6216517,
 'XIST': 0.61720455,
 'GAS5': 0.61709607,
 'MIAT': 0.6135536,
 'CRNDE': 0.60637075,
 'DANCR': 0.5998458,
 'CYTOR': 0.5988136,
 'SOX2-OT': 0.5971471,
 'LINC-ROR': 0.5954078,
 'SNHG1': 0.5952522,
 'AFAP1-AS1': 0.5945187,
 'CCAT2': 0.59426934,
 'PCAT1': 0.59304976}

In [10]:
# Top-10 disease candidates for the lncRNA node 'PCAT6'.
predlist = predict.topK('lncRNA', 'PCAT6', 10)
print("Node:PCAT6")
predlist

Node:PCAT6


{'esophageal cancer': 0.66218656,
 'hepatocellular carcinoma': 0.5879968,
 'breast cancer': 0.57722086,
 'colorectal cancer': 0.57497174,
 'stomach cancer': 0.55453163,
 'osteoarthritis': 0.5543028,
 'osteosarcoma': 0.55313647,
 'prostate cancer': 0.5502772,
 'ovarian cancer': 0.5501168,
 'lung cancer': 0.54934317}

In [11]:
# Top-10 disease candidates for the lncRNA node 'MEG3'.
predlist = predict.topK('lncRNA', 'MEG3', 10)
print("Node:MEG3")
predlist

Node:MEG3


{'esophageal cancer': 0.7804128,
 'hepatocellular carcinoma': 0.6606383,
 'breast cancer': 0.6416696,
 'colorectal cancer': 0.63789314,
 'osteoarthritis': 0.6011936,
 'stomach cancer': 0.60088235,
 'osteosarcoma': 0.5983643,
 'prostate cancer': 0.5938645,
 'ovarian cancer': 0.5932632,
 'lung cancer': 0.59191}