In [None]:
%matplotlib inline

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
from keras.utils.np_utils import to_categorical
from keras.models import Sequential, load_model, Model
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D

In [None]:
from keras.datasets import mnist

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
# Jupyter only displays a cell's last expression, so show both values together
# instead of letting the shape be silently discarded.
X_train.shape, Y_train[0]

In [None]:
# Preview the first training digit. squeeze() drops a trailing channel axis, so
# this still works after the preprocessing cell reshapes X_train to (N, 28, 28, 1);
# the trailing ';' suppresses the AxesImage repr.
plt.imshow(X_train[0].squeeze(), cmap='binary');

In [None]:
# This cell rebinds X_*/Y_* in place: running it twice would rescale the pixels
# by 255 again and one-hot encode already one-hot labels. Fail loudly instead.
assert Y_train.ndim == 1, (
    "Preprocessing already applied - re-run the data-loading cell first."
)

# Add the single channel axis Conv2D expects (input_shape=(28, 28, 1)).
# Use -1 for both splits instead of hard-coding the training-set size.
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

# One-hot encode the labels; 10 classes matches the Dense(10) softmax output.
Y_train = to_categorical(Y_train, num_classes=10)
Y_test = to_categorical(Y_test, num_classes=10)

# Scale pixels to [0, 1]. Casting to float32 first avoids an intermediate
# float64 copy (twice the memory) that the model would not use anyway.
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

In [None]:
# Small CNN: two 5x5 'same'-padded ReLU conv layers (16 then 36 filters) with a
# 2x2 max-pool between them, followed by a 128-unit ReLU layer and a 10-way softmax.
model = Sequential([
    Conv2D(16, (5, 5), strides=1, padding='same', activation='relu',
           input_shape=(28, 28, 1)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(36, (5, 5), strides=1, padding='same', activation='relu'),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax'),
])
model.summary()

In [None]:
# One-hot targets + softmax output -> categorical cross-entropy, Adam optimiser.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])


In [None]:
# Training hyper-parameters (tune here).
BATCH_SIZE = 200
EPOCHS = 1
VALIDATION_SPLIT = 0.2  # fraction of the training data held out for validation

# Keep the returned History (per-epoch loss/accuracy) instead of discarding it;
# assigning it also stops Jupyter from echoing the object's repr.
history = model.fit(X_train, Y_train,
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    validation_split=VALIDATION_SPLIT)

In [None]:
score = model.evaluate(X_test, Y_test)
# With one compiled metric, evaluate() returns [loss, accuracy] in that order.
test_loss, test_accuracy = score
print(f"test loss: {test_loss:.4f}  test accuracy: {test_accuracy:.4f}")

In [None]:
# Architecture as reported by model.summary(). Keras auto-generates layer names
# such as 'conv2d_13' from a per-session counter, so they change whenever the
# kernel is restarted or the model cell is re-run; layers are therefore picked
# by type below instead of by name.
#
# Layer (type)    Output Shape          Param #
# Conv2D          (None, 28, 28, 16)    416
# MaxPooling2D    (None, 14, 14, 16)    0
# Conv2D          (None, 14, 14, 36)    14436
# Flatten         (None, 7056)          0
# Dense           (None, 128)           903296
# Dense           (None, 10)            1290
conv_layers = [layer for layer in model.layers if isinstance(layer, Conv2D)]
l1 = conv_layers[0]
# Kernel of the first conv layer; expected (5, 5, 1, 16) for 5x5 filters over
# one input channel (Keras channels-last kernel layout) - TODO confirm.
l1.get_weights()[0].shape

In [None]:
def plot_weight(w, input_channel=0):
    """Draw every filter of a 2-D convolution kernel on a square grid.

    Parameters
    ----------
    w : numpy.ndarray
        4-D kernel indexed as ``w[row, col, in_channel, filter]`` (the layout of
        ``Conv2D.get_weights()[0]`` in channels-last Keras).
    input_channel : int, default 0
        Which input-channel slice of each filter to draw.

    All panels share one colour scale (min/max over the whole kernel) so the
    filters are visually comparable. The figure is displayed with
    ``plt.show()``; nothing is returned.
    """
    w_min = np.min(w)
    w_max = np.max(w)
    num_filters = w.shape[3]
    num_grid = math.ceil(math.sqrt(num_filters))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; otherwise
    # matplotlib returns a bare Axes (no `.flat`) for single-filter kernels.
    _, axes = plt.subplots(num_grid, num_grid, squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_filters:
            ax.imshow(w[:, :, input_channel, i], vmin=w_min, vmax=w_max)
        # Hide ticks on every panel, including the unused trailing ones.
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

In [None]:
# Visualise the learned kernels of both convolution layers. Layers are selected
# by type rather than by auto-generated names like 'conv2d_13', which depend on
# how many layers the current kernel session has already created.
conv_layers = [layer for layer in model.layers if isinstance(layer, Conv2D)]
l1, l2 = conv_layers[:2]

w1 = l1.get_weights()[0]
plot_weight(w1)

w2 = l2.get_weights()[0]
plot_weight(w2)

In [None]:
# Build a sub-model that exposes the activations of the second conv layer.
# Architecture as reported by model.summary() (auto-generated layer names such
# as 'conv2d_13'/'conv2d_14' are omitted; they differ between kernel sessions):
#
# Layer (type)    Output Shape          Param #
# Conv2D          (None, 28, 28, 16)    416
# MaxPooling2D    (None, 14, 14, 16)    0
# Conv2D          (None, 14, 14, 36)    14436
# Flatten         (None, 7056)          0
# Dense           (None, 128)           903296
# Dense           (None, 10)            1290
conv_layers = [layer for layer in model.layers if isinstance(layer, Conv2D)]
temp_model = Model(inputs=conv_layers[0].input, outputs=conv_layers[1].output)

# Only a handful of images are visualised, so don't materialise a 14x14x36
# activation map for every image in the test set.
N_PREVIEW = 10
output = temp_model.predict(X_test[:N_PREVIEW])

In [None]:
# Activation tensor shape, expected (num_images, 14, 14, 36) per the summary.
np.shape(output)

In [None]:
def plot_output(w, index=0):
    """Draw every channel of one sample's feature map on a square grid.

    Parameters
    ----------
    w : numpy.ndarray
        4-D activations indexed as ``w[sample, row, col, channel]``, e.g. the
        output of ``Model.predict`` on a channels-last conv layer.
    index : int, default 0
        Which sample (first axis) to display.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    num_channels = w.shape[3]
    num_grid = math.ceil(math.sqrt(num_channels))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; otherwise
    # matplotlib returns a bare Axes (no `.flat`) for single-channel outputs.
    _, axes = plt.subplots(num_grid, num_grid, squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_channels:
            ax.imshow(w[index, :, :, i], cmap='binary')
        # Hide ticks on every panel, including the unused trailing ones.
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

In [None]:
# Show each channel of the first image's second-conv-layer activations.
plot_output(w=output)