In [2]:
import datetime
import os

import yaml
from tensorflow.keras import datasets, utils, models, layers, optimizers, losses, callbacks

from src.model import cnn_model
from src.utils import plot_results, plot_predicts


In [None]:
# Load the run settings from the YAML configuration file into `config`.
# The encoding is pinned to UTF-8 so the file is decoded identically on every
# platform instead of relying on the locale-dependent default of open().
with open("./configs/configurations.yaml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

## Data Preparation

In [1]:
# Number of target classes: MNIST labels are the digits 0-9.
NUM_CLASSES = 10

# Load the MNIST train/test split shipped with Keras.
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

# Preprocessing: add a trailing channel axis -> (n_samples, 28, 28, 1)
X_train = train_images.reshape(len(train_images), 28, 28, 1)
X_test = test_images.reshape(len(test_images), 28, 28, 1)

# Normalization: cast to float32 and divide by 255
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

# OneHot Encoding
# The class count is passed explicitly: without it, to_categorical derives the
# width from the largest label present, so the train and test encodings would
# only match as long as both label arrays happen to contain the highest class.
Y_train = utils.to_categorical(train_labels, num_classes=NUM_CLASSES)
Y_test = utils.to_categorical(test_labels, num_classes=NUM_CLASSES)

## Model Compile and Callbacks

In [None]:
# Build the network returned by cnn_model() (src.model) and print its summary.
model = cnn_model()
model.summary()

# Training setup: Adam with its default arguments, categorical cross-entropy
# as the loss, and accuracy as the reported metric.
adam = optimizers.Adam()
model.compile(
    optimizer=adam,
    loss=losses.categorical_crossentropy,
    metrics=['accuracy'],
)

In [None]:
# Output locations are read from the "training_data" section of the config and
# are used as plain string prefixes for the file names below.
training_cfg = config["training_data"]
checkpoint_path = training_cfg["callbacks_path_checkpoints"] + 'model.{epoch}.h5'
tensorboard_dir = training_cfg["callbacks_path_tensorboard"] + 'tensorboard'
csv_log_path = training_cfg["callbacks_path_csvlogger"] + 'training.log'

# Make sure the folders for the file-based outputs exist before training
# starts, so a fresh checkout without them does not fail when the first file
# is written. `or "."` covers a prefix that has no directory component.
# NOTE(review): CSVLogger in particular appears to open its file without
# creating missing parent directories — confirm for the installed Keras version.
for output_path in (checkpoint_path, csv_log_path):
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

# Checkpoint file name contains the {epoch} placeholder, TensorBoard gets a
# log directory, and CSVLogger writes to training.log.
checkpoints = callbacks.ModelCheckpoint(checkpoint_path)
tensorboard = callbacks.TensorBoard(tensorboard_dir)
csv_logger = callbacks.CSVLogger(csv_log_path)
callbacks_list = [checkpoints, tensorboard, csv_logger]

## Model Train

In [6]:
# Training hyper-parameters live in the "training_data" section of the config.
fit_settings = config["training_data"]

# Take wall-clock timestamps around fit() so the training duration can be reported.
time_start = datetime.datetime.now()
model_history = model.fit(
    X_train,
    Y_train,
    batch_size=fit_settings["batch_size"],
    epochs=fit_settings["epochs"],
    validation_split=fit_settings["validation_split"],
    shuffle=True,
    callbacks=callbacks_list,
)
time_end = datetime.datetime.now()
print(f"Training Time: {time_end-time_start}")

In [None]:
# Persist the trained network under the configured path prefix: first the
# full model, then only its weights, both as .h5 files.
save_prefix = config["training_data"]["model_save_path"]
model.save(save_prefix + 'model.h5')
# NOTE(review): Keras 3 is reported to require weight file names ending in
# `.weights.h5` — confirm against the installed version before upgrading.
model.save_weights(save_prefix + 'model_weights.h5')

In [None]:
# Hand the history object returned by model.fit() to the project plotting helper (src.utils).
# NOTE(review): presumably draws the per-epoch loss/accuracy curves — confirm in src/utils.
plot_results(model_history)

## Model Evaluation

In [None]:
# Score the trained model on the test split. The result is unpacked as
# (loss, accuracy), matching the single 'accuracy' metric set at compile time.
evaluation = model.evaluate(X_test, Y_test)
test_loss, test_accuracy = evaluation
print("Test Accuracy: ", test_accuracy)
print("Test Loss: ", test_loss)

In [None]:
# Pass the model, the preprocessed test inputs, the integer test labels and the
# raw test images to the project helper from src.utils.
# NOTE(review): the trailing 10 is presumably the number of samples to display — confirm in src/utils.
plot_predicts(model, X_test, test_labels, test_images, 10)