# Keras Custom callback 

This is a notebook for the medium article [A practical introduction to Keras Callbacks](https://medium.com/@bindiatwork/a-practical-introduction-to-keras-callbacks-in-tensorflow-2-705d0c584966)

Please check out article for instructions

**License**: [BSD 2-Clause](https://opensource.org/licenses/BSD-2-Clause)

In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt

### Helper function to plot metric

In [2]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
def plot_metric(history, metric):
    train_metrics = history.history[metric]
    val_metrics = history.history['val_'+metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics)
    plt.plot(epochs, val_metrics)
    plt.title('Training and validation '+ metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_"+metric, 'val_'+metric])
    plt.show()

In [3]:
def plot_lr(history):
    learning_rate = history.history['lr']
    epochs = range(1, len(learning_rate) + 1)
    plt.plot(epochs, learning_rate)
    plt.title('Learning rate')
    plt.xlabel('Epochs')
    plt.ylabel('Learning rate')
    plt.show()

## Fashion MMIST dataset

In [4]:
fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()

In [5]:
X_train_full.shape

(60000, 28, 28)

In [6]:
X_train_full.dtype

dtype('uint8')

In [24]:
X_test.shape

(10000, 28, 28)

In [25]:
y_test.dtype

dtype('uint8')

In [7]:
# For faster training, let's use a subset 10,000
X_train, y_train = X_train_full[:10000] / 255.0, y_train_full[:10000]

## Building a NN model

In [8]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

def create_model(): 
    model = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(300, activation='relu'),
        Dense(100, activation='relu'),
        Dense(10, activation='softmax'),
    ])
    model.compile(
        optimizer='sgd', 
        loss='sparse_categorical_crossentropy', 
        metrics=['accuracy']
    )
    return model

In [9]:
model = create_model()
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 300)               235500    
_________________________________________________________________
dense_1 (Dense)              (None, 100)               30100     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1010      
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________


## 1. Building a custom callback for training

In [38]:
from tensorflow.keras.callbacks import Callback

class TrainingCallback(Callback):
    
    def on_train_begin(self, logs=None):
        print("Starting training...")
        
    def on_epoch_begin(self, epoch, logs=None):
        print(f"Starting epoch {epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        print(f"Training: Starting batch {batch}")
        
    def on_train_batch_end(self, batch, logs=None):
        print(f"Training: Finished batch {batch}, loss is {logs['loss']}")
        
    def on_epoch_end(self, epoch, logs=None):
        print(f"Finished epoch {epoch}, loss is {logs['loss']}, accuracy is {logs['accuracy']}")
        
    def on_train_end(self, logs=None):
        print("Finished training")

In [39]:
history = model.fit(
    X_train, 
    y_train, 
    epochs=2, 
    validation_split=0.20, 
    batch_size=4000, 
    verbose=2,
    callbacks=[TrainingCallback()]
)

Train on 8000 samples, validate on 2000 samples
Starting training...
Starting epoch 0
Epoch 1/2
Training: Starting batch 0
Training: Finished batch 0, loss is 0.3087729215621948
Training: Starting batch 1
Training: Finished batch 1, loss is 0.302765429019928
Finished epoch 0, loss is 0.3057691752910614, accuracy is 0.8997499942779541
8000/8000 - 1s - loss: 0.3058 - accuracy: 0.8997 - val_loss: 0.4249 - val_accuracy: 0.8515
Starting epoch 1
Epoch 2/2
Training: Starting batch 0
Training: Finished batch 0, loss is 0.2967260777950287
Training: Starting batch 1
Training: Finished batch 1, loss is 0.31522929668426514
Finished epoch 1, loss is 0.3059776872396469, accuracy is 0.8993750214576721
8000/8000 - 0s - loss: 0.3060 - accuracy: 0.8994 - val_loss: 0.4249 - val_accuracy: 0.8510
Finished training


## 2. Building a custom callback for testing

In [19]:
class TestingCallback(Callback):
    
    def on_test_begin(self, logs=None):
        print("Starting testing ...")
        
    def on_test_batch_begin(self, batch, logs=None):
        print(f"Testing: Starting batch {batch}")
    
    def on_test_batch_end(self, batch, logs=None):
        print(f"Testing: Finished batch {batch}")
        
    def on_test_end(self, logs=None):
        print("Finished testing")

In [23]:
model.evaluate(X_test, y_test, verbose=False, callbacks=[TestingCallback()], batch_size=2000)

Starting testing ...
Testing: Starting batch 0
Testing: Finished batch 0
Testing: Starting batch 1
Testing: Finished batch 1
Testing: Starting batch 2
Testing: Finished batch 2
Testing: Starting batch 3
Testing: Finished batch 3
Testing: Starting batch 4
Testing: Finished batch 4
Finished testing


[85.61210479736329, 0.8063]

## 3. Building a custom callback for prediction

In [21]:
class PredictionCallback(Callback):
    
    def on_predict_begin(self, logs=None):
        print("Starting prediction ...")
    
    def on_predict_batch_begin(self, batch, logs=None):
        print(f"Prediction: Starting batch {batch}")
        
    def on_predict_batch_end(self, batch, logs=None):
        print(f"Prediction: Finish batch {batch}")
    
    def on_predict_end(self, logs=None):
        print("Finished prediction")

In [27]:
model.predict(X_test, verbose=False, callbacks=[PredictionCallback()], batch_size=2000)

Starting prediction ...
Prediction: Starting batch 0
Prediction: Finish batch 0
Prediction: Starting batch 1
Prediction: Finish batch 1
Prediction: Starting batch 2
Prediction: Finish batch 2
Prediction: Starting batch 3
Prediction: Finish batch 3
Prediction: Starting batch 4
Prediction: Finish batch 4
Finished prediction


array([[0.000000e+00, 0.000000e+00, 0.000000e+00, ..., 0.000000e+00,
        0.000000e+00, 1.000000e+00],
       [0.000000e+00, 0.000000e+00, 1.000000e+00, ..., 0.000000e+00,
        0.000000e+00, 0.000000e+00],
       [0.000000e+00, 1.000000e+00, 0.000000e+00, ..., 0.000000e+00,
        0.000000e+00, 0.000000e+00],
       ...,
       [0.000000e+00, 0.000000e+00, 0.000000e+00, ..., 0.000000e+00,
        1.000000e+00, 0.000000e+00],
       [0.000000e+00, 1.000000e+00, 0.000000e+00, ..., 0.000000e+00,
        0.000000e+00, 0.000000e+00],
       [0.000000e+00, 0.000000e+00, 0.000000e+00, ..., 2.315278e-25,
        0.000000e+00, 0.000000e+00]], dtype=float32)

## 2. Examples of Keras callback applications
Reference: https://keras.io/guides/writing_your_own_callbacks/

#### 2.1 Early stopping at minimum loss

In [42]:
import numpy as np


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after min has been hit. After this
      number of no improvement, training stops.
  """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epoch it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

In [43]:
model = create_model()

In [45]:
history = model.fit(
    X_train, 
    y_train, 
    epochs=50, 
    validation_split=0.20, 
    batch_size=64, 
    verbose=2,
    callbacks=[EarlyStoppingAtMinLoss()]
)

Train on 8000 samples, validate on 2000 samples
Epoch 1/50
8000/8000 - 2s - loss: 1.2644 - accuracy: 0.6643 - val_loss: 1.0006 - val_accuracy: 0.6845
Epoch 2/50
8000/8000 - 1s - loss: 0.8767 - accuracy: 0.7197 - val_loss: 0.8218 - val_accuracy: 0.7325
Epoch 3/50
8000/8000 - 1s - loss: 0.7529 - accuracy: 0.7579 - val_loss: 0.7451 - val_accuracy: 0.7475
Epoch 4/50
8000/8000 - 1s - loss: 0.6833 - accuracy: 0.7774 - val_loss: 0.6963 - val_accuracy: 0.7685
Epoch 5/50
8000/8000 - 1s - loss: 0.6363 - accuracy: 0.7878 - val_loss: 0.6644 - val_accuracy: 0.7730
Epoch 6/50
8000/8000 - 1s - loss: 0.6030 - accuracy: 0.7996 - val_loss: 0.6251 - val_accuracy: 0.7880
Epoch 7/50
8000/8000 - 1s - loss: 0.5751 - accuracy: 0.8099 - val_loss: 0.5956 - val_accuracy: 0.7980
Epoch 8/50
8000/8000 - 1s - loss: 0.5531 - accuracy: 0.8166 - val_loss: 0.5838 - val_accuracy: 0.8005
Epoch 9/50
8000/8000 - 1s - loss: 0.5355 - accuracy: 0.8204 - val_loss: 0.5615 - val_accuracy: 0.8065
Epoch 10/50
8000/8000 - 1s - loss:

#### 2.2 Learning rate scheduler

In [46]:
class CustomLearningRateScheduler(Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

  Arguments:
      schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
  """

    def __init__(self, schedule):
        super(CustomLearningRateScheduler, self).__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "lr"):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Get the current learning rate from model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))


LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

In [47]:
model = create_model()

In [49]:
history = model.fit(
    X_train, 
    y_train, 
    epochs=15, 
    validation_split=0.20, 
    batch_size=64, 
    verbose=2,
    callbacks=[CustomLearningRateScheduler(lr_schedule)]
)

Train on 8000 samples, validate on 2000 samples

Epoch 00000: Learning rate is 0.0010.
Epoch 1/15
8000/8000 - 1s - loss: 0.3998 - accuracy: 0.8665 - val_loss: 0.4640 - val_accuracy: 0.8330

Epoch 00001: Learning rate is 0.0010.
Epoch 2/15
8000/8000 - 1s - loss: 0.3993 - accuracy: 0.8651 - val_loss: 0.4638 - val_accuracy: 0.8335

Epoch 00002: Learning rate is 0.0010.
Epoch 3/15
8000/8000 - 1s - loss: 0.3987 - accuracy: 0.8643 - val_loss: 0.4638 - val_accuracy: 0.8350

Epoch 00003: Learning rate is 0.0500.
Epoch 4/15
8000/8000 - 1s - loss: 0.4711 - accuracy: 0.8365 - val_loss: 0.4964 - val_accuracy: 0.8165

Epoch 00004: Learning rate is 0.0500.
Epoch 5/15
8000/8000 - 1s - loss: 0.4348 - accuracy: 0.8494 - val_loss: 0.5464 - val_accuracy: 0.8025

Epoch 00005: Learning rate is 0.0500.
Epoch 6/15
8000/8000 - 1s - loss: 0.4169 - accuracy: 0.8508 - val_loss: 0.4499 - val_accuracy: 0.8395

Epoch 00006: Learning rate is 0.0100.
Epoch 7/15
8000/8000 - 1s - loss: 0.3632 - accuracy: 0.8771 - val_l