<a href="https://colab.research.google.com/github/AxinLi1/CS436_quiz4/blob/main/early_stop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist

In [2]:
# Load the MNIST dataset
(x_train, y_train), (x_val, y_val) = mnist.load_data()

x_train = x_train.reshape((x_train.shape[0], -1)) / 255.0
x_val = x_val.reshape((x_val.shape[0], -1)) / 255.0

y_train = tf.keras.utils.to_categorical(y_train, 10)
y_val = tf.keras.utils.to_categorical(y_val, 10)

model = Sequential([
    Dense(128, activation='relu', input_shape=(x_train.shape[1],)),
    Dense(64, activation='relu'),
    Dense(10, activation='softmax')
])

model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])

'''
  Early Stopping Implementation
Early stopping pervents overfitting during the training of machine omdels by
stopping when the model stops improving. If the validation loss fails to improve
over a specific number of epochs (5), then the training is stopped.
This prevents the model from continuing learning from the noise in the training
data. This saves computatitional resources and help ensures that the model generalizes
better to unseen data.
'''

patience = 5  # How many epochs to wait before stopping
best_val_loss = float('inf')
epochs_without_improvement = 0

max_epochs = 100
for epoch in range(max_epochs):
    model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0)

    val_loss, val_accuracy = model.evaluate(x_val, y_val, verbose=0)
    print(f"Epoch {epoch+1}: Validation Loss = {val_loss}, Validation Accuracy = {val_accuracy}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print("Early stopping triggered!")
        break

final_loss, final_accuracy = model.evaluate(x_val, y_val)
print(f"Final validation accuracy: {final_accuracy}")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1: Validation Loss = 0.11584243178367615, Validation Accuracy = 0.9656000137329102
Epoch 2: Validation Loss = 0.08798765391111374, Validation Accuracy = 0.9713000059127808
Epoch 3: Validation Loss = 0.07714589685201645, Validation Accuracy = 0.9750000238418579
Epoch 4: Validation Loss = 0.09113817662000656, Validation Accuracy = 0.9728999733924866
Epoch 5: Validation Loss = 0.07659478485584259, Validation Accuracy = 0.9771999716758728
Epoch 6: Validation Loss = 0.07312313467264175, Validation Accuracy = 0.9779999852180481
Epoch 7: Validation Loss = 0.08488316088914871, Validation Accuracy = 0.9760000109672546
Epoch 8: Validation Loss = 0.07805923372507095, Validation Accuracy = 0.9790999889373779
Epoch 9: Validation Loss = 0.07891424000263214, Validation Accuracy = 0.9797000288963318
Epoch 10: Validation Loss = 0.11099742352962494, Validation Accuracy = 0.9711999893188477
Epoch 11: Validation Loss = 0.13522830605506897, Validation Accuracy = 0.9703999757766724
Early stopping trig