In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models, callbacks

In [None]:
# Fetch the MNIST train/test splits through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel values by 1/255 (presumably mapping 8-bit intensities
# into [0, 1] -- confirm against the dataset dtype).
x_train = x_train / 255.0
x_test = x_test / 255.0

# Append a trailing singleton axis so each image is (rows, cols, 1), the
# layout expected by the first Conv2D layer's input_shape=(28, 28, 1).
N, n, m = x_train.shape
x_train = x_train.reshape(N, n, m, 1)

N, n, m = x_test.shape
x_test = x_test.reshape(N, n, m, 1)

In [None]:
def displayConvLayer(layer_name, figsize=(15, 15)):
    """Plot every 2-D kernel slice of a convolution layer and return the kernel.

    One subplot is drawn per (second-to-last axis, last axis) index pair of
    the kernel: rows follow the second-to-last axis and columns follow the
    last axis. For a Keras ``Conv2D`` kernel these are presumably the input
    channels and the output filters respectively.

    Parameters
    ----------
    layer_name : object
        Despite the name this is the layer object itself, not a string. It
        must expose ``get_weights()`` whose first element is a 4-D array
        indexable as ``kernel[:, :, row, col]``.
    figsize : tuple, optional
        Figure size forwarded to ``plt.subplots``. Defaults to ``(15, 15)``,
        the value that used to be hard-coded.

    Returns
    -------
    The kernel array, i.e. ``layer_name.get_weights()[0]``, unchanged.
    """
    weights = layer_name.get_weights()[0]
    n_rows, n_cols = weights.shape[-2], weights.shape[-1]

    # squeeze=False keeps `ax` a 2-D array for every grid shape. With the
    # default squeeze=True a single row or column collapses to 1-D and a
    # 1x1 grid collapses to a bare Axes, which broke the indexing whenever
    # the layer had one filter or a single 1x1 grid.
    _, ax = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for col in range(n_cols):
        for row in range(n_rows):
            ax[row, col].imshow(weights[:, :, row, col])

    return weights

In [None]:
# Small CNN: two conv/pool stages followed by three dense layers ending in a
# 10-way softmax. The conv layers are bound to names so their kernels can be
# inspected later (see displayConvLayer).
conv1 = layers.Conv2D(6, (5, 5), padding='same', activation='relu', input_shape=(28, 28, 1))
conv2 = layers.Conv2D(16, (5, 5), activation='relu')

model = models.Sequential([
    conv1,
    layers.MaxPooling2D((2, 2)),
    conv2,
    layers.MaxPooling2D((2, 2)),
    # Only the first layer of a Sequential needs an input shape; Flatten
    # infers its (5, 5, 16) input from the preceding pooling layer, so the
    # redundant input_shape argument has been dropped.
    layers.Flatten(),
    layers.Dense(120, activation='relu'),
    layers.Dense(80, activation='relu'),
    layers.Dense(10, activation='softmax'),
])
model.summary()

In [None]:
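# Optional debugging aid: prints conv1's kernel at the end of every epoch.
# To enable, uncomment and pass callbacks=[print_weights] to model.fit.
#print_weights = callbacks.LambdaCallback(on_epoch_end=lambda epoch, logs: print(conv1.get_weights()[0]))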

In [None]:
# Labels are supplied one-hot encoded at fit/evaluate time, hence the
# (non-sparse) categorical cross-entropy loss.
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Plot and keep conv1's kernel as it is before model.fit runs (the initial
# weights), for comparison with the post-training snapshot weights_f.
weights_i  = displayConvLayer(conv1)

In [None]:
# One-hot encode the integer class labels into 10-wide vectors, matching the
# CategoricalCrossentropy loss the model was compiled with.
train_targets = tf.one_hot(y_train, 10)
val_targets = tf.one_hot(y_test, 10)

# Note: the test split doubles as the validation set here.
history = model.fit(
    x_train,
    train_targets,
    batch_size=128,
    epochs=20,
    validation_data=(x_test, val_targets),
)

In [None]:
# Plot and keep conv1's kernel again after training, to compare against the
# pre-training snapshot weights_i.
weights_f = displayConvLayer(conv1)

In [None]:
# Per-epoch training and validation accuracy recorded by model.fit.
for curve in ('accuracy', 'val_accuracy'):
    plt.plot(history.history[curve], label=curve)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')

# Final loss and accuracy of the trained model on the test split.
test_targets = tf.one_hot(y_test, 10)
test_loss, test_acc = model.evaluate(x_test, test_targets, verbose=2)

In [None]:
# Per-epoch training loss, plus validation loss when it was recorded (it is
# whenever model.fit was given validation_data), mirroring the accuracy plot.
plt.plot(history.history['loss'], label='loss')
if 'val_loss' in history.history:
    plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
# The curve labels are only rendered once a legend is drawn.
plt.legend(loc='upper right');