In [1]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
physical_devices = tf.config.list_physical_devices("GPU")
# Let TensorFlow allocate GPU memory on demand rather than reserving it all
# up front. Applied to every visible GPU so the setting is consistent across
# devices, and a no-op on CPU-only machines (the original `physical_devices[0]`
# raised IndexError when no GPU was present).
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot

In [2]:
# Fetch MNIST's "train" and "test" splits. as_supervised=True yields
# (image, label) tuples instead of feature dicts; with_info=True also returns
# the dataset metadata (used below for the training-split size).
(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
    split=["train", "test"],
)


def normalize_img(image, label):
    """Cast `image` to float32 and divide it by 255.0; `label` passes through unchanged.

    Intended as a `tf.data.Dataset.map` function over (image, label) pairs.
    """
    scaled_image = tf.cast(image, tf.float32) / 255.0
    return scaled_image, label


AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 128

# Training pipeline: normalize in parallel, cache the normalized examples,
# shuffle with a buffer as large as the whole training split, then batch and
# prefetch so input preparation overlaps with training.
ds_train = (
    ds_train.map(normalize_img, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits["train"].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Test pipeline: same normalization and batch size, no caching or shuffling.
ds_test = (
    ds_test.map(normalize_img, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

In [6]:
# Save only the model's weights to the fixed path "checkpoint/". The path has
# no epoch placeholder, so each save overwrites the previous one.
# NOTE(review): per the Keras docs, `monitor` only influences saving when
# save_best_only=True, so "accuracy" is currently unused — confirm whether
# best-only checkpointing was intended.
save_callback = keras.callbacks.ModelCheckpoint(
    filepath="checkpoint/",
    save_weights_only=True,
    save_best_only=False,
    monitor="accuracy",
)

def scheduler(epoch, lr, warmup_epochs=2, decay=0.99):
    """Learning-rate schedule for `keras.callbacks.LearningRateScheduler`.

    Leaves the rate unchanged for the first `warmup_epochs` epochs, then
    multiplies the current rate by `decay` for every later epoch, which gives
    an exponential decay.

    Args:
        epoch: zero-based epoch index supplied by the callback.
        lr: learning rate currently set on the optimizer.
        warmup_epochs: number of initial epochs without decay (default 2).
        decay: multiplicative factor applied per epoch after warm-up
            (default 0.99).

    Returns:
        The learning rate to use for this epoch.
    """
    if epoch < warmup_epochs:
        return lr
    return lr * decay

# Hand `scheduler` to Keras; verbose=1 prints the rate it sets each epoch
# (the "LearningRateScheduler reducing learning rate to ..." lines).
lr_scheduler = keras.callbacks.LearningRateScheduler(schedule=scheduler, verbose=1)

class CustomCallback(keras.callbacks.Callback):
    """Stop training once an epoch's training accuracy exceeds a threshold.

    Args:
        threshold: accuracy value that must be strictly exceeded to stop
            training (default 0.99).
    """

    def __init__(self, threshold=0.99):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # `logs` defaults to None and may lack "accuracy" (e.g. if the model
        # was compiled without that metric). The previous
        # `logs.get("accuracy") > 0.99` raised in both cases.
        accuracy = (logs or {}).get("accuracy")
        if accuracy is not None and accuracy > self.threshold:
            print(f"Accuracy over {self.threshold * 100:g}%, quitting training")
            self.model.stop_training = True

In [7]:
# Small CNN for 28x28x1 inputs: one 3x3 convolution with 32 ReLU filters,
# flattened into a 10-way softmax classifier. Because the output is already a
# probability distribution, the loss must use from_logits=False (its default).
# All layers come from the same `layers` module for consistency.
model = keras.Sequential(
    [
        keras.Input((28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        layers.Dense(10, activation="softmax"),
    ]
)

In [8]:
# Adam starting at learning rate 0.01; `lr_scheduler` adjusts it per epoch.
# SparseCategoricalCrossentropy expects integer class labels and, with its
# default from_logits=False, matches the model's softmax output.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.01),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

# Keep the History object for later inspection/plotting rather than leaving
# its repr as the cell output. Training can stop before `epochs` if
# CustomCallback sets `stop_training`.
history = model.fit(
    ds_train,
    epochs=10,
    callbacks=[save_callback, lr_scheduler, CustomCallback()],
    verbose=1,
)

Epoch 1/10

Epoch 00001: LearningRateScheduler reducing learning rate to 0.009999999776482582.
Epoch 2/10

Epoch 00002: LearningRateScheduler reducing learning rate to 0.009999999776482582.
Epoch 3/10

Epoch 00003: LearningRateScheduler reducing learning rate to 0.009899999778717757.
Epoch 4/10

Epoch 00004: LearningRateScheduler reducing learning rate to 0.009800999946892262.
Accuracy over 99%, quitting training


<tensorflow.python.keras.callbacks.History at 0x7f9891f56880>