Skip to content
Permalink
Browse files

Rename the 'type' parameter so it no longer masks the built-in `type` function (#63)

* Issue #44: rename the 'type' parameter so it no longer masks the built-in `type` function

* Rename `type` parameter in the documentation as well

* Choose a clearer name (`embedding_type`) for the former `type` parameter
  • Loading branch information...
iamaziz authored and rorymcgrath committed Apr 5, 2019
1 parent 94a5c31 commit ac515bc34c64becc04385797c8b9f0d93b20b58d
Showing with 15 additions and 16 deletions.
  1. +9 −10 ampligraph/latent_features/models.py
  2. +2 −2 docs/examples.md
  3. +1 −1 test.py
  4. +3 −3 tests/ampligraph/latent_features/test_models.py
@@ -329,15 +329,15 @@ def _load_model_from_trained_params(self):
self.ent_emb = tf.constant(self.trained_model_params[0])
self.rel_emb = tf.constant(self.trained_model_params[1])

def get_embeddings(self, entities, type='entity'):
def get_embeddings(self, entities, embedding_type='entity'):
"""Get the embeddings of entities or relations.
Parameters
----------
entities : array-like, dtype=int, shape=[n]
The entities (or relations) of interest. Element of the vector must be the original string literals, and
not internal IDs.
type : string
embedding_type : string
If 'entity', the ``entities`` argument will be considered as a list of knowledge graph entities (i.e. nodes).
If set to 'relation', they will be treated as relation types instead (i.e. predicates).
@@ -347,20 +347,19 @@ def get_embeddings(self, entities, type='entity'):
An array of k-dimensional embeddings.
"""
# TODO - Rename type with something else. This is masking the built-in function "type" #44
if not self.is_fitted:
msg = 'Model has not been fitted.'
logger.error(msg)
raise RuntimeError(msg)

if type is 'entity':
if embedding_type is 'entity':
emb_list = self.trained_model_params[0]
lookup_dict = self.ent_to_idx
elif type is 'relation':
elif embedding_type is 'relation':
emb_list = self.trained_model_params[1]
lookup_dict = self.rel_to_idx
else:
msg = 'Invalid entity type: {}'.format(type)
msg = 'Invalid entity type: {}'.format(embedding_type)
logger.error(msg)
raise ValueError(msg)

@@ -1102,7 +1101,7 @@ class TransE(EmbeddingModel):
>>> model.fit(X)
>>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
[-2.219729, -3.9848995]
>>> model.get_embeddings(['f','e'], type='entity')
>>> model.get_embeddings(['f','e'], embedding_type='entity')
array([[-0.65229136, -0.50060457, 1.2316223 , 0.23738968, 0.29145557,
-0.20187911, -0.3053819 , -0.6947149 , 0.9377473 , 0.12985024],
[-1.1272118 , 0.10723944, 0.79431695, 0.6795645 , -0.14428931,
@@ -1336,7 +1335,7 @@ class DistMult(EmbeddingModel):
>>> model.fit(X)
>>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
[3.29703, -3.543957]
>>> model.get_embeddings(['f','e'], type='entity')
>>> model.get_embeddings(['f','e'], embedding_type='entity')
array([[-0.7101061 , -0.35752687, 0.5337027 , -0.612499 , -0.34532365,
-0.7219143 , -0.07083285, 0.19323194, 1.0108972 , 0.42850104],
[-1.2280471 , -0.22018537, 0.17179069, 0.757755 , -0.05845603,
@@ -1565,7 +1564,7 @@ class ComplEx(EmbeddingModel):
>>> model.fit(X)
>>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
[0.96325016, -0.17629346]
>>> model.get_embeddings(['f','e'], type='entity')
>>> model.get_embeddings(['f','e'], embedding_type='entity')
array([[-0.11257 , -0.09226837, 0.2829331 , -0.02094189, 0.02826234,
-0.3068198 , -0.41022655, -0.23714773, -0.00084166, 0.22521858,
-0.48155236, 0.29627186, 0.29841757, 0.16540456, 0.45836073,
@@ -1804,7 +1803,7 @@ class HolE(ComplEx):
>>> model.fit(X)
>>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]), get_ranks=True)
[0.3046168, -0.0379385]
>>> model.get_embeddings(['f','e'], type='entity')
>>> model.get_embeddings(['f','e'], embedding_type='entity')
array([[-0.2704807 , -0.05434025, 0.13363852, 0.04879733, 0.00184516,
-0.1149573 , -0.1177371 , -0.20798951, 0.01935115, 0.13033926,
-0.81528974, 0.22864424, 0.2045117 , 0.1145515 , 0.248952 ,
@@ -151,7 +151,7 @@ X = np.array([['a', 'y', 'b'],
['b', 'y', 'c'],
['f', 'y', 'e']])
model.fit(X)
model.get_embeddings(['f','e'], type='entity')
model.get_embeddings(['f','e'], embedding_type='entity')
```

## Save and restore a model
@@ -192,4 +192,4 @@ print(y_pred_after)
# Assert that the before and after values are same
assert(y_pred_before==y_pred_after)
```
```
@@ -31,7 +31,7 @@
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

embs = model.get_embeddings(embs_labels, type='entity')
embs = model.get_embeddings(embs_labels, embedding_type='entity')
embs_2d = TSNE(n_components=2).fit_transform(embs)

fig, ax = plt.subplots()
@@ -137,7 +137,7 @@ def test_lookup_embeddings():
['b', 'y', 'c'],
['f', 'y', 'e']])
model.fit(X)
model.get_embeddings(['a', 'b'], type='entity')
model.get_embeddings(['a', 'b'], embedding_type='entity')


def test_save_and_restore_model():
@@ -181,7 +181,7 @@ def test_save_and_restore_model():
y_pred_after, _ = loaded_model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]), get_ranks=True)
npt.assert_array_equal(y_pred_after, y_pred_before)

npt.assert_array_equal(loaded_model.get_embeddings(['a', 'b'], type='entity'),
model.get_embeddings(['a', 'b'], type='entity'))
npt.assert_array_equal(loaded_model.get_embeddings(['a', 'b'], embedding_type='entity'),
model.get_embeddings(['a', 'b'], embedding_type='entity'))

shutil.rmtree(EXAMPLE_LOC)

0 comments on commit ac515bc

Please sign in to comment.
You can’t perform that action at this time.