<a href="https://colab.research.google.com/github/ReutFarkash/useful/blob/main/TensorFlow_Tutorial_14.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[TensorFlow Tutorial 14 - Callbacks with Keras and Writing Custom Callbacks](https://www.youtube.com/watch?v=WUzLJZCKNu4&ab_channel=AladdinPersson)

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

In [3]:
(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True, # will return tuple (img, label) otherwise dict
    with_info=True, # able to get info about dataset
)

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…



[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [4]:
def normalize_img(image, label):
  """Normalizes images"""
  return tf.cast(image, tf.float32) / 255.0, label

In [5]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 128

# Setup for train dataset
ds_train = ds_train.map(normalize_img, num_parallel_calls=AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(AUTOTUNE)

In [6]:
model = keras.Sequential(
    [
     keras.Input((28, 28, 1)),
     layers.Conv2D(32, 3, activation="relu"),
     layers.Flatten(),
     layers.Dense(10),
    ]
)

save_callback = keras.callbacks.ModelCheckpoint(
    'checkpoint/',
    save_weights_only=True,
    monitor='accuracy',
    save_best_only=False,
)

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(0.01),
    metrics=["accuracy"],
)

model.fit(ds_train, epochs=10, verbose=2, callbacks=[save_callback])

Epoch 1/10
469/469 - 15s - loss: 0.1576 - accuracy: 0.9530
Epoch 2/10
469/469 - 15s - loss: 0.0574 - accuracy: 0.9827
Epoch 3/10
469/469 - 14s - loss: 0.0381 - accuracy: 0.9880
Epoch 4/10
469/469 - 14s - loss: 0.0269 - accuracy: 0.9909
Epoch 5/10
469/469 - 14s - loss: 0.0177 - accuracy: 0.9941
Epoch 6/10
469/469 - 15s - loss: 0.0163 - accuracy: 0.9945
Epoch 7/10
469/469 - 15s - loss: 0.0149 - accuracy: 0.9948
Epoch 8/10
469/469 - 15s - loss: 0.0150 - accuracy: 0.9949
Epoch 9/10
469/469 - 15s - loss: 0.0112 - accuracy: 0.9961
Epoch 10/10
469/469 - 15s - loss: 0.0075 - accuracy: 0.9976


<tensorflow.python.keras.callbacks.History at 0x7f8192d13550>

In [7]:
def scheduler(epoch, lr):
  if epoch < 2:
    return lr
  else:
    return lr * 0.99

lr_scheduler= keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

model.fit(ds_train, epochs=10, verbose=2, callbacks=[save_callback, lr_scheduler])


Epoch 00001: LearningRateScheduler reducing learning rate to 0.009999999776482582.
Epoch 1/10
469/469 - 15s - loss: 0.0083 - accuracy: 0.9973

Epoch 00002: LearningRateScheduler reducing learning rate to 0.009999999776482582.
Epoch 2/10
469/469 - 15s - loss: 0.0102 - accuracy: 0.9967

Epoch 00003: LearningRateScheduler reducing learning rate to 0.009899999778717757.
Epoch 3/10
469/469 - 15s - loss: 0.0152 - accuracy: 0.9949

Epoch 00004: LearningRateScheduler reducing learning rate to 0.009800999946892262.
Epoch 4/10
469/469 - 15s - loss: 0.0077 - accuracy: 0.9974

Epoch 00005: LearningRateScheduler reducing learning rate to 0.009702990353107453.
Epoch 5/10
469/469 - 15s - loss: 0.0092 - accuracy: 0.9970

Epoch 00006: LearningRateScheduler reducing learning rate to 0.009605960855260491.
Epoch 6/10
469/469 - 15s - loss: 0.0058 - accuracy: 0.9984

Epoch 00007: LearningRateScheduler reducing learning rate to 0.00950990131124854.
Epoch 7/10
469/469 - 15s - loss: 0.0071 - accuracy: 0.9978


<tensorflow.python.keras.callbacks.History at 0x7f8192d98b70>

In [9]:
class CustomCallback(keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print(logs.keys)

model.fit(ds_train, epochs=10, verbose=2, callbacks=[save_callback, lr_scheduler, CustomCallback()])


Epoch 00001: LearningRateScheduler reducing learning rate to 0.009227447211742401.
Epoch 1/10
<built-in method keys of dict object at 0x7f8194988048>
469/469 - 15s - loss: 0.0043 - accuracy: 0.9988

Epoch 00002: LearningRateScheduler reducing learning rate to 0.009227447211742401.
Epoch 2/10
<built-in method keys of dict object at 0x7f81948a4480>
469/469 - 15s - loss: 0.0035 - accuracy: 0.9988

Epoch 00003: LearningRateScheduler reducing learning rate to 0.009135172739624976.
Epoch 3/10
<built-in method keys of dict object at 0x7f8194988048>
469/469 - 15s - loss: 0.0073 - accuracy: 0.9979

Epoch 00004: LearningRateScheduler reducing learning rate to 0.009043820975348353.
Epoch 4/10
<built-in method keys of dict object at 0x7f8192d41120>
469/469 - 15s - loss: 0.0063 - accuracy: 0.9980

Epoch 00005: LearningRateScheduler reducing learning rate to 0.008953382922336458.
Epoch 5/10
<built-in method keys of dict object at 0x7f8194988d80>
469/469 - 15s - loss: 0.0048 - accuracy: 0.9987

Epoc

<tensorflow.python.keras.callbacks.History at 0x7f8192c0de48>

In [10]:
class CustomCallback(keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    if logs.get("accuracy") > 0.90:
      print("Accuracy over 90%, quitting training")
      self.model.stop_training = True

model.fit(ds_train, epochs=10, verbose=2, callbacks=[save_callback, lr_scheduler, CustomCallback()])


Epoch 00001: LearningRateScheduler reducing learning rate to 0.008514578454196453.
Epoch 1/10
Accuracy over 90%, quitting training
469/469 - 15s - loss: 0.0096 - accuracy: 0.9978


<tensorflow.python.keras.callbacks.History at 0x7f8192daeb00>