Callbacks are a way to customize the behaviour of a Keras model during training or evaluation — for example, to save checkpoints, adjust the learning rate, or stop training early.

In [1]:
import os
import matplotlib.pyplot

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

In [2]:
# Enable on-demand GPU memory allocation so TensorFlow does not reserve all GPU
# memory up front. Looping over every detected device (instead of indexing [0])
# keeps this cell working on CPU-only machines, where the list is empty, and
# applies the same setting to every GPU on multi-GPU hosts. TensorFlow rejects
# mixed memory-growth settings across GPUs.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
# Load the MNIST train/test splits. as_supervised=True yields (image, label)
# tuples instead of feature dicts; with_info=True also returns the DatasetInfo
# object, whose split sizes are used later to size the shuffle buffer.
(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
    split=["train", "test"],
)

In [4]:
def normalize_img(image, label):
    """Cast ``image`` to float32 and divide it by 255; ``label`` passes through.

    Presumably the images are uint8 pixels in 0-255, so this maps them into
    [0, 1] — confirm against the dataset's feature spec.
    """
    scaled_image = tf.cast(image, tf.float32) / 255.0
    return scaled_image, label

BATCH_SIZE = 128
AUTOTUNE = tf.data.experimental.AUTOTUNE  # let tf.data tune parallelism/prefetch depth

In [5]:
# Training input pipeline: normalize -> cache the normalized examples ->
# shuffle with a buffer as large as the whole split (caching before shuffling
# lets each epoch see a fresh order) -> batch -> prefetch the next batch while
# the current one trains.
# NOTE: this rebinds ds_train, so re-running this cell without re-running the
# loading cell would apply the whole pipeline twice.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits["train"].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

In [18]:
# Small CNN for 28x28 single-channel inputs: one 3x3 convolution with 32
# filters, flatten, then a 10-unit softmax layer producing class probabilities.
# All layers come from the imported `layers` alias for consistency.
model = keras.Sequential(
    [
        keras.Input((28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        layers.Dense(10, activation="softmax"),
    ]
)

### Callbacks
The model is set up above, and the compile and fit code follows below. First, we define a few callbacks.

#### Save Callback
This callback saves the model's weights at different points during training.

In [19]:
# Save the model weights at the end of every epoch. The filepath is templated
# with the epoch number, so each epoch writes its own checkpoint instead of
# overwriting a single one at a bare directory path. `monitor` is omitted:
# ModelCheckpoint only consults it when save_best_only=True.
save_callback = keras.callbacks.ModelCheckpoint(
    "checkpoint/weights_epoch_{epoch:02d}",
    save_weights_only=True,
    save_best_only=False,
)

#### Learning Rate Scheduler
Change the learning rate as training progresses from epoch to epoch.

In [20]:
# custom define a function to schedule LR
def scheduler(epoch, lr):
    if epoch < 2:
        return lr
    else:
        return lr * 0.999 #reduce by 0.1 percent every epoch

In [21]:
# Wrap the schedule function in a Keras callback; verbose=1 prints the
# learning rate chosen at the start of each epoch.
lr_scheduler = keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

#### Custom Callback
We will write a class that inherits from `keras.callbacks.Callback` and overrides one of its hooks. The TensorFlow documentation has a good guide on writing custom callbacks: [Documentation on Callbacks](https://www.tensorflow.org/guide/keras/custom_callback)

In [22]:
class CustomCallback(keras.callbacks.Callback):
    """Stop training once the epoch's training accuracy exceeds a threshold.

    The check runs at the end of every epoch against the ``"accuracy"`` entry
    of the epoch logs. That entry is logged when the model is compiled with
    ``metrics=["accuracy"]``.
    """

    def __init__(self, threshold=0.90):
        """
        Args:
            threshold: Accuracy strictly above which training is stopped.
        """
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # logs may be None, and "accuracy" is absent when that metric was not
        # compiled in; skip the check in both cases instead of raising
        # AttributeError / TypeError (None > float).
        logs = logs or {}
        accuracy = logs.get("accuracy")
        if accuracy is not None and accuracy > self.threshold:
            print(f"Accuracy over {self.threshold:.0%}, Quitting Training..")
            self.model.stop_training = True

### Model Compiling and Training

In [23]:
# Adam with a learning rate of 0.01. Sparse categorical cross-entropy takes
# integer class labels; the model ends in a softmax, so it outputs
# probabilities rather than logits (from_logits=False, the default).
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.01),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

In [24]:
# Train for up to 10 epochs; CustomCallback can stop training earlier.
# The returned History object is kept so per-epoch metrics
# (history.history) can be inspected later, instead of discarding it and
# echoing its repr as the cell output.
history = model.fit(
    ds_train,
    epochs=10,
    callbacks=[save_callback, lr_scheduler, CustomCallback()],
    verbose=2,
)


Epoch 00001: LearningRateScheduler reducing learning rate to 0.009999999776482582.
Epoch 1/10
Accuracy over 90%, Quitting Training..
469/469 - 1s - loss: 0.1525 - accuracy: 0.9542 - lr: 0.0100


<tensorflow.python.keras.callbacks.History at 0x7f1f37699040>