### Save Checkpoint (EMNIST Digits)

This notebook retrains the best model configuration found during the hyperparameter search. \
The model is trained on the full training set and then saved as a checkpoint for use in the demo notebook.

In [None]:
# Install dependencies
%pip install --upgrade pip
%pip install torchvision
%pip install opencv-python-headless
%pip install pandas
%pip install numpy
%pip install tensorflow[and-cuda]

In [None]:
# Import dependencies
import random
import numpy as np # type: ignore
import pandas as pd # type: ignore
import tensorflow as tf # type: ignore
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint # type: ignore
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, Input, MaxPooling2D # type: ignore
from tensorflow.keras.models import Sequential # type: ignore
from tensorflow.keras.utils import to_categorical # type: ignore

import cv2 # type: ignore
from torchvision import datasets # type: ignore

In [None]:
# Report the installed TensorFlow version
print(f"Tensorflow: v{tf.__version__}")

# Query the visible GPU devices once and report both the count and the list
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPUs Available: {len(gpu_devices)}")
print(gpu_devices)

In [None]:
# Set seeds for reproducibility.
# Seed every RNG this notebook relies on: Python's `random`, NumPy's legacy
# global generator, and TensorFlow's global seed. The TensorFlow seed is what
# Keras initializers and Dropout fall back on when no per-op seed is given, so
# without it each run starts from different weights.
# NOTE(review): some GPU kernels are still nondeterministic; call
# tf.config.experimental.enable_op_determinism() if bit-exact runs are needed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Define data transformations
def transform(image):
    """Convert one raw EMNIST sample into a normalised, re-oriented array.

    The previous implementation flipped the image horizontally and then rotated
    it 90 degrees counter-clockwise (presumably to undo EMNIST's transposed
    storage). Together, those two steps are exactly a transpose. However, the
    rotation was taken about ``(W // 2, H // 2)``, which for an even-sized
    28x28 image is half a pixel off the true centre ``((W - 1) / 2, (H - 1) / 2)``.
    That shifted every digit by one pixel, zero-filled one edge row and pushed
    the opposite edge out of frame. Transposing directly gives the intended
    orientation exactly, with no interpolation and no lost pixels.

    Args:
        image: A 2-D greyscale image (PIL image or array-like) of shape
            (H, W) with pixel values in [0, 255].

    Returns:
        A float32 array of shape (W, H, 1) with values in [0, 1]; the trailing
        axis is the single channel expected by the Conv2D input layer.
    """
    pixels = np.asarray(image)
    # Swap rows and columns: equivalent to flip-left-right + exact 90° CCW rotation
    pixels = pixels.T
    # Scale 8-bit intensities into [0, 1]
    pixels = pixels.astype('float32') / 255.0
    # Add the channel axis -> (W, H, 1)
    return np.expand_dims(pixels, axis=-1)

NUM_CLASSES = 10  # EMNIST Digits has labels 0-9


def _to_arrays(dataset):
    """Materialise a dataset of (image, label) pairs as two NumPy arrays.

    Iterates the dataset exactly once, so ``transform`` runs once per sample.
    Iterating separately for images and labels would pay for the image
    transform twice just to read the labels.

    Returns:
        (images, labels) as NumPy arrays; both are empty if the dataset is empty.
    """
    images, labels = [], []
    for image, label in dataset:
        images.append(image)
        labels.append(label)
    return np.array(images), np.array(labels)


# Load EMNIST Digits subset
emnist_train = datasets.EMNIST(root='../data', split='digits', train=True, transform=transform, download=True)
emnist_test = datasets.EMNIST(root='../data', split='digits', train=False, transform=transform, download=True)

# Convert to numpy arrays for TensorFlow (single pass per split)
X_train, y_train = _to_arrays(emnist_train)
X_test, y_test = _to_arrays(emnist_test)

# The model's Input layer expects 28x28 single-channel images
assert X_train.shape[1:] == (28, 28, 1), X_train.shape
assert X_test.shape[1:] == (28, 28, 1), X_test.shape

# One-hot encode the labels
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)

# Display some info and stats about the dataset
print(f'Training data shape: {X_train.shape}')
print(f'Test data shape: {X_test.shape}')
print(f'Number of classes: {y_train.shape[1]}')

In [None]:
# Pick the winning configuration from the combined hyperparameter-search results.
# NOTE(review): the first row is taken as-is, so this assumes all.csv is already
# sorted best-first — TODO confirm against the script that writes it.
results_path = '../results/emnist-digits/all.csv'
all_results = pd.read_csv(results_path)
top_model = all_results.iloc[:1]  # one-row DataFrame, same as .head(1)
top_model

In [None]:
# Define the model
def create_custom_model(conv1, conv2, dense_units, dropout_rate, learning_rate,
                        input_shape=(28, 28, 1), num_classes=10):
    """Build and compile the two-stage CNN used in the hyperparameter search.

    Architecture: two (Conv2D 3x3 + ReLU -> MaxPool 2x2) stages, then
    Flatten -> Dense(ReLU) -> Dropout -> Dense(softmax).

    Args:
        conv1: Number of filters in the first convolution.
        conv2: Number of filters in the second convolution.
        dense_units: Width of the hidden fully connected layer.
        dropout_rate: Fraction of hidden units dropped during training.
        learning_rate: Learning rate for the Adam optimizer.
        input_shape: Shape of a single input image (H, W, C). Defaults to
            28x28 greyscale, matching the EMNIST Digits arrays.
        num_classes: Number of output classes. Defaults to 10 (digits 0-9).

    Returns:
        A compiled ``Sequential`` model. It uses categorical cross-entropy, so
        labels must be one-hot encoded, and it tracks accuracy.
    """
    # Build the model architecture
    model = Sequential([
        Input(shape=input_shape),
        Conv2D(conv1, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Conv2D(conv2, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(dense_units, activation='relu'),
        Dropout(dropout_rate),
        Dense(num_classes, activation='softmax')
    ])

    # Compile the model with the chosen learning rate
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

# Build the model from the winning row's hyperparameters
hyperparams = {
    column: top_model[column].iloc[0]
    for column in ('conv1', 'conv2', 'dense_units', 'dropout_rate', 'learning_rate')
}
model = create_custom_model(
    conv1=int(hyperparams['conv1']),
    conv2=int(hyperparams['conv2']),
    dense_units=int(hyperparams['dense_units']),
    dropout_rate=float(hyperparams['dropout_rate']),
    learning_rate=float(hyperparams['learning_rate']),
)

In [None]:
# Train the model.
# NOTE(review): the test split doubles as the validation set, so it drives both
# checkpoint selection and early stopping; accuracy measured on it afterwards is
# optimistically biased. Hold out part of the training data if an unbiased
# test score is needed.
history = model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    epochs=30,
    # `.iloc[0]` extracts the scalar. Calling int() on a one-element Series is
    # deprecated in pandas and will become a TypeError.
    batch_size=int(top_model['batch_size'].iloc[0]),
    callbacks=[
        # Save the model whenever validation accuracy reaches a new maximum
        ModelCheckpoint(
            '../results/emnist-digits/best-model.keras',
            monitor='val_accuracy',
            mode='max',
            save_best_only=True,
            verbose=1
        ),
        # Stop after 3 epochs without improvement and roll back to the best weights;
        # mode='max' is explicit to match the checkpoint above
        EarlyStopping(
            monitor='val_accuracy',
            mode='max',
            patience=3,
            restore_best_weights=True,
            verbose=1
        ),
    ]
)