In [None]:
# 02_entity_similarity_analysis
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 4, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 4, 2023
#
# 该脚本展示了如何分析训练的实体嵌入.
#
# 需要的包:
#          numpy
#          csv
#          matplotlib
#          sklearn
#          MulticoreTSNE, 安装命令为:
#                pip install cmake==3.18.4 -i https://pypi.tuna.tsinghua.edu.cn/simple
#                pip install MulticoreTSNE -i https://pypi.tuna.tsinghua.edu.cn/simple
#
# 需要的文件:
#          ../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_entity.npy
#          ../../data/drkg/entities.tsv
#
# 源教程链接: https://github.com/gnn4dr/DRKG/blob/master/embedding_analysis/Entity_similarity_analysis.ipynb

# DRKG Entity Embedding Similarity Analysis

这个 notebook 展示了如何分析训练的实体嵌入.

在这个例子中, 我们首先加载训练的实体嵌入向量, 然后将它们映射回原始的实体名, 最后应用两种方法分析它们:

- 将实体嵌入投影到低维空间, 并可视化它们的分布.

- 使用余弦相似度分析实体间的相似程度.

In [None]:
import csv
import os

import matplotlib.pyplot as plt
import numpy as np
import sklearn
from MulticoreTSNE import MulticoreTSNE as TSNE
from sklearn.metrics.pairwise import cosine_similarity

## Loading Entity ID Mapping

In [None]:
# Build both directions of the entity-name <-> integer-id mapping from the
# tab-separated mapping file.
entity2id = {}
id2entity = {}

with open("../../data/drkg/entities.tsv", newline='', encoding='utf-8') as csvfile:
    # Column names are supplied explicitly, so the first line of the file is
    # read as data rather than as a header.
    reader = csv.DictReader(csvfile, delimiter='\t', fieldnames=['id','entity'])
    for row_val in reader:
        # Parse the id once. The name `entity_id` (instead of `id`) keeps the
        # builtin id() usable in every later cell of the same kernel session.
        entity_id = int(row_val['id'])
        entity = row_val['entity']

        entity2id[entity] = entity_id
        id2entity[entity_id] = entity

print("Number of entities: {}".format(len(entity2id)))

## Loading Entity Embeddings

In [None]:
# Load the saved entity embedding array from the checkpoint directory and
# report its shape as a quick sanity check.
entity_emb_path = '../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_entity.npy'
entity_emb = np.load(entity_emb_path)
print(entity_emb.shape)

## General Entity Embedding Clustering

这里我们使用 t-SNE 将实体嵌入降到二维, 然后按实体名中 `::` 之前的前缀分组, 可视化它们的分布.

In [None]:
# Group entity ids by the part of the entity name that precedes the first '::'
# (presumably the entity type / source label -- TODO confirm against the data).
dataset_id = {}
for name, idx in entity2id.items():
    prefix = name.split('::', 1)[0]
    # setdefault creates the bucket the first time a prefix is seen.
    dataset_id.setdefault(prefix, []).append(idx)

# Project the entity embeddings to 2-D with t-SNE, then transpose so that
# X_embedded[0] holds the first coordinate of every point and X_embedded[1]
# the second (this is how the plotting cell indexes it).
#
# t-SNE is stochastic: seeding it keeps the layout comparable across
# "Restart & Run All" executions instead of changing on every run.
# NOTE(review): n_jobs=32 is hard-coded; tune it to the cores of the host.
TSNE_RANDOM_STATE = 42
X_embedded = TSNE(
    n_components=2,
    n_jobs=32,
    random_state=TSNE_RANDOM_STATE,
).fit_transform(entity_emb).T

In [None]:
# Plot the 2-D projection with one legend entry per entity group and save it
# as an SVG file.
fig, ax = plt.subplots()

for key, val in dataset_id.items():
    # Integer index array selecting this group's points in the projection.
    val = np.asarray(val, dtype=int)
    ax.plot(X_embedded[0][val], X_embedded[1][val], '.', label=key)

lgd = ax.legend(bbox_to_anchor=(1.0, 1.0))

# savefig does not create missing directories, so make sure the output folder
# exists; otherwise a fresh checkout fails here with FileNotFoundError.
output_path = './result/entity.svg'
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# The legend is passed as an extra artist so the tight bounding box includes it.
fig.savefig(output_path, bbox_extra_artists=(lgd,), bbox_inches='tight', format='svg')