In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch

import pykeen
from pykeen.triples import TriplesFactory
from pykeen import predict

from utils import load_IDMapping, load_gene_disease_ids, load_train_data, read_triplets
from utils import load_entity_feature
from utils import get_merged_embeddings #dict-->numpy
from utils import roc_auc, pr_auc
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler,MinMaxScaler,LabelEncoder
from sklearn.decomposition import PCA
from mlp_model import MLPScoringModel

data_dir = "../data/dataset"

# device = torch.device(f"cuda:{2}" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

columns = ['head', 'relation', 'tail']
entity2id, relation2id = load_IDMapping(data_dir)
gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
print(f"Gene ids: {len(gene_ids)}, Disease ids: {len(disease_ids)}")
train_data, test_data, train_data_id, test_data_id = load_train_data(data_dir, entity2id, relation2id, 0)
known_triplets_df = pd.read_csv(os.path.join(data_dir, "gene_disease_triplet.tsv"), sep='\t', names=['node1', 'relation', 'node2'], dtype=object)
known_triplets = read_triplets(os.path.join(data_dir, "gene_disease_triplet.tsv"), entity2id, relation2id)
known_triplets = torch.tensor(known_triplets, dtype=torch.long, device=device)
print(f" Gene_Disease triplets: {len(known_triplets)}")

model_res = './train_results/mlp_ablation_merged0215/MLP_merged_20250215-144803'
checkpoint = torch.load(os.path.join(model_res, "best_MLPmodel.pth"), map_location='cpu')
print(f"Best epoch: {checkpoint['epoch']}")
mlp_model = checkpoint['model']

compgcn_model = torch.load("/home/worker/users/ZC/KnowledgeGraph/TarKG_reason/model/train_results/baseline/CompGCN_20250109-215102/trained_model.pkl", map_location='cpu')
transe_model = torch.load("/home/worker/users/ZC/KnowledgeGraph/TarKG_reason/model/train_results/baseline/TransE_20250109-215115/trained_model.pkl", map_location='cpu')
transr_model = torch.load("/home/worker/users/ZC/KnowledgeGraph/TarKG_reason/model/train_results/baseline/TransR_20250109-215206/trained_model.pkl", map_location='cpu')
rotate_model = torch.load("/home/worker/users/ZC/KnowledgeGraph/TarKG_reason/model/train_results/baseline/RotatE_20250109-215430/trained_model.pkl", map_location='cpu')
distmult_model = torch.load("/home/worker/users/ZC/KnowledgeGraph/TarKG_reason/model/train_results/baseline/DistMult_20250111-102406/trained_model.pkl", map_location='cpu')
complex_model = torch.load("/home/worker/users/ZC/KnowledgeGraph/TarKG_reason/model/train_results/baseline/ComplEx_20250111-102451/trained_model.pkl", map_location='cpu')
rescal_model = torch.load("/home/worker/users/ZC/KnowledgeGraph/TarKG_reason/model/train_results/baseline/RESCAL_20250117-214900/trained_model.pkl", map_location='cpu')
conve_model = torch.load("/home/worker/users/ZC/KnowledgeGraph/TarKG_reason/model/train_results/baseline/ConvE_20250124-172107/trained_model.pkl", map_location='cpu')

# model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


 num_entity: 945552
 num_relation: 126
Gene ids: 151231, Disease ids: 26996
Load data from ../data/dataset and 1 fold:
 num_train_triples: 1879648
 num_test_triples: 469912
 Gene_Disease triplets: 2349560
Best epoch: 920


In [2]:
# gene_disease_tri = pd.read_csv("../data/data/gene_disease_relation.csv", dtype=object)
disease_feature = pd.read_csv("../data/data/disease_feature.csv", dtype=object)
disease_feature['target_id'] = disease_feature['kind'] + "::" + disease_feature['id']
disease_feature = disease_feature[['target_id', 'name']]

gene_feature = pd.read_csv("../data/data/gene_feature.csv", dtype=object)
gene_feature['target_id'] = gene_feature['kind'] + "::" + gene_feature['id']
gene_feature = gene_feature[['target_id', 'name']]

## 1.获取MLP嵌入

In [3]:
def get_mlp_embeddings(mlp_model_res):
    entity_embeddings = np.load(os.path.join(mlp_model_res, 'entity_embedding.npy'))
    relation_embeddings = np.load(os.path.join(mlp_model_res, 'relation_embedding.npy'))

    return entity_embeddings, relation_embeddings

entity_embeddings, relation_embeddings = get_mlp_embeddings(model_res)
entity_embeddings = torch.tensor(entity_embeddings, dtype=torch.float32).to(device)
relation_embeddings = torch.tensor(relation_embeddings, dtype=torch.float32).to(device)
print(entity_embeddings.shape, relation_embeddings.shape)
print(entity_embeddings.device, relation_embeddings.device)

torch.Size([178227, 120]) torch.Size([126, 20])
cpu cpu


## 2.模型预测

In [4]:
# MLP预测
def score_triplet(model, triplet, entity_embeddings, relation_embeddings, device):
    """
    给定模型和三元组，生成评分。
    triplet: 包含 head, relation, tail 的三元组
    entity_embeddings: 实体嵌入 tensor
    relation_embeddings: 关系嵌入 tensor
    """
    # 获取 head, relation 和 tail 的嵌入
    head_emb = entity_embeddings[triplet[0]]
    relation_emb = relation_embeddings[triplet[1]]
    tail_emb = entity_embeddings[triplet[2]]
    
    # 将它们沿着最后一个维度拼接
    input_emb = torch.cat([head_emb, relation_emb, tail_emb], dim=-1)
    
    # 将输入传递到模型并计算分数
    score = model(input_emb)#.squeeze()#.item()

    return score

def score_batch_triplets(model, fix_entity, relation, perturb_entity_index, perturb_type, entity_embeddings, relation_embeddings, device):
    """
    给定模型、三元组和扰动实体索引，批量生成评分。
    fix_entity: 固定实体
    relation: 关系
    perturb_entity_index: 扰动实体索引列表
    perturb_type: 'head' 或 'tail'
    entity_embeddings: 实体嵌入 tensor, 已在 device 上
    relation_embeddings: 关系嵌入 tensor, 已在 device 上
    """

    # 获取固定实体和关系的嵌入
    fix_emb = entity_embeddings[fix_entity]  # (embedding_dim,)
    relation_emb = relation_embeddings[relation]  # (embedding_dim,)
    
    # 获取扰动实体的嵌入
    perturb_emb = entity_embeddings[perturb_entity_index]  # (len(perturb_entity_index), embedding_dim)
    
    # 扩展固定实体和关系嵌入
    fix_emb = fix_emb.unsqueeze(0).expand(len(perturb_entity_index), -1)  # (len(perturb_entity_index), embedding_dim)
    relation_emb = relation_emb.unsqueeze(0).expand(len(perturb_entity_index), -1)  # (len(perturb_entity_index), embedding_dim)
    
    # 根据扰动类型拼接输入嵌入
    if perturb_type == 'tail':
        input_emb = torch.cat([fix_emb, relation_emb, perturb_emb], dim=-1)  # (len(perturb_entity_index), 3 * embedding_dim)
    elif perturb_type == 'head':
        input_emb = torch.cat([perturb_emb, relation_emb, fix_emb], dim=-1)  # (len(perturb_entity_index), 3 * embedding_dim)

    # 计算分数并返回结果
    scores = model(input_emb).squeeze()  # (len(perturb_entity_index),)
    return scores#.cpu().detach().numpy()  # 直接返回结果

def predict_hrt(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                 query_entity, query_entity_location, query_relation, known_triplets, device):
    # known_triplets = read_triplets(os.path.join(data_dir, "gene_disease_triplet.tsv"), entity2id, relation2id)
    head_relation_triplets = known_triplets[:, :2]
    tail_relation_triplets = torch.stack((known_triplets[:, 2], known_triplets[:, 1])).transpose(0, 1)

    # gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
    entity_index = entity2id[query_entity]
    relation_index = relation2id[query_relation]
    query_entity_type = query_entity.split("::")[0]

    if query_entity_type == "Gene":
        target_ids = torch.tensor(disease_ids, dtype=torch.long, device=entity_embeddings.device)
    elif query_entity_type == "Disease":
        target_ids =  torch.tensor(gene_ids, dtype=torch.long, device=entity_embeddings.device)

    # 剔除已存在的三元组
    if query_entity_location == "head":
        target_entity_location = "tail"

        subject_relation = torch.tensor([entity_index, relation_index], dtype=torch.long, device=entity_embeddings.device)
        delete_index = torch.sum(head_relation_triplets == subject_relation, dim=1)
        delete_index = torch.nonzero(delete_index == 2).squeeze()
        delete_entity_index = known_triplets[delete_index, 2]
        mask = torch.isin(target_ids, delete_entity_index)
        target_entity_index = target_ids[~mask]

    elif query_entity_location == "tail":
        target_entity_location = "head"

        object_relation = torch.tensor([entity_index, relation_index], dtype=torch.long, device=entity_embeddings.device)
        delete_index = torch.sum(tail_relation_triplets == object_relation, dim=1)
        delete_index = torch.nonzero(delete_index == 2).squeeze()
        delete_entity_index = known_triplets[delete_index, 0]
        mask = torch.isin(target_ids, delete_entity_index)
        target_entity_index = target_ids[~mask]

    model.eval()  # 切换到评估模式
    with torch.no_grad():
        scores = score_batch_triplets(model, entity_index, relation_index, target_entity_index, target_entity_location, entity_embeddings, relation_embeddings, device)
        index_scores = torch.cat([target_entity_index.unsqueeze(1), scores.unsqueeze(1)], dim=1)
    scores_df = pd.DataFrame(index_scores.cpu().numpy(), columns=['target_index', 'score'])
    scores_df = scores_df.sort_values(by='score', ascending=False, ignore_index=True)
    scores_df['target_index'] = scores_df['target_index'].astype('int64')

    id2entity = {v: k for k, v in entity2id.items()}
    scores_df.insert(1, 'target_id', scores_df['target_index'].map(id2entity))
    scores_df['rank'] = scores_df.index + 1

    return scores_df

def predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                 query_entity, query_entity_location, query_relation, known_triplets, device):
    # known_triplets = read_triplets(os.path.join(data_dir, "gene_disease_triplet.tsv"), entity2id, relation2id)
    head_ents = known_triplets[:, 0]
    tail_ents = known_triplets[:, 2]
    head_tail_triplets = torch.stack((known_triplets[:, 0], known_triplets[:, 2])).transpose(0, 1)

    # gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
    entity_index = entity2id[query_entity]
    relation_index = relation2id[query_relation]
    query_entity_type = query_entity.split("::")[0]

    if query_entity_type == "Gene":
        target_ids = torch.tensor(disease_ids, dtype=torch.long, device=entity_embeddings.device)
    elif query_entity_type == "Disease":
        target_ids =  torch.tensor(gene_ids, dtype=torch.long, device=entity_embeddings.device)

    # 剔除与查询实体关联的所有实体（不考虑方向）
    head_matches = torch.nonzero(head_ents == entity_index, as_tuple=True)[0]
    tail_matches = torch.nonzero(tail_ents == entity_index, as_tuple=True)[0]
    related_entities = torch.cat([
        known_triplets[head_matches, 2],
        known_triplets[tail_matches, 0]
    ]).unique()
    mask = torch.isin(target_ids, related_entities)
    target_entity_index = target_ids[~mask]

    if query_entity_location == "head":
        target_entity_location = "tail"
    elif query_entity_location == "tail":
        target_entity_location = "head"

    model.eval()  # 切换到评估模式
    with torch.no_grad():
        scores = score_batch_triplets(model, entity_index, relation_index, target_entity_index, target_entity_location, entity_embeddings, relation_embeddings, device)
        index_scores = torch.cat([target_entity_index.unsqueeze(1), scores.unsqueeze(1)], dim=1)
    scores_df = pd.DataFrame(index_scores.cpu().numpy(), columns=['target_index', 'score'])
    scores_df = scores_df.sort_values(by='score', ascending=False, ignore_index=True)
    scores_df['target_index'] = scores_df['target_index'].astype('int64')

    id2entity = {v: k for k, v in entity2id.items()}
    scores_df.insert(1, 'target_id', scores_df['target_index'].map(id2entity))
    scores_df['rank'] = scores_df.index + 1

    return scores_df


# KGE预测
def score_triplets_kge(model, fix_entity, relation, perturb_entity_index, perturb_type, entity2id, relation2id, device):
    perturb_entities = perturb_entity_index.cpu().numpy()
    n = len(perturb_entities)

    # 创建逆映射：ID -> 字符串
    id2entity = {v: k for k, v in entity2id.items()}  # 实体 ID 到标签
    id2relation = {v: k for k, v in relation2id.items()}  # 关系 ID 到标签

    fix_entity_str = id2entity[fix_entity]  # 固定实体的字符串形式
    relation_str = id2relation[relation]    # 关系的字符串形式
    perturb_entities_str = [id2entity[idx] for idx in perturb_entities]  # 扰动实体的字符串列表

    # 初始化三元组数组
    triples = np.zeros((n, 3), dtype=object)  # 使用 object 类型存储字符串
    
    if perturb_type == "head":
        triples[:, 0] = perturb_entities_str  # head
        triples[:, 1] = relation_str          # relation
        triples[:, 2] = fix_entity_str        # tail
    elif perturb_type == "tail":
        triples[:, 0] = fix_entity_str        # head
        triples[:, 1] = relation_str          # relation
        triples[:, 2] = perturb_entities_str  # tail
    else:
        raise ValueError("perturb_type must be 'head' or 'tail'")
    
    triplets_tf = TriplesFactory.from_labeled_triples(triples, entity_to_id=entity2id, relation_to_id=relation2id, create_inverse_triples=False)
    score = model.predict_hrt(triplets_tf.mapped_triples).detach().cpu().numpy()
    score = pd.DataFrame(np.concatenate([triplets_tf.triples, score.reshape(-1, 1)], axis=1), columns=['head', 'relation', 'tail', 'score'])
    score = score.sort_values(by='score', ascending=False, ignore_index=True)
    score['rank'] = score.index + 1
    
    return score
    
def predict_hrt_kge(model, entity2id, relation2id, gene_ids, disease_ids, query_entity, query_entity_location, query_relation, known_triplets, device):
    head_ents = known_triplets[:, 0]
    tail_ents = known_triplets[:, 2]
    head_tail_triplets = torch.stack((known_triplets[:, 0], known_triplets[:, 2])).transpose(0, 1)

    # gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
    entity_index = entity2id[query_entity]
    relation_index = relation2id[query_relation]
    query_entity_type = query_entity.split("::")[0]

    if query_entity_type == "Gene":
        target_ids = torch.tensor(disease_ids, dtype=torch.long)
    elif query_entity_type == "Disease":
        target_ids =  torch.tensor(gene_ids, dtype=torch.long)

    # 剔除与查询实体关联的所有实体（不考虑方向）
    head_matches = torch.nonzero(head_ents == entity_index, as_tuple=True)[0]
    tail_matches = torch.nonzero(tail_ents == entity_index, as_tuple=True)[0]
    related_entities = torch.cat([
        known_triplets[head_matches, 2],
        known_triplets[tail_matches, 0]
    ]).unique()
    mask = torch.isin(target_ids, related_entities)
    target_entity_index = target_ids[~mask]

    if query_entity_location == "head":
        target_entity_location = "tail"
    elif query_entity_location == "tail":
        target_entity_location = "head"

    model.eval()  # 切换到评估模式
    with torch.no_grad():
        scores = score_triplets_kge(model, entity_index, relation_index, target_entity_index, target_entity_location, entity2id, relation2id, device)
        return scores
        index_scores = torch.cat([target_entity_index.unsqueeze(1), scores.unsqueeze(1)], dim=1)
    scores_df = pd.DataFrame(index_scores.cpu().numpy(), columns=['target_index', 'score'])
    scores_df = scores_df.sort_values(by='score', ascending=False, ignore_index=True)
    scores_df['target_index'] = scores_df['target_index'].astype('int64')

    id2entity = {v: k for k, v in entity2id.items()}
    scores_df.insert(1, 'target_id', scores_df['target_index'].map(id2entity))
    scores_df['rank'] = scores_df.index + 1

    return scores_df


In [5]:
new_targets = pd.read_csv("./train_results/case_res/0304/new_targets_30.tsv", sep='\t', dtype=object)
new_targets

Unnamed: 0,d_name,d_id,g_name,g_id
0,阿尔兹海默症,DOID:10652,ISG15_HUMAN,P05161
1,癌症,DOID:162,MOT11_HUMAN,Q8NCK7
2,结肠炎,DOID:0060180,NLRC5_HUMAN,Q86WI3
3,炎症,MESH:D007249,GLO2_HUMAN,Q16775
4,肺纤维化,DOID:3770,TPIS_HUMAN,P60174
5,肥胖,DOID:9970,TIGAR_HUMAN,Q9NQ88
6,肝癌,DOID:3571,DYR1A_HUMAN,Q13627
7,癌症,DOID:162,IRGQ_HUMAN,Q8WZA9
8,三阴性乳腺癌,DOID:0060081,SOSD1_HUMAN,Q6X4U4
9,骨质疏松症,DOID:11476,CCN3_HUMAN,P48745


In [22]:
gene_ranks = []
disease_ranks = []

query_relation = 'Gene:Disease::drug targets'
# model = mlp_model
# model = compgcn_model
# model = transe_model
# model = transr_model
# model = rotate_model
# model = distmult_model
# model = complex_model
# model = rescal_model
model = conve_model
for i in range(len(new_targets)):
    disease = 'Disease::' + new_targets.loc[i, 'd_id']
    gene = 'Gene::' + new_targets.loc[i, 'g_id']
    # print(disease)

    # predict_scores_df1 = predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
    #                              disease, 'tail', query_relation, known_triplets, device)
    # gene_rank = predict_scores_df1[predict_scores_df1['target_id'] == gene]['rank'].values[0]

    # predict_scores_df2 = predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
    #                                 gene, 'head', query_relation, known_triplets, device)
    # disease_rank = predict_scores_df2[predict_scores_df2['target_id'] == disease]['rank'].values[0]

    predict_scores_kge_df1 = predict_hrt_kge(model, entity2id, relation2id, gene_ids, disease_ids, disease, 'tail', query_relation, known_triplets, device)
    gene_rank = predict_scores_kge_df1[predict_scores_kge_df1['head'] == gene]['rank'].values[0]

    predict_scores_kge_df2 = predict_hrt_kge(model, entity2id, relation2id, gene_ids, disease_ids, gene, 'head', query_relation, known_triplets, device)
    disease_rank = predict_scores_kge_df2[predict_scores_kge_df2['tail'] == disease]['rank'].values[0]

    gene_ranks.append(gene_rank)
    disease_ranks.append(disease_rank)

new_targets_rank = new_targets.copy(deep=True)
new_targets_rank['gene_rank'] = gene_ranks
new_targets_rank['disease_rank'] = disease_ranks

new_targets_rank

Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.
Reconstructing all label-based triples. This is expensive and rarely needed.

Unnamed: 0,d_name,d_id,g_name,g_id,gene_rank,disease_rank
0,阿尔兹海默症,DOID:10652,ISG15_HUMAN,P05161,140329,47
1,癌症,DOID:162,MOT11_HUMAN,Q8NCK7,13374,4392
2,结肠炎,DOID:0060180,NLRC5_HUMAN,Q86WI3,139387,171
3,炎症,MESH:D007249,GLO2_HUMAN,Q16775,130246,28
4,肺纤维化,DOID:3770,TPIS_HUMAN,P60174,145275,276
5,肥胖,DOID:9970,TIGAR_HUMAN,Q9NQ88,123473,132
6,肝癌,DOID:3571,DYR1A_HUMAN,Q13627,146165,27
7,癌症,DOID:162,IRGQ_HUMAN,Q8WZA9,147739,5759
8,三阴性乳腺癌,DOID:0060081,SOSD1_HUMAN,Q6X4U4,143659,82
9,骨质疏松症,DOID:11476,CCN3_HUMAN,P48745,92790,152


In [23]:
# new_targets_rank.to_csv("./train_results/case_res/0304/mlp_target_ranks.csv", index=False)
# new_targets_rank.to_csv("./train_results/case_res/0304/compgcn_target_ranks.csv", index=False)
# new_targets_rank.to_csv("./train_results/case_res/0304/transe_target_ranks.csv", index=False)
# new_targets_rank.to_csv("./train_results/case_res/0304/transr_target_ranks.csv", index=False)
# new_targets_rank.to_csv("./train_results/case_res/0304/rotate_target_ranks.csv", index=False)
# new_targets_rank.to_csv("./train_results/case_res/0304/distmult_target_ranks.csv", index=False)
# new_targets_rank.to_csv("./train_results/case_res/0304/complex_target_ranks.csv", index=False)
# new_targets_rank.to_csv("./train_results/case_res/0304/rescal_target_ranks.csv", index=False)
new_targets_rank.to_csv("./train_results/case_res/0304/conve_target_ranks.csv", index=False)

In [None]:
# query_entity = 'Disease::DOID:8577' #ulcerative colitis

node1 = "Disease::DOID:10652"
node2 = "Gene::P05161"
query_relation = 'Gene:Disease::drug targets'

known_triplets_df[((known_triplets_df['node1'] == node1) & (known_triplets_df['node2'] == node2)|
               (known_triplets_df['node1'] == node2) & (known_triplets_df['node2'] == node1))]

In [None]:
predict_scores_kge_df1 = predict_hrt_kge(model, entity2id, relation2id, gene_ids, disease_ids, node1, 'tail', query_relation, known_triplets, device)
print(predict_scores_kge_df1[predict_scores_kge_df1['head'] == node2]['rank'].values[0])

predict_scores_kge_df2 = predict_hrt_kge(model, entity2id, relation2id, gene_ids, disease_ids, node2, 'head', query_relation, known_triplets, device)
print(predict_scores_kge_df2[predict_scores_kge_df2['tail'] == node1]['rank'].values[0])

In [None]:
predict_scores_df1 = predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                                 node1, 'tail', query_relation, known_triplets, device)
print(predict_scores_df1[predict_scores_df1['target_id'] == node2]['rank'].values[0])

predict_scores_df2 = predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                                 node2, 'head', query_relation, known_triplets, device)
print(predict_scores_df2[predict_scores_df2['target_id'] == node1]['rank'].values[0])
predict_scores_df1