In [48]:
import tensorflow
import numpy
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pandas as pd

In [49]:
# Handle to Keras' bundled MNIST handwritten-digit dataset module (exposes load_data()).
dataSet = tensorflow.keras.datasets.mnist

In [50]:
# Load MNIST and scale pixel intensities from the 0-255 integer range to
# [0, 1] floats so the Dense layers receive well-conditioned inputs.
# float32 (rather than the float64 that `/ 255.0` alone would produce) halves memory.
(X_train, y_train), (X_test, y_test) = dataSet.load_data()
X_train = X_train.astype("float32") / 255.0
X_test = X_test.astype("float32") / 255.0

# We have 60000 training samples

In [51]:
# Shape of the full training split, before the validation hold-out.
numpy.shape(X_train)

# We hold out 30% of the training samples for validation

In [52]:
# Carve 30% of the training data off as a validation set; the fixed
# random_state keeps the split reproducible across runs.
X_train, X_val, y_train, y_val = train_test_split(
    X_train,
    y_train,
    test_size=0.3,
    random_state=101,
)

In [53]:
# Training-set shape after 30% was moved into the validation split.
numpy.shape(X_train)

In [54]:
# Validation-set shape (the 30% hold-out).
numpy.shape(X_val)

In [55]:
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import callbacks

In [56]:
# Halt training once validation loss (Keras' default monitored quantity) has
# not improved by at least 0.01 for 10 consecutive epochs, then roll the
# model back to the best weights seen.
early_stopping = callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.01,
    patience=10,
    restore_best_weights=True,
)

In [57]:
# Fully connected digit classifier: flatten each 28x28 image, pass it through
# two ReLU hidden layers (batch norm + dropout after the first), and finish
# with a 10-unit softmax giving one probability per digit class.
model = keras.Sequential()
model.add(layers.Flatten(input_shape=[28, 28]))
model.add(layers.Dense(units=256, activation='relu'))
model.add(layers.BatchNormalization())
model.add(layers.Dropout(0.3))  # randomly drop 30% of activations while training
model.add(layers.Dense(units=128, activation='relu'))
model.add(layers.Dense(units=10, activation='softmax'))

In [58]:
# The final Dense layer already applies softmax, so the loss must be told it is
# receiving probabilities (from_logits=False). With from_logits=True the loss
# would apply softmax a second time, squashing the outputs and weakening the
# gradients. Accuracy is tracked so the training history reports it per epoch.
model.compile(
    optimizer='adam',
    loss=tensorflow.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

In [59]:
# Train for up to 30 epochs in mini-batches of 256, evaluating on the held-out
# validation split after each epoch; early stopping may end training sooner.
history = model.fit(
    X_train,
    y_train,
    batch_size=256,
    epochs=30,
    validation_data=(X_val, y_val),
    callbacks=[early_stopping],
)

In [60]:
# Learning curves: per-epoch training vs. validation loss. Only the loss columns
# are plotted so that any extra metrics (e.g. accuracy, on a 0-1 scale) do not
# share — and distort — the same y-axis.
history_df = pd.DataFrame(history.history)
ax = history_df[["loss", "val_loss"]].plot(title="Training vs. validation loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Sparse categorical cross-entropy");

# We use an early-stopping callback to avoid overfitting the model; the Dropout layer serves the same purpose.

In [61]:
# Softmax class probabilities for every test image (one row per sample).
results_prob = model.predict(x=X_test)

In [62]:
# Predicted digit = column index of the highest probability in each row.
results = results_prob.argmax(axis=1)

In [63]:
# Sanity check: one predicted label per test image.
numpy.shape(results)

In [64]:
# Boolean mask: True where the predicted digit matches the ground-truth label.
accurate_results = numpy.equal(results, y_test)

# Test-set accuracy computed to be 0.9807 (value from the recorded run; it will vary between training runs)

In [65]:
# Test accuracy = fraction of correct predictions. ndarray.mean() on the boolean
# mask is vectorized, unlike the builtin sum(), which iterates element by element
# in Python over numpy scalars.
accurate_results.mean()

In [68]:
# Show the first N_SHOW test digits in a 4x8 grid. Each sample gets an image
# panel (titled with predicted vs. true label) followed by a bar chart of its
# predicted class probabilities. The sample index is kept separate from the
# subplot position so that consecutive samples are shown (previously only
# odd-indexed samples appeared).
N_SHOW = 16  # 4 rows x 4 image/bar pairs

fig, axes = plt.subplots(4, 8, figsize=(20, 10))
flat_axes = axes.ravel()
for sample_idx in range(N_SHOW):
    img_ax = flat_axes[2 * sample_idx]
    prob_ax = flat_axes[2 * sample_idx + 1]
    predicted, actual = results[sample_idx], y_test[sample_idx]

    img_ax.imshow(X_test[sample_idx], cmap="gray")
    img_ax.set_xticks([])
    img_ax.set_yticks([])
    # Misclassifications are titled in red so they stand out at a glance.
    img_ax.set_title(f"Predicted: {predicted} (true: {actual})",
                     color="black" if predicted == actual else "red")

    prob_ax.bar(range(10), results_prob[sample_idx])
    prob_ax.set_xticks(range(10))
    prob_ax.set_ylim(0, 1)  # softmax outputs are probabilities
fig.tight_layout()