# Using Callbacks in Keras

In this notebook, we will see how to use pre-defined and custom callbacks in Keras for tasks such as checkpointing, learning rate scheduling, etc.

We'll use the same simple dataset and linear model as in the previous notebook.


In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing



##### Download the Auto-MPG dataset
 
Let's use the Auto-MPG dataset (seen in a previous notebook).

In [2]:
url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model Year', 'Origin']

dataset = pd.read_csv(url, names=column_names, na_values='?', comment='\t', sep=' ', skipinitialspace=True)
dataset = dataset.dropna()
dataset['Origin'] = dataset['Origin'].map({1: 'USA', 2: 'Europe', 3: 'Japan'})
dataset = pd.get_dummies(dataset, prefix='', prefix_sep='')
dataset.tail()

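|     | MPG  | Cylinders | Displacement | Horsepower | Weight | Acceleration | Model Year | Europe | Japan | USA |
|-----|------|-----------|--------------|------------|--------|--------------|------------|--------|-------|-----|
| 393 | 27.0 | 4         | 140.0        | 86.0       | 2790.0 | 15.6         | 82         | 0      | 0     | 1   |
| 394 | 44.0 | 4         | 97.0         | 52.0       | 2130.0 | 24.6         | 82         | 1      | 0     | 0   |
| 395 | 32.0 | 4         | 135.0        | 84.0       | 2295.0 | 11.6         | 82         | 0      | 0     | 1   |
| 396 | 28.0 | 4         | 120.0        | 79.0       | 2625.0 | 18.6         | 82         | 0      | 0     | 1   |
| 397 | 31.0 | 4         | 119.0        | 82.0       | 2720.0 | 19.4         | 82         | 0      | 0     | 1   |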


In [3]:
train_dataset = dataset.sample(frac=0.8, random_state=0)
test_dataset = dataset.drop(train_dataset.index)

train_features = train_dataset.copy()
test_features = test_dataset.copy()

train_labels = train_features.pop('MPG')
test_labels = test_features.pop('MPG')

##### Build the model

Let's build a simple linear regression model (seen in a previous notebook) to test different callbacks during its training.

We use a `get_model()` function so that we can re-create the model from scratch multiple times with a single instruction.

In [4]:
def get_model():
    normalizer = preprocessing.Normalization(input_shape=[9,])
    normalizer.adapt(np.array(train_features))
    
    model = keras.Sequential([
        normalizer,
        layers.Dense(units=1)
    ])
    
    model.compile(
        optimizer=tf.optimizers.Adam(learning_rate=0.1),
        loss='mse', metrics=['mae', 'mse']
    )
    
    return model
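
To make the structure of this model concrete, here is a minimal pure-Python sketch of what the two layers compute for one input row: the normalization step standardizes each feature with a mean and variance estimated from the training data, and the single dense unit applies a weighted sum plus a bias. The helper names (`fit_normalizer`, `predict_row`), the epsilon guard, and the toy numbers are illustrative assumptions, not Keras code.

```python
import math

def fit_normalizer(rows):
    """Per-feature mean and population variance, estimated from the data."""
    n = len(rows)
    n_features = len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(n_features)]
    variances = [sum((r[j] - means[j]) ** 2 for r in rows) / n
                 for j in range(n_features)]
    return means, variances

def predict_row(row, means, variances, weights, bias):
    """Standardize one row, then apply a single linear unit."""
    eps = 1e-7  # assumed small constant guarding against zero variance
    z = [(x - m) / max(math.sqrt(v), eps)
         for x, m, v in zip(row, means, variances)]
    return sum(w * zi for w, zi in zip(weights, z)) + bias

# Toy data: 3 rows, 2 features (hypothetical values).
rows = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
means, variances = fit_normalizer(rows)

# A row equal to the feature means is standardized to zeros,
# so the prediction reduces to the bias.
prediction = predict_row([2.0, 20.0], means, variances,
                         weights=[0.5, -0.5], bias=23.0)
print(means, prediction)
```

In the real model the weights and bias are learned by the optimizer; here they are fixed by hand only to show the computation.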

## Early Stopping Callback

Let's use the `EarlyStopping` callback to stop training once the monitored quantity stops improving.

The `monitor` parameter specifies the loss or metric to watch, and the `patience` parameter specifies how many epochs without improvement to tolerate before stopping.
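
The role of `patience` can be illustrated with a small pure-Python sketch. It is a simplified model of the stopping rule, not the Keras implementation: it assumes that lower values are better (as for `val_loss`) and that any decrease counts as an improvement. The function name `stopping_epoch` and the loss values are made up for illustration.

```python
def stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float('inf')
    wait = 0  # epochs elapsed since the best value was last improved
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None  # the loss kept improving often enough

val_losses = [10.0, 8.0, 7.0, 7.5, 7.2, 6.9]  # hypothetical values
stop_p2 = stopping_epoch(val_losses, patience=2)
stop_p3 = stopping_epoch(val_losses, patience=3)
print(stop_p2, stop_p3)
```

With `patience=2` the sketch stops at epoch 5, because epochs 4 and 5 both fail to beat the best value reached at epoch 3. With `patience=3` it runs to the end, since epoch 6 improves before the budget is used up.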

In [5]:
es_callback = keras.callbacks.EarlyStopping(monitor='val_loss', patience=2, verbose=1)

# re-create the model to restart training every time
model = get_model()
history = model.fit(train_features, train_labels, epochs=200, validation_split=0.2, callbacks=[es_callback])



Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200


Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200
Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 64/200
Epoch 65/200
Epoch 66/200
Epoch 67/200
Epoch 68/200
Epoch 69/200
Epoch 70/200
Epoch 71/200
Epoch 72/200
Epoch 73/200
Epoch 00073: early stopping


As you can see, training stopped after 73 epochs in this run, rather than running for the entire 200 epochs specified in `fit()`. The exact stopping point varies from run to run.

## Checkpoint Callback

Let's add a second callback to save a model checkpoint after every epoch. Notice that we can pass multiple callbacks at the same time to `fit()`.

In [6]:
cp_callback = keras.callbacks.ModelCheckpoint(
    './callback_test_chkp/chkp_{epoch:02d}',
    # './callback_test_chkp/chkp_best',
    monitor='val_loss',
    verbose=0, 
    save_best_only=False,
    # save_best_only=True,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch'
)
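
The `{epoch:02d}` placeholder in the file path is a standard Python format field that gets filled in with the epoch number. The sketch below uses plain `str.format` to show the resulting paths and, as a simplified assumption about how `save_best_only=True` behaves for a quantity that should decrease, which epochs would trigger a save. The helper `epochs_saved` and the loss values are hypothetical.

```python
template = './callback_test_chkp/chkp_{epoch:02d}'

# Epoch numbers are zero-padded to two digits.
paths = [template.format(epoch=e) for e in (1, 2, 10)]
print(paths)

def epochs_saved(val_losses, save_best_only):
    """Return the 1-based epochs at which a checkpoint would be written."""
    best = float('inf')
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        # Without save_best_only every epoch is saved; with it, only improvements.
        if not save_best_only or loss < best:
            saved.append(epoch)
        best = min(best, loss)
    return saved

val_losses = [5.0, 4.0, 4.5, 3.0]  # made-up validation losses
saved_all = epochs_saved(val_losses, save_best_only=False)
saved_best = epochs_saved(val_losses, save_best_only=True)
print(saved_all, saved_best)
```

With a fixed path such as the commented-out `'./callback_test_chkp/chkp_best'`, every save would overwrite the previous one, which is why that path pairs naturally with `save_best_only=True`.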

In [7]:
model = get_model()
history = model.fit(train_features, train_labels, epochs=200, validation_split=0.2,
                    callbacks=[es_callback, cp_callback])

Epoch 1/200
1/8 [==>...........................] - ETA: 1s - loss: 612.8845 - mae: 23.1756 - mse: 612.8845



INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_01/assets
Epoch 2/200
1/8 [==>...........................] - ETA: 0s - loss: 638.7021 - mae: 24.4261 - mse: 638.7021INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_02/assets
Epoch 3/200
1/8 [==>...........................] - ETA: 0s - loss: 491.9919 - mae: 21.8199 - mse: 491.9919INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_03/assets
Epoch 4/200
1/8 [==>...........................] - ETA: 0s - loss: 432.2855 - mae: 20.5799 - mse: 432.2855INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_04/assets
Epoch 5/200
1/8 [==>...........................] - ETA: 0s - loss: 431.1570 - mae: 20.4707 - mse: 431.1570INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_05/assets
Epoch 6/200
1/8 [==>...........................] - ETA: 0s - loss: 410.1916 - mae: 19.7040 - mse: 410.1916INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_06/assets
Epoch 7/200
1/8 [==>........................

Epoch 26/200
1/8 [==>...........................] - ETA: 0s - loss: 68.1808 - mae: 7.3463 - mse: 68.1808INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_26/assets
Epoch 27/200
1/8 [==>...........................] - ETA: 0s - loss: 69.1897 - mae: 7.7204 - mse: 69.1897INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_27/assets
Epoch 28/200
1/8 [==>...........................] - ETA: 0s - loss: 49.1182 - mae: 6.3059 - mse: 49.1182INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_28/assets
Epoch 29/200
1/8 [==>...........................] - ETA: 0s - loss: 65.9604 - mae: 7.3686 - mse: 65.9604INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_29/assets
Epoch 30/200
1/8 [==>...........................] - ETA: 0s - loss: 43.6574 - mae: 5.5425 - mse: 43.6574INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_30/assets
Epoch 31/200
1/8 [==>...........................] - ETA: 0s - loss: 30.4164 - mae: 4.9916 - mse: 30.4164INFO:tensorflow:Asset

Epoch 51/200
1/8 [==>...........................] - ETA: 0s - loss: 11.5428 - mae: 2.4929 - mse: 11.5428INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_51/assets
Epoch 52/200
1/8 [==>...........................] - ETA: 0s - loss: 20.9639 - mae: 3.4833 - mse: 20.9639INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_52/assets
Epoch 53/200
1/8 [==>...........................] - ETA: 0s - loss: 19.2376 - mae: 3.2313 - mse: 19.2376INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_53/assets
Epoch 54/200
1/8 [==>...........................] - ETA: 0s - loss: 10.4956 - mae: 2.5923 - mse: 10.4956INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_54/assets
Epoch 55/200
1/8 [==>...........................] - ETA: 0s - loss: 16.4567 - mae: 2.7160 - mse: 16.4567INFO:tensorflow:Assets written to: ./callback_test_chkp/chkp_55/assets
Epoch 56/200
1/8 [==>...........................] - ETA: 0s - loss: 7.8315 - mae: 2.2470 - mse: 7.8315INFO:tensorflow:Assets 

Epoch 00075: early stopping


##### Restore a saved checkpoint

Let's load two of the saved checkpoints and evaluate them on the training data.

In [8]:
model_epoch1 = keras.models.load_model('./callback_test_chkp/chkp_01')
model_epoch1.evaluate(train_features, train_labels)



[534.7635498046875, 22.51128578186035, 534.7635498046875]

In [9]:
model_epoch10 = keras.models.load_model('./callback_test_chkp/chkp_10')
model_epoch10.evaluate(train_features, train_labels)



[263.253662109375, 15.80517578125, 263.253662109375]
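
The lists returned by `evaluate()` are easier to compare once each value is paired with a name. The sketch below hard-codes the names in the order suggested by the `compile()` call in `get_model()` (the loss first, then `mae` and `mse`); that ordering is an assumption made for illustration, and the numbers are copied from the two outputs above.

```python
metric_names = ['loss', 'mae', 'mse']  # assumed order: loss, then compiled metrics

epoch1 = [534.7635498046875, 22.51128578186035, 534.7635498046875]
epoch10 = [263.253662109375, 15.80517578125, 263.253662109375]

results_epoch1 = dict(zip(metric_names, epoch1))
results_epoch10 = dict(zip(metric_names, epoch10))

# Print each metric at both checkpoints together with the change.
for name in metric_names:
    before = results_epoch1[name]
    after = results_epoch10[name]
    print(f"{name}: {before:.2f} -> {after:.2f} ({after - before:+.2f})")
```

The first and third values coincide because the model is compiled with `loss='mse'` and also tracks `mse` as a metric. The checkpoint from epoch 10 has a clearly lower error than the one from epoch 1, as expected while training is still making progress.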

## Learning Rate Scheduling

Let's lower the learning rate by 0.01 at the start of every epoch, down to a floor of 0.01. This is only meant to demonstrate LR scheduling; it is not a particularly useful schedule.

In [10]:
def my_schedule(epoch, lr):
    return max(lr - 0.01, 0.01)

Test if the schedule works for different input LR values.

In [11]:
print(my_schedule(1, 0.05))
print(my_schedule(1, 0.01))

0.04
0.01
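
To see the whole trajectory rather than single values, the schedule can be applied repeatedly, feeding each result back in as the next learning rate. This is a plain-Python sketch that starts from the 0.1 used in `get_model()`; the loop is an illustration of the idea, not Keras code, and the number of simulated epochs (12) is arbitrary.

```python
def my_schedule(epoch, lr):
    # Same schedule as above: subtract 0.01, but never go below 0.01.
    return max(lr - 0.01, 0.01)

lr = 0.1  # initial learning rate set in get_model()
lrs = []
for epoch in range(12):
    lr = my_schedule(epoch, lr)
    lrs.append(lr)

# Round for display only; the raw values carry small floating-point errors.
print([round(v, 2) for v in lrs])
```

The learning rate drops in nine steps from 0.09 to 0.01 and then stays at the floor for all remaining epochs.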


In [12]:
lr_callback = keras.callbacks.LearningRateScheduler(my_schedule, verbose=1)

In [13]:
model = get_model()
history = model.fit(train_features, train_labels, epochs=200, validation_split=0.2,
                    callbacks=[lr_callback, es_callback])


Epoch 00001: LearningRateScheduler setting learning rate to 0.09000000149011612.
Epoch 1/200

Epoch 00002: LearningRateScheduler setting learning rate to 0.08000000357627869.
Epoch 2/200

Epoch 00003: LearningRateScheduler setting learning rate to 0.07000000566244126.
Epoch 3/200

Epoch 00004: LearningRateScheduler setting learning rate to 0.06000000774860382.
Epoch 4/200

Epoch 00005: LearningRateScheduler setting learning rate to 0.05000000610947609.
Epoch 5/200

Epoch 00006: LearningRateScheduler setting learning rate to 0.040000004470348356.
Epoch 6/200

Epoch 00007: LearningRateScheduler setting learning rate to 0.030000002831220625.
Epoch 7/200

Epoch 00008: LearningRateScheduler setting learning rate to 0.020000003054738043.
Epoch 8/200

Epoch 00009: LearningRateScheduler setting learning rate to 0.010000003278255462.
Epoch 9/200

Epoch 00010: LearningRateScheduler setting learning rate to 0.01.
Epoch 10/200

Epoch 00011: LearningRateScheduler setting learning rate to 0.01.
Epo


Epoch 00033: LearningRateScheduler setting learning rate to 0.01.
Epoch 33/200

Epoch 00034: LearningRateScheduler setting learning rate to 0.01.
Epoch 34/200

Epoch 00035: LearningRateScheduler setting learning rate to 0.01.
Epoch 35/200

Epoch 00036: LearningRateScheduler setting learning rate to 0.01.
Epoch 36/200

Epoch 00037: LearningRateScheduler setting learning rate to 0.01.
Epoch 37/200

Epoch 00038: LearningRateScheduler setting learning rate to 0.01.
Epoch 38/200

Epoch 00039: LearningRateScheduler setting learning rate to 0.01.
Epoch 39/200

Epoch 00040: LearningRateScheduler setting learning rate to 0.01.
Epoch 40/200

Epoch 00041: LearningRateScheduler setting learning rate to 0.01.
Epoch 41/200

Epoch 00042: LearningRateScheduler setting learning rate to 0.01.
Epoch 42/200

Epoch 00043: LearningRateScheduler setting learning rate to 0.01.
Epoch 43/200

Epoch 00044: LearningRateScheduler setting learning rate to 0.01.
Epoch 44/200

Epoch 00045: LearningRateScheduler sett


Epoch 00066: LearningRateScheduler setting learning rate to 0.01.
Epoch 66/200

Epoch 00067: LearningRateScheduler setting learning rate to 0.01.
Epoch 67/200

Epoch 00068: LearningRateScheduler setting learning rate to 0.01.
Epoch 68/200

Epoch 00069: LearningRateScheduler setting learning rate to 0.01.
Epoch 69/200

Epoch 00070: LearningRateScheduler setting learning rate to 0.01.
Epoch 70/200

Epoch 00071: LearningRateScheduler setting learning rate to 0.01.
Epoch 71/200

Epoch 00072: LearningRateScheduler setting learning rate to 0.01.
Epoch 72/200

Epoch 00073: LearningRateScheduler setting learning rate to 0.01.
Epoch 73/200

Epoch 00074: LearningRateScheduler setting learning rate to 0.01.
Epoch 74/200

Epoch 00075: LearningRateScheduler setting learning rate to 0.01.
Epoch 75/200

Epoch 00076: LearningRateScheduler setting learning rate to 0.01.
Epoch 76/200

Epoch 00077: LearningRateScheduler setting learning rate to 0.01.
Epoch 77/200

Epoch 00078: LearningRateScheduler sett


Epoch 00098: LearningRateScheduler setting learning rate to 0.01.
Epoch 98/200

Epoch 00099: LearningRateScheduler setting learning rate to 0.01.
Epoch 99/200

Epoch 00100: LearningRateScheduler setting learning rate to 0.01.
Epoch 100/200

Epoch 00101: LearningRateScheduler setting learning rate to 0.01.
Epoch 101/200

Epoch 00102: LearningRateScheduler setting learning rate to 0.01.
Epoch 102/200

Epoch 00103: LearningRateScheduler setting learning rate to 0.01.
Epoch 103/200

Epoch 00104: LearningRateScheduler setting learning rate to 0.01.
Epoch 104/200

Epoch 00105: LearningRateScheduler setting learning rate to 0.01.
Epoch 105/200

Epoch 00106: LearningRateScheduler setting learning rate to 0.01.
Epoch 106/200

Epoch 00107: LearningRateScheduler setting learning rate to 0.01.
Epoch 107/200

Epoch 00108: LearningRateScheduler setting learning rate to 0.01.
Epoch 108/200

Epoch 00109: LearningRateScheduler setting learning rate to 0.01.
Epoch 109/200

Epoch 00110: LearningRateSche


Epoch 00131: LearningRateScheduler setting learning rate to 0.01.
Epoch 131/200

Epoch 00132: LearningRateScheduler setting learning rate to 0.01.
Epoch 132/200

Epoch 00133: LearningRateScheduler setting learning rate to 0.01.
Epoch 133/200

Epoch 00134: LearningRateScheduler setting learning rate to 0.01.
Epoch 134/200

Epoch 00135: LearningRateScheduler setting learning rate to 0.01.
Epoch 135/200

Epoch 00136: LearningRateScheduler setting learning rate to 0.01.
Epoch 136/200

Epoch 00137: LearningRateScheduler setting learning rate to 0.01.
Epoch 137/200

Epoch 00138: LearningRateScheduler setting learning rate to 0.01.
Epoch 138/200

Epoch 00139: LearningRateScheduler setting learning rate to 0.01.
Epoch 139/200

Epoch 00140: LearningRateScheduler setting learning rate to 0.01.
Epoch 140/200

Epoch 00141: LearningRateScheduler setting learning rate to 0.01.
Epoch 141/200

Epoch 00142: LearningRateScheduler setting learning rate to 0.01.
Epoch 142/200

Epoch 00143: LearningRateSc


Epoch 00163: LearningRateScheduler setting learning rate to 0.01.
Epoch 163/200

Epoch 00164: LearningRateScheduler setting learning rate to 0.01.
Epoch 164/200

Epoch 00165: LearningRateScheduler setting learning rate to 0.01.
Epoch 165/200

Epoch 00166: LearningRateScheduler setting learning rate to 0.01.
Epoch 166/200

Epoch 00167: LearningRateScheduler setting learning rate to 0.01.
Epoch 167/200

Epoch 00168: LearningRateScheduler setting learning rate to 0.01.
Epoch 168/200

Epoch 00169: LearningRateScheduler setting learning rate to 0.01.
Epoch 169/200

Epoch 00170: LearningRateScheduler setting learning rate to 0.01.
Epoch 170/200

Epoch 00171: LearningRateScheduler setting learning rate to 0.01.
Epoch 171/200

Epoch 00172: LearningRateScheduler setting learning rate to 0.01.
Epoch 172/200

Epoch 00173: LearningRateScheduler setting learning rate to 0.01.
Epoch 173/200

Epoch 00174: LearningRateScheduler setting learning rate to 0.01.
Epoch 174/200

Epoch 00175: LearningRateSc


Epoch 00197: LearningRateScheduler setting learning rate to 0.01.
Epoch 197/200

Epoch 00198: LearningRateScheduler setting learning rate to 0.01.
Epoch 198/200

Epoch 00199: LearningRateScheduler setting learning rate to 0.01.
Epoch 199/200

Epoch 00200: LearningRateScheduler setting learning rate to 0.01.
Epoch 200/200


## Custom Callback N.1

Let's write a simple custom callback that prints the content of `logs` (the loss and metric values) when training starts and ends, after every epoch, and after every training batch.

In [14]:
class CustomLogger(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        print("Starting training; log content: {}".format(logs))

    def on_train_end(self, logs=None):
        print("Stop training; got log content: {}".format(logs))

    def on_epoch_end(self, epoch, logs=None):
        print("End epoch {} of training; log content: {}".format(epoch, logs))

    def on_train_batch_end(self, batch, logs=None):
        print("...Training: end of batch {}; log content: {}".format(batch, logs))
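
The order in which these hooks fire can be sketched without Keras by driving an object with the same method names from a toy loop. The `HookRecorder` class and the `toy_fit` function below are hypothetical stand-ins: `toy_fit` only mimics the nesting of training, epochs, and batches, and it passes empty `logs` dictionaries rather than real metrics.

```python
class HookRecorder:
    """Records which hook was called, using the same method names as above."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append('train_begin')

    def on_train_end(self, logs=None):
        self.events.append('train_end')

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f'epoch_end {epoch}')

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(f'batch_end {batch}')


def toy_fit(callback, epochs, batches_per_epoch):
    """Call the hooks in the nesting order of a training loop (no real training)."""
    callback.on_train_begin(logs={})
    for epoch in range(epochs):  # zero-based epoch index
        for batch in range(batches_per_epoch):  # zero-based batch index
            callback.on_train_batch_end(batch, logs={})
        callback.on_epoch_end(epoch, logs={})
    callback.on_train_end(logs={})


recorder = HookRecorder()
toy_fit(recorder, epochs=2, batches_per_epoch=3)
print(recorder.events)
```

In a real run the indices passed to the hooks are zero-based in the same way, and `logs` carries the running loss and metrics, with the validation values (`val_loss`, `val_mae`, `val_mse`) present only in the `on_epoch_end` logs.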


In [15]:
log_callback = CustomLogger()

model = get_model()
history = model.fit(train_features, train_labels, epochs=200,
                    verbose=0,  # avoid mixing our prints with Keras's own progress output
                    validation_split=0.2,
                    callbacks=[log_callback])

Starting training; log content: {}
...Training: end of batch 0; log content: {'loss': 640.1168212890625, 'mae': 23.924179077148438, 'mse': 640.1168212890625}
...Training: end of batch 1; log content: {'loss': 601.4976196289062, 'mae': 23.367958068847656, 'mse': 601.4976196289062}
...Training: end of batch 2; log content: {'loss': 603.247802734375, 'mae': 23.63275718688965, 'mse': 603.247802734375}
...Training: end of batch 3; log content: {'loss': 586.8490600585938, 'mae': 23.314189910888672, 'mse': 586.8490600585938}
...Training: end of batch 4; log content: {'loss': 568.4512939453125, 'mae': 22.980281829833984, 'mse': 568.4512939453125}
...Training: end of batch 5; log content: {'loss': 558.9595336914062, 'mae': 22.837554931640625, 'mse': 558.9595336914062}
...Training: end of batch 6; log content: {'loss': 556.0956420898438, 'mae': 22.81576919555664, 'mse': 556.0956420898438}
...Training: end of batch 7; log content: {'loss': 561.2073974609375, 'mae': 22.94468879699707, 'mse': 561.2

End epoch 10 of training; log content: {'loss': 250.2487030029297, 'mae': 15.372559547424316, 'mse': 250.2487030029297, 'val_loss': 259.7608642578125, 'val_mae': 15.729279518127441, 'val_mse': 259.7608642578125}
...Training: end of batch 0; log content: {'loss': 219.69247436523438, 'mae': 14.310229301452637, 'mse': 219.69247436523438}
...Training: end of batch 1; log content: {'loss': 249.23289489746094, 'mae': 15.20467758178711, 'mse': 249.23289489746094}
...Training: end of batch 2; log content: {'loss': 241.60801696777344, 'mae': 14.985664367675781, 'mse': 241.60801696777344}
...Training: end of batch 3; log content: {'loss': 240.0593719482422, 'mae': 14.980049133300781, 'mse': 240.0593719482422}
...Training: end of batch 4; log content: {'loss': 239.7289581298828, 'mae': 14.99125862121582, 'mse': 239.7289581298828}
...Training: end of batch 5; log content: {'loss': 230.0460662841797, 'mae': 14.690673828125, 'mse': 230.0460662841797}
...Training: end of batch 6; log content: {'loss'

End epoch 18 of training; log content: {'loss': 124.58759307861328, 'mae': 10.566314697265625, 'mse': 124.58759307861328, 'val_loss': 131.84349060058594, 'val_mae': 10.96011734008789, 'val_mse': 131.84349060058594}
...Training: end of batch 0; log content: {'loss': 116.57064819335938, 'mae': 10.363588333129883, 'mse': 116.57064819335938}
...Training: end of batch 1; log content: {'loss': 119.81245422363281, 'mae': 10.290020942687988, 'mse': 119.81245422363281}
...Training: end of batch 2; log content: {'loss': 112.76578521728516, 'mae': 10.046494483947754, 'mse': 112.76578521728516}
...Training: end of batch 3; log content: {'loss': 115.21142578125, 'mae': 10.063802719116211, 'mse': 115.21142578125}
...Training: end of batch 4; log content: {'loss': 114.1043930053711, 'mae': 10.024415969848633, 'mse': 114.1043930053711}
...Training: end of batch 5; log content: {'loss': 113.20852661132812, 'mae': 9.982121467590332, 'mse': 113.20852661132812}
...Training: end of batch 6; log content: {'

...Training: end of batch 5; log content: {'loss': 54.75857925415039, 'mae': 6.580890655517578, 'mse': 54.75857925415039}
...Training: end of batch 6; log content: {'loss': 54.31869888305664, 'mae': 6.525269985198975, 'mse': 54.31869888305664}
...Training: end of batch 7; log content: {'loss': 53.55525207519531, 'mae': 6.506133556365967, 'mse': 53.55525207519531}
End epoch 27 of training; log content: {'loss': 53.55525207519531, 'mae': 6.506133556365967, 'mse': 53.55525207519531, 'val_loss': 58.736698150634766, 'val_mae': 6.916505813598633, 'val_mse': 58.736698150634766}
...Training: end of batch 0; log content: {'loss': 61.73976135253906, 'mae': 6.933187484741211, 'mse': 61.73976135253906}
...Training: end of batch 1; log content: {'loss': 47.07685089111328, 'mae': 6.088812828063965, 'mse': 47.07685089111328}
...Training: end of batch 2; log content: {'loss': 48.54850387573242, 'mae': 6.21239709854126, 'mse': 48.54850387573242}
...Training: end of batch 3; log content: {'loss': 47.838

End epoch 34 of training; log content: {'loss': 29.056724548339844, 'mae': 4.451137065887451, 'mse': 29.056724548339844, 'val_loss': 32.411991119384766, 'val_mae': 4.766209125518799, 'val_mse': 32.411991119384766}
...Training: end of batch 0; log content: {'loss': 21.57598114013672, 'mae': 3.6450138092041016, 'mse': 21.57598114013672}
...Training: end of batch 1; log content: {'loss': 19.100364685058594, 'mae': 3.505295753479004, 'mse': 19.100364685058594}
...Training: end of batch 2; log content: {'loss': 25.434341430664062, 'mae': 3.9524612426757812, 'mse': 25.434341430664062}
...Training: end of batch 3; log content: {'loss': 27.3979549407959, 'mae': 4.21675968170166, 'mse': 27.3979549407959}
...Training: end of batch 4; log content: {'loss': 26.446508407592773, 'mae': 4.189908027648926, 'mse': 26.446508407592773}
...Training: end of batch 5; log content: {'loss': 27.263036727905273, 'mae': 4.226528167724609, 'mse': 27.263036727905273}
...Training: end of batch 6; log content: {'los

End epoch 44 of training; log content: {'loss': 15.798630714416504, 'mae': 3.032130002975464, 'mse': 15.798630714416504, 'val_loss': 16.72330093383789, 'val_mae': 3.077984571456909, 'val_mse': 16.72330093383789}
...Training: end of batch 0; log content: {'loss': 11.792201042175293, 'mae': 2.804126024246216, 'mse': 11.792201042175293}
...Training: end of batch 1; log content: {'loss': 11.98503303527832, 'mae': 2.845733165740967, 'mse': 11.98503303527832}
...Training: end of batch 2; log content: {'loss': 16.22673988342285, 'mae': 2.99578857421875, 'mse': 16.22673988342285}
...Training: end of batch 3; log content: {'loss': 17.248491287231445, 'mae': 3.1394951343536377, 'mse': 17.248491287231445}
...Training: end of batch 4; log content: {'loss': 16.820524215698242, 'mae': 3.121908187866211, 'mse': 16.820524215698242}
...Training: end of batch 5; log content: {'loss': 15.829340934753418, 'mae': 3.031987190246582, 'mse': 15.829340934753418}
...Training: end of batch 6; log content: {'loss

End epoch 54 of training; log content: {'loss': 12.426100730895996, 'mae': 2.6393463611602783, 'mse': 12.426100730895996, 'val_loss': 12.23901653289795, 'val_mae': 2.651272773742676, 'val_mse': 12.23901653289795}
...Training: end of batch 0; log content: {'loss': 19.046146392822266, 'mae': 2.969933271408081, 'mse': 19.046146392822266}
...Training: end of batch 1; log content: {'loss': 13.875015258789062, 'mae': 2.702345848083496, 'mse': 13.875015258789062}
...Training: end of batch 2; log content: {'loss': 12.090901374816895, 'mae': 2.5573532581329346, 'mse': 12.090901374816895}
...Training: end of batch 3; log content: {'loss': 10.916969299316406, 'mae': 2.389885902404785, 'mse': 10.916969299316406}
...Training: end of batch 4; log content: {'loss': 12.513747215270996, 'mae': 2.5495100021362305, 'mse': 12.513747215270996}
...Training: end of batch 5; log content: {'loss': 11.698369026184082, 'mae': 2.5066885948181152, 'mse': 11.698369026184082}
...Training: end of batch 6; log content

End epoch 64 of training; log content: {'loss': 11.636542320251465, 'mae': 2.5599405765533447, 'mse': 11.636542320251465, 'val_loss': 10.813929557800293, 'val_mae': 2.5090341567993164, 'val_mse': 10.813929557800293}
...Training: end of batch 0; log content: {'loss': 14.56159782409668, 'mae': 2.4341213703155518, 'mse': 14.56159782409668}
...Training: end of batch 1; log content: {'loss': 12.815101623535156, 'mae': 2.378309726715088, 'mse': 12.815101623535156}
...Training: end of batch 2; log content: {'loss': 12.389732360839844, 'mae': 2.429386615753174, 'mse': 12.389732360839844}
...Training: end of batch 3; log content: {'loss': 12.077654838562012, 'mae': 2.418034076690674, 'mse': 12.077654838562012}
...Training: end of batch 4; log content: {'loss': 11.399609565734863, 'mae': 2.4300196170806885, 'mse': 11.399609565734863}
...Training: end of batch 5; log content: {'loss': 11.196110725402832, 'mae': 2.479538917541504, 'mse': 11.196110725402832}
...Training: end of batch 6; log content

End epoch 74 of training; log content: {'loss': 11.479355812072754, 'mae': 2.5637998580932617, 'mse': 11.479355812072754, 'val_loss': 10.369196891784668, 'val_mae': 2.4663989543914795, 'val_mse': 10.369196891784668}
...Training: end of batch 0; log content: {'loss': 9.048422813415527, 'mae': 2.208756446838379, 'mse': 9.048422813415527}
...Training: end of batch 1; log content: {'loss': 9.464361190795898, 'mae': 2.3993887901306152, 'mse': 9.464361190795898}
...Training: end of batch 2; log content: {'loss': 9.09804630279541, 'mae': 2.347954034805298, 'mse': 9.09804630279541}
...Training: end of batch 3; log content: {'loss': 10.923328399658203, 'mae': 2.6119256019592285, 'mse': 10.923328399658203}
...Training: end of batch 4; log content: {'loss': 10.52726936340332, 'mae': 2.564666986465454, 'mse': 10.52726936340332}
...Training: end of batch 5; log content: {'loss': 11.826176643371582, 'mae': 2.681410074234009, 'mse': 11.826176643371582}
...Training: end of batch 6; log content: {'loss

End epoch 84 of training; log content: {'loss': 11.429457664489746, 'mae': 2.575195550918579, 'mse': 11.429457664489746, 'val_loss': 10.240166664123535, 'val_mae': 2.4483139514923096, 'val_mse': 10.240166664123535}
...Training: end of batch 0; log content: {'loss': 9.13519287109375, 'mae': 2.277327537536621, 'mse': 9.13519287109375}
...Training: end of batch 1; log content: {'loss': 8.557243347167969, 'mae': 2.1775014400482178, 'mse': 8.557243347167969}
...Training: end of batch 2; log content: {'loss': 9.054030418395996, 'mae': 2.2506089210510254, 'mse': 9.054030418395996}
...Training: end of batch 3; log content: {'loss': 9.46218490600586, 'mae': 2.3568379878997803, 'mse': 9.46218490600586}
...Training: end of batch 4; log content: {'loss': 9.656797409057617, 'mae': 2.4288620948791504, 'mse': 9.656797409057617}
...Training: end of batch 5; log content: {'loss': 11.01569652557373, 'mae': 2.514904260635376, 'mse': 11.01569652557373}
...Training: end of batch 6; log content: {'loss': 10

...Training: end of batch 2; log content: {'loss': 9.048181533813477, 'mae': 2.336766004562378, 'mse': 9.048181533813477}
...Training: end of batch 3; log content: {'loss': 9.840761184692383, 'mae': 2.3899600505828857, 'mse': 9.840761184692383}
...Training: end of batch 4; log content: {'loss': 12.092401504516602, 'mae': 2.6098341941833496, 'mse': 12.092401504516602}
...Training: end of batch 5; log content: {'loss': 12.012706756591797, 'mae': 2.641591787338257, 'mse': 12.012706756591797}
...Training: end of batch 6; log content: {'loss': 11.758798599243164, 'mae': 2.6484549045562744, 'mse': 11.758798599243164}
...Training: end of batch 7; log content: {'loss': 11.340327262878418, 'mae': 2.5844709873199463, 'mse': 11.340327262878418}
End epoch 94 of training; log content: {'loss': 11.340327262878418, 'mae': 2.5844709873199463, 'mse': 11.340327262878418, 'val_loss': 10.257661819458008, 'val_mae': 2.4530436992645264, 'val_mse': 10.257661819458008}
...Training: end of batch 0; log content

...Training: end of batch 3; log content: {'loss': 12.283809661865234, 'mae': 2.6045279502868652, 'mse': 12.283809661865234}
...Training: end of batch 4; log content: {'loss': 11.806069374084473, 'mae': 2.5780189037323, 'mse': 11.806069374084473}
...Training: end of batch 5; log content: {'loss': 11.243138313293457, 'mae': 2.5275561809539795, 'mse': 11.243138313293457}
...Training: end of batch 6; log content: {'loss': 11.667283058166504, 'mae': 2.5884616374969482, 'mse': 11.667283058166504}
...Training: end of batch 7; log content: {'loss': 11.369519233703613, 'mae': 2.580326557159424, 'mse': 11.369519233703613}
End epoch 103 of training; log content: {'loss': 11.369519233703613, 'mae': 2.580326557159424, 'mse': 11.369519233703613, 'val_loss': 10.07176399230957, 'val_mae': 2.429921865463257, 'val_mse': 10.07176399230957}
...Training: end of batch 0; log content: {'loss': 15.642499923706055, 'mae': 2.772655963897705, 'mse': 15.642499923706055}
...Training: end of batch 1; log content: 

End epoch 111 of training; log content: {'loss': 11.360349655151367, 'mae': 2.585927724838257, 'mse': 11.360349655151367, 'val_loss': 10.078371047973633, 'val_mae': 2.431715726852417, 'val_mse': 10.078371047973633}
...Training: end of batch 0; log content: {'loss': 5.728188514709473, 'mae': 2.0338966846466064, 'mse': 5.728188514709473}
...Training: end of batch 1; log content: {'loss': 8.649191856384277, 'mae': 2.3612937927246094, 'mse': 8.649191856384277}
...Training: end of batch 2; log content: {'loss': 10.046685218811035, 'mae': 2.3580052852630615, 'mse': 10.046685218811035}
...Training: end of batch 3; log content: {'loss': 10.109508514404297, 'mae': 2.3672149181365967, 'mse': 10.109508514404297}
...Training: end of batch 4; log content: {'loss': 10.394903182983398, 'mae': 2.4409701824188232, 'mse': 10.394903182983398}
...Training: end of batch 5; log content: {'loss': 10.401373863220215, 'mae': 2.4355785846710205, 'mse': 10.401373863220215}
...Training: end of batch 6; log conten

...Training: end of batch 0; log content: {'loss': 14.621796607971191, 'mae': 2.9855029582977295, 'mse': 14.621796607971191}
...Training: end of batch 1; log content: {'loss': 11.047760009765625, 'mae': 2.6430020332336426, 'mse': 11.047760009765625}
...Training: end of batch 2; log content: {'loss': 11.862040519714355, 'mae': 2.6891214847564697, 'mse': 11.862040519714355}
...Training: end of batch 3; log content: {'loss': 11.529729843139648, 'mae': 2.5956473350524902, 'mse': 11.529729843139648}
...Training: end of batch 4; log content: {'loss': 11.087690353393555, 'mae': 2.565415859222412, 'mse': 11.087690353393555}
...Training: end of batch 5; log content: {'loss': 10.77916431427002, 'mae': 2.532031774520874, 'mse': 10.77916431427002}
...Training: end of batch 6; log content: {'loss': 10.736971855163574, 'mae': 2.5389304161071777, 'mse': 10.736971855163574}
...Training: end of batch 7; log content: {'loss': 11.383538246154785, 'mae': 2.575547218322754, 'mse': 11.383538246154785}
End e

...Training: end of batch 5; log content: {'loss': 12.460037231445312, 'mae': 2.699021100997925, 'mse': 12.460037231445312}
...Training: end of batch 6; log content: {'loss': 11.960405349731445, 'mae': 2.651123285293579, 'mse': 11.960405349731445}
...Training: end of batch 7; log content: {'loss': 11.31297779083252, 'mae': 2.579160451889038, 'mse': 11.31297779083252}
End epoch 129 of training; log content: {'loss': 11.31297779083252, 'mae': 2.579160451889038, 'mse': 11.31297779083252, 'val_loss': 10.041238784790039, 'val_mae': 2.4488422870635986, 'val_mse': 10.041238784790039}
...Training: end of batch 0; log content: {'loss': 13.854005813598633, 'mae': 2.9107556343078613, 'mse': 13.854005813598633}
...Training: end of batch 1; log content: {'loss': 11.085378646850586, 'mae': 2.673971652984619, 'mse': 11.085378646850586}
...Training: end of batch 2; log content: {'loss': 12.655142784118652, 'mae': 2.838080644607544, 'mse': 12.655142784118652}
...Training: end of batch 3; log content: {

End epoch 138 of training; log content: {'loss': 11.33808422088623, 'mae': 2.580759286880493, 'mse': 11.33808422088623, 'val_loss': 10.014967918395996, 'val_mae': 2.423929452896118, 'val_mse': 10.014967918395996}
...Training: end of batch 0; log content: {'loss': 10.517919540405273, 'mae': 2.504194736480713, 'mse': 10.517919540405273}
...Training: end of batch 1; log content: {'loss': 11.551580429077148, 'mae': 2.5665111541748047, 'mse': 11.551580429077148}
...Training: end of batch 2; log content: {'loss': 9.770895957946777, 'mae': 2.37129282951355, 'mse': 9.770895957946777}
...Training: end of batch 3; log content: {'loss': 9.949148178100586, 'mae': 2.3982324600219727, 'mse': 9.949148178100586}
...Training: end of batch 4; log content: {'loss': 10.388580322265625, 'mae': 2.501370906829834, 'mse': 10.388580322265625}
...Training: end of batch 5; log content: {'loss': 11.81733226776123, 'mae': 2.6298418045043945, 'mse': 11.81733226776123}
...Training: end of batch 6; log content: {'los

...Training: end of batch 3; log content: {'loss': 12.91079044342041, 'mae': 2.7449913024902344, 'mse': 12.91079044342041}
...Training: end of batch 4; log content: {'loss': 11.446213722229004, 'mae': 2.5737760066986084, 'mse': 11.446213722229004}
...Training: end of batch 5; log content: {'loss': 11.35612964630127, 'mae': 2.541987180709839, 'mse': 11.35612964630127}
...Training: end of batch 6; log content: {'loss': 11.355813026428223, 'mae': 2.5619640350341797, 'mse': 11.355813026428223}
...Training: end of batch 7; log content: {'loss': 11.338687896728516, 'mae': 2.5768258571624756, 'mse': 11.338687896728516}
End epoch 147 of training; log content: {'loss': 11.338687896728516, 'mae': 2.5768258571624756, 'mse': 11.338687896728516, 'val_loss': 9.946730613708496, 'val_mae': 2.442250967025757, 'val_mse': 9.946730613708496}
...Training: end of batch 0; log content: {'loss': 9.944671630859375, 'mae': 2.5002951622009277, 'mse': 9.944671630859375}
...Training: end of batch 1; log content: {

...Training: end of batch 6; log content: {'loss': 11.480706214904785, 'mae': 2.591892719268799, 'mse': 11.480706214904785}
...Training: end of batch 7; log content: {'loss': 11.413748741149902, 'mae': 2.599073886871338, 'mse': 11.413748741149902}
End epoch 155 of training; log content: {'loss': 11.413748741149902, 'mae': 2.599073886871338, 'mse': 11.413748741149902, 'val_loss': 10.048480033874512, 'val_mae': 2.421628713607788, 'val_mse': 10.048480033874512}
...Training: end of batch 0; log content: {'loss': 8.000626564025879, 'mae': 2.3386712074279785, 'mse': 8.000626564025879}
...Training: end of batch 1; log content: {'loss': 11.644508361816406, 'mae': 2.4196040630340576, 'mse': 11.644508361816406}
...Training: end of batch 2; log content: {'loss': 11.344955444335938, 'mae': 2.4757702350616455, 'mse': 11.344955444335938}
...Training: end of batch 3; log content: {'loss': 11.897747039794922, 'mae': 2.5746986865997314, 'mse': 11.897747039794922}
...Training: end of batch 4; log conten

End epoch 164 of training; log content: {'loss': 11.432588577270508, 'mae': 2.606843948364258, 'mse': 11.432588577270508, 'val_loss': 9.924283027648926, 'val_mae': 2.4150688648223877, 'val_mse': 9.924283027648926}
...Training: end of batch 0; log content: {'loss': 10.877325057983398, 'mae': 2.730469226837158, 'mse': 10.877325057983398}
...Training: end of batch 1; log content: {'loss': 10.55762767791748, 'mae': 2.592839479446411, 'mse': 10.55762767791748}
...Training: end of batch 2; log content: {'loss': 9.60256576538086, 'mae': 2.506728410720825, 'mse': 9.60256576538086}
...Training: end of batch 3; log content: {'loss': 8.848785400390625, 'mae': 2.3971240520477295, 'mse': 8.848785400390625}
...Training: end of batch 4; log content: {'loss': 9.819811820983887, 'mae': 2.495211601257324, 'mse': 9.819811820983887}
...Training: end of batch 5; log content: {'loss': 11.377223014831543, 'mae': 2.5955944061279297, 'mse': 11.377223014831543}
...Training: end of batch 6; log content: {'loss':

End epoch 174 of training; log content: {'loss': 11.291980743408203, 'mae': 2.573643684387207, 'mse': 11.291980743408203, 'val_loss': 10.011096954345703, 'val_mae': 2.445298433303833, 'val_mse': 10.011096954345703}
...Training: end of batch 0; log content: {'loss': 10.098746299743652, 'mae': 2.3413033485412598, 'mse': 10.098746299743652}
...Training: end of batch 1; log content: {'loss': 14.49561882019043, 'mae': 2.7893435955047607, 'mse': 14.49561882019043}
...Training: end of batch 2; log content: {'loss': 12.478515625, 'mae': 2.554910898208618, 'mse': 12.478515625}
...Training: end of batch 3; log content: {'loss': 13.329014778137207, 'mae': 2.743533134460449, 'mse': 13.329014778137207}
...Training: end of batch 4; log content: {'loss': 12.412055969238281, 'mae': 2.625807523727417, 'mse': 12.412055969238281}
...Training: end of batch 5; log content: {'loss': 11.858630180358887, 'mae': 2.5915772914886475, 'mse': 11.858630180358887}
...Training: end of batch 6; log content: {'loss': 1

...Training: end of batch 3; log content: {'loss': 10.248334884643555, 'mae': 2.587010145187378, 'mse': 10.248334884643555}
...Training: end of batch 4; log content: {'loss': 10.638434410095215, 'mae': 2.606750965118408, 'mse': 10.638434410095215}
...Training: end of batch 5; log content: {'loss': 10.939873695373535, 'mae': 2.610248327255249, 'mse': 10.939873695373535}
...Training: end of batch 6; log content: {'loss': 11.34882926940918, 'mae': 2.6258223056793213, 'mse': 11.34882926940918}
...Training: end of batch 7; log content: {'loss': 11.323558807373047, 'mae': 2.5732691287994385, 'mse': 11.323558807373047}
End epoch 185 of training; log content: {'loss': 11.323558807373047, 'mae': 2.5732691287994385, 'mse': 11.323558807373047, 'val_loss': 10.019936561584473, 'val_mae': 2.440904378890991, 'val_mse': 10.019936561584473}
...Training: end of batch 0; log content: {'loss': 5.464108467102051, 'mae': 1.8658342361450195, 'mse': 5.464108467102051}
...Training: end of batch 1; log content:

...Training: end of batch 4; log content: {'loss': 11.849218368530273, 'mae': 2.6018729209899902, 'mse': 11.849218368530273}
...Training: end of batch 5; log content: {'loss': 11.50069808959961, 'mae': 2.5721538066864014, 'mse': 11.50069808959961}
...Training: end of batch 6; log content: {'loss': 11.498367309570312, 'mae': 2.5616977214813232, 'mse': 11.498367309570312}
...Training: end of batch 7; log content: {'loss': 11.356362342834473, 'mae': 2.5678436756134033, 'mse': 11.356362342834473}
End epoch 195 of training; log content: {'loss': 11.356362342834473, 'mae': 2.5678436756134033, 'mse': 11.356362342834473, 'val_loss': 10.046046257019043, 'val_mae': 2.4511568546295166, 'val_mse': 10.046046257019043}
...Training: end of batch 0; log content: {'loss': 8.762008666992188, 'mae': 2.4012558460235596, 'mse': 8.762008666992188}
...Training: end of batch 1; log content: {'loss': 11.206226348876953, 'mae': 2.597317695617676, 'mse': 11.206226348876953}
...Training: end of batch 2; log conte

# Custom Callback N. 2

Let's write another custom callback. This time, we implement a simple early-stopping mechanism that halts training as soon as the validation MAE drops below a pre-defined threshold (10.0 here).

In [16]:
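class MyEarlyStopping(tf.keras.callbacks.Callback):
    """Stop training once the validation MAE drops below 10.0."""

    def on_epoch_end(self, epoch, logs=None):
        # 'val_mae' is only present when validation data is used.
        val_mae = (logs or {}).get('val_mae')
        if val_mae is not None and val_mae < 10.0:
            print("\nReached MAE < 10.0, so cancelling training!")
            self.model.stop_training = True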
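
The stopping rule itself is just a comparison on the `logs` dictionary, so it can be tried out without training a model. The block below is a minimal sketch of that idea, not part of Keras: `ThresholdRule` and `fake_epochs` are hypothetical names introduced here, and the log values are made up purely for illustration.

```python
class ThresholdRule:
    """Decide whether training should stop, given one epoch's logs dict.

    Hypothetical helper (not a Keras class): it only inspects plain dicts.
    """

    def __init__(self, monitor="val_mae", threshold=10.0):
        self.monitor = monitor
        self.threshold = threshold

    def should_stop(self, logs=None):
        # `logs` defaults to None in the Keras callback signature, and the
        # monitored key is absent when no validation data is configured.
        # Both cases are treated as "keep training".
        value = (logs or {}).get(self.monitor)
        return value is not None and value < self.threshold


rule = ThresholdRule(monitor="val_mae", threshold=10.0)

# Made-up per-epoch logs, used only to exercise the rule.
fake_epochs = [
    {"loss": 400.0, "val_mae": 18.5},
    {"loss": 150.0, "val_mae": 11.2},
    {"loss": 97.8, "val_mae": 9.4},
]

decisions = [rule.should_stop(logs) for logs in fake_epochs]
print(decisions)
```

Inside a real callback, `on_epoch_end` of `MyEarlyStopping` could call `rule.should_stop(logs)` and set `self.model.stop_training = True` when it returns `True`. Keeping the monitored metric and the threshold as constructor arguments avoids hard-coding `'val_mae'` and `10.0`.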


In [17]:
my_es_callback = MyEarlyStopping()

model = get_model()
history = model.fit(train_features, train_labels, epochs=200,
                    validation_split=0.2, callbacks=[my_es_callback])

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
1/8 [==>...........................] - ETA: 0s - loss: 97.8023 - mae: 9.4119 - mse: 97.8023
Reached MAE < 10.0, so cancelling training!
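
In the run above, training ends at epoch 22 of 200. One way to double-check where a threshold was first crossed is to scan the per-epoch metric lists that `model.fit` returns in `history.history` (a dictionary mapping metric names to lists). The sketch below uses a hypothetical helper, `first_epoch_below`, and a made-up stand-in dictionary instead of a real `history.history`.

```python
def first_epoch_below(history_dict, monitor="val_mae", threshold=10.0):
    """Return the 1-based epoch at which `monitor` first drops below
    `threshold`, or None if it never does (or the metric is missing)."""
    for epoch, value in enumerate(history_dict.get(monitor, []), start=1):
        if value < threshold:
            return epoch
    return None


# Stand-in for history.history; the values are invented for illustration.
example_history = {
    "loss": [400.0, 150.0, 97.8, 90.1],
    "val_mae": [18.5, 11.2, 9.4, 8.9],
}

crossing_epoch = first_epoch_below(example_history, "val_mae", 10.0)
print(crossing_epoch)
```

With the notebook's own objects, the call would be `first_epoch_below(history.history)`. Because the callback stops training at the first epoch that satisfies the condition, the result should then equal the number of epochs actually run.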
