# Using Callbacks to Control Training
In this notebook, we will use the Keras `Callbacks API` to stop training once a specified metric reaches a threshold.
For example, if you set 1000 epochs and the accuracy you want is already reached at epoch 200, training stops automatically at that point instead of running the remaining 800 epochs.

## Load and Normalize the Fashion MNIST dataset
We will scale the pixel values from integers in the range 0 to 255 down to floats between 0 and 1, which generally makes the model easier to train.

```python
import tensorflow as tf

# Reference the Fashion MNIST dataset module
fmnist = tf.keras.datasets.fashion_mnist

# Load the dataset
(x_train, y_train), (x_test, y_test) = fmnist.load_data()

# Normalize the pixel values to the [0, 1] range
x_train, x_test = x_train / 255.0, x_test / 255.0
```
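
Dividing by `255.0` does two things at once: it converts the integer pixel array to floating point and rescales it into [0, 1]. The small NumPy sketch below uses a made-up 2x2 "image" in place of the real data to show both effects.

```python
import numpy as np

# Made-up stand-in for a few Fashion MNIST pixels: raw values are uint8 in 0..255
pixels = np.array([[0, 51], [102, 255]], dtype=np.uint8)

# Dividing by the float 255.0 promotes the array to float64 and rescales it to [0.0, 1.0]
normalized = pixels / 255.0
print(normalized.dtype)  # float64
print(normalized)

# Optional: cast to float32, which uses half the memory of float64
normalized32 = (pixels / 255.0).astype(np.float32)
print(normalized32.dtype)  # float32
```

Keras layers compute in `float32` by default and generally cast floating-point inputs to that type, so the extra `.astype(np.float32)` step is optional; it mainly reduces the memory the full training arrays occupy.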



## Creating a Callback class

* You can create a callback by defining a class that inherits from the `tf.keras.callbacks.Callback` base class.
* From there, you override the hook methods that correspond to the points in training where your code should run (for example, `on_epoch_end()`, `on_train_batch_end()`, or `on_train_end()`).
* Below, we use the `on_epoch_end()` method to check the training loss at the end of each epoch.
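
To see how these pieces fit together, here is a minimal, TensorFlow-free sketch of the loop that `model.fit()` runs around its callbacks: each callback receives a reference to the model, `on_epoch_end(epoch, logs)` is called after every epoch, and the loop exits once `model.stop_training` is `True`. `ToyModel`, `ToyCallback`, `StopBelowLoss`, `toy_fit`, and the per-epoch loss values are hypothetical stand-ins for illustration, not Keras code.

```python
# Hypothetical sketch of how a fit() loop drives epoch-end callbacks.
# ToyModel, ToyCallback, StopBelowLoss and toy_fit are illustrative names, not Keras APIs.


class ToyModel:
    def __init__(self):
        # Keras models expose a flag with this same name
        self.stop_training = False


class ToyCallback:
    """Base class: stores the model and provides a no-op epoch-end hook."""

    def set_model(self, model):
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        pass


class StopBelowLoss(ToyCallback):
    """Same idea as myCallback: request a stop once the loss falls below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss < self.threshold:
            self.model.stop_training = True


def toy_fit(model, epoch_losses, epochs, callbacks):
    """Pretends to train: epoch i 'produces' epoch_losses[i]. Returns epochs completed."""
    for cb in callbacks:
        cb.set_model(model)
    model.stop_training = False
    completed = 0
    for epoch in range(epochs):
        logs = {"loss": epoch_losses[epoch]}  # a real fit() computes this from the data
        completed += 1
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
        if model.stop_training:  # in this sketch the flag is checked only between epochs
            break
    return completed


toy_model = ToyModel()
ran = toy_fit(toy_model, [0.55, 0.38, 0.33, 0.30], epochs=4,
              callbacks=[StopBelowLoss(0.4)])
print(ran)  # 2: the loss first drops below 0.4 after the second epoch
```

Because the callback only sets a flag, it never interrupts an epoch in progress: the epoch in which the threshold is crossed runs to completion, and the loop stops before the next one begins.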

```python
class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    '''
    Halts the training once the training loss drops below 0.4

    Args:
      epoch (integer) - index of epoch (required but unused in the function definition below)
      logs (dict) - metric results from the training epoch
    '''
    # Avoid a shared mutable default argument; Keras normally passes a dict here
    logs = logs or {}

    # Check the loss (get() returns None if the metric is missing)
    loss = logs.get('loss')
    if loss is not None and loss < 0.4:

      # Stop if threshold is met
      print("\nLoss is lower than 0.4 so cancelling training!")
      self.model.stop_training = True

# Instantiate class
callbacks = myCallback()
```
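
The check above only handles a metric that should fall, such as the loss. If you would rather stop on a metric that should rise, such as accuracy, the comparison has to flip. The hypothetical helper `threshold_reached` below makes that direction explicit through a `mode` parameter; the `"min"`/`"max"` naming is borrowed from the `mode` argument of `tf.keras.callbacks.EarlyStopping`, but the function itself is an illustration, not a Keras API.

```python
def threshold_reached(logs, monitor="loss", threshold=0.4, mode="min"):
    """Hypothetical helper: return True once logs[monitor] passes threshold.

    mode="min": the metric should fall, so stop when it drops below threshold (e.g. loss).
    mode="max": the metric should rise, so stop when it climbs above threshold (e.g. accuracy).
    A missing metric never triggers a stop.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    value = (logs or {}).get(monitor)
    if value is None:
        return False
    return value < threshold if mode == "min" else value > threshold


print(threshold_reached({"loss": 0.35}))                              # True
print(threshold_reached({"accuracy": 0.61}, "accuracy", 0.6, "max"))  # True
print(threshold_reached({}, "accuracy", 0.6, "max"))                  # False
```

Inside `on_epoch_end()` you could then write `if threshold_reached(logs, 'accuracy', 0.6, mode='max'): self.model.stop_training = True`. With `metrics=['accuracy']`, TensorFlow 2.x Keras reports training accuracy in `logs` under the `'accuracy'` key.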

## Define and compile the model
Next, define the model and compile it, setting the optimizer, loss, and metrics to use during training.

```python
# Define the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# Compile the model
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```
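
As a quick sanity check on the architecture, the number of trainable parameters follows directly from the layer sizes: a `Dense` layer with `n` inputs and `u` units holds `n * u` weights plus `u` biases, and `Flatten` adds none. The plain-Python arithmetic below (with a hypothetical `dense_params` helper) should agree with the total that `model.summary()` prints for this model.

```python
def dense_params(inputs, units):
    # Hypothetical helper: one weight per (input, unit) pair plus one bias per unit
    return inputs * units + units

flat = 28 * 28                     # Flatten has no weights; it reshapes 28x28 to 784 values
hidden = dense_params(flat, 512)   # 784 * 512 + 512
output = dense_params(512, 10)     # 512 * 10 + 10
total = hidden + output
print(flat, hidden, output, total)  # 784 401920 5130 407050
```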

## Train the model
Now you are ready to train the model. To attach the callback, pass a list containing the `myCallback` instance created earlier (named `callbacks`) to the `callbacks` parameter of `model.fit()`.

```python
# Train the model with the callback
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])
```

```
Epoch 1/10
Epoch 2/10
Loss is lower than 0.4 so cancelling training!
```



You will notice that training does not run all 10 epochs. Because the callback runs at the end of every epoch, it can read the metrics in `logs` and compare the loss against the threshold defined in `on_epoch_end()`. Here the loss dropped below `0.40` at the end of the second epoch, so training stopped there.
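
What does a loss of 0.4 actually mean? With `loss='sparse_categorical_crossentropy'`, each example contributes -ln of the probability the model assigns to its true class, and the loss reported for an epoch is (roughly) the average of these values. A loss below 0.4 therefore means the geometric mean of the true-class probabilities exceeds e^(-0.4) ≈ 0.67. The hand-rolled `sparse_ce_sketch` below is a simplified illustration (it omits the numerical clipping Keras applies internally), and the three-class probabilities are made up.

```python
import math

def sparse_ce_sketch(probs, labels):
    """Mean of -ln(probability of the true class); labels are integer class ids.

    Simplified sketch: no clipping, no sample weights.
    """
    return sum(-math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)

def accuracy(probs, labels):
    """Fraction of examples whose highest-probability class equals the label."""
    hits = sum(max(range(len(p)), key=p.__getitem__) == y for p, y in zip(probs, labels))
    return hits / len(labels)

# Made-up softmax outputs for 3 examples over 3 classes (Fashion MNIST has 10)
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.8, 0.1],
         [0.3, 0.3, 0.4]]
labels = [0, 1, 2]

loss = sparse_ce_sketch(probs, labels)
print(round(loss, 3))            # 0.499
print(accuracy(probs, labels))   # 1.0
print(round(math.exp(-0.4), 3))  # 0.67
```

Every prediction in this toy batch is correct, yet the loss is still above 0.4, so a callback like `myCallback` would keep training. The loss also rewards confidence in the correct class, which is why a loss threshold and an accuracy threshold can stop training at different epochs.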