In [7]:
import os
import dgl
import torch
import numpy as np
import pandas as pd

# --- Configuration ---------------------------------------------------------
# Directory handed to dgl.data.CSVDataset to load the knowledge graph.
kg_data_path = '../kg/'
# Directory containing the entity / relation dictionaries
# (entities.tsv and relations.tsv).
train_data_path = 'train_data/'
# Directory whose sub-entries hold the outputs of the KGE training runs.
kge_data_path = 'out/'
# Name prefix used to select the KGE run to load.
kge_method = 'TransR'

# --- Knowledge graph -------------------------------------------------------
# Build the dataset from the directory (force_reload=True is passed so the
# dataset is reloaded rather than reused) and take its first graph.
kg_dataset = dgl.data.CSVDataset(kg_data_path, force_reload=True)
kg = kg_dataset[0]

Done saving data into cached files.


In [8]:
# Read the entity and relation dictionaries.
# Both files are read as header-less, tab-separated (index, name) tables.
entity_dict_file = os.path.join(train_data_path, 'entities.tsv')
relation_dict_file = os.path.join(train_data_path, 'relations.tsv')

# The name column is forced to `str` and pandas' default NA detection is
# switched off, so a name such as "NA", "null" or "123" is kept verbatim
# instead of being converted to NaN / a number. Downstream cells rely on the
# names being strings (string splitting, list lookups by name).
df_entity = pd.read_csv(
    entity_dict_file,
    sep='\t',
    names=['index', 'entity'],
    dtype={'entity': str},
    keep_default_na=False,
)
df_relation = pd.read_csv(
    relation_dict_file,
    sep='\t',
    names=['index', 'relation'],
    dtype={'relation': str},
    keep_default_na=False,
)

# Plain Python lists; position i corresponds to the i-th row of the file.
entity_list = df_entity['entity'].tolist()
relation_list = df_relation['relation'].tolist()

In [9]:
# Load the pretrained KGE arrays from the most recent `kge_method` run.
def _run_sort_key(run_name):
    """Sort key for run directories: name prefix first, then numeric run index.

    A trailing ``_<digits>`` is compared as an integer, so ``..._10`` sorts
    after ``..._9``; a plain string sort would place it before and the wrong
    run would be treated as the latest one.
    """
    prefix, _, suffix = run_name.rpartition('_')
    if suffix.isdigit():
        return (prefix, int(suffix))
    return (run_name, -1)


# Only directories are candidate runs; a stray file that merely starts with
# the method name must not be selected.
kge_outs = [
    entry for entry in os.listdir(kge_data_path)
    if entry.startswith(kge_method)
    and os.path.isdir(os.path.join(kge_data_path, entry))
]
if not kge_outs:
    raise FileNotFoundError(
        f'No {kge_method!r} output directory found under {kge_data_path!r}'
    )
kge_outs.sort(key=_run_sort_key)
kge_run_dir = os.path.join(kge_data_path, kge_outs[-1])

# Sorted so that the file picked below does not depend on listing order.
kge_files = sorted(os.listdir(kge_run_dir))


def _find_embedding_file(suffix):
    """Return the first file in the selected run whose name ends with `suffix`."""
    matches = [f for f in kge_files if f.endswith(suffix)]
    if not matches:
        raise FileNotFoundError(
            f'No file ending in {suffix!r} found in {kge_run_dir!r}'
        )
    return matches[0]


entity_emb_file = _find_embedding_file('entity.npy')
entity_emb = np.load(os.path.join(kge_run_dir, entity_emb_file))
relation_emb_file = _find_embedding_file('relation.npy')
relation_emb = np.load(os.path.join(kge_run_dir, relation_emb_file))

In [10]:
# Attach the pretrained entity embeddings to the graph as the node feature
# 'embedding', then concatenate them over all node types.
kg_embedding_size = entity_emb.shape[1]

# Entity names are parsed as '<node type>ID<node id>'. The split is done on
# the LAST 'ID' so a node type whose own name contains 'ID' still parses.
node_ids_by_type = {}  # node type -> node ids within that type
emb_rows_by_type = {}  # node type -> matching row numbers in `entity_emb`
for row, entity_name in enumerate(entity_list):
    ntype, _, nid = str(entity_name).rpartition('ID')
    node_ids_by_type.setdefault(ntype, []).append(int(nid))
    emb_rows_by_type.setdefault(ntype, []).append(row)

unknown_ntypes = sorted(set(node_ids_by_type) - set(kg.ntypes))
if unknown_ntypes:
    raise ValueError(
        f'Entity dictionary references node types missing from the graph: {unknown_ntypes}'
    )

for ntype in kg.ntypes:
    # Nodes without a pretrained vector keep an all-zero embedding.
    ntype_emb = torch.zeros(kg.number_of_nodes(ntype), kg_embedding_size, device=kg.device)
    if ntype in node_ids_by_type:
        # One vectorised write per node type instead of one tensor per entity.
        nids = torch.tensor(node_ids_by_type[ntype], dtype=torch.long, device=kg.device)
        ntype_emb[nids] = torch.tensor(
            entity_emb[emb_rows_by_type[ntype]],
            dtype=ntype_emb.dtype,
            device=kg.device,
        )
    # Per-type assignment works whether the graph has one or many node types.
    kg.nodes[ntype].data['embedding'] = ntype_emb

# Row order follows kg.ntypes, with node ids ascending inside each type.
kg_entity_embs = torch.cat(
    [kg.nodes[ntype].data['embedding'] for ntype in kg.ntypes], dim=0
)

In [11]:
# Build the relation embedding matrix with one row per edge type, in
# kg.etypes order. Each edge type name is looked up in `relation_list` to
# find its row in `relation_emb`.
relation_rows = [
    torch.tensor(relation_emb[relation_list.index(rel_name)]).unsqueeze(0)
    for rel_name in kg.etypes
]
# Move the result to the device that holds the entity embeddings.
kg_relation_embs = torch.cat(relation_rows, dim=0).to(kg_entity_embs.device)

In [12]:
# Stack everything into a single lookup table:
#   row 0          -> all-zero vector
#   next rows      -> entity embeddings
#   remaining rows -> relation embeddings
# The zero row is created on the same device as the embeddings; a tensor on
# the default device would make torch.cat fail when the graph is on a GPU.
zero_row = torch.zeros(
    (1, kg_embedding_size),
    dtype=kg_relation_embs.dtype,
    device=kg_entity_embs.device,
)
all_kg_embs = torch.cat([zero_row, kg_entity_embs, kg_relation_embs], dim=0)

# # save as npy
# np.save('entity_embedding.npy', kg_entity_embs.cpu().numpy())
# np.save('relation_embedding.npy', kg_relation_embs.cpu().numpy())
# The output file name carries the embedding width (last dimension).
np.save(f'kg_embedding_{all_kg_embs.shape[-1]}.npy', all_kg_embs.cpu().numpy())
