Skip to content
Permalink
Browse files

Passing an unseen entity or relation to predict() now returns a meaningful error message
  • Loading branch information...
lukostaz committed Apr 16, 2019
1 parent c209c95 commit fab722719debeb7aa890fa088162ba2abf09f610
Showing with 39 additions and 11 deletions.
  1. +17 −10 ampligraph/evaluation/protocol.py
  2. +22 −1 tests/ampligraph/latent_features/test_models.py
@@ -413,17 +413,24 @@ def generate_corruptions_for_fit(X, entities_list=None, eta=1, corrupt_side='s+o


def _convert_to_idx(X, ent_to_idx, rel_to_idx, obj_to_idx):
    """Map the labels of a batch of triples to their integer indices.

    Parameters
    ----------
    X : ndarray
        Two-dimensional array of triples; column 0 is looked up in
        ``ent_to_idx``, column 1 in ``rel_to_idx`` and column 2 in
        ``obj_to_idx``.
    ent_to_idx : dict
        Mapping from subject label to index.
    rel_to_idx : dict
        Mapping from relation label to index.
    obj_to_idx : dict
        Mapping from object label to index.

    Returns
    -------
    ndarray
        Array with one row per triple and three columns
        (subject index, relation index, object index).

    Raises
    ------
    ValueError
        If any subject or object label is missing from its mapping, or if
        any relation label is missing from ``rel_to_idx``.
    """
    # Plain dict lookups are used instead of np.vectorize(dict.get):
    # np.vectorize infers its output dtype from the first result, so a batch
    # whose first label is known but a later one is not would fail with an
    # unrelated TypeError before the missing label could be reported.
    subj_idx = [ent_to_idx.get(label) for label in X[:, 0]]
    rel_idx = [rel_to_idx.get(label) for label in X[:, 1]]
    obj_idx = [obj_to_idx.get(label) for label in X[:, 2]]

    # Both entity columns must be checked: subjects AND objects.
    if any(i is None for i in subj_idx) or any(i is None for i in obj_idx):
        msg = 'Input triples include one or more entities not present in the training set. ' \
              'Please filter X using evaluation.filter_unseen_entities(), or retrain the model on a training set ' \
              'that includes all the desired distinct entities.'
        logger.error(msg)
        raise ValueError(msg)

    if any(i is None for i in rel_idx):
        msg = 'Input triples include one or more relation type not present in the training set. ' \
              'Please filter all relation in X that do not occur in the training test. ' \
              'or retrain the model on a training set that includes all the desired relation types.'
        logger.error(msg)
        raise ValueError(msg)

    # Every lookup succeeded, so numpy can infer a numeric dtype from the values.
    columns = [np.array(subj_idx), np.array(rel_idx), np.array(obj_idx)]
    return np.dstack(columns).reshape((-1, 3))


@@ -1,4 +1,5 @@
import numpy as np
import pytest

from ampligraph.latent_features import TransE, DistMult, ComplEx, HolE
from ampligraph.datasets import load_wn18
@@ -129,13 +130,33 @@ def test_retrain():
def test_fit_predict_wn18_TransE():
    """Smoke test: fit TransE on the WN18 training split, then predict.

    Prediction is run with ``get_ranks=True`` on the first test triple only,
    and the first returned value is printed.
    """
    wn18 = load_wn18()
    model = TransE(
        batches_count=1,
        seed=555,
        epochs=5,
        k=100,
        loss='pairwise',
        loss_params={'margin': 5},
        verbose=True,
        optimizer='adagrad',
        optimizer_params={'lr': 0.1},
    )
    model.fit(wn18['train'])

    first_test_triple = wn18['test'][:1]
    y, _ = model.predict(first_test_triple, get_ranks=True)

    print(y)


def test_missing_entity_ComplEx():
    """predict() must raise ValueError for labels absent from the training triples."""
    train_triples = np.array([['a', 'y', 'b'],
                              ['b', 'y', 'a'],
                              ['a', 'y', 'c'],
                              ['c', 'y', 'a'],
                              ['a', 'y', 'd'],
                              ['c', 'y', 'd'],
                              ['b', 'y', 'c'],
                              ['f', 'y', 'e']])
    model = ComplEx(batches_count=1, seed=555, epochs=2, k=5)
    model.fit(train_triples)

    # Each query contains exactly one label that never occurs in train_triples.
    unseen_queries = (
        ['a', 'y', 'zzzzzzzzzzz'],   # unseen object
        ['a', 'xxxxxxxxxx', 'e'],    # unseen relation
        ['zzzzzzzz', 'y', 'e'],      # unseen subject
    )
    for query in unseen_queries:
        with pytest.raises(ValueError):
            model.predict(query)


def test_fit_predict_wn18_ComplEx():
X = load_wn18()
model = ComplEx(batches_count=1, seed=555, epochs=5, k=100,

0 comments on commit fab7227

Please sign in to comment.
You can’t perform that action at this time.