#### Importing Python packages

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

#### Loading and normalizing the MNIST dataset

In [None]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel intensities from [0, 255] down to [0, 1].
# Casting to float32 before dividing keeps the arrays half the size of the
# float64 result that dividing the integer arrays by a Python float would
# produce; float32 is Keras' default float type, so nothing is lost.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

##### Convolutional layers expect a trailing channel dimension, so add one (this changes only the array shape, not the pixel values)

In [None]:
# Append a single-channel axis at the end: (N, 28, 28) -> (N, 28, 28, 1),
# matching the (28, 28, 1) input shape the model is built for.
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)

## Defining the model

In [None]:
# Small CNN: three 3x3 convolution stages (2x2 max-pooling after the first
# two) extract features, then a dense head maps them to probabilities over
# the 10 digit classes.
model = tf.keras.models.Sequential([
    # Declare the input with an explicit Input layer rather than passing
    # `input_shape` to the first Conv2D; newer Keras versions warn about the
    # latter. The resulting architecture is identical.
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    # softmax outputs probabilities, matching the non-logit
    # 'sparse_categorical_crossentropy' loss used at compile time.
    tf.keras.layers.Dense(10, activation='softmax'),
])

In [None]:
# Print each layer's output shape and parameter count.
model.summary()

## Configuring training (optimizer, loss and metrics)

In [None]:
# Step size for the Adam optimizer (Keras' Adam defaults to 1e-3).
LEARNING_RATE = 1e-4

In [None]:
# Labels are integer class indices, so use the *sparse* categorical
# cross-entropy; it takes the softmax probabilities the model outputs.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    metrics=['accuracy'],
)

## Training the model

In [None]:
EPOCHS = 10  # full passes over the training portion of the data
BATCH_SIZE = 32  # samples per gradient update
VALIDATION_SPLIT = 0.2  # fraction of x_train/y_train held out for validation

In [None]:
# Train while holding out VALIDATION_SPLIT of the training data. The returned
# History records per-epoch loss/accuracy for both the training and the
# validation portions (the val_* entries).
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=VALIDATION_SPLIT,
)

## Plotting training curves

In [None]:
def plot_learning_curves(history: tf.keras.callbacks.History):
    """Plot per-epoch training and validation curves from a Keras fit run.

    Draws two stacked subplots: ``loss`` / ``val_loss`` on top and
    ``accuracy`` / ``val_accuracy`` below, against the 1-based epoch number.

    Args:
        history: The object returned by ``model.fit``. It must have been
            trained with validation data and an ``'accuracy'`` metric, so
            that all four keys above are present in ``history.history``.
    """
    metrics = history.history
    # Size the x-axis from what was actually recorded instead of the global
    # EPOCHS, so the plot stays correct if training stopped early or the
    # History comes from a run with a different epoch count.
    epochs = range(1, len(metrics['loss']) + 1)

    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))

    for ax, title, key in ((ax1, 'Loss curves', 'loss'),
                           (ax2, 'Accuracy curves', 'accuracy')):
        ax.title.set_text(title)
        ax.plot(epochs, metrics[key], label=key)
        ax.plot(epochs, metrics[f'val_{key}'], label=f'val_{key}')
        ax.set_xlabel('epoch')
        ax.legend()

In [None]:
# Render the loss/accuracy curves recorded by model.fit above.
plot_learning_curves(history=history)

## Testing the model

#### Performance on the full training set (includes the validation split)

In [None]:
# Loss and accuracy on the whole training array, including the fraction that
# fit held out for validation. Left as the cell's last expression so the
# notebook displays the returned values.
model.evaluate(x=x_train, y=y_train, verbose=2)

#### Performance on the hold-out test set (should be close to the training-set results; a large gap indicates overfitting)

In [None]:
# Loss and accuracy on the test arrays, which were not passed to fit.
# Left as the cell's last expression so the notebook displays the result.
model.evaluate(x=x_test, y=y_test, verbose=2)

## Saving the model for use from JavaScript
#### Overwrites the model stored in the repository

In [None]:
from pathlib import Path

# The .h5 extension makes Keras write the HDF5 format. Any existing file at
# this path is replaced. Create the parent directory if it is missing so the
# save does not fail with a missing-path error.
MODEL_PATH = Path('../model/model.h5')
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
model.save(str(MODEL_PATH))