# Early Stopping
The following example shows how to use early stopping while training a model and how to create checkpoints.

In [2]:
import os
import sys

# Add the directory two levels up to the module search path
# (presumably the repository root, so the local packages are importable).
sys.path.append('../..')

# Environment configuration. These are set before TensorFlow is imported
# so that they are already in place when the library initialises.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '2',    # only expose GPU index 2 - adjust for your machine
    'TF_ENABLE_ONEDNN_OPTS': '0',   # switch off the oneDNN optimisations
    'TF_CPP_MIN_LOG_LEVEL': '3',    # quieten TensorFlow's native (C++) logging
})

import tensorflow as tf

# Limit TensorFlow's Python logger to errors only
tf.get_logger().setLevel('ERROR')

In [4]:
import ampligraph
# The benchmark dataset loaders are exposed by the ampligraph.datasets module
from ampligraph.datasets import load_fb15k_237

# Load FB15k-237; the following cells index the result by its
# 'train', 'valid' and 'test' keys.
dataset = load_fb15k_237()

In [6]:
# Import the KGE model
from ampligraph.latent_features import ScoringBasedEmbeddingModel

# Create the model with the TransE scoring function
model = ScoringBasedEmbeddingModel(eta=1,                   # negatives generated per positive triple
                                   k=10,                    # embedding size
                                   scoring_type='TransE')

# Compile the model with loss and optimizer
model.compile(optimizer='adam', loss='multiclass_nll')

# Optional: checkpoint the best model seen so far (save_best_only=True saves only
# when the monitored metric improves). To use it, uncomment the line below and add
# `checkpoint` to the `callbacks` list passed to fit().
# mode='max' because a higher val_mrr is better.
# checkpoint = tf.keras.callbacks.ModelCheckpoint('./chkpt_transe', monitor='val_mrr', verbose=1, save_best_only=True, mode='max')

# Early stopping: halt training once the validation MRR stops improving
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_mrr",            # which metric to monitor
    patience=3,                   # if the monitored metric doesn't improve for
                                  # this many checks, training stops early
    verbose=1,                    # verbosity
    mode="max",                   # how to compare the monitored metric:
                                  # max - means higher is better
    restore_best_weights=True)    # restore the weights with the best value

# `dataset` was already loaded in an earlier cell, so it is reused here
# instead of reading FB15k-237 a second time.
model.fit(dataset['train'],
          batch_size=10000,
          epochs=100,
          validation_freq=2,                            # Epochs to elapse before next evaluation
          validation_batch_size=100,                    # Batch size used during validation
          validation_burn_in=1,                         # Epoch to start the validation process
          validation_data=dataset['valid'][::100],      # Subsample: every 100th validation triple
          callbacks=[early_stopping])                   # Pass the callback to the fit function


Epoch 1/100
Epoch 2/100

2023-02-08 13:21:22.984591: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 3/100
Epoch 4/100

2023-02-08 13:21:26.191442: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 5/100
Epoch 6/100

2023-02-08 13:21:29.466724: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 7/100
Epoch 8/100

2023-02-08 13:21:32.519336: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 9/100
Epoch 10/100

2023-02-08 13:21:35.457800: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 11/100
Epoch 12/100
      2/Unknown - 0s 159ms/step===>..] - ETA: 0s - loss: 5171.8193

2023-02-08 13:21:39.521803: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 13/100
Epoch 14/100


2023-02-08 13:21:40.702638: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 15/100
Epoch 16/100

2023-02-08 13:21:42.851130: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 17/100
Epoch 18/100

2023-02-08 13:21:45.612853: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 19/100
Epoch 20/100

2023-02-08 13:21:47.841319: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Epoch 21/100
Epoch 22/100
      2/Unknown - 0s 152ms/step===>..] - ETA: 0s - loss: 4140.2354

2023-02-08 13:21:50.470017: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8


Restoring model weights from the end of the best epoch: 16.
Epoch 22: early stopping


<tensorflow.python.keras.callbacks.History at 0x28a6d6da0>

In [7]:
# import the evaluation metrics
from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score

# Triples from every split, passed to evaluate() as the filter
known_triples = {
    'train': dataset['train'],
    'valid': dataset['valid'],
    'test': dataset['test'],
}

# Rank the test-set triples
ranks = model.evaluate(
    dataset['test'],            # test set
    batch_size=100,             # evaluation batch size
    corrupt_side='s,o',         # corruption side(s) used when ranking
    use_filter=known_triples,   # filter to be used for evaluation
)

# Report each metric as "<label> <value>", one per line
metric_table = (
    ('MR:', mr_score),
    ('MRR:', mrr_score),
    ('hits@1:', lambda r: hits_at_n_score(r, 1)),
    ('hits@10:', lambda r: hits_at_n_score(r, 10)),
)
for label, metric_fn in metric_table:
    print(label, metric_fn(ranks))

MR: 600.0142626480086
MRR: 0.18888812773966743
hits@1: 0.1267491926803014
hits@10: 0.3130687934240141


In [None]:
# If ModelCheckpoint was used during training, the saved checkpoint can be restored with restore_model:
# from ampligraph.utils import restore_model
# model = restore_model('chkpt_transe')