# Physics 494/594
## MNIST and Softmax for Classification of Handwritten Digits

In [None]:
# %load ./include/header.py
import numpy as np
import matplotlib.pyplot as plt
import sys
from tqdm import trange,tqdm
sys.path.append('./include')
import ml4s

# draw figures inline in the notebook and request SVG (vector) output for them
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
# apply the plotting style sheet shipped in ./include
plt.style.use('./include/notebook.mplstyle')
# allow up to 120 characters per line when numpy arrays are printed
np.set_printoptions(linewidth=120)
# ml4s is the local helper module found via ./include; presumably this applies
# the given CSS file to the notebook -- confirm against ml4s.set_css_style
ml4s.set_css_style('./include/bootstrap.css')
# list of colors in the active style's property cycle, reused as line colors in later plots
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

## Last Time

### [Notebook Link: 20_Classification.ipynb](./20_Classification.ipynb)

- Studied a new type of problem involving the classification of inputs (mapping continuous/discrete inputs into discrete outputs)
- New cost functions: binary and categorical cross-entropy

## Today
- Classification for many classes (1-hot encoding and softmax)
- MNIST: hand-written digits

#### Categorical Cross Entropy

\begin{equation}
C(\boldsymbol{w}) = -\frac{1}{N} \sum_{n=1}^{N} \sum_{j=1}^K y_j^{(n)} \ln a_j^L
\end{equation}

#### Softmax Layer
\begin{equation}
a_j^L = \frac{\mathrm{e}^{z_j^L}}{\sum_{k=1}^K \mathrm{e}^{z_k^L}} \, .
\end{equation}

### Import tensorflow

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import datetime

### Machine Learning for Digit Recognition

This is an **old** and important problem (think computer vision for processing checks). 

In [None]:
from IPython.display import YouTubeVideo

# The player object is the last expression of the cell, so it is rendered
# as the cell's output.
YouTubeVideo(
    id='FwFduRA_L6Q',
    width=600,
    height=300,
)

### Load the MNIST Data Set

This is a standard [(and famous)](https://en.wikipedia.org/wiki/MNIST_database) set of 60,000 training and 10,000 test images, each containing 28x28 pixels.

We want to:

1. Flatten the input data such that it is 1D.
2. Rescale the pixel values from integers in [0, 255] to floats in [0, 1].
3. Turn the integer labels into 1-hot vectors.

In [None]:
# load the data: one (images, labels) pair for the training split and one for the test split
train_split, test_split = tf.keras.datasets.mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

# determine the properties: drop the leading sample axis to get the pixel grid of a single image
rows, cols = x_train.shape[1:]
# one class per digit
num_classes = 10

In [None]:
# reshape and rescale
# This cell overwrites x_train/x_test, so only transform them while they are
# still the raw 3D image stacks of shape (num_images, rows, cols). Without the
# guard, re-running the cell would divide the pixel values by 255 a second time.
if x_train.ndim == 3:
    x_train = x_train.reshape(x_train.shape[0], rows*cols).astype('float32')/255
if x_test.ndim == 3:
    x_test = x_test.reshape(x_test.shape[0], rows*cols).astype('float32')/255

# use a built-in function to get 1-hot encoding
# (these write to new names, so the integer labels y_train/y_test stay available)
y_train_hot = keras.utils.to_categorical(y_train, num_classes)
y_test_hot = keras.utils.to_categorical(y_test, num_classes)

### Compare the class vs. 1-hot encoding and show digit

In [None]:
# pick the index of one training example uniformly at random
idx = np.random.randint(len(x_train))

# the integer label next to its 1-hot vector (cast to int for compact printing)
label, one_hot = y_train[idx], y_train_hot[idx].astype(int)
print(f'class: {label}')
print(f'1-hot: {one_hot}')

# undo the flattening so the digit can be drawn as an image, without tick marks
plt.matshow(x_train[idx].reshape(rows, cols), cmap='binary')
plt.xticks([])
plt.yticks([]);

### Define a function that shows many digits

In [None]:
def plot_digit_array(x,y, show_prediction=False):
    '''Plot a set of square digit images on a near-square grid of panels.

    Parameters
    ----------
    x : array whose first axis indexes the digits; every x[i] is reshaped to a
        square image, so it must hold a perfect-square number of pixels.
    y : sequence of labels, one per digit in x. Only read when
        show_prediction is True.
    show_prediction : if True, write y[i] in red in the top-right corner of
        the panel showing x[i].

    Nothing is drawn (and no figure is created) when x is empty.
    '''

    num_digits = x.shape[0]
    if num_digits == 0:
        # nothing to show; also avoids a division by zero in the grid size below
        return

    # determine the number of rows and columns of our image array; the ceiling
    # division gives exactly enough rows, without an empty extra row when the
    # digits fill the grid (e.g. 100 digits -> 10 x 10)
    num_cols = int(np.sqrt(num_digits))
    num_rows = -(-num_digits // num_cols)

    # side length of one image, inferred from the data instead of assuming 28
    side = int(round(np.sqrt(np.size(x[0]))))

    # squeeze=False always returns a 2D array of axes, even for a 1 x 1 grid
    fig,ax = plt.subplots(nrows=num_rows,ncols=num_cols,sharex=True,sharey=True,
                          figsize=(num_cols,num_rows),squeeze=False)

    # plot all the numbers; panels beyond the last digit are left blank
    for i,cax in enumerate(ax.flatten()):
        if i < num_digits:
            cax.matshow(x[i].reshape(side,side), cmap='binary')
            if show_prediction:
                cax.text(0.99,0.99,f'{y[i]}',horizontalalignment='right',verticalalignment='top',
                         transform=cax.transAxes, fontsize=8, color='r')
        cax.axis('off')

In [None]:
%%time
idx = np.random.randint(low=0, high=x_train.shape[0], size=100)
plot_digit_array(x_train[idx],y_train[idx])

### Setup and train a neural network to learn the digits

In [None]:
# Fully connected classifier: flattened image -> 128 ReLU units -> dropout ->
# softmax over the digit classes.
# The input shape is declared with an explicit keras.Input rather than the
# `input_shape=` argument of the first Dense layer, which current Keras
# versions flag as deprecated for Sequential models.
model = keras.Sequential(
[
    keras.Input(shape=(rows*cols,)),
    layers.Dense(128,activation='relu'),
    layers.Dropout(0.25),
    layers.Dense(num_classes, activation='softmax')
])
model.summary()

# this provides a more modern approach to the compilation: the loss and the
# metric are passed as objects instead of string shorthands
model.compile(loss=tf.losses.CategoricalCrossentropy(), optimizer='adam', metrics=[tf.metrics.CategoricalAccuracy()])

# string-shorthand alternative, kept for reference:
#model.compile(loss="categorical_crossentropy", optimizer='adam', metrics=['accuracy'])

### Setup Checkpointing of the Model

In [None]:
# Weights-only checkpoints must end in `.weights.h5` under Keras 3 (a `.ckpt`
# path is rejected there with a ValueError). Older tf.keras writes HDF5 weights
# for an `.h5` suffix as well, so this path works with both.
checkpoint_path = "checkpoints/mnist_dnn/cp.weights.h5"

# Create a callback that saves the model's weights (not the full model) to
# checkpoint_path during training; verbose=1 prints a message on each save
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

In [None]:
# training hyperparameters
batch_size = 32
epochs = 10

# Fit on the training split. The test split is passed as validation data, so
# the val_* entries of the returned history are measured on the test set. The
# checkpoint callback saves the weights while training runs.
training = model.fit(
    x_train,
    y_train_hot,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(x_test, y_test_hot),
    callbacks=[cp_callback],
)

In [None]:
# look into training history: accuracy in the top panel, loss in the bottom one
fig,ax = plt.subplots(2,1, sharex=True, figsize=(5,5))

# final evaluation on the test set; score[0] annotates the loss panel and
# score[1] the accuracy panel
score = model.evaluate(x_test, y_test_hot, verbose=0)

# (panel, training-history key, validation-history key, y-axis label, annotated value)
history_panels = (
    (ax[0], 'categorical_accuracy', 'val_categorical_accuracy', 'model accuracy', score[1]),
    (ax[1], 'loss', 'val_loss', 'model loss', score[0]),
)

for panel, train_key, val_key, ylabel, final_value in history_panels:
    # solid line: training data, dashed line: validation (test) data
    panel.plot(training.history[train_key], color=colors[0])
    panel.plot(training.history[val_key], ls='--', color=colors[-3])
    panel.set_ylabel(ylabel)
    panel.legend(['train', 'test'], loc='best')
    # final test-set value, written at the top centre of the panel
    panel.text(0.5,0.95,f'{final_value:.2f}',horizontalalignment='center',
               verticalalignment='top', transform=panel.transAxes)

# accuracy is capped at 1 and the loss axis starts at 0
ax[0].set_ylim(top=1)
ax[1].set_ylim(bottom=0)
ax[1].set_xlabel('epoch');

### Investigate the predictions

In [None]:
# softmax outputs of the trained network for every image of each split
predictions_prob_train = model(x_train)
predictions_prob_test = model(x_test)

# the predicted digit is the class with the largest probability
predictions_train = np.argmax(predictions_prob_train, axis=1)
predictions_test = np.argmax(predictions_prob_test, axis=1)

# indices of the images whose prediction differs from the integer label
mistakes_train = np.flatnonzero(predictions_train != y_train)
mistakes_test = np.flatnonzero(predictions_test != y_test)

num_mistakes_train = len(mistakes_train)
num_mistakes_test = len(mistakes_test)

# error rate of each split in percent
print(f'Train Mistakes: {100*num_mistakes_train/x_train.shape[0]:.2f}%')
print(f'Test Mistakes : {100*num_mistakes_test/x_test.shape[0]:.2f}%')

### Plot some mistakes

In [None]:
# show at most the first 100 misclassified test digits, each tagged with the
# (wrong) digit the model predicted
shown_mistakes = mistakes_test[:100]
plot_digit_array(x_test[shown_mistakes], predictions_test[shown_mistakes], show_prediction=True)

### Check the Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# count the test images for every (true label, predicted label) pair
cm = confusion_matrix(y_true=y_test, y_pred=predictions_test)

# draw the counts as a heat map; plot() returns the display object itself
cm_display = ConfusionMatrixDisplay(confusion_matrix=cm)
cm_display.plot(cmap='Blues');

### Investigate the softmax probabilities to see how far off we are

In [None]:
from matplotlib import gridspec

# pick one misclassified test image at random
idx = np.random.choice(mistakes_test)

# two stacked panels: a short bar chart on top of a taller image (height ratio 1:5)
fig = plt.figure(figsize=(3,1.2*3),constrained_layout=True)
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 5], figure=fig)
ax = [fig.add_subplot(gs[k]) for k in range(2)]
prob_ax, image_ax = ax

# top: softmax probability assigned to each class for this image
prob_ax.bar(range(num_classes), predictions_prob_test[idx], color='r')
prob_ax.set_xticks(range(num_classes))
prob_ax.set_yticks([])
prob_ax.set_xlim(-0.5,9.5)

# bottom: the digit itself, with the model's prediction in the top-right corner
image_ax.matshow(x_test[idx].reshape(rows,cols), cmap='binary')
image_ax.text(0.99,0.99,f'{predictions_test[idx]}',horizontalalignment='right',
              verticalalignment='top', transform=image_ax.transAxes, color='r')
image_ax.set_xticks([])
image_ax.set_yticks([]);