In [1]:
import tensorflow as tf

# Fashion-MNIST via the Keras datasets API; load_data() returns
# (train, test) pairs of (x, y) arrays
fmnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fmnist.load_data()

# Rescale pixel values by 1/255 (assumes raw 0-255 intensities) so inputs lie in [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0

## Creating a Callback class

You can create a callback by defining a class that inherits from the [tf.keras.callbacks.Callback](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback) base class. From there, you override the methods that correspond to the points in training where the callback should run. In the example below, you will use the [on_epoch_end()](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback#on_epoch_end) method to check the loss at the end of each training epoch.

<h3>An overview of callback methods</h3>

* Global methods

>`on_(train|test|predict)_begin(self, logs=None)`
Called at the beginning of fit/evaluate/predict.

>`on_(train|test|predict)_end(self, logs=None)`
Called at the end of fit/evaluate/predict.

* Batch-level methods for training/testing/predicting

> `on_(train|test|predict)_batch_begin(self, batch, logs=None)`
Called right before processing a batch during training/testing/predicting.

> `on_(train|test|predict)_batch_end(self, batch, logs=None)`
Called at the end of training/testing/predicting a batch. Within this method, `logs` is a dict containing the metric results.

* Epoch-level methods (training only)

>`on_epoch_begin(self, epoch, logs=None)`
Called at the beginning of an epoch during training.

>`on_epoch_end(self, epoch, logs=None)`
Called at the end of an epoch during training.

In [20]:
class myCallback(tf.keras.callbacks.Callback):
    '''
    Prints training progress and stops training once the loss drops below a threshold.

    Args:
      loss_threshold (float) - training is halted at the end of the first epoch
        whose 'loss' is strictly below this value (default 0.4)
    '''

    def __init__(self, loss_threshold=0.4):
        super().__init__()
        self.loss_threshold = loss_threshold

    def on_train_begin(self, logs=None):
        '''Prints the metric keys available when training starts.'''
        # `logs` defaults to None; fall back to an empty dict instead of crashing on .keys()
        keys = list((logs or {}).keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        '''Prints the metric keys and the final loss/accuracy when training ends.'''
        logs = logs or {}
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))
        # .get() prints None instead of raising KeyError if a metric was not reported
        print("Final Loss: {}".format(logs.get("loss")))
        print("Final accuracy: {}".format(logs.get("accuracy")))

    def on_epoch_begin(self, epoch, logs=None):
        '''Prints the (0-based) epoch index and the metric keys at the start of each epoch.'''
        # None default instead of a shared mutable `{}` default
        keys = list((logs or {}).keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        '''
        Halts the training once the epoch's loss falls below `self.loss_threshold`

        Args:
          epoch (integer) - index of epoch (required but unused in the function definition below)
          logs (dict) - metric results from the training epoch
        '''
        loss = (logs or {}).get('loss')

        # Check loss; skip the comparison when no loss was reported (None < float raises TypeError)
        if loss is not None and loss < self.loss_threshold:

            # Stop if threshold is met
            print("\nLoss is lower than {} so cancelling training!".format(self.loss_threshold))
            self.model.stop_training = True

# Instantiate class (default threshold 0.4)
callbacks = myCallback()

## Define and compile the model

Next, you will define and compile the model. The architecture will be similar to the one you built in the previous lab. Afterwards, you will set the optimizer, loss, and metrics that you will use for training.

In [37]:
# Define the model: an explicit Input layer declares the per-sample shape
# (preferred over passing `input_shape` to the first layer in recent Keras)
model = tf.keras.models.Sequential([
  tf.keras.Input(shape=(28, 28)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  # 10-way softmax; labels are integer class ids, hence sparse_categorical_crossentropy below
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# Compile the model (tf.keras.optimizers matches the tf.keras.* namespace used above)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])



### Train the model

Now you are ready to train the model. To attach the callback, pass a list containing the `myCallback` instance you created earlier (`callbacks`) as the `callbacks` argument of `model.fit()`. Run the cell below and observe what happens.

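You can pass a list of callbacks (via the `callbacks` keyword argument) to the following model methods:


* `keras.Model.fit()`
* `keras.Model.evaluate()`
* `keras.Model.predict()`

Note that the epoch-level methods (`on_epoch_begin` / `on_epoch_end`) are only called during `fit()`.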

In [None]:
# Train for 10 epochs; the callback instance is passed inside a list and may stop training early
model.fit(
    x=x_train,
    y=y_train,
    epochs=10,
    callbacks=[callbacks],
)

Starting training; got log keys: []
Start epoch 0 of training; got log keys: []
Epoch 1/10