In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

from resnet50.model.architecture import res_net_50

# Get the data

In [None]:
# Load the MNIST train/test splits through the Keras datasets API.
mnist_train, mnist_test = keras.datasets.mnist.load_data()
Xtrain, ytrain = mnist_train
Xtest, ytest = mnist_test

In [None]:
# Sanity-check the array shapes of both splits (displayed as the cell output).
tuple(arr.shape for arr in (Xtrain, ytrain, Xtest, ytest))

# Plot some digits

In [None]:
# Show a random sample of training digits, each titled with its label.
n_images = 10
n_columns = 5
rng = np.random.default_rng(seed=0)  # fixed seed -> same digits on every Run All

# Sample without replacement so no digit appears twice in the grid.
index_sample = rng.choice(Xtrain.shape[0], size=n_images, replace=False)

n_rows = int(np.ceil(n_images / n_columns))
# squeeze=False keeps `axes` 2-D even for a single row/column, so ravel always works.
fig, axes = plt.subplots(n_rows, n_columns, figsize=(4 * n_columns, 4 * n_rows), squeeze=False)

raveled_axes = np.ravel(axes)

for ax, i in zip(raveled_axes, index_sample):
    ax.set_title(ytrain[i], fontsize=16)
    # np.squeeze drops a trailing channel axis, so the plot still works if this
    # cell is re-run after the data-preparation cells below.
    ax.imshow(np.squeeze(Xtrain[i]), cmap='Greys')

# Hide grid cells left over when n_images is not a multiple of n_columns.
for ax in raveled_axes[n_images:]:
    ax.set_visible(False)

# Prepare data

In [None]:
# Append a channels-last axis: (N, H, W) -> (N, H, W, 1).
# tf.image.resize (next cell) interprets a 3-D input as a single (H, W, C)
# image, so the batch of grayscale digits needs an explicit channel dimension.
# The ndim guard makes the cell idempotent: re-running it (even after the
# resize cell) does not stack a second axis onto the data.
if Xtrain.ndim == 3:
    Xtrain = Xtrain[..., np.newaxis]
if Xtest.ndim == 3:
    Xtest = Xtest[..., np.newaxis]

In [None]:
# Upsample the digits to IMAGE_SIZE before feeding the network.
# NOTE(review): presumably chosen so the feature maps stay non-degenerate
# through ResNet-50's downsampling stages — confirm against res_net_50.
# tf.image.resize returns float32 tf.Tensors; `.numpy()` converts them back so
# the later cells (boolean-mask selection, `[i, :, :, 0]` indexing, imshow,
# sklearn metrics) all operate on plain NumPy arrays.
# No intensity rescaling is applied here.
IMAGE_SIZE = (64, 64)
Xtrain = tf.image.resize(Xtrain, IMAGE_SIZE).numpy()
Xtest = tf.image.resize(Xtest, IMAGE_SIZE).numpy()

# Create model

In [None]:
# Seed TensorFlow's global RNG so weight initialisation (and the shuffling in
# `model.fit`) should repeat across Restart & Run All; some GPU kernels can
# still introduce small nondeterminism.
tf.random.set_seed(42)

# Number of distinct labels in the training set, passed as the model's class
# count (presumably the width of its output layer — see res_net_50).
n_classes = np.unique(ytrain).shape[0]
model = res_net_50(Xtrain.shape[1:], n_classes)

In [None]:
# Labels are plain integers (not one-hot), hence the *sparse* cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train

In [None]:
# Keras takes the validation split (last 20% of the arrays) before shuffling;
# `shuffle=True` then reshuffles only the remaining training portion.
history = model.fit(
    Xtrain,
    ytrain,
    batch_size=128,
    epochs=1,
    validation_split=0.2,
    shuffle=True,
)

# Evaluate on test set

In [None]:
# Score the trained model on the held-out test split without a progress bar.
test_loss, test_accuracy = model.evaluate(x=Xtest, y=ytest, batch_size=256, verbose=0)
print("Test loss: {}".format(test_loss), "Test accuracy: {}".format(test_accuracy), sep="\n")

# Visualize some misclassified test digits

In [None]:
# Per-class scores (presumably softmax probabilities) for every test image.
test_predictions = model.predict(x=Xtest, batch_size=256)

In [None]:
# Predicted label = column index of the highest score in each row.
test_predictions_class = test_predictions.argmax(axis=1)

In [None]:
# Boolean mask: True where the prediction disagrees with the true label.
wrong_prediction_bools = np.not_equal(ytest, test_predictions_class)

In [None]:
# Keep only the misclassified test samples: images, true labels, predictions.
Xtest_wrong, ytest_wrong, test_prediction_class_wrong = (
    arr[wrong_prediction_bools] for arr in (Xtest, ytest, test_predictions_class)
)

In [None]:
# Confusion matrix: rows are true labels, columns are predicted labels.
confusion_matrix(y_true=ytest, y_pred=test_predictions_class)

In [None]:
# Show a random sample of misclassified test digits, titled "True/Pred".
n_images = 15
n_columns = 5
rng = np.random.default_rng(seed=0)  # fixed seed -> same sample on every Run All

# Never request more panels than there are mistakes; sampling without
# replacement below keeps the same digit from being shown twice.
n_shown = min(n_images, len(ytest_wrong))

if n_shown == 0:
    print("No misclassified test images to show.")
else:
    index_sample = rng.choice(len(ytest_wrong), size=n_shown, replace=False)

    n_rows = int(np.ceil(n_shown / n_columns))
    # squeeze=False keeps `axes` 2-D even for a single row, so ravel always works.
    fig, axes = plt.subplots(n_rows, n_columns, figsize=(4 * n_columns, 4 * n_rows), squeeze=False)

    raveled_axes = np.ravel(axes)

    for ax, i in zip(raveled_axes, index_sample):
        ax.set_title("True/Pred: {}/{}".format(ytest_wrong[i], test_prediction_class_wrong[i]), fontsize=16)
        ax.imshow(Xtest_wrong[i, :, :, 0], cmap='Greys')

    # Hide grid cells left over when n_shown is not a multiple of n_columns.
    for ax in raveled_axes[n_shown:]:
        ax.set_visible(False)