In [1]:
import pickle
import numpy as np
import bcolz
import openke
from openke.config import Trainer, Tester
from openke.module.model import TransE
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

In [3]:
# dataloader for training
train_dataloader = TrainDataLoader(
	in_path = "./benchmarks/FB15K237/",
	nbatches = 100,
	threads = 8,
	sampling_mode = "normal",
	bern_flag = 1,
	filter_flag = 1,
	neg_ent = 25,
	neg_rel = 0)

# dataloader for test
test_dataloader = TestDataLoader("./benchmarks/FB15K237/", "link")
#load pretrained glove dictionary
glove_path = "/home/ubuntu/text-pwrd-kg-reasoning/data/"
vectors = bcolz.open(f'{glove_path}/6B.200.dat')[:]
words = pickle.load(open(f'{glove_path}/6B.200_words.pkl', 'rb'))
word2idx = pickle.load(open(f'{glove_path}/6B.200_idx.pkl', 'rb'))
# wrd2embedding
glove = {w: vectors[word2idx[w]] for w in words}
entity2name = {}

with open(f'{glove_path}/mid2name.tsv', 'r') as f:
	line = f.readline().split()
	while line:
		entity2name[line[0]] = line[1:]
		line = f.readline().split()

id2entity = {}
with open("/home/ubuntu/text-pwrd-kg-reasoning/OpenKE/benchmarks/FB15K237/entity2id.txt") as f:
	line = f.readline().split()
	max_id = int(line[0])
	line = f.readline().split()
	while line:
		id2entity[int(line[1])] = line[0]
		line = f.readline().split()

# create weight matrix for entity
matrix_len = train_dataloader.get_ent_tot()
weights_matrix = np.zeros((matrix_len, 200))
not_found = 0

for i in range(max_id):
	entity = id2entity[i]
	try:
		words = entity2name[entity]
	except KeyError:
		words = ['unk']
	for wrd in words:
		try:
			weights_matrix[i] += glove[wrd]
		except KeyError:
			weights_matrix[i] += glove['unk']
			not_found += 1

In [7]:
weights_matrix

array([[ 0.15637 , -0.40361 , -0.29629 , ..., -0.82404 , -0.034287,
        -0.16396 ],
       [ 0.15637 , -0.40361 , -0.29629 , ..., -0.82404 , -0.034287,
        -0.16396 ],
       [-0.52733 , -0.476889,  0.20259 , ..., -0.858234,  0.348813,
         0.24793 ],
       ...,
       [ 0.15637 , -0.40361 , -0.29629 , ..., -0.82404 , -0.034287,
        -0.16396 ],
       [ 0.15637 , -0.40361 , -0.29629 , ..., -0.82404 , -0.034287,
        -0.16396 ],
       [ 0.15637 , -0.40361 , -0.29629 , ..., -0.82404 , -0.034287,
        -0.16396 ]])

In [8]:
entity2name['/m/027rn']

['U.S.-attempted', 'annexation', 'of', 'the', 'Dominican', 'Republic']