In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

In [2]:
# Load MNIST as (image, label) tuples. `ds_info` carries dataset metadata;
# its train-split size is used below to size the shuffle buffer.
# NOTE(review): ds_test is loaded here but is not normalized or evaluated in
# the cells shown — confirm whether an evaluation step is missing.
(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    as_supervised=True,   # yield (image, label) pairs rather than feature dicts
    shuffle_files=True,   # NOTE(review): no seed is set, so file order may differ per run
    with_info=True,
)

In [3]:
def normalize_img(image, label):
    """Cast an image to float32 and divide by 255; return the label unchanged.

    Meant for `tf.data.Dataset.map` over `(image, label)` pairs. Assumes 8-bit
    pixel values (0-255), so the result lies in [0, 1] — TODO confirm dtype.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return scaled, label

In [4]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32

# Training input pipeline: normalize -> cache -> full-size shuffle -> batch -> prefetch.
# cache() precedes shuffle() so each epoch reshuffles the cached, normalized
# examples. NOTE: this cell rebinds `ds_train` to a transformed version of
# itself, so re-running it without re-running the load cell stacks a second
# normalize/batch on top of the first.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits["train"].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

In [5]:
# Minimal CNN on 28x28x1 inputs: one 3x3 convolution (32 filters, ReLU),
# flattened into a 10-unit Dense layer. The Dense layer has no activation,
# so it outputs raw logits; the loss is built with from_logits=True to match.
model = keras.Sequential()
model.add(keras.Input((28, 28, 1)))
model.add(layers.Conv2D(32, 3, activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(10))

In [6]:
# Write model weights (not the full model) to 'checkpoint/'. The path has no
# per-epoch placeholder, so each save replaces the previous one.
# NOTE(review): per the Keras docs, `monitor` only matters when
# save_best_only=True; with False it has no effect here.
save_callback = keras.callbacks.ModelCheckpoint(
    'checkpoint/',
    monitor='accuracy',
    save_best_only=False,
    save_weights_only=True,
)

In [7]:
def scheduler(epoch, lr, warmup_epochs=2, decay=0.99):
    """Learning-rate schedule for `keras.callbacks.LearningRateScheduler`.

    Leaves the rate untouched while `epoch < warmup_epochs`, then multiplies
    the current rate by `decay` every subsequent epoch (compounding).

    Args:
        epoch: epoch index passed by Keras (0 for the first epoch).
        lr: the optimizer's current learning rate.
        warmup_epochs: number of leading epochs with no decay (default 2).
        decay: per-epoch multiplicative factor applied after warm-up (default 0.99).

    Returns:
        The learning rate to use for this epoch.
    """
    if epoch < warmup_epochs:
        return lr
    return lr * decay

In [8]:
# Hand `scheduler` to Keras; verbose=1 prints the rate chosen for each epoch.
lr_scheduler = keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

In [9]:
class CustomCallback(keras.callbacks.Callback):
    """Stop training once the epoch's logged training accuracy exceeds 90%.

    Reads the 'accuracy' entry of the per-epoch `logs` dict. That key is
    present here because the model is compiled with metrics=["accuracy"].
    """

    def on_epoch_end(self, epoch, logs=None):
        """Set `self.model.stop_training` when logs['accuracy'] > 0.90."""
        # `logs` defaults to None, and 'accuracy' is missing unless requested in
        # compile(). The original `logs.get(...) > 0.90` raised AttributeError
        # (None logs) or TypeError (None > float) in those cases, so skip the
        # check instead of crashing training.
        accuracy = (logs or {}).get('accuracy')
        if accuracy is not None and accuracy > 0.90:
            print('Accuracy over 90%, quitting training')
            self.model.stop_training = True

In [10]:
# Adam with learning rate 0.01. The loss takes raw logits (from_logits=True)
# because the model's final Dense layer has no activation.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.01),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [11]:
# Train for up to 5 epochs. CustomCallback may stop earlier once accuracy
# exceeds 90%. The returned History is bound to `history`, so per-epoch
# metrics (`history.history`) stay available for later inspection or plotting.
# Otherwise it is shown only as an opaque repr and reachable only via Out[N].
history = model.fit(
    ds_train,
    epochs=5,
    verbose=2,
    callbacks=[save_callback, lr_scheduler, CustomCallback()],
)

Epoch 1/5

Epoch 00001: LearningRateScheduler setting learning rate to 0.009999999776482582.
1875/1875 - 29s - loss: 0.1467 - accuracy: 0.9565
Accuracy over 90%, quitting training


<keras.callbacks.History at 0x1dfd813ef48>