# Callbacks

In [29]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler
from tensorflow.keras.callbacks import TensorBoard, Callback

In [6]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0 #normalization

model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dropout(0.2),#regularization
    Dense(10, activation='softmax')
])

model.compile(optimizer = Adam(learning_rate = 0.01), loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

In [9]:
model_checkpoint_cb = ModelCheckpoint(
    filepath = 'saved_models/model{epoch:02d}.keras',
    verbose = 2,
    save_best_only = True
)

In [13]:
early_stopping_cb = EarlyStopping(
    patience = 10
)

In [16]:
import math

def scheduler(epoch, lr):
    print('LEARNING_RATE:', lr)
    if epoch < 5 :
        return lr
    else :
        return lr * math.exp(-0.1)
    
learning_rate_scheduler_cb = LearningRateScheduler(scheduler)

In [27]:
tensorboard_cb = TensorBoard('logs', update_freq=1)

In [30]:
class CustomLoger(Callback):
    def __init__(self, logfile):
        super().__init__()
        self.logfile = logfile
    def on_epoch_end(self, epoch, logs):
        with open(self.logfile, 'a+') as f:
            f.write(f"Epoch {epoch} : Validation Loss : {logs['val_loss']}" + '\n')

In [31]:
custom_logger_cb = CustomLoger('mylog.txt')

In [32]:
model.fit(X_train, y_train, epochs = 50, validation_data = (X_test, y_test), callbacks = [model_checkpoint_cb, early_stopping_cb, learning_rate_scheduler_cb, tensorboard_cb, custom_logger_cb])

LEARNING_RATE: 0.009999999776482582
Epoch 1/50
[1m1867/1875[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9617 - loss: 0.1484
Epoch 1: val_loss did not improve from 0.16758
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 6ms/step - accuracy: 0.9617 - loss: 0.1484 - val_accuracy: 0.9652 - val_loss: 0.2545 - learning_rate: 0.0100
LEARNING_RATE: 0.009999999776482582
Epoch 2/50
[1m1868/1875[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9607 - loss: 0.1664
Epoch 2: val_loss did not improve from 0.16758
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 6ms/step - accuracy: 0.9607 - loss: 0.1664 - val_accuracy: 0.9600 - val_loss: 0.2983 - learning_rate: 0.0100
LEARNING_RATE: 0.009999999776482582
Epoch 3/50
[1m1865/1875[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 6ms/step - accuracy: 0.9638 - loss: 0.1481
Epoch 3: val_loss did not improve from 0.16758
[1m1875/1875[0m [32m━━━━━━

<keras.src.callbacks.history.History at 0x201c88c8550>