# Implementing Callbacks in TensorFlow using the MNIST Dataset



In [9]:
import tensorflow as tf
from tensorflow import keras

## Load and inspect the data



In [None]:
# Load the data

# Keras' built-in MNIST dataset module
data = tf.keras.datasets.mnist

# load_data() returns (train, test) pairs; only the training split is kept
(x_train, y_train), _ = data.load_data()

# Scale pixel values to [0, 1] (assumes raw values in [0, 255] -- the /255.0
# normalization relies on that). Cast to float32 first: true division of an
# integer array yields a float64 copy, which doubles memory, while Keras
# models compute in float32 by default anyway.
x_train = x_train.astype("float32") / 255.0

Now take a look at the shape of the training data:

In [None]:
# The first dimension counts examples; the next two are the per-image size.
data_shape = x_train.shape
print("There are {} examples with shape ({}, {})".format(*data_shape[:3]))

There are 60000 examples with shape (28, 28)


## Defining your callback

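Now it is time to create your own custom callback. At the end of every epoch it checks the training accuracy and stops training once that accuracy exceeds 99%.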

In [14]:
# CLASS: myCallback

class myCallback(tf.keras.callbacks.Callback):
    """Keras callback that stops training once a monitored metric passes a threshold.

    At the end of every epoch it reads ``logs[metric]``. If that value is
    strictly greater than ``target_accuracy``, it prints a message and sets
    ``self.model.stop_training = True`` so that ``model.fit`` ends after the
    current epoch.

    Args:
        target_accuracy: Threshold the metric must exceed. Defaults to 0.99,
            the original 99% target.
        metric: Key looked up in the per-epoch ``logs`` dict. Defaults to
            ``'accuracy'``, the key reported when compiling with
            ``metrics=['accuracy']``.
    """

    def __init__(self, target_accuracy=0.99, metric='accuracy'):
        super().__init__()
        self.target_accuracy = target_accuracy
        self.metric = metric

    def on_epoch_end(self, epoch, logs=None):
        """Stop training if the monitored metric exceeded the target this epoch.

        Args:
            epoch: Index of the epoch that just finished (not used).
            logs: Metric values supplied by Keras for this epoch; may be None.
        """
        # Use None instead of a shared mutable ``{}`` default.
        logs = logs or {}
        value = logs.get(self.metric)
        if value is not None and value > self.target_accuracy:
            # ``:g`` on the percentage renders 0.99 as "99", keeping the
            # original message text for the default threshold.
            print(f"\nReached {self.target_accuracy * 100:g}% accuracy so cancelling training!")

            # Ask Keras to stop after this epoch.
            self.model.stop_training = True




## Create and train your model

Now that you have defined your callback, it is time to complete the `train_mnist` function below.



In [15]:
# FUNCTION: train_mnist
def train_mnist(x_train, y_train, epochs=10):
    """Build, compile and train a small dense classifier.

    Training may stop early: ``myCallback`` ends ``fit`` once training
    accuracy exceeds its threshold.

    Args:
        x_train: Training inputs exposing ``.shape``. Every dimension after
            the first is treated as the per-example shape fed to ``Flatten``
            ((28, 28) for MNIST).
        y_train: Integer class labels, as expected by
            ``sparse_categorical_crossentropy`` with the 10-unit softmax output.
        epochs: Maximum number of training epochs. Defaults to 10.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # A single callback instance; fit() takes a list of callbacks.
    early_stop = myCallback()

    model = tf.keras.models.Sequential([
        # Derive the per-example shape from the data rather than hard-coding it.
        tf.keras.layers.Flatten(input_shape=tuple(x_train.shape[1:])),
        tf.keras.layers.Dense(units=150, activation=tf.nn.relu),
        tf.keras.layers.Dense(units=10, activation=tf.nn.softmax),
    ])

    # Integer labels -> sparse categorical cross-entropy.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train, epochs=epochs, callbacks=[early_stop])
    return history

Call the `train_mnist` function, passing in the appropriate parameters, to get the training history:

In [16]:
# Train the model; the callback may stop training before the final epoch.
hist = train_mnist(x_train=x_train, y_train=y_train)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Reached 99% accuracy so cancelling training!


If you see the message `Reached 99% accuracy so cancelling training!` printed after fewer than 9 epochs, your callback worked as expected.