In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
import datetime
import io
import itertools
import os

import matplotlib.pyplot as plt
import sklearn.metrics

In [3]:
# Load MNIST as (image, label) pairs and keep the dataset metadata alongside it.
mnist_dataset, mnist_info = tfds.load(
    name='mnist',
    as_supervised=True,
    with_info=True,
)



In [4]:
# Pull the two predefined splits out of the loaded dataset dictionary.
mnist_train = mnist_dataset['train']
mnist_test = mnist_dataset['test']

In [5]:
# Hold out 10% of the training split for validation; counts are cast to int64
# tensors so they can be used as element/batch counts in the tf.data pipeline.
no_of_validation_samples = tf.cast(
    0.1 * mnist_info.splits['train'].num_examples,
    tf.int64,
)

# The whole test split is later consumed as a single batch.
no_of_test_samples = tf.cast(
    mnist_info.splits['test'].num_examples,
    tf.int64,
)

In [6]:
# Training configuration: shuffle buffer size, mini-batch size, epoch cap.
buffer_size, batch_size, max_epoch = 70_000, 128, 20

In [7]:
def scale(image, label):
    """Convert an image to float32 and divide its values by 255.

    Args:
        image: image tensor to rescale.
        label: label paired with the image; passed through unchanged.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    scaled = tf.cast(image, tf.float32) / 255.
    return scaled, label

In [8]:
# Apply the same pixel scaling to both splits.
scaled_test_data = mnist_test.map(map_func=scale)
scaled_train_and_validation_data = mnist_train.map(map_func=scale)

In [9]:
# Shuffle once and keep that order for every later pass over the data.
# With the default reshuffle_each_iteration=True the dataset is reshuffled each
# time it is iterated, so a take()/skip() split built on top of it would select
# different samples on every epoch and validation samples would leak into the
# training set. Trade-off: the training order is also fixed across epochs.
suffled_train_and_validation_data = scaled_train_and_validation_data.shuffle(
    buffer_size,
    reshuffle_each_iteration=False,
)

In [10]:
# The first `no_of_validation_samples` elements form the validation set;
# everything after them is used for training.
train_data = suffled_train_and_validation_data.skip(count=no_of_validation_samples)
validation_data = suffled_train_and_validation_data.take(count=no_of_validation_samples)

In [11]:
# Training uses mini-batches; validation and test are each served as one
# batch containing the entire split.
test_data = scaled_test_data.batch(batch_size=no_of_test_samples)
validation_data = validation_data.batch(batch_size=no_of_validation_samples)
train_data = train_data.batch(batch_size=batch_size)

In [23]:
# The validation dataset holds a single batch with every validation sample,
# so fetching the first batch yields the full set as NumPy arrays.
images, labels = next(iter(validation_data))
images_val, labels_val = images.numpy(), labels.numpy()

In [24]:
# Two conv + max-pool stages followed by a linear 10-way output layer.
# The last Dense layer has no activation, so the model outputs raw logits.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(
        filters=50, kernel_size=5, activation='relu', input_shape=(28, 28, 1)
    ),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Conv2D(filters=50, kernel_size=3, activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=10),
])

In [25]:
# The model's final Dense(10) layer has no activation, so it emits logits.
# `from_logits` is a boolean flag: True tells the loss to apply softmax
# internally. (The previous value, the string 'softmax', only worked because a
# non-empty string happens to be truthy.)
loss_rn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_rn, metrics=['accuracy'])

In [26]:
# Build the run directory with the platform's own path separator instead of
# hard-coded Windows backslashes, and use lowercase 'logs' so it matches the
# `--logdir "logs/fit"` passed to TensorBoard on case-sensitive filesystems.
log_dir = os.path.join('logs', 'fit', 'run-1')

In [27]:
def plot_confusion_matrix(cm, class_names):
    """Draw a confusion matrix as an annotated heat map.

    Cell colours follow the raw counts in ``cm``; the number printed in each
    cell is the count divided by its row total (rows are true labels), rounded
    to two decimals.

    Args:
        cm: square 2-D array-like of counts, rows = true label,
            columns = predicted label.
        class_names: sequence of tick labels, one per class.

    Returns:
        The matplotlib figure containing the plot.
    """
    # Accept lists / nested sequences as well as ndarrays.
    cm = np.asarray(cm)

    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Compute the labels from the row-normalized confusion matrix.
    # A class with no true samples has a zero row total; divide by 1 there so
    # the row shows 0.0 instead of NaN (and no divide-by-zero warning is raised).
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    safe_totals = np.where(row_totals == 0, 1, row_totals)
    labels = np.around(cm.astype('float') / safe_totals, decimals=2)

    # Use white text if squares are dark; otherwise black.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass last so the axis labels are taken into account.
    plt.tight_layout()
    return figure

In [28]:
def plot_to_image(figure):
    """Render a matplotlib figure to a PNG image tensor and close the figure.

    Args:
        figure: the matplotlib figure to render.

    Returns:
        The decoded 4-channel PNG as a tensor with a leading batch axis of
        size 1 added, as expected by ``tf.summary.image``.
    """
    buf = io.BytesIO()
    # Save the figure that was passed in. `plt.savefig` would instead save
    # whichever figure pyplot currently considers active, which is not
    # guaranteed to be `figure`.
    figure.savefig(buf, format='png')
    # Closing the figure releases its memory and stops it from being
    # displayed in the notebook.
    plt.close(figure)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension.
    return tf.expand_dims(image, 0)

In [29]:
# Summary writer for the confusion-matrix images, in a 'cm' sub-directory of
# the run's log directory (joined with the platform's path separator).
file_writer_cm = tf.summary.create_file_writer(os.path.join(log_dir, 'cm'))

def log_confusion_matrix(epoch, logs):
    """Log a confusion matrix of the model on the validation set as an image summary.

    Has the ``(epoch, logs)`` signature of an ``on_epoch_end`` hook. Reads the
    notebook-level ``model``, ``images_val`` and ``labels_val``.

    Args:
        epoch: epoch index, used as the summary step.
        logs: metrics dict supplied by Keras; unused here.
    """
    class_names = [str(digit) for digit in range(10)]

    # Predicted class = index of the largest model output per sample.
    test_pred_raw = model.predict(images_val)
    test_pred = np.argmax(test_pred_raw, axis=1)

    # Pass the label set explicitly so the matrix is always 10x10 and lines up
    # with `class_names`, even if some digit never occurs in the labels or
    # predictions.
    cm = sklearn.metrics.confusion_matrix(
        labels_val, test_pred, labels=np.arange(len(class_names))
    )
    figure = plot_confusion_matrix(cm, class_names=class_names)
    cm_image = plot_to_image(figure)

    with file_writer_cm.as_default():
        tf.summary.image("Confusion Matrix", cm_image, step=epoch)

In [30]:
# TensorBoard logging (weight histograms every epoch, profiling disabled) plus
# a per-epoch hook that writes the confusion-matrix image.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    profile_batch=0,
)
cm_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=log_confusion_matrix,
)

In [31]:
# Stop once validation loss has failed to improve for 2 consecutive epochs and
# roll the model back to the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=2,
    verbose=0,
    mode='auto',
    restore_best_weights=True,
)

In [32]:
# Train for up to `max_epoch` epochs, validating after each one.
model.fit(
    x=train_data,
    validation_data=validation_data,
    epochs=max_epoch,
    callbacks=[tensorboard_callback, cm_callback, early_stopping],
    verbose=2,
)

Epoch 1/20
422/422 - 93s - loss: 0.2689 - accuracy: 0.9248 - val_loss: 0.0790 - val_accuracy: 0.9782
Epoch 2/20
422/422 - 80s - loss: 0.0703 - accuracy: 0.9794 - val_loss: 0.0529 - val_accuracy: 0.9835
Epoch 3/20
422/422 - 81s - loss: 0.0521 - accuracy: 0.9839 - val_loss: 0.0437 - val_accuracy: 0.9868
Epoch 4/20
422/422 - 76s - loss: 0.0421 - accuracy: 0.9874 - val_loss: 0.0308 - val_accuracy: 0.9897
Epoch 5/20
422/422 - 84s - loss: 0.0349 - accuracy: 0.9890 - val_loss: 0.0250 - val_accuracy: 0.9937
Epoch 6/20
422/422 - 100s - loss: 0.0310 - accuracy: 0.9909 - val_loss: 0.0212 - val_accuracy: 0.9935
Epoch 7/20
422/422 - 102s - loss: 0.0274 - accuracy: 0.9917 - val_loss: 0.0230 - val_accuracy: 0.9928
Epoch 8/20
422/422 - 79s - loss: 0.0235 - accuracy: 0.9926 - val_loss: 0.0223 - val_accuracy: 0.9927


<tensorflow.python.keras.callbacks.History at 0x2480017ef48>

In [33]:
# Load the TensorBoard notebook extension and open the dashboard on the
# "logs/fit" directory.
%load_ext tensorboard
%tensorboard --logdir "logs/fit"

Reusing TensorBoard on port 6006 (pid 3184), started 2:18:42 ago. (Use '!kill 3184' to kill it.)