### Callbacks:
A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch, before or after a single batch, etc.). In Keras, callbacks are passed to `model.fit()` as a list through its `callbacks` argument.
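To make "perform actions at various stages" concrete, here is a minimal, framework-free sketch of how a training loop dispatches to callback hooks. The hook names mirror those of `keras.callbacks.Callback`, but the `Callback` base class, `EventRecorder` and `toy_fit` below are hypothetical illustrations, not Keras's actual implementation.

```python
class Callback:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass


class EventRecorder(Callback):
    """Records the order in which the training loop calls each hook."""

    def __init__(self):
        self.events = []

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(("epoch_begin", epoch))

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(("batch_end", batch))

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(("epoch_end", epoch))


def toy_fit(num_epochs, num_batches, callbacks):
    """Hypothetical training loop: it only shows *when* hooks fire."""
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(num_batches):
            # A real loop would run a forward and backward pass here.
            for cb in callbacks:
                cb.on_train_batch_end(batch, logs={})
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={})


recorder = EventRecorder()
toy_fit(num_epochs=2, num_batches=2, callbacks=[recorder])
print(recorder.events)
```

A subclass only overrides the hooks it cares about; the base class keeps every other hook a no-op, which is the same design `keras.callbacks.Callback` uses.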

Common uses:
- Write TensorBoard logs after every training batch to monitor your metrics
- Periodically save your model to disk
- Stop training early (e.g. when a monitored metric stops improving)
- Inspect a model's internal states and statistics during training
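Keras ships early stopping as the built-in `keras.callbacks.EarlyStopping` (configured with e.g. `monitor="val_loss"` and `patience`). The sketch below is a hypothetical, framework-free reimplementation of the patience idea only, assuming a lower-is-better metric such as a loss; `PatienceStopper` and `first_stop_epoch` are illustrative names, not Keras APIs.

```python
class PatienceStopper:
    """Hypothetical helper: signals a stop after `patience` consecutive epochs
    without improvement in a lower-is-better metric (e.g. a loss)."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0  # epochs since the last improvement

    def update(self, value):
        """Feed one epoch's metric value; return True if training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def first_stop_epoch(values, patience=2):
    """Return the 0-based epoch at which training would stop, or None."""
    stopper = PatienceStopper(patience=patience)
    for epoch, value in enumerate(values):
        if stopper.update(value):
            return epoch
    return None


print(first_stop_epoch([0.9, 0.7, 0.6, 0.61, 0.62, 0.5]))  # 4
```

Here the loss stops improving after epoch 2, so with `patience=2` the stop is signalled at epoch 4, even though epoch 5 would have improved again. That trade-off is why a patience value is needed at all.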

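Common built-in callbacks:
1. `tf.keras.callbacks.ModelCheckpoint(filepath, monitor: str = "val_loss", verbose: int = 0, save_best_only: bool = False, save_weights_only: bool = False, mode: str = "auto", save_freq="epoch", options=None, initial_value_threshold=None, **kwargs)`: used together with `model.fit()` to save the model or only its weights (in a checkpoint file) at some interval, so that they can be loaded later to continue training from the saved state. With `save_best_only=True`, a checkpoint is written only when the `monitor` quantity improves, and `mode` (`"min"`, `"max"` or `"auto"`) decides whether improvement means a lower or a higher value. `save_freq="epoch"` saves at the end of every epoch; an integer saves after that many batches.
2. `tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)`: at the start of every epoch, calls `schedule(epoch, lr)` with the 0-based epoch index and the current learning rate, and sets the optimizer's learning rate to the returned value.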
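The interaction of `save_best_only`, `mode` and `initial_value_threshold` can be sketched without TensorFlow. This is a simplified, hypothetical model of the decision, not the Keras source; in particular, it assumes that `mode="auto"` treats monitor names containing `"acc"` as higher-is-better and everything else as lower-is-better, which is how Keras 2.x resolves common metric names but may differ in other versions.

```python
def resolve_mode(monitor, mode="auto"):
    """Return "max" or "min". Assumption: with mode="auto", names containing
    "acc" are higher-is-better and all other names are lower-is-better."""
    if mode in ("max", "min"):
        return mode
    return "max" if "acc" in monitor else "min"


def epochs_saved(values, monitor, mode="auto", save_best_only=True,
                 initial_value_threshold=None):
    """Return the 0-based epochs at which a checkpoint would be written."""
    if not save_best_only:
        return list(range(len(values)))  # save every epoch
    direction = resolve_mode(monitor, mode)
    if initial_value_threshold is not None:
        best = initial_value_threshold  # a checkpoint must beat this first
    else:
        best = float("-inf") if direction == "max" else float("inf")
    saved = []
    for epoch, value in enumerate(values):
        improved = value > best if direction == "max" else value < best
        if improved:
            best = value
            saved.append(epoch)
    return saved


print(epochs_saved([0.30, 0.25, 0.27, 0.20], monitor="val_loss"))      # [0, 1, 3]
print(epochs_saved([0.80, 0.85, 0.84, 0.90], monitor="val_accuracy"))  # [0, 1, 3]
```

Setting `mode` explicitly avoids relying on the name-based guess, which matters for custom metrics whose names say nothing about their direction.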

```python
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds

layers = keras.layers
regularizers = keras.regularizers

# Hyperparameters
BATCH_SIZE = 32
WEIGHT_DECAY = 0.001
LEARNING_RATE = 0.001
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data pick parallelism/prefetch sizes
```

```python
(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
# print(ds_info)
```

```python
def normalize_img(image, label):
    """Normalizes uint8 images in [0, 255] to float32 in [0, 1]."""
    return tf.cast(image, tf.float32) / 255.0, label
```

```python
# Set up the training input pipeline
ds_train = ds_train.map(normalize_img, num_parallel_calls=AUTOTUNE)
ds_train = ds_train.cache()  # cache the normalized images after the first pass
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(AUTOTUNE)  # overlap data loading with training
```

```python
model = keras.Sequential(
    [
        keras.Input((28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        layers.Dense(10, activation="softmax"),
    ]
)

model.compile(
    optimizer=keras.optimizers.Adam(0.01),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)
```

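```python
# ModelCheckpoint: save the weights at the end of every epoch.
# The training-accuracy key logged by fit() is "accuracy" (not "train_acc");
# with save_best_only=False the monitored value does not gate saving.
model_save_callback = tf.keras.callbacks.ModelCheckpoint(
    "checkpoint/", save_weights_only=True, monitor="accuracy", save_best_only=False,
)
```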

```python
def scheduler(epoch, lr):
    # Keep the learning rate for the first two epochs, then decay it by 1% per epoch
    if epoch < 2:
        return lr
    return lr * 0.99


lr_scheduler = keras.callbacks.LearningRateScheduler(schedule=scheduler, verbose=1)
```
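Because `LearningRateScheduler` passes the optimizer's *current* learning rate back into `schedule(epoch, lr)` each epoch, the 1% decay compounds. The short simulation below reproduces that feedback in plain Python; `simulate` is a hypothetical helper, and the starting rate of 0.01 is an assumption chosen to match the `Adam(0.01)` optimizer compiled above.

```python
def scheduler(epoch, lr):
    # Same rule as the cell above: keep lr for epochs 0 and 1, then decay by 1%.
    if epoch < 2:
        return lr
    return lr * 0.99


def simulate(schedule, initial_lr, num_epochs):
    """Feed each epoch's returned lr into the next call, as the callback does."""
    lr = initial_lr
    history = []
    for epoch in range(num_epochs):  # epoch indices are 0-based
        lr = schedule(epoch, lr)
        history.append(lr)
    return history


lrs = simulate(scheduler, initial_lr=0.01, num_epochs=5)
print([round(x, 8) for x in lrs])  # [0.01, 0.01, 0.0099, 0.009801, 0.00970299]
```

Since the optimizer keeps its decayed rate between `fit()` calls, re-running training on the same compiled model starts from the already-decayed value rather than from 0.01.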

```python
class CustomCallBack(keras.callbacks.Callback):
    """Stops training once the training accuracy exceeds 90%."""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print(logs.keys())
        accuracy = logs.get("accuracy")
        if accuracy is not None and accuracy > 0.90:
            print("Accuracy over 90%, quitting training.")
            self.model.stop_training = True
```

```python
model.fit(
    ds_train,
    epochs=3,
    callbacks=[model_save_callback, lr_scheduler, CustomCallBack()],
    verbose=2,
)
```

```text
Epoch 1: LearningRateScheduler setting learning rate to 0.00989999994635582.
Epoch 1/3
dict_keys(['loss', 'accuracy', 'lr'])
Accuracy over 90%, quitting training.
1875/1875 - 18s - loss: 0.0411 - accuracy: 0.9870 - lr: 0.0099 - 18s/epoch - 10ms/step

<keras.callbacks.History at 0x1fdf7b75850>
```
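Training stopped after the first of the three requested epochs because the logged accuracy (0.9870) exceeded the 90% threshold and the callback set `model.stop_training = True`, which `fit()` checks after each epoch. The stopping logic can be exercised without TensorFlow; `StubModel`, `AccuracyThresholdStop` and `run_epochs` below are hypothetical stand-ins that mirror `CustomCallBack.on_epoch_end`, not Keras classes.

```python
class StubModel:
    """Hypothetical stand-in exposing only the stop_training flag fit() checks."""

    def __init__(self):
        self.stop_training = False


class AccuracyThresholdStop:
    """Mirrors CustomCallBack.on_epoch_end without depending on Keras."""

    def __init__(self, model, threshold=0.90):
        self.model = model
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        accuracy = logs.get("accuracy")  # None if the metric was not logged
        if accuracy is not None and accuracy > self.threshold:
            self.model.stop_training = True


def run_epochs(epoch_accuracies, threshold=0.90):
    """Return how many epochs run before the stop flag ends the loop."""
    model = StubModel()
    callback = AccuracyThresholdStop(model, threshold)
    epochs_run = 0
    for epoch, acc in enumerate(epoch_accuracies):
        epochs_run += 1
        callback.on_epoch_end(epoch, {"accuracy": acc})
        if model.stop_training:  # checked after each epoch, like fit()
            break
    return epochs_run


print(run_epochs([0.80, 0.88, 0.93, 0.95]))  # 3
print(run_epochs([0.80, 0.85]))  # 2
```

Reading the metric with `logs.get` and checking for `None` keeps the callback from crashing when it is attached to a model compiled without an `"accuracy"` metric.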