In [1]:
import os
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
import torch

from utils import load_IDMapping, load_gene_disease_ids, load_train_data, read_triplets
from utils import load_entity_feature
from utils import get_merged_embeddings #dict-->numpy
from utils import roc_auc, pr_auc
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler,MinMaxScaler,LabelEncoder
from sklearn.decomposition import PCA
from mlp_model import MLPScoringModel

data_dir = "../data/dataset" #You can find the data at [Zenodo](https://zenodo.org/records/17156565)(data.zip)
# model_res = './train_results/mlp_ablation_mergedAtt/MLP_merged_20250202-002342'
# model_res = './train_results/mlp_ablation_merged/MLP_merged_20250202-002254'
# model_res = './train_results/mlp_ablation_merged0213/MLP_merged_20250213-095514'
model_res = './train_results/mlp_ablation_merged0215/MLP_merged_20250215-144803' #You can find the training results at [Zenodo](https://zenodo.org/records/17156565)(mlp_ablation.tar.gz)
# model_res = './train_results/mlp_ablation_merged0215/MLP_onlyKGE_20250215-144803'
# model_res = './train_results/mlp_ablation_merged0215/MLP_onlyFeat_20250215-144804'
# model_res = './train_results/mlp_ablation_merged0215/MLP_empty_20250215-144807'

# device = torch.device(f"cuda:{2}" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

columns = ['head', 'relation', 'tail']
entity2id, relation2id = load_IDMapping(data_dir)
gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
print(f"Gene ids: {len(gene_ids)}, Disease ids: {len(disease_ids)}")
train_data, test_data, train_data_id, test_data_id = load_train_data(data_dir, entity2id, relation2id, 0)
known_triplets = read_triplets(os.path.join(data_dir, "gene_disease_triplet.tsv"), entity2id, relation2id)
known_triplets = torch.tensor(known_triplets, dtype=torch.long, device=device)
print(f" Gene_Disease triplets: {len(known_triplets)}")

checkpoint = torch.load(os.path.join(model_res, "best_MLPmodel.pth"), map_location='cpu')
print(f"Best epoch: {checkpoint['epoch']}")
model = checkpoint['model'].to(device)

  from tqdm.autonotebook import tqdm


 num_entity: 945552
 num_relation: 126
Gene ids: 151231, Disease ids: 26996
Load data from ../data/dataset and 1 fold:
 num_train_triples: 1879648
 num_test_triples: 469912
 Gene_Disease triplets: 2349560
Best epoch: 920


In [2]:
gene_disease_tri = pd.read_csv("../data/data/gene_disease_relation.csv", dtype=object)

disease_feature = pd.read_csv("../data/data/disease_feature.csv", dtype=object)
disease_feature['target_id'] = disease_feature['kind'] + "::" + disease_feature['id']
disease_feature = disease_feature[['target_id', 'name']]

gene_feature = pd.read_csv("../data/data/gene_feature.csv", dtype=object)
gene_feature['target_id'] = gene_feature['kind'] + "::" + gene_feature['id']
gene_feature = gene_feature[['target_id', 'name']]

test_data_id_pos = test_data_id[test_data_id['label'] == 1]
test_triplets = test_data_id_pos[columns].values
# test_triplets = test_triplets[:1000]

test_triplets = torch.tensor(test_triplets, dtype=torch.long, device=device)
print(known_triplets.shape, test_triplets.shape)

torch.Size([2349560, 3]) torch.Size([469912, 3])


# 1.修改嵌入合并操作
使用numpy操作嵌入以代替字典，易于与torch转换

In [None]:
def load_entity_feature_np(file_path):
    # read gene and disease feature embedding from preprocessing file
    gene_feat_emb = np.load(file_path + '/gene_embedding.npy')
    disease_feat_emb = np.load(file_path + '/disease_embedding.npy')

    return gene_feat_emb, disease_feat_emb

def get_full_embeddings(entity_ids, emb_dict, dim):
    """
    将嵌入字典补齐到完整实体集合，缺失的实体补零向量。
    """
    full_emb = {}
    zero_vector = np.zeros(dim)
    for entity_id in entity_ids:
        full_emb[entity_id] = emb_dict.get(entity_id, zero_vector)
    return full_emb

## 2.convert to matrix
def dict_to_matrix(emb_dict, entity_ids):
    """
    将嵌入字典转换为矩阵，行对应实体，列为嵌入向量。
    """
    return np.array([emb_dict[eid] for eid in entity_ids])

## 3.principal component dimension reduction
def reduce_dim(embedding, entity_ids, target_dim):
    """
    使用PCA对嵌入矩阵降维。
    """
    # emb_dim = next(iter(embedding.values())).shape[0]
    # full_emb = get_full_embeddings(entity_ids, embedding, emb_dim)
    # emb_matrix = dict_to_matrix(full_emb, entity_ids)
    emb_matrix = embedding
    emb_dim = emb_matrix.shape[1]

    mms = MinMaxScaler(feature_range=(0,1))
    emb_matrix_scaled = mms.fit_transform(emb_matrix)
    if emb_dim > target_dim:
        pca = PCA(n_components=target_dim, random_state=42)
        reduced_matrix_scaled = pca.fit_transform(emb_matrix_scaled)
        reduced_matrix_scaled1 = mms.fit_transform(reduced_matrix_scaled)
        return reduced_matrix_scaled1
    else:
        return emb_matrix_scaled
   
def get_merged_embeddings_np(kge_ent_emb, kge_rel_emb, entity2id, relation2id,
                           gene_feat_emb, disease_feat_emb, gene_ids, disease_ids, target_dim=300):
    mms = MinMaxScaler(feature_range=(0,1))
    
    entity_ids = disease_ids + gene_ids # the entity ids order is [disease,gene]
    # kge_ent_emb = {id: kge_ent_emb[id] for name, id in entity2id.items()}
    kge_ent_emb = kge_ent_emb[:len(entity_ids)]
    kge_ent_reduced = reduce_dim(kge_ent_emb, entity_ids, target_dim)

    disease_feat_reduced = reduce_dim(disease_feat_emb, disease_ids, target_dim)
    gene_feat_reduced = reduce_dim(gene_feat_emb, gene_ids, target_dim)
    feat_reduced = np.concatenate([disease_feat_reduced, gene_feat_reduced])

    combined_ent_emb = np.concatenate([kge_ent_reduced, feat_reduced], axis=1)
    combined_ent_emb_scaled = mms.fit_transform(combined_ent_emb)
    # combined_ent_emb_scaled_dict = {eid: combined_ent_emb_scaled[idx] for idx, eid in enumerate(entity_ids)}

    relation_ids = sorted(list(relation2id.values()))
    # kge_rel_emb = {id: kge_rel_emb[id] for name, id in relation2id.items()}
    kge_rel_emb = kge_rel_emb[:len(relation_ids)]
    kge_rel_reduced = reduce_dim(kge_rel_emb, relation_ids, target_dim)
    # rel_emb_scaled_dict = {eid: kge_rel_reduced[idx] for idx, eid in enumerate(relation_ids)}

    return combined_ent_emb_scaled, kge_rel_reduced

gene_feat_emb, disease_feat_emb = load_entity_feature_np(data_dir)
kge_ent_emb = np.load(f'{model_res}/entity_embedding.npy')
kge_rel_emb = np.load(f'{model_res}/relation_embedding.npy')
print(gene_feat_emb.shape, disease_feat_emb.shape)
print(kge_ent_emb.shape, kge_rel_emb.shape)

entity_embeddings, relation_embeddings = get_merged_embeddings_np(kge_ent_emb, kge_rel_emb, entity2id, relation2id,
                                                                gene_feat_emb, disease_feat_emb, gene_ids, disease_ids, target_dim=200)
print(entity_embeddings.shape, relation_embeddings.shape)
# print(entity_embeddings.device, relation_embeddings.device)

entity_embeddings = torch.tensor(entity_embeddings, dtype=torch.float32).to(device)
relation_embeddings = torch.tensor(relation_embeddings, dtype=torch.float32).to(device)
print(entity_embeddings.shape, relation_embeddings.shape)
print(entity_embeddings.device, relation_embeddings.device)

(151231, 1280) (26996, 768)
(945552, 20) (252, 20)
(178227, 220) (126, 20)
torch.Size([178227, 220]) torch.Size([126, 20])
cpu cpu


In [3]:
def get_mlp_embeddings(mlp_model_res):
    entity_embeddings = np.load(os.path.join(mlp_model_res, 'entity_embedding.npy'))
    relation_embeddings = np.load(os.path.join(mlp_model_res, 'relation_embedding.npy'))

    return entity_embeddings, relation_embeddings

entity_embeddings, relation_embeddings = get_mlp_embeddings(model_res)
entity_embeddings = torch.tensor(entity_embeddings, dtype=torch.float32).to(device)
relation_embeddings = torch.tensor(relation_embeddings, dtype=torch.float32).to(device)
print(entity_embeddings.shape, relation_embeddings.shape)
print(entity_embeddings.device, relation_embeddings.device)

torch.Size([178227, 120]) torch.Size([126, 20])
cpu cpu


# 2.模型打分、评价和预测(新版本-嵌入使用字典numpy)

In [4]:
def score_triplet(model, triplet, entity_embeddings, relation_embeddings, device):
    """
    给定模型和三元组，生成评分。
    triplet: 包含 head, relation, tail 的三元组
    entity_embeddings: 实体嵌入 tensor
    relation_embeddings: 关系嵌入 tensor
    """
    # 获取 head, relation 和 tail 的嵌入
    head_emb = entity_embeddings[triplet[0]]
    relation_emb = relation_embeddings[triplet[1]]
    tail_emb = entity_embeddings[triplet[2]]
    
    # 将它们沿着最后一个维度拼接
    input_emb = torch.cat([head_emb, relation_emb, tail_emb], dim=-1)
    
    # 将输入传递到模型并计算分数
    score = model(input_emb)#.squeeze()#.item()

    return score

def score_batch_triplets(model, fix_entity, relation, perturb_entity_index, perturb_type, entity_embeddings, relation_embeddings, device):
    """
    给定模型、三元组和扰动实体索引，批量生成评分。
    fix_entity: 固定实体
    relation: 关系
    perturb_entity_index: 扰动实体索引列表
    perturb_type: 'head' 或 'tail'
    entity_embeddings: 实体嵌入 tensor, 已在 device 上
    relation_embeddings: 关系嵌入 tensor, 已在 device 上
    """

    # 获取固定实体和关系的嵌入
    fix_emb = entity_embeddings[fix_entity]  # (embedding_dim,)
    relation_emb = relation_embeddings[relation]  # (embedding_dim,)
    
    # 获取扰动实体的嵌入
    perturb_emb = entity_embeddings[perturb_entity_index]  # (len(perturb_entity_index), embedding_dim)
    
    # 扩展固定实体和关系嵌入
    fix_emb = fix_emb.unsqueeze(0).expand(len(perturb_entity_index), -1)  # (len(perturb_entity_index), embedding_dim)
    relation_emb = relation_emb.unsqueeze(0).expand(len(perturb_entity_index), -1)  # (len(perturb_entity_index), embedding_dim)
    
    # 根据扰动类型拼接输入嵌入
    if perturb_type == 'tail':
        input_emb = torch.cat([fix_emb, relation_emb, perturb_emb], dim=-1)  # (len(perturb_entity_index), 3 * embedding_dim)
    elif perturb_type == 'head':
        input_emb = torch.cat([perturb_emb, relation_emb, fix_emb], dim=-1)  # (len(perturb_entity_index), 3 * embedding_dim)

    # 计算分数并返回结果
    scores = model(input_emb).squeeze()  # (len(perturb_entity_index),)
    return scores#.cpu().detach().numpy()  # 直接返回结果

def predict_hrt(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                 query_entity, query_entity_location, query_relation, known_triplets, device):
    # known_triplets = read_triplets(os.path.join(data_dir, "gene_disease_triplet.tsv"), entity2id, relation2id)
    head_relation_triplets = known_triplets[:, :2]
    tail_relation_triplets = torch.stack((known_triplets[:, 2], known_triplets[:, 1])).transpose(0, 1)

    # gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
    entity_index = entity2id[query_entity]
    relation_index = relation2id[query_relation]
    query_entity_type = query_entity.split("::")[0]

    if query_entity_type == "Gene":
        target_ids = torch.tensor(disease_ids, dtype=torch.long, device=entity_embeddings.device)
    elif query_entity_type == "Disease":
        target_ids =  torch.tensor(gene_ids, dtype=torch.long, device=entity_embeddings.device)

    # 剔除已存在的三元组
    if query_entity_location == "head":
        target_entity_location = "tail"

        subject_relation = torch.tensor([entity_index, relation_index], dtype=torch.long, device=entity_embeddings.device)
        delete_index = torch.sum(head_relation_triplets == subject_relation, dim=1)
        delete_index = torch.nonzero(delete_index == 2).squeeze()
        delete_entity_index = known_triplets[delete_index, 2]
        mask = torch.isin(target_ids, delete_entity_index)
        target_entity_index = target_ids[~mask]

    elif query_entity_location == "tail":
        target_entity_location = "head"

        object_relation = torch.tensor([entity_index, relation_index], dtype=torch.long, device=entity_embeddings.device)
        delete_index = torch.sum(tail_relation_triplets == object_relation, dim=1)
        delete_index = torch.nonzero(delete_index == 2).squeeze()
        delete_entity_index = known_triplets[delete_index, 0]
        mask = torch.isin(target_ids, delete_entity_index)
        target_entity_index = target_ids[~mask]

    model.eval()  # 切换到评估模式
    with torch.no_grad():
        scores = score_batch_triplets(model, entity_index, relation_index, target_entity_index, target_entity_location, entity_embeddings, relation_embeddings, device)
        index_scores = torch.cat([target_entity_index.unsqueeze(1), scores.unsqueeze(1)], dim=1)
    scores_df = pd.DataFrame(index_scores.cpu().numpy(), columns=['target_index', 'score'])
    scores_df = scores_df.sort_values(by='score', ascending=False, ignore_index=True)
    scores_df['target_index'] = scores_df['target_index'].astype('int64')

    id2entity = {v: k for k, v in entity2id.items()}
    scores_df.insert(1, 'target_id', scores_df['target_index'].map(id2entity))
    scores_df['rank'] = scores_df.index + 1

    return scores_df

def predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                 query_entity, query_entity_location, query_relation, known_triplets, device):
    # known_triplets = read_triplets(os.path.join(data_dir, "gene_disease_triplet.tsv"), entity2id, relation2id)
    head_ents = known_triplets[:, 0]
    tail_ents = known_triplets[:, 2]
    head_tail_triplets = torch.stack((known_triplets[:, 0], known_triplets[:, 2])).transpose(0, 1)

    # gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
    entity_index = entity2id[query_entity]
    relation_index = relation2id[query_relation]
    query_entity_type = query_entity.split("::")[0]

    if query_entity_type == "Gene":
        target_ids = torch.tensor(disease_ids, dtype=torch.long, device=entity_embeddings.device)
    elif query_entity_type == "Disease":
        target_ids =  torch.tensor(gene_ids, dtype=torch.long, device=entity_embeddings.device)

    # 剔除与查询实体关联的所有实体（不考虑方向）
    head_matches = torch.nonzero(head_ents == entity_index, as_tuple=True)[0]
    tail_matches = torch.nonzero(tail_ents == entity_index, as_tuple=True)[0]
    related_entities = torch.cat([
        known_triplets[head_matches, 2],
        known_triplets[tail_matches, 0]
    ]).unique()
    mask = torch.isin(target_ids, related_entities)
    target_entity_index = target_ids[~mask]

    if query_entity_location == "head":
        target_entity_location = "tail"
    elif query_entity_location == "tail":
        target_entity_location = "head"

    model.eval()  # 切换到评估模式
    with torch.no_grad():
        scores = score_batch_triplets(model, entity_index, relation_index, target_entity_index, target_entity_location, entity_embeddings, relation_embeddings, device)
        index_scores = torch.cat([target_entity_index.unsqueeze(1), scores.unsqueeze(1)], dim=1)
    scores_df = pd.DataFrame(index_scores.cpu().numpy(), columns=['target_index', 'score'])
    scores_df = scores_df.sort_values(by='score', ascending=False, ignore_index=True)
    scores_df['target_index'] = scores_df['target_index'].astype('int64')

    id2entity = {v: k for k, v in entity2id.items()}
    scores_df.insert(1, 'target_id', scores_df['target_index'].map(id2entity))
    scores_df['rank'] = scores_df.index + 1

    return scores_df

def predict_triplets_scores(model, entity_embeddings, relation_embeddings, triplets: pd.DataFrame, device):
    """
    给定模型和三元组，批量预测三元组的分数。
    triplets: 包含 head, relation, tail 的三元组, 为 DataFrame结构
    entity_embeddings: 实体嵌入 tensor, 已在 device 上
    relation_embeddings: 关系嵌入 tensor, 已在 device 上
    """
    # 从 triplets 中提取 head, relation, tail 的索引
    head_entity_index = triplets.iloc[:, 0]#.to(device)
    relation_index = triplets.iloc[:, 1]#.to(device)
    tail_entity_index = triplets.iloc[:, 2]#.to(device)
    
    # 获取头实体、关系和尾实体的嵌入
    head_emb = entity_embeddings[head_entity_index]  # (len(triplets), embedding_dim)
    relation_emb = relation_embeddings[relation_index]  # (len(triplets), embedding_dim)
    tail_emb = entity_embeddings[tail_entity_index]  # (len(triplets), embedding_dim)
    
    # 拼接 head, relation, tail 的嵌入
    input_emb = torch.cat([head_emb, relation_emb, tail_emb], dim=-1)  # (len(triplets), 3 * embedding_dim)
    
    # 计算分数
    model.eval()
    with torch.no_grad():
        scores = model(input_emb).squeeze()  # (len(triplets),)
    
    return scores.cpu().detach().numpy()  # 转换为 NumPy 数组并返回

def sort_and_rank(score, target):
    _, indices = torch.sort(score, dim=0, descending=True)
    indices = torch.nonzero(indices == target.view(-1, 1))
    indices = indices[:, 1].view(-1)
    return indices

def calc_mrr(model, test_triplets, entity_embeddings, relation_embeddings, known_triplets, device, hits=[1, 3, 10]):
    """
    计算测试集上的 MR, MRR, Hit@K。
    test_triplets: 测试集 test_triplets
    model: 训练好的评分模型
    entity_embeddings: 实体嵌入 tensor, 已在 device 上
    relation_embeddings: 关系嵌入 tensor, 已在 device 上
    known_triplets: 所有已知的三元组集合（例如，训练集+测试集的三元组）
    device: 设备 (CPU 或 GPU)
    k_values: Hit@K 的 K 值列表，例如 [1, 3, 10]
    """
    model.eval()  # 切换到评估模式
    with torch.no_grad():
        ranks_s = []
        ranks_o = []

        head_relation_triplets = known_triplets[:, :2]
        tail_relation_triplets = torch.stack((known_triplets[:, 2], known_triplets[:, 1])).transpose(0, 1)

        subject_relation_map = {}
        object_relation_map = {}

        all_entities = torch.arange(len(entity_embeddings), device=entity_embeddings.device)
        for i in range(len(test_triplets)):
            test_triplet = test_triplets[i]
            # if i % 100 == 0:
            #     print(f" Processing triplet {i}/{len(test_triplets)}")
            
            # target = score_triplet(model, test_triplet, entity_embeddings, relation_embeddings, device)
            subject, relation, object_ = test_triplet[0], test_triplet[1], test_triplet[2]

            # Perturb object (head is fixed)
            subject_relation = test_triplet[:2]  # (subject, relation)
            subject_relation_key = (subject_relation[0].item(), subject_relation[1].item())

            if subject_relation_key not in subject_relation_map:
                delete_index = torch.sum(head_relation_triplets == subject_relation, dim=1)
                delete_index = torch.nonzero(delete_index == 2).squeeze()
                delete_entity_index = known_triplets[delete_index, 2] #columns 3, is tail entity 
                
                mask = torch.isin(all_entities, delete_entity_index)
                perturb_entity_index = all_entities[~mask]
                subject_relation_map[subject_relation_key] = perturb_entity_index

            perturb_entity_index = subject_relation_map[subject_relation_key]
            perturb_entity_index = torch.cat((perturb_entity_index, object_.view(-1)))
            scores = score_batch_triplets(model, subject, relation, perturb_entity_index, 'tail', entity_embeddings, relation_embeddings, device)
            target = torch.tensor(len(perturb_entity_index) - 1).to(entity_embeddings.device)

            ranks_s.append(sort_and_rank(scores, target))

            # Perturb subject (tail is fixed)
            object_relation  = torch.cat([object_.unsqueeze(0), relation.unsqueeze(0)], dim=0)  # (subject, relation)
            object_relation_key = (object_relation[0].item(), object_relation[1].item())

            if object_relation_key  not in object_relation_map:
                delete_index = torch.sum(tail_relation_triplets == object_relation, dim=1)
                delete_index = torch.nonzero(delete_index == 2).squeeze()
                delete_entity_index = known_triplets[delete_index, 0] #columns 1, is head entity
                
                mask = torch.isin(all_entities, delete_entity_index)
                perturb_entity_index = all_entities[~mask]
                object_relation_map[object_relation_key] = perturb_entity_index

            perturb_entity_index = object_relation_map[object_relation_key]
            perturb_entity_index = torch.cat((perturb_entity_index, subject.view(-1)))
            scores = score_batch_triplets(model, object_, relation, perturb_entity_index, 'head', entity_embeddings, relation_embeddings, device)

            target = torch.tensor(len(perturb_entity_index) - 1).to(entity_embeddings.device)
            ranks_o.append(sort_and_rank(scores, target))
        
        ranks_s = torch.cat(ranks_s)
        ranks_o = torch.cat(ranks_o)

        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1
        # Calculate MRR (Mean Reciprocal Rank)
        mrr = torch.mean(1.0 / ranks.float()).item()

        # Calculate MR (Mean Rank)
        mr = torch.mean(ranks.float()).item()

        # Calculate Hits@k
        hits_result = {}
        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float()).item()
            hits_result[hit] = avg_count
        
        evaluate_res = {'MRR': mrr, 'MR': mr, 'Hits@1': hits_result[1], 'Hits@3': hits_result[3], 'Hits@10': hits_result[10]}
        return evaluate_res

def calc_mrr_simple(model, test_triplets, entity_embeddings, relation_embeddings, known_triplets, device, hits=[1, 3, 10]):
    """
    计算测试集上的 MR, MRR, Hit@K。
    test_triplets: 测试集 test_triplets
    model: 训练好的评分模型
    entity_embeddings: 实体嵌入字典
    relation_embeddings: 关系嵌入字典
    known_triplets: 所有已知的三元组集合（例如，训练集+测试集的三元组）
    device: 设备 (CPU 或 GPU)
    k_values: Hit@K 的 K 值列表，例如 [1, 3, 10]
    """
    model.eval()  # 切换到评估模式
    with torch.no_grad():
        ranks_s = []
        ranks_o = []

        all_entities = torch.arange(len(entity_embeddings), device=entity_embeddings.device)
        for i in range(len(test_triplets)):
            test_triplet = test_triplets[i]
            # if i % 100 == 0:
            #     print(f" Processing triplet {i}/{len(test_triplets)}")
            
            target = score_triplet(model, test_triplet, entity_embeddings, relation_embeddings, device).item()
            subject, relation, object_ = test_triplet[0], test_triplet[1], test_triplet[2]

            # Perturb object (head is fixed)
            scores = score_batch_triplets(model, subject, relation, all_entities, 'tail', entity_embeddings, relation_embeddings, device)
            target = torch.tensor(object_.item()).to(entity_embeddings.device)
            ranks_s.append(sort_and_rank(scores, target))

            # Perturb subject (tail is fixed)
            scores = score_batch_triplets(model, object_, relation, all_entities, 'head', entity_embeddings, relation_embeddings, device)
            target = torch.tensor(subject.item()).to(entity_embeddings.device)
            ranks_o.append(sort_and_rank(scores, target))
        
        ranks_s = torch.cat(ranks_s)
        ranks_o = torch.cat(ranks_o)

        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1
        # Calculate MRR (Mean Reciprocal Rank)
        mrr = torch.mean(1.0 / ranks.float()).item()

        # Calculate MR (Mean Rank)
        mr = torch.mean(ranks.float()).item()

        # Calculate Hits@k
        hits_result = {}
        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float()).item()
            hits_result[hit] = avg_count
        
        evaluate_res = {'MRR': mrr, 'MR': mr, 'Hits@1': hits_result[1], 'Hits@3': hits_result[3], 'Hits@10': hits_result[10]}
        return evaluate_res

def evaluate_metrics(model, test_triplets, entity_embeddings, relation_embeddings, known_triplets, device, batch_size:int=200):
    # test_triplets = test_triplets.to(device)
    # entity_embeddings = entity_embeddings.to(device)
    # relation_embeddings = relation_embeddings.to(device)
    # known_triplets = known_triplets.to(device)
    
    total_res = {'MRR': 0, 'MR': 0, 'Hits@1': 0, 'Hits@3': 0, 'Hits@10': 0}
    for i in tqdm(range(0, len(test_triplets), batch_size), desc="Evaluating Batches", unit="batch"):
        batch_valid_triplets = test_triplets[i:i+batch_size]
        batch_length = len(batch_valid_triplets)
        # print(f"now evaluating {i+1} batch, its length is {batch_length}")

        batch_res = calc_mrr(model, batch_valid_triplets, entity_embeddings, relation_embeddings, known_triplets, device, hits=[1, 3, 10])
        total_res['MRR'] += batch_res['MRR'] * batch_length
        total_res['MR'] += batch_res['MR'] * batch_length
        total_res['Hits@1'] += batch_res['Hits@1'] * batch_length
        total_res['Hits@3'] += batch_res['Hits@3'] * batch_length
        total_res['Hits@10'] += batch_res['Hits@10'] * batch_length

        # Update progress bar with current batch Hits@1, Hits@3, Hits@10
        tqdm.write(f"Batch {i//batch_size + 1}, Hits@1: {batch_res['Hits@1']:.4f}, Hits@3: {batch_res['Hits@3']:.4f}, Hits@10: {batch_res['Hits@10']:.4f}")

    total_res['MRR'] /= len(test_triplets)
    total_res['MR'] /= len(test_triplets)
    total_res['Hits@1'] /= len(test_triplets)
    total_res['Hits@3'] /= len(test_triplets)
    total_res['Hits@10'] /= len(test_triplets)

    return total_res

In [None]:
once_test = calc_mrr(model, test_triplets, entity_embeddings, relation_embeddings, known_triplets, device, hits=[1, 3, 10])
print(once_test)
batch_res = evaluate_metrics(model, test_triplets, entity_embeddings, relation_embeddings, known_triplets, device, batch_size=300)
print(batch_res)

{'MRR': 0.03419439122080803,
 'MR': 3432.0166015625,
 'Hits@1': 0.016499999910593033,
 'Hits@3': 0.027500001713633537,
 'Hits@10': 0.057500001043081284}

In [5]:
# query_entity = 'Disease::DOID:8577' #ulcerative colitis

node1 = "DOID:162"
# node1 = "DOID:162" #cancer
# node1 = "MESH:D009369" #Neoplasms
node2 = "Q8WZA9"

gene_disease_tri[((gene_disease_tri['node1'] == node1) & (gene_disease_tri['node2'] == node2)|
               (gene_disease_tri['node1'] == node2) & (gene_disease_tri['node2'] == node1))]

Unnamed: 0,index,node1,node1_type,relation,node2,node2_type,direction


In [6]:
query_relation = 'Gene:Disease::drug targets'

query_entity_location1 = 'tail'
query_entity1 = f'Disease::{node1}'
predict_scores_df1 = predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                                 query_entity1, query_entity_location1, query_relation, known_triplets, device)
try:
    print(predict_scores_df1[predict_scores_df1['target_id'] == f'Gene::{node2}']['rank'].values[0])
except:
    print('not found')
    predict_scores_df1 = predict_hrt(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                                 query_entity1, query_entity_location1, query_relation, known_triplets, device)
    print(predict_scores_df1[predict_scores_df1['target_id'] == f'Gene::{node2}']['rank'].values[0])

query_entity_location2 = 'head'
query_entity2 = f'Gene::{node2}'
predict_scores_df2 = predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                                 query_entity2, query_entity_location2, query_relation, known_triplets, device)
try:
    print(predict_scores_df2[predict_scores_df2['target_id'] == f'Disease::{node1}']['rank'].values[0])
except:
    print('not found')
    predict_scores_df2 = predict_hrt(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                                 query_entity2, query_entity_location2, query_relation, known_triplets, device)
    print(predict_scores_df2[predict_scores_df2['target_id'] == f'Disease::{node1}']['rank'].values[0])

1541
1046


In [15]:
# query_entity = 'Gene::Q16769' #sQC
# query_entity = 'Gene::Q9NXS2' #gQC
# query_entity = 'Gene::O95630' #AMSH
# query_entity = 'Gene::Q96FJ0' #AMSH-LP
# query_entity_location = 'head'

# query_entity = 'Disease::DOID:8577' #ulcerative colitis
query_entity = 'Disease::DOID:0050589' #inflammatory bowel disease
# query_entity = 'Disease::DOID:784' #chronic kidney disease
query_entity_location = 'tail'
query_relation = 'Gene:Disease::drug targets'
# query_relation = 'Disease:Gene::associated with'

predict_scores_df = predict_hrt_filter(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                                 query_entity, query_entity_location, query_relation, known_triplets, device)
predict_scores_df['query_entity'] = query_entity
predict_scores_df['query_relation'] = query_relation
if query_entity.split('::')[0] == 'Gene':
    predict_scores_df = pd.merge(predict_scores_df, disease_feature, on='target_id', how='left')
    predict_scores_df = predict_scores_df[['query_entity', 'query_relation', 'target_index', 'target_id', 'name', 'score', 'rank']]
    print("ulcerative colitis", predict_scores_df[predict_scores_df['target_id'] == 'Disease::DOID:8577']['rank'].values[0])
    print("inflammatory bowel disease", predict_scores_df[predict_scores_df['target_id'] == 'Disease::DOID:0050589']['rank'].values[0])
    # print("chronic kidney disease:", predict_scores_df[predict_scores_df['target_id'] == 'Disease::DOID:784']['rank'].values[0])
    # predict_scores_df[predict_scores_df['name'].str.contains('inflammatory bowel disease|ulcerative colitis', regex=True)]
elif query_entity.split('::')[0] == 'Disease':
    predict_scores_df = pd.merge(predict_scores_df, gene_feature, on='target_id', how='left')
    predict_scores_df = predict_scores_df[['query_entity', 'query_relation', 'target_index', 'target_id', 'name', 'score', 'rank']]
    print("sQC", predict_scores_df[predict_scores_df['target_id'] == 'Gene::Q16769']['rank'].values[0])
    print("gQC", predict_scores_df[predict_scores_df['target_id'] == 'Gene::Q9NXS2']['rank'].values[0])
    # print("AMSH", predict_scores_df[predict_scores_df['target_id'] == 'Gene::O95630']['rank'].values[0])
    # print("AMSH-LP", predict_scores_df[predict_scores_df['target_id'] == 'Gene::Q96FJ0']['rank'].values[0])
predict_scores_df

sQC 2571
gQC 10121


Unnamed: 0,query_entity,query_relation,target_index,target_id,name,score,rank
0,Disease::DOID:0050589,Gene:Disease::drug targets,168163,Gene::Q9R0Q8,CLC4E_MOUSE,9.922761e-01,1
1,Disease::DOID:0050589,Gene:Disease::drug targets,175975,Gene::Q9Y5Y4,PD2R2_HUMAN,9.879166e-01,2
2,Disease::DOID:0050589,Gene:Disease::drug targets,85630,Gene::P51685,CCR8_HUMAN,9.875825e-01,3
3,Disease::DOID:0050589,Gene:Disease::drug targets,85624,Gene::P51679,CCR4_HUMAN,9.867319e-01,4
4,Disease::DOID:0050589,Gene:Disease::drug targets,73393,Gene::P17735,ATTY_HUMAN,9.854205e-01,5
...,...,...,...,...,...,...,...
147916,Disease::DOID:0050589,Gene:Disease::drug targets,62573,Gene::O82775,ATI1_ARATH,1.892593e-06,147917
147917,Disease::DOID:0050589,Gene:Disease::drug targets,96065,Gene::P9WNH5,HSAD_MYCTU,1.248290e-06,147918
147918,Disease::DOID:0050589,Gene:Disease::drug targets,160625,Gene::Q9FNP9,AGCT_ARATH,1.070355e-06,147919
147919,Disease::DOID:0050589,Gene:Disease::drug targets,96324,Gene::P9WP86,COBH_MYCTO,4.490742e-07,147920


In [16]:
# predict_scores_df.to_csv("./train_results/case_res/sQC_disease_rank.csv", index=False)
# predict_scores_df.to_csv("./train_results/case_res/gQC_disease_rank.csv", index=False)
# predict_scores_df.to_csv("./train_results/case_res/AMSH_disease_rank.csv", index=False)
# predict_scores_df.to_csv("./train_results/case_res/AMSH-LP_disease_rank.csv", index=False)

# predict_scores_df.to_csv("./train_results/case_res/colitis_gene_rank.csv", index=False)
predict_scores_df.to_csv("./train_results/case_res/IBD_gene_rank.csv", index=False)
predict_scores_df

Unnamed: 0,query_entity,query_relation,target_index,target_id,name,score,rank
0,Disease::DOID:0050589,Gene:Disease::drug targets,168163,Gene::Q9R0Q8,CLC4E_MOUSE,9.922761e-01,1
1,Disease::DOID:0050589,Gene:Disease::drug targets,175975,Gene::Q9Y5Y4,PD2R2_HUMAN,9.879166e-01,2
2,Disease::DOID:0050589,Gene:Disease::drug targets,85630,Gene::P51685,CCR8_HUMAN,9.875825e-01,3
3,Disease::DOID:0050589,Gene:Disease::drug targets,85624,Gene::P51679,CCR4_HUMAN,9.867319e-01,4
4,Disease::DOID:0050589,Gene:Disease::drug targets,73393,Gene::P17735,ATTY_HUMAN,9.854205e-01,5
...,...,...,...,...,...,...,...
147916,Disease::DOID:0050589,Gene:Disease::drug targets,62573,Gene::O82775,ATI1_ARATH,1.892593e-06,147917
147917,Disease::DOID:0050589,Gene:Disease::drug targets,96065,Gene::P9WNH5,HSAD_MYCTU,1.248290e-06,147918
147918,Disease::DOID:0050589,Gene:Disease::drug targets,160625,Gene::Q9FNP9,AGCT_ARATH,1.070355e-06,147919
147919,Disease::DOID:0050589,Gene:Disease::drug targets,96324,Gene::P9WP86,COBH_MYCTO,4.490742e-07,147920


In [5]:
test_data_triplets = test_data_id[columns]
labels = test_data_id['label'].values
scores = predict_triplets_scores(model, entity_embeddings, relation_embeddings, test_data_triplets, device)

auc_roc = roc_auc(labels, scores)
auc_pr = pr_auc(labels, scores)

preds = np.where(scores > 0.5, 1, 0)
acc = accuracy_score(labels, preds)
print(f"AUC-ROC: {auc_roc}")
print(f"AUC-PR: {auc_pr}")
print(f"ACC: {acc}")

AUC-ROC: 0.9504401622884378
AUC-PR: 0.9481627791443648
ACC: 0.8793284700114063


In [None]:
scores_record = np.stack([labels, scores], axis=1)
scores_record = pd.DataFrame(scores_record, columns=['label', 'score'])
# scores_record.to_csv(f"./train_results/mlp_ablation_merged0215/merged_scores.csv", index=False, header=False)
scores_record

Unnamed: 0,label,score
0,1.0,0.736643
1,1.0,0.883155
2,1.0,0.887239
3,1.0,0.631249
4,1.0,0.766208
...,...,...
939819,0.0,0.020183
939820,0.0,0.342502
939821,0.0,0.189065
939822,0.0,0.963828


# 2.模型打分、评价和预测(旧版本-嵌入使用字典)

In [3]:
def get_embeddings(entity2id, relation2id, gene_ids, disease_ids):
    kge_ent_emb = np.load(f'{model_res}/entity_embedding.npy')
    kge_rel_emb = np.load(f'{model_res}/relation_embedding.npy')
    gene_feat_emb, disease_feat_emb = load_entity_feature(data_dir)
    # gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
    entity_embeddings, relation_embeddings = get_merged_embeddings(kge_ent_emb, kge_rel_emb, entity2id, relation2id,
                                                                    gene_feat_emb, disease_feat_emb, gene_ids, disease_ids, 100)
    return entity_embeddings, relation_embeddings

def score_triplet(model, triplet, entity_embeddings, relation_embeddings, device):
    """
    给定模型和三元组，生成评分。
    triplet: 包含 head, relation, tail 的三元组
    entity_embeddings: 实体嵌入字典
    relation_embeddings: 关系嵌入字典
    """
    head_emb = entity_embeddings[triplet[0]]
    relation_emb = relation_embeddings[triplet[1]]
    tail_emb = entity_embeddings[triplet[2]]
    input_emb = np.concatenate([head_emb, relation_emb, tail_emb], axis=-1)
    input_tensor = torch.tensor(input_emb, dtype=torch.float32).to(device).unsqueeze(0)
    score = model(input_tensor).item()

    return score

def score_batch_triplets(model, fix_entity, relation, perturb_entity_index, perturb_type, entity_embeddings, relation_embeddings, device):
    fix_emb = torch.tensor(entity_embeddings[fix_entity], dtype=torch.float32).to(device)  # (embedding_dim,)
    relation_emb = torch.tensor(relation_embeddings[relation], dtype=torch.float32).to(device)  # (embedding_dim,)
    perturb_emb = torch.stack([torch.tensor(entity_embeddings[entity_id], dtype=torch.float32).to(device) 
                            for entity_id in perturb_entity_index])  # (len(perturb_entity_index), embedding_dim)
    
    fix_emb = fix_emb.unsqueeze(0).expand(len(perturb_entity_index), -1)  # (len(perturb_entity_index), embedding_dim)
    relation_emb = relation_emb.unsqueeze(0).expand(len(perturb_entity_index), -1)  # (len(perturb_entity_index), embedding_dim)

    if perturb_type == 'head':
        input_emb = torch.cat([fix_emb, relation_emb, perturb_emb], dim=-1)  # (len(perturb_entity_index), 3 * embedding_dim)
    elif perturb_type == 'tail':
        input_emb = torch.cat([perturb_emb, relation_emb, fix_emb], dim=-1)  # (len(perturb_entity_index), 3 * embedding_dim)
    
    # calculate scores
    scores = model(input_emb).squeeze()  # (len(perturb_entity_index),)
    scores = scores.cpu()
    scores = scores.detach().numpy()
    return scores

def predict_hrt(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                 query_entity, query_entity_location, query_relation, known_triplets, device):
    # known_triplets = read_triplets(os.path.join(data_dir, "gene_disease_triplet.tsv"), entity2id, relation2id)
    head_relation_triplets = known_triplets[:, :2]
    tail_relation_triplets = np.stack((known_triplets[:, 2], known_triplets[:, 1]), axis=1)

    # gene_ids, disease_ids = load_gene_disease_ids(data_dir, entity2id)
    entity_index = entity2id[query_entity]
    relation_index = relation2id[query_relation]
    query_entity_type = query_entity.split("::")[0]

    if query_entity_type == "Gene":
        target_ids = disease_ids
    elif query_entity_type == "Disease":
        target_ids = gene_ids

    if query_entity_location == "head":
        subject_relation = np.stack((entity_index, relation_index), axis=0)
        delete_index = np.sum(head_relation_triplets == subject_relation, axis=1)
        delete_index = np.where(delete_index == 2)[0]
        delete_entity_index = known_triplets[delete_index, 2]
        target_entity_index = np.array(list(set(target_ids) - set(delete_entity_index)))
    elif query_entity_location == "tail":
        object_relation = np.stack((relation_index, entity_index), axis=0)
        delete_index = np.sum(tail_relation_triplets == object_relation, axis=1)
        delete_index = np.where(delete_index == 2)[0]
        delete_entity_index = known_triplets[delete_index, 0]
        target_entity_index = np.array(list(set(target_ids) - set(delete_entity_index)))

    scores = score_batch_triplets(model, entity_index, relation_index, target_entity_index, query_entity_location, entity_embeddings, relation_embeddings, device)
    scores_df = pd.DataFrame(np.concatenate([target_entity_index.reshape(-1, 1), scores.reshape(-1, 1)], axis=1), columns=['target_index', 'score'])
    scores_df = scores_df.sort_values(by='score', ascending=False, ignore_index=True)
    scores_df['target_index'] = scores_df['target_index'].astype('int64')

    id2entity = {v: k for k, v in entity2id.items()}
    scores_df.insert(1, 'target_id', scores_df['target_index'].map(id2entity))
    scores_df['rank'] = scores_df.index + 1

    return scores_df

def predict_triplets_scores(model, entity_embeddings, relation_embeddings, triplets, device):
    head_entity_index = triplets.iloc[:, 0]
    relation_index = triplets.iloc[:, 1]
    tail_entity_index = triplets.iloc[:, 2]
    head_emb = torch.stack([torch.tensor(entity_embeddings[entity_id], dtype=torch.float32).to(device) 
                            for entity_id in head_entity_index])
    relation_emb = torch.stack([torch.tensor(relation_embeddings[relation_id], dtype=torch.float32).to(device)
                            for relation_id in relation_index])
    tail_emb = torch.stack([torch.tensor(entity_embeddings[entity_id], dtype=torch.float32).to(device) 
                            for entity_id in tail_entity_index])
    input_emb = torch.cat([head_emb, relation_emb, tail_emb], dim=-1)# (len(perturb_entity_index), 3 * embedding_dim)
    scores = model(input_emb).squeeze()  # (len(perturb_entity_index),)
    scores = scores.cpu()
    scores = scores.detach().numpy()
    return scores

In [4]:
entity_embeddings, relation_embeddings = get_embeddings(entity2id, relation2id, gene_ids, disease_ids)
predict_scores_df = predict_hrt(model, entity_embeddings, relation_embeddings, entity2id, relation2id, gene_ids, disease_ids,
                                 query_entity, query_entity_location, query_relation, known_triplets, device)
print(predict_scores_df[predict_scores_df['target_id'] == 'Disease::DOID:784']['rank'].values[0])
print(predict_scores_df[predict_scores_df['target_id'] == 'Disease::DOID:8577']['rank'].values[0])
predict_scores_df

265
317


Unnamed: 0,target_index,target_id,score,rank
0,13502,Disease::MESH:D009369,9.885598e-01,1
1,8665,Disease::DOID:3910,9.870047e-01,2
2,7985,Disease::DOID:2645,9.809914e-01,3
3,8718,Disease::DOID:4015,9.804516e-01,4
4,14933,Disease::MONDO:0005369,9.785535e-01,5
...,...,...,...,...
26991,25143,Disease::MONDO:0044800,2.277495e-07,26992
26992,22751,Disease::MONDO:0021280,1.906655e-07,26993
26993,9947,Disease::DOID:6229,1.865906e-07,26994
26994,25584,Disease::MONDO:0100292,1.854514e-07,26995


In [5]:
test_triplets = test_data_id[columns]
labels = test_data_id['label'].values
scores = predict_triplets_scores(model, entity_embeddings, relation_embeddings, test_triplets, device)

auc_roc = roc_auc(labels, scores)
auc_pr = pr_auc(labels, scores)

preds = np.where(scores > 0.5, 1, 0)
acc = accuracy_score(labels, preds)
print(f"AUC-ROC: {auc_roc}")
print(f"AUC-PR: {auc_pr}")
print(f"ACC: {acc}")

AUC-ROC: 0.9484319046989589
AUC-PR: 0.9463732475895523
ACC: 0.8764364391630773
