# Backward Compatibility with AmpliGraph 1

The main difference in the API of AmpliGraph 2 is how you import the models and evaluate performance.
We still provide backward compatibility with the APIs of AmpliGraph 1 through the module ampligraph.compat.

In [1]:
import sys
sys.path.append('../..')
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
import numpy as np

In [2]:
import ampligraph
# load the dataset
from ampligraph.datasets import load_wn18rr
X = load_wn18rr()

In [3]:
# Import the models from ampligraph.compat
# AmpliGraph 2 APIs support TransE, DistMult, ComplEx, HolE

from ampligraph.compat import DistMult

model = DistMult(batches_count=10, seed=0, epochs=500, k=350, eta=10,
                    # Use adam optimizer with learning rate 1e-3
                    optimizer='adam', optimizer_params={'lr':1e-3},
                    # Use multiclass_nll loss 
                    loss='multiclass_nll', loss_params={},
                    # Use L3 regularizer with regularizer weight 1e-3
                    regularizer='LP', regularizer_params={'p':3, 'lambda':1e-3}, 
                    # Enable stdout messages (set to false if you don't want to display)
                    verbose=True)


Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB



In [4]:
# Create the filter
filter = np.concatenate((X['train'], X['valid'][::10], X['test']))


In [5]:
# Fit the model on training and validation set
model.fit(X['train'][::2], 
          early_stopping = True,
          early_stopping_params = \
                  {
                      'x_valid': X['valid'][::10],  # validation set
                      'criteria':'hits@10',         # Uses hits10 criteria for early stopping
                      'burn_in': 20,                # early stopping kicks in after 100 epochs
                      'check_interval':20,          # validates every 20th epoch
                      'stop_interval':5,            # stops if 5 successive validation checks are bad.
                      'x_filter': filter,           # Use filter for filtering out positives 
                      'corruption_entities':'all',  # corrupt using all entities
                      'corrupt_side':'s'            # corrupt only subject
                  }
          )

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
73 triples containing invalid keys skipped!

749 triples containing invalid keys skipped!
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
73 triples containing invalid keys skipped!

749 triples containing invalid keys skipped!
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
73 triples containing invalid keys skipped!

749 

In [6]:
X_test = X['test']
X_test[:2]

array([['06845599', '_member_of_domain_usage', '03754979'],
       ['00789448', '_verb_group', '01062739']], dtype=object)

In [7]:
# Score assigned to unseen triples
model.predict(X_test[:2])


1 triples containing invalid keys skipped!


array([-0.3750222], dtype=float32)

In [8]:
# Get embedding of entities
embed = model.get_embeddings(['11647131','02518161'], embedding_type='entity')
print('Embedding size: ', embed.shape[1])
# Notice that the embedding size for ComplEx is double
# compared to the k specified when initializing the model,
# since ComplEx embeddings live in the space of complex numbers.
print('\n Embedding vectors: ')
print(embed)

Embedding size:  350

 Embedding vectors: 
[[ 0.11967714  0.14262393 -0.15412526  0.14064218 -0.03678143  0.17855963
   0.13154885  0.1871646   0.15934017  0.13739511  0.19862601  0.13786295
  -0.19559897 -0.19016445  0.16443959 -0.19838691  0.17837389  0.02906989
  -0.19489937 -0.14237712 -0.16674249  0.17720278  0.06629281  0.10147594
  -0.07981575 -0.15890078 -0.1647761  -0.11785773 -0.20817536  0.15641469
   0.20843565  0.15639313  0.16039309 -0.15879703 -0.15178938 -0.16976918
   0.17478164 -0.19067667 -0.12253983 -0.19105424 -0.14591898 -0.20406112
  -0.17460953 -0.19648121 -0.09293007 -0.19525793  0.18996416  0.17143488
   0.1854977   0.20067944 -0.18529464  0.1355615   0.06794422 -0.19307975
  -0.18535511  0.1253285  -0.18829922  0.1228343   0.13849634 -0.18207617
   0.19859134  0.16638105  0.20536643  0.18736392 -0.18714447  0.15228283
   0.1753144   0.17433217 -0.17006505 -0.1720568   0.19801918  0.18579687
  -0.18681724 -0.1791706   0.1113568   0.14846447 -0.20054623  0.2037

In [9]:
# get the entity and relation mappings to emb matrix
ent_to_idx, rel_to_idx = model.get_hyperparameter_dict()
len(ent_to_idx), len(rel_to_idx)

(33117, 11)

In [10]:
# import the evaluate_performance API from compat module
from ampligraph.compat import evaluate_performance
ranks = evaluate_performance(X_test, model, filter_triples=filter, corrupt_side='s,o', verbose=True)

# import the evaluation metrics
from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score

print('MR:', mr_score(ranks))
print('MRR:', mrr_score(ranks))
print('hits@1:', hits_at_n_score(ranks, 1))
print('hits@10:', hits_at_n_score(ranks, 10))


676 triples containing invalid keys skipped!

749 triples containing invalid keys skipped!
MR: 9791.32028469751
MRR: 0.266744577802399
hits@1: 0.2257562277580071
hits@10: 0.33451957295373663


In [11]:
from ampligraph.utils import save_model
# save the model
save_model(model, 'backward_model')




In [12]:
from ampligraph.utils import restore_model

# restore saved models or checkpoints
res_model = restore_model('backward_model')

Saved model does not include a db file. Skipping.


In [13]:
# import the evaluate_performance API from compat module
from ampligraph.compat import evaluate_performance
ranks = evaluate_performance(X_test, res_model, filter_triples=filter, corrupt_side='s,o', verbose=True)

# import the evaluation metrics
from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score

print('MR:', mr_score(ranks))
print('MRR:', mrr_score(ranks))
print('hits@1:', hits_at_n_score(ranks, 1))
print('hits@10:', hits_at_n_score(ranks, 10))


676 triples containing invalid keys skipped!

749 triples containing invalid keys skipped!
MR: 9791.32028469751
MRR: 0.266744577802399
hits@1: 0.2257562277580071
hits@10: 0.33451957295373663


# Discovery
The APIs for knowledge discovery can be imported from the ampligraph.discovery modules.
They are designed to be backward compatible.

In [15]:
from ampligraph.discovery import discover_facts

discover_facts(X['train'][:100], 
               res_model, 
               top_n=100, 
               strategy='entity_frequency', 
               max_candidates=100, 
               target_rel='/location/country/form_of_government', 
               seed=0)


(array([['/m/06w99h3', '/location/country/form_of_government', '/m/09nqf'],
        ['/m/0fvf9q', '/location/country/form_of_government',
         '/m/05b4l5x']], dtype=object),
 array([27.5, 47.5]))