# Physics 494/594
## Generalization of the MNIST Model

In [None]:
# %load ./include/header.py
import numpy as np
import matplotlib.pyplot as plt
import sys
from tqdm import trange,tqdm
sys.path.append('./include')
import ml4s

# draw matplotlib figures inline in the notebook, using SVG output
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
# apply the local matplotlib style sheet to all figures
plt.style.use('./include/notebook.mplstyle')
# allow up to 120 characters per line when printing numpy arrays
np.set_printoptions(linewidth=120)
# NOTE(review): presumably applies the bundled CSS file to the notebook -- confirm in ml4s
ml4s.set_css_style('./include/bootstrap.css')
# colors of the active style's default property cycle, for reuse in plots
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

## Last Time

### [Notebook Link: 21_MNIST_Softmax.ipynb](./21_MNIST_Softmax.ipynb)

- Classification for many classes (1-hot encoding and softmax)
- MNIST: hand-written digits

## Today
- Explore the robustness and generalizability of our classification model.
- Investigate translations and rotations of digits.

### Import tensorflow

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import datetime

### Load the MNIST Data Set

In [None]:
# load the data (images and integer labels for the train and test splits)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# determine the properties
rows,cols = x_train[0].shape
# Labels are stored as small unsigned integers; convert to a plain Python int
# so later arithmetic on num_classes cannot inherit the narrow label dtype,
# and look at both splits so no label falls outside the 1-hot encoding.
num_classes = int(max(np.max(y_train), np.max(y_test))) + 1

# flatten every image to a vector of rows*cols pixels and rescale to [0,1]
x_train = x_train.reshape(-1, rows*cols).astype('float32')/255
x_test = x_test.reshape(-1, rows*cols).astype('float32')/255

# use a built-in function to get 1-hot encoding
y_train_hot = keras.utils.to_categorical(y_train, num_classes)
y_test_hot = keras.utils.to_categorical(y_test, num_classes)

### Define a function that shows many digits

In [None]:
def plot_digit_array(x,y, show_prediction=False):
    '''Plot a collection of digit images on a near-square grid of panels.

    Parameters
    ----------
    x : array_like
        Digit images, one per entry along the first axis. Each entry must
        contain a square number of pixels (e.g. 784 values for 28x28 digits).
    y : sequence
        Labels associated with the entries of `x`. Only read when
        `show_prediction` is True.
    show_prediction : bool, optional
        If True, write `y[i]` in red in the top-right corner of panel `i`.

    Raises
    ------
    ValueError
        If the images do not contain a square number of pixels.

    Nothing is drawn (and no figure is created) when `x` is empty.
    '''
    x = np.asarray(x)
    num_digits = x.shape[0]
    if num_digits == 0:
        return

    # infer the side length of the (square) images from the pixel count
    num_pixels = x[0].size
    side = int(round(np.sqrt(num_pixels)))
    if side*side != num_pixels:
        raise ValueError(f'expected square images, got {num_pixels} pixels per digit')

    # determine the number of rows and columns of our image array;
    # ceiling division avoids an extra empty row when the grid is exactly filled
    num_cols = int(np.sqrt(num_digits))
    num_rows = -(-num_digits//num_cols)

    # squeeze=False keeps `ax` a 2D array even for a single panel
    fig,ax = plt.subplots(nrows=num_rows,ncols=num_cols,sharex=True,sharey=True,
                          figsize=(num_cols,num_rows),squeeze=False)

    # plot all the numbers; panels beyond the last digit are simply blanked
    for i,cax in enumerate(ax.flat):
        if i < num_digits:
            cax.matshow(x[i].reshape(side,side), cmap='binary')
            if show_prediction:
                cax.text(0.99,0.99,f'{y[i]}',horizontalalignment='right',verticalalignment='top',
                         transform=cax.transAxes, fontsize=8, color='r')
        cax.axis('off')

### Setup the Model and Load Weights from a Checkpoint

To load weights, the model defined here must have exactly the same architecture as the one used when the checkpoint was saved.

In [None]:
# Fully connected classifier acting on flattened images of rows*cols pixels:
# a 128-unit ReLU hidden layer, dropout with rate 0.25, and a softmax output
# layer with one unit per class.
model = keras.Sequential(
[
    layers.Dense(128,input_shape=(rows*cols,),activation='relu'),
    layers.Dropout(0.25),
    layers.Dense(num_classes, activation='softmax')
])
model.summary()

# load the weights from the checkpoint file
# NOTE(review): presumably written by the previous notebook (21_MNIST_Softmax) -- confirm
checkpoint_path = "checkpoints/mnist_dnn/cp.ckpt"
model.load_weights(checkpoint_path)

# compile the new model with a categorical cross-entropy loss (for 1-hot labels),
# the adam optimizer, and categorical accuracy as the reported metric
model.compile(loss=tf.losses.CategoricalCrossentropy(), optimizer='adam', metrics=[tf.metrics.CategoricalAccuracy()])

### Investigate the predictions

In [None]:
# training set: class probabilities, most probable class, misclassified indices
predictions_prob_train = model(x_train)
predictions_train = np.argmax(predictions_prob_train,axis=1)
mistakes_train = np.flatnonzero(predictions_train != y_train)
num_mistakes_train = mistakes_train.size

# test set: same quantities
predictions_prob_test = model(x_test)
predictions_test = np.argmax(predictions_prob_test,axis=1)
mistakes_test = np.flatnonzero(predictions_test != y_test)
num_mistakes_test = mistakes_test.size

# report the fraction of misclassified images in each split
print(f'Train Mistakes: {100*num_mistakes_train/x_train.shape[0]:.2f}%')
print(f'Test Mistakes : {100*num_mistakes_test/x_test.shape[0]:.2f}%')

### Plot some mistakes

In [None]:
# show up to the first 100 misclassified test digits, annotated with the (wrong) predicted label
first_mistakes = mistakes_test[:100]
plot_digit_array(x_test[first_mistakes], predictions_test[first_mistakes], show_prediction=True)

### Investigate how well our network generalizes

Let's first determine the cases where we predicted the correct label.

In [None]:
# indices of the images whose predicted class agrees with the true label
correct_train = np.flatnonzero(predictions_train == y_train)
correct_test = np.flatnonzero(predictions_test == y_test)

In [None]:
from matplotlib import gridspec

# pick one correctly classified test digit at random
idx = np.random.choice(correct_test)

# two stacked panels: predicted class probabilities (top) and the digit (bottom)
fig = plt.figure(figsize=(3,1.2*3),constrained_layout=True)
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 5], figure=fig)

ax = [fig.add_subplot(gs[0]),fig.add_subplot(gs[1])]

ax[0].bar(range(num_classes),predictions_prob_test[idx], color='r')
ax[0].set_xticks(range(num_classes))
ax[0].set_yticks([]);
# derive the limits from num_classes so every bar stays visible for any class count
ax[0].set_xlim(-0.5,num_classes-0.5)

ax[1].matshow(x_test[idx,:].reshape(rows,cols), cmap='binary')
# predicted label in the top-right corner of the image
ax[1].text(0.99,0.99,f'{predictions_test[idx]}',horizontalalignment='right',verticalalignment='top',
                         transform=ax[1].transAxes, color='r')
ax[1].set_xticks([]);
ax[1].set_yticks([]);

### Let's see what happens if we translate the image

In [None]:
# number of pixels by which the digit is shifted along both image axes
# (0 leaves it unchanged); np.roll wraps pixels around the image edges
shift = 0
translated_image = np.roll(x_test[idx,:].reshape(rows,cols), shift, axis=(1, 0))

# make a prediction
predict = model(translated_image.reshape(1,rows*cols))

# two stacked panels: predicted class probabilities (top) and the digit (bottom)
fig = plt.figure(figsize=(3,1.2*3),constrained_layout=True)
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 5], figure=fig)

ax = [fig.add_subplot(gs[0]),fig.add_subplot(gs[1])]

ax[0].bar(range(num_classes),predict[0,:], color='r')
ax[0].set_xticks(range(num_classes))
ax[0].set_yticks([]);
# derive the limits from num_classes so every bar stays visible for any class count
ax[0].set_xlim(-0.5,num_classes-0.5)

ax[1].matshow(translated_image, cmap='binary')

# predicted label in the top-right corner of the image
ax[1].text(0.99,0.99,f'{np.argmax(predict,axis=1)[0]}',horizontalalignment='right',verticalalignment='top',
                         transform=ax[1].transAxes, color='r')
ax[1].set_xticks([]);
ax[1].set_yticks([]);

### Or rotate it

In [None]:
from scipy import ndimage

# rotation angle passed to ndimage.rotate (in degrees)
angle = 15
rotated_digit = ndimage.rotate(x_test[idx,:].reshape(rows,cols), angle, reshape=False)

# The interpolation used by the rotation can overshoot below zero; clip those
# values to zero rather than taking the absolute value, which would turn the
# negative artifacts into spurious ink.
rotated_digit = np.clip(rotated_digit, 0.0, None)

# rescale the pixel values back to [0,1]; skip the division for a blank image
rotated_digit -= np.min(rotated_digit)
peak = np.max(rotated_digit)
if peak > 0:
    rotated_digit /= peak

# make a prediction
predict = model(rotated_digit.reshape(1,rows*cols))

# two stacked panels: predicted class probabilities (top) and the digit (bottom)
fig = plt.figure(figsize=(3,1.2*3),constrained_layout=True)
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 5], figure=fig)

ax = [fig.add_subplot(gs[0]),fig.add_subplot(gs[1])]

ax[0].bar(range(num_classes),predict[0,:], color='r')
ax[0].set_xticks(range(num_classes))
ax[0].set_yticks([]);
# derive the limits from num_classes so every bar stays visible for any class count
ax[0].set_xlim(-0.5,num_classes-0.5)

ax[1].matshow(rotated_digit, cmap='binary')

# predicted label in the top-right corner of the image
ax[1].text(0.99,0.99,f'{np.argmax(predict,axis=1)[0]}',horizontalalignment='right',verticalalignment='top',
                         transform=ax[1].transAxes, color='r')
ax[1].set_xticks([]);
ax[1].set_yticks([]);

### This seems worrisome; our brains can do this trivially!

Taking into account spatial correlations is the idea behind convolutional neural networks (CNN)!