# Partitioned Training
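This example shows how to train a knowledge graph embedding model with AmpliGraph using partitioned training, evaluate it on the test set (unfiltered and filtered), and then save and restore the trained model.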

In [3]:
# Environment setup for the notebook.
# NOTE: the os.environ assignments are intentionally placed BEFORE
# `import tensorflow` -- these are TensorFlow/CUDA settings that are presumably
# read when TensorFlow is first loaded, so do not reorder these lines.
import sys
import os
# Expose only CUDA device 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Switch off TensorFlow's oneDNN optimizations.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# Raise the threshold of TensorFlow's native (C++) logging to level 3 to keep
# the cell output quiet.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
# Only let the TensorFlow Python logger emit messages at ERROR level or above.
tf.get_logger().setLevel('ERROR')
import numpy as np
import ampligraph

## Train and evaluate

In [4]:
# Import the KGE model
from ampligraph.latent_features import ScoringBasedEmbeddingModel

# Root folder that contains the `wn18RR/` dataset directory; edit before running.
PATH_TO_DATASET = 'your/path/to/dataset/'

# Build the path with os.path.join so it resolves correctly whether or not
# PATH_TO_DATASET ends with a path separator.
train_path = os.path.join(PATH_TO_DATASET, 'wn18RR/train.txt')

# Create the model with the TransE scoring function.
#   eta: number of negatives generated per positive triple
#   k:   embedding dimensionality
partitioned_model = ScoringBasedEmbeddingModel(eta=2,
                                               k=50,
                                               scoring_type='TransE')
partitioned_model.compile(optimizer='adam', loss='multiclass_nll')

# Here the training data is passed as a file path.
# Alternatively, load the same dataset with a built-in loader such as
# ampligraph.datasets.load_wn18rr() and pass the numpy arrays instead.
partitioned_model.fit(train_path,
                      batch_size=10000,
                      partitioning_k=3,  # number of partitions to split the input into
                      epochs=10)


_split: memory before: 848.0Bytes, after: 4.3447MB, consumed: 4.3439MB; exec time: 29.242s


2023-02-08 16:47:49.873938: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x28f31d630>

In [6]:
# Unfiltered evaluation: no `use_filter` argument is supplied to evaluate().
from ampligraph.evaluation.metrics import hits_at_n_score, mr_score, mrr_score

test_triples_path = PATH_TO_DATASET + 'wn18RR/test.txt'
ranks = partitioned_model.evaluate(test_triples_path, batch_size=400)

# Displayed as (MR, MRR, Hits@1, Hits@10, number of ranks).
(
    mr_score(ranks),
    mrr_score(ranks),
    hits_at_n_score(ranks, 1),
    hits_at_n_score(ranks, 10),
    len(ranks),
)



210 triples containing invalid keys skipped!


(20079.140731874144, 0.011132840015629617, 0.0, 0.03625170998632011, 2924)

In [7]:
# Filtered evaluation: the train/valid/test files are handed to `use_filter`,
# with corrupt_side='s,o'.
test_triples_path = PATH_TO_DATASET + 'wn18RR/test.txt'
filter_files = {
    'train': PATH_TO_DATASET + 'wn18RR/train.txt',
    'valid': PATH_TO_DATASET + 'wn18RR/valid.txt',
    'test': test_triples_path,
}
ranks = partitioned_model.evaluate(
    test_triples_path,
    batch_size=400,
    corrupt_side='s,o',
    use_filter=filter_files,
)

# Displayed as (MR, MRR, Hits@1, Hits@10, number of ranks).
(
    mr_score(ranks),
    mrr_score(ranks),
    hits_at_n_score(ranks, 1),
    hits_at_n_score(ranks, 10),
    len(ranks),
)



210 triples containing invalid keys skipped!

210 triples containing invalid keys skipped!

210 triples containing invalid keys skipped!


(20066.594562243503,
 0.01583735697421522,
 0.005471956224350205,
 0.038132694938440494,
 2924)

In [8]:
from ampligraph.utils import save_model

# Write the trained model to ./partitioned_model so it can be reloaded later.
save_model(
    model=partitioned_model,
    model_name_path='./partitioned_model',
)

The path ./partitioned_model already exists. This save operation will overwrite the model                 at the specified path.


In [10]:
from ampligraph.utils import restore_model

# Load the model back from the path it was saved to; the cells below evaluate
# this restored copy.
model = restore_model(model_name_path='./partitioned_model')

Saved model does not include a db file. Skipping.


In [11]:
# Unfiltered evaluation of the restored model: no `use_filter` argument is supplied.
from ampligraph.evaluation.metrics import hits_at_n_score, mr_score, mrr_score

test_triples_path = PATH_TO_DATASET + 'wn18RR/test.txt'
ranks = model.evaluate(test_triples_path, batch_size=400)

# Displayed as (MR, MRR, Hits@1, Hits@10, number of ranks).
(
    mr_score(ranks),
    mrr_score(ranks),
    hits_at_n_score(ranks, 1),
    hits_at_n_score(ranks, 10),
    len(ranks),
)


210 triples containing invalid keys skipped!


(20079.140731874144, 0.011132840015629617, 0.0, 0.03625170998632011, 2924)

In [12]:
# Filtered evaluation of the restored model: the train/valid/test files are
# handed to `use_filter`, with corrupt_side='s,o'.
test_triples_path = PATH_TO_DATASET + 'wn18RR/test.txt'
filter_files = {
    'train': PATH_TO_DATASET + 'wn18RR/train.txt',
    'valid': PATH_TO_DATASET + 'wn18RR/valid.txt',
    'test': test_triples_path,
}
ranks = model.evaluate(
    test_triples_path,
    batch_size=400,
    corrupt_side='s,o',
    use_filter=filter_files,
)

# Displayed as (MR, MRR, Hits@1, Hits@10, number of ranks).
(
    mr_score(ranks),
    mrr_score(ranks),
    hits_at_n_score(ranks, 1),
    hits_at_n_score(ranks, 10),
    len(ranks),
)



210 triples containing invalid keys skipped!

210 triples containing invalid keys skipped!

210 triples containing invalid keys skipped!


(20066.594562243503,
 0.01583735697421522,
 0.005471956224350205,
 0.038132694938440494,
 2924)