In [1]:
# Import Python libraries for Training and Processing

import tensorflow as tf
from keras.datasets import cifar10

import os  # Save a trained model
from matplotlib import pyplot as plt  # Plot charts

#  Define constants
IMAGE_WIDTH = 32
IMAGE_HEIGHT = 32
CHANNEL = 3  # color image
EPOCHS = 5  # Number of Training cycles

ModuleNotFoundError: No module named 'keras'

In [None]:
def load_data():
    # Load cifar10 data into train and test sets
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    x_train = x_train.astype('float32')  # Type Conversion
    x_train = tf.keras.utils.normalize(x_train, axis=1)

    return [x_train, y_train]

<img src="img/architecture.png" width="800">

In [None]:
def build_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, CHANNEL)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),  # Dense Layer 128
        tf.keras.layers.Dropout(0.2),  # Dropout chance 20%
        tf.keras.layers.Dense(64, activation='relu'),  # Dense Layer 64
        tf.keras.layers.Dropout(0.2),  # Dropout chance 20%
        tf.keras.layers.Dense(10, activation='softmax'),  # Probability distribution
    ])

    return model

In [None]:
def train_model(model, training_data):
    x_train, y_train = training_data[0], training_data[1]

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),  # Optimizer = Adam
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

    model.fit(x_train, y_train, epochs=EPOCHS, shuffle=True, validation_split=0.2)

In [None]:
def plot_metrics(model):
    plt.plot(model.history.history['sparse_categorical_accuracy'])
    plt.plot(model.history.history['val_sparse_categorical_accuracy'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()


<img src="img/metrics.png" width="400">


In [None]:
def save_model(model):
    while (True):
        path = input("Enter directory where you would like to save the model:\n")
        path = path.replace("\\", "/")  # Replace Slash with Backslash
        filename = input("Enter a filename for the model\n")
        if not os.path.isfile(path + filename + '.model'):  # Check if file already exists
            break
        else:
            print("\nThe file " + filename + " does already exist in the directory. \n"
                                             "Please try again.\n\n ")
    model.save(path + filename + '.model')

In [None]:
def main():
    training_data = load_data()  # load training and testing data
    model = build_model()  # build CNN model
    train_model(model, training_data)  # train CNN model
    plot_metrics(model)  # plot loss after each iteration
    save_model(model)


if __name__ == '__main__':
    main()