In [1]:
import torch
import os
from configparser import ConfigParser
import numpy as np
import scipy.sparse as sp
import tensorflow as tf
import numpy as np
import networkx as nx
import pandas as pd



In [2]:


def load_data(sample_data_path, lnc_col="lncRNA", disease_col="Disease"):
    '''
    Load the lncRNA-disease edge list and build the graph and its adjacency matrix.

    The CSV is read with pandas and converted to a networkx graph (networkx's
    default undirected ``Graph``). Each row contributes an edge between its
    ``lnc_col`` value and its ``disease_col`` value. Node indices follow the
    order in which the nodes appear in the file (networkx node insertion order).
    The rows/columns of ``adj`` use that same order.

    :param sample_data_path: path to a comma-separated file containing (at least)
        the columns ``lnc_col`` and ``disease_col``.
    :param lnc_col: name of the lncRNA column (default "lncRNA").
    :param disease_col: name of the disease column (default "Disease").
    :return: (adj, nodeIndex, nodeAll) where
        adj       -- sparse adjacency matrix returned by ``nx.adjacency_matrix``;
        nodeIndex -- {"index_lncRNA": {name: idx}, "index_disease": {name: idx}};
        nodeAll   -- {name: idx} for every node in the graph.
    '''
    df = pd.read_csv(sample_data_path, sep=",")
    g = nx.from_pandas_edgelist(df, lnc_col, disease_col)
    # nx.adjacency_matrix orders rows/columns by g.nodes, the same order enumerated below.
    adj = nx.adjacency_matrix(g)

    # Sets give O(1) membership tests (the previous `name in ndarray` scan was
    # O(rows) per node). The columns are read by name, not position, so the
    # result is consistent with the edge-list construction above.
    lnc_names = set(df[lnc_col])
    disease_names = set(df[disease_col])

    nodeAll = {}
    nodeLncRNA = {}
    nodeDisease = {}
    for idx, name in enumerate(g.nodes):
        nodeAll[name] = idx
        # A name present in both columns is recorded in both maps.
        if name in lnc_names:
            nodeLncRNA[name] = idx
        if name in disease_names:
            nodeDisease[name] = idx

    nodeIndex = {"index_lncRNA": nodeLncRNA, "index_disease": nodeDisease}
    return adj, nodeIndex, nodeAll



class Predict():
    """Link predictor over pre-trained node embeddings of a lncRNA-disease graph.

    Attributes (all ``None`` until ``load_model_adj`` succeeds):
        hidden_emb: array loaded from ``<model_path>/lncRNA_disease.npy``. Row i is
            presumably the embedding of graph node i in ``load_data`` order
            (TODO confirm against the training code).
        adj_orig: training adjacency matrix with its diagonal (self-loops) removed.
        nodeList: {"index_lncRNA": {name: idx}, "index_disease": {name: idx}}.
        nodeAll: {name: idx} for every node in the training graph.
    """

    def __init__(self):
        self.hidden_emb = None
        self.adj_orig = None
        self.nodeList = None
        self.nodeAll = None

    @staticmethod
    def _sigmoid(x):
        # Logistic function, element-wise on numpy input.
        return 1 / (1 + np.exp(-x))

    def load_model_adj(self, config_path):
        '''
        Load hidden_emb and the training adjacency matrix described by a config file.

        The file must exist and be named ``config[.*].cfg``. Its first section
        must define ``data_catalog``, ``train_file_name`` and ``model_path``.

        :param config_path: path to the config file.
        :raises FileNotFoundError: if the config file, the embedding file or the
            training file is missing.
        :raises ValueError: if the config file contains no section.
        '''
        file_name = os.path.basename(config_path)
        is_config_file = (
            os.path.isfile(config_path)
            and file_name.split('.')[0] == 'config'
            # Compare the whole extension; the old `.split('.')[1]` raised
            # IndexError for an existing path with no extension.
            and os.path.splitext(file_name)[1] == '.cfg'
        )
        if not is_config_file:
            raise FileNotFoundError('File config.cfg not found : ' + config_path)

        config = ConfigParser()
        config.read(config_path)
        sections = config.sections()
        if not sections:
            raise ValueError('No section found in config file: ' + config_path)
        section = sections[0]  # only the first section is used

        # data catalog path
        data_catalog = config.get(section, "data_catalog")
        # train file name (inside data_catalog)
        train_file_name = config.get(section, "train_file_name")
        # model save/load directory
        model_path = config.get(section, "model_path")

        train_file_path = os.path.join(data_catalog, train_file_name)
        emb_file_path = os.path.join(model_path, "lncRNA_disease.npy")
        # Fail fast if either input is missing. The previous condition only
        # raised when the model dir was missing AND the training file existed.
        for required in (emb_file_path, train_file_path):
            if not os.path.exists(required):
                raise FileNotFoundError('Not found file: ' + required)

        self.hidden_emb = np.load(emb_file_path)

        # Load the original adjacency matrix and remove its diagonal elements by
        # subtracting a diagonal matrix built from adj's own diagonal.
        adj, self.nodeList, self.nodeAll = load_data(train_file_path)
        self.adj_orig = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
        self.adj_orig.eliminate_zeros()

    def predict(self):
        """Reconstruct edge probabilities from the embeddings.

        :return: (adj_orig, adj_rec), where adj_rec = sigmoid(hidden_emb @ hidden_emb.T)
            scores every node pair.
        """
        adj_rec = self._sigmoid(np.dot(self.hidden_emb, self.hidden_emb.T))
        return self.adj_orig, adj_rec

    def topK(self, type='disease', node='YJL081C', k=10):
        """
        Return the k best-scoring partners of ``node`` from the opposite node set.

        Scores are sigmoid(<emb[node], emb[candidate]>), the same quantity as the
        corresponding entry of ``predict()``'s adj_rec.

        @param type: 'disease' | 'lncRNA' -- the set ``node`` belongs to. Candidates
                     are taken from the other set.
        @param node: node name; must be a key of nodeList['index_<type>']. The
                     default looks like a leftover placeholder (NOTE(review)).
        @param k: number of results, clamped to [0, number of candidates].
        @return: dict {candidate name: score} in descending score order. Ties keep
                 the lower candidate position first.
        @raises ValueError: if ``type`` is not 'disease' or 'lncRNA'.
        @raises KeyError: if ``node`` is not a node of the given type.
        """
        if type not in ('disease', 'lncRNA'):
            raise ValueError("type must be 'disease' or 'lncRNA', got {!r}".format(type))
        other = 'lncRNA' if type == 'disease' else 'disease'

        node_id = self.nodeList["index_{}".format(type)][node]
        candidates = np.array(list(self.nodeList["index_{}".format(other)].values()), dtype=np.int64)

        emb = np.asarray(self.hidden_emb)
        scores = self._sigmoid(emb[candidates] @ emb[node_id])

        k = max(0, min(int(k), len(candidates)))
        # Stable sort of the negated scores gives descending order with ties
        # resolved lower-position-first (the tie rule of tf.nn.top_k).
        best = np.argsort(-scores, kind='stable')[:k]

        # Invert name->index instead of relying on dict insertion order.
        index_to_name = {idx: name for name, idx in self.nodeAll.items()}
        return {index_to_name[int(candidates[i])]: scores[i] for i in best}

In [3]:
# Build the predictor and load embeddings + training graph from ./config.cfg.
predict = Predict()
config_path = os.path.abspath('config.cfg')
predict.load_model_adj(config_path)

In [12]:
# Dataset used below to enumerate the query nodes. Set the LNCRNA_DISEASE_CSV
# environment variable to point at it instead of editing this machine-specific
# default path.
# NOTE(review): topK() resolves names through the index built from the config's
# training file; if this is a different file, some names may raise KeyError.
path = os.environ.get(
    "LNCRNA_DISEASE_CSV",
    "F:/Python/Python_learning/BJUT/Paper/dataset/lncRNA_disease_towsets.csv",
)
adj, nodeIndex, nodeAll = load_data(path)

In [27]:
# Print the 10 highest-scoring lncRNAs for every disease in the dataset.
for dis in nodeIndex['index_disease']:
    predlist = predict.topK(type='disease', node=dis, k=10)
    print(f"Node: {dis} ")
    print(predlist)

Node: hepatocellular carcinoma 
{'MALAT1': 0.54677176, 'H19': 0.54522765, 'CDKN2B-AS1': 0.540593, 'HOTAIR': 0.5372136, 'MEG3': 0.5361009, 'PVT1': 0.5346256, 'NEAT1': 0.5342995, 'UCA1': 0.5327157, 'GAS5': 0.53157175, 'MIATNB': 0.5299909}
Node: stomach cancer 
{'MALAT1': 0.5243366, 'H19': 0.52350336, 'CDKN2B-AS1': 0.521108, 'HOTAIR': 0.51933026, 'MEG3': 0.51873064, 'PVT1': 0.5179672, 'NEAT1': 0.5178218, 'UCA1': 0.51698047, 'GAS5': 0.5163766, 'MIATNB': 0.51569176}
Node: breast cancer 
{'MALAT1': 0.54055834, 'H19': 0.5392642, 'CDKN2B-AS1': 0.5353279, 'HOTAIR': 0.53232425, 'MEG3': 0.53126854, 'PVT1': 0.53000605, 'NEAT1': 0.5297092, 'UCA1': 0.5283341, 'GAS5': 0.5273869, 'MIATNB': 0.525994}
Node: esophageal cancer 
{'MALAT1': 0.63951194, 'H19': 0.6349881, 'CDKN2B-AS1': 0.62197256, 'HOTAIR': 0.61208594, 'MEG3': 0.60876834, 'PVT1': 0.6043802, 'NEAT1': 0.60380715, 'UCA1': 0.5989716, 'GAS5': 0.5956417, 'MIATNB': 0.5933694}
Node: prostate cancer 
{'MALAT1': 0.52930665, 'H19': 0.5281993, 'CDKN2B-AS

Node: chronic kidney disease 
{'MALAT1': 0.5030724, 'H19': 0.5029959, 'CDKN2B-AS1': 0.5026542, 'HOTAIR': 0.5025128, 'MEG3': 0.5024012, 'PVT1': 0.5023214, 'NEAT1': 0.50227267, 'UCA1': 0.50223315, 'BARX1-DT': 0.5022301, 'NOP14-AS1': 0.50220513}
Node: colorectal carcinoma 
{'MALAT1': 0.50709236, 'H19': 0.50687647, 'CDKN2B-AS1': 0.50614846, 'HOTAIR': 0.50563574, 'MEG3': 0.5054554, 'PVT1': 0.5052447, 'NEAT1': 0.5051572, 'UCA1': 0.5049405, 'GAS5': 0.5047703, 'MIAT': 0.50455725}
Node: chronic lymphocytic leukemia 
{'MALAT1': 0.5078969, 'H19': 0.50765973, 'CDKN2B-AS1': 0.5068612, 'HOTAIR': 0.5062696, 'MEG3': 0.50603527, 'PVT1': 0.5058049, 'NEAT1': 0.50567347, 'UCA1': 0.50545144, 'GAS5': 0.50526226, 'MIAT': 0.50502056}
Node: childhood medulloblastoma 
{'MALAT1': 0.5030664, 'H19': 0.5029453, 'CDKN2B-AS1': 0.50263554, 'HOTAIR': 0.50242794, 'MEG3': 0.5023616, 'PVT1': 0.5022628, 'NEAT1': 0.5022334, 'UCA1': 0.5021366, 'GAS5': 0.50205845, 'MIATNB': 0.502016}
Node: adrenocortical carcinoma 
{'MALAT1':

Node: maxillary sinus squamous cell carcinoma 
{'MALAT1': 0.5062739, 'H19': 0.5061114, 'CDKN2B-AS1': 0.50544995, 'HOTAIR': 0.505008, 'MEG3': 0.504846, 'PVT1': 0.50466317, 'NEAT1': 0.504606, 'UCA1': 0.5043894, 'GAS5': 0.5042612, 'MIAT': 0.5040378}
Node: cataract 
{'MALAT1': 0.5057627, 'H19': 0.50561786, 'CDKN2B-AS1': 0.5049739, 'HOTAIR': 0.5045542, 'MEG3': 0.5044311, 'PVT1': 0.5042853, 'NEAT1': 0.50419325, 'UCA1': 0.5039677, 'GAS5': 0.5038833, 'LINC02835': 0.5038762}
Node: allergic rhinitis 
{'MALAT1': 0.5096883, 'H19': 0.5093194, 'CDKN2B-AS1': 0.50814223, 'HOTAIR': 0.5076233, 'MEG3': 0.50740385, 'PVT1': 0.50715715, 'NEAT1': 0.50663775, 'UCA1': 0.506544, 'GAS5': 0.50642556, 'MIATNB': 0.50601524}
Node: chordoma 
{'MALAT1': 0.5067421, 'H19': 0.50648963, 'CDKN2B-AS1': 0.5057842, 'HOTAIR': 0.5052642, 'MEG3': 0.5051444, 'PVT1': 0.5049437, 'NEAT1': 0.5049254, 'UCA1': 0.5046059, 'GAS5': 0.50448745, 'MIAT': 0.50432724}
Node: spinal cord ependymoma 
{'MALAT1': 0.5082328, 'H19': 0.5079849, 'CDKN2

In [11]:
# Top-20 candidate lncRNAs for lung cancer.
predlist = predict.topK(type='disease', node='lung cancer', k=20)
print("Node:Lung cancer")
predlist

Node:Lung cancer


{'MALAT1': 0.5228399,
 'H19': 0.5219895,
 'CDKN2B-AS1': 0.5198178,
 'HOTAIR': 0.5181221,
 'MEG3': 0.51765996,
 'PVT1': 0.5168938,
 'NEAT1': 0.5168832,
 'UCA1': 0.515985,
 'GAS5': 0.51544946,
 'MIATNB': 0.51503,
 'MIAT': 0.51462996,
 'DANCR': 0.5145188,
 'LINC00645': 0.5140596,
 'CRNDE': 0.51403743,
 'XIST': 0.5138776,
 'KCNQ1OT1': 0.5133226,
 'LINC02835': 0.5131647,
 'CYTOR': 0.51308584,
 'AFAP1-AS1': 0.51300997,
 'SOX2-OT': 0.51295775}

In [8]:
# Top-20 candidate lncRNAs for one disease. The printed label is derived from the
# queried name; it previously said "breast cancer" while querying esophageal cancer.
query_disease = 'esophageal cancer'
predlist = predict.topK('disease', query_disease, 20)
print("Node:{}".format(query_disease))
predlist

Node:breast cancer


{'MALAT1': 0.63951194,
 'H19': 0.6349881,
 'CDKN2B-AS1': 0.62197256,
 'HOTAIR': 0.61208594,
 'MEG3': 0.60876834,
 'PVT1': 0.6043802,
 'NEAT1': 0.60380715,
 'UCA1': 0.5989716,
 'GAS5': 0.5956417,
 'MIATNB': 0.5933694,
 'MIAT': 0.59049237,
 'DANCR': 0.58990496,
 'CRNDE': 0.5872192,
 'LINC00645': 0.5865546,
 'XIST': 0.5863835,
 'KCNQ1OT1': 0.58274895,
 'CYTOR': 0.5813588,
 'LINC02835': 0.5813334,
 'AFAP1-AS1': 0.58078814,
 'SOX2-OT': 0.5805251}

In [6]:
# Top-10 candidate diseases for the lncRNA XIST.
predlist = predict.topK(type='lncRNA', node='XIST', k=10)
print("Node:XIST")
predlist

Node:XIST


{'esophageal cancer': 0.5863835,
 'hepatocellular carcinoma': 0.5285112,
 'colorectal cancer': 0.5276216,
 'breast cancer': 0.52469736,
 'osteoarthritis': 0.5189229,
 'prostate cancer': 0.5178109,
 'glioblastoma': 0.5161108,
 'pancreatic cancer': 0.5158422,
 'osteosarcoma': 0.5154561,
 'ovarian cancer': 0.51531225}