In [8]:
{}
## Tensorflow Custom Training with Custom Callbacks
# Custom Training & Evaluation based on https://keras.io/guides/writing_a_training_loop_from_scratch/
# Reduce learning rate on plateau based on https://keras.io/api/callbacks/reduce_lr_on_plateau/
## Tutorial by Alexander Pelkmann 2021 
# Linkedin --> https://www.linkedin.com/in/alexander-pelkmann
# Github --> https://github.com/Pelk89
# Medium --> https://medium.com/@alexander.pelkmann


{}

In [9]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

## Custom Modification: import custom reduce learning rate on plateau callback
from custom_callback import CustomReduceLRoP


In [10]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value


@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)



In [11]:
# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)


In [12]:
# Get model
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

# Prepare the training dataset.
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)


In [14]:

## Custom Modification: Set reduce_lin true for linear reducing; Set factor to reduce learning rate by 0.0001

reduce_rl_plateau = CustomReduceLRoP(patience=10,
                              factor=0.0001,
                              verbose=1, 
                              optim_lr=optimizer.learning_rate, 
                              reduce_lin=True)



In [15]:
epochs = 15

## Custom Modification: Reset cooldown and wait timer for the callback
reduce_rl_plateau.on_train_begin()

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * 64))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()

 
    print("Validation acc: %.4f" % (float(val_acc),))

    ## Custom Modification: pass epoch and validation loss to the callback 
    reduce_rl_plateau.on_epoch_end(epoch, val_acc)




Start of epoch 0
Training loss (for one batch) at step 0: 111.1472
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.0064
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.3015
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.4842
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 0.7201
Seen so far: 51264 samples
Training acc over epoch: 0.6875
Validation acc: 0.8217

Start of epoch 1
Training loss (for one batch) at step 0: 0.6239
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.8006
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.7073
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.6513
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 0.7789
Seen so far: 51264 samples
Training acc over epoch: 0.8340
Validation acc: 0.8689

Start of epoch 2
Training loss (for one batch) at step 0: 0.4879
Seen so far: 64 samples
Tr