In [6]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Flatten, concatenate
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical, plot_model

In [2]:
# Load the MNIST dataset; load_data() returns ((x_train, y_train), (x_test, y_test)).
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [3]:
# from sparse labels to categorical (one-hot)
# Only encode while the labels are still sparse (1-D): re-running this cell on
# already one-hot arrays would re-encode them, and np.unique would then report
# 2 classes (0.0 and 1.0), silently shrinking the softmax layer.
if y_train.ndim == 1:
    # Size the encoding from the largest label in either split and pass it
    # explicitly, so both one-hot arrays and the Dense(num_labels) output layer
    # always have the same width, even if some class were absent from a split.
    num_labels = int(max(y_train.max(), y_test.max())) + 1
    y_train = to_categorical(y_train, num_classes=num_labels)
    y_test = to_categorical(y_test, num_classes=num_labels)

In [4]:
# reshape and normalize input images
# Add an explicit single channel axis so each sample is (H, W, 1), as Conv2D
# expects. Assumes square images (H == W), since shape[1] is used for both.
# The reshape is a no-op if this cell is re-run.
image_size = x_train.shape[1]
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])

# Scale integer pixel values (presumably 8-bit, 0-255) to [0, 1]. Guarded on
# dtype so that re-running the cell does not divide already-normalized
# float32 arrays by 255 a second time.
if np.issubdtype(x_train.dtype, np.integer):
    x_train = x_train.astype('float32') / 255
if np.issubdtype(x_test.dtype, np.integer):
    x_test = x_test.astype('float32') / 255

In [5]:
# network params
input_shape = (image_size, image_size, 1)  # single-channel square images
n_filters = 32     # filters in the first Conv2D of each branch; doubled per layer
batch_size = 32
kernel_size = 3
dropout = 0.4      # dropout rate used after each Conv2D and before the classifier

# Seed TensorFlow's global RNG and NumPy so weight initialisation and dropout
# are repeatable under Restart & Run All. Some GPU kernels may still be
# non-deterministic, so metrics can differ slightly between runs.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [8]:
# Y-network: two convolutional branches, one per input, whose feature maps are
# concatenated and passed to a softmax classifier.


def build_cnn_branch(inputs, n_filters, kernel_size, dropout, n_layers=3):
    """Stack `n_layers` Conv2D -> Dropout -> MaxPooling2D blocks on `inputs`.

    The number of filters starts at `n_filters` and doubles after every
    layer (e.g. 32, 64, 128 for n_filters=32 and n_layers=3).

    Args:
        inputs: Keras tensor the branch is built on (e.g. an ``Input``).
        n_filters: number of filters in the first Conv2D layer.
        kernel_size: Conv2D kernel size.
        dropout: dropout rate applied after each Conv2D.
        n_layers: number of Conv2D-Dropout-MaxPooling2D blocks.

    Returns:
        The output tensor of the last block (``inputs`` if n_layers == 0).
    """
    features = inputs
    filters = n_filters
    for _ in range(n_layers):
        features = Conv2D(filters=filters,
                          kernel_size=kernel_size,
                          padding='same',
                          activation='relu')(features)
        features = Dropout(dropout)(features)
        features = MaxPooling2D()(features)
        filters *= 2
    return features


# left branch of Y-network (built first, as before, so layer naming and
# initialisation order are unchanged)
left_inputs = Input(shape=input_shape)
left_features = build_cnn_branch(left_inputs, n_filters, kernel_size, dropout)

# right branch of Y-network
right_inputs = Input(shape=input_shape)
right_features = build_cnn_branch(right_inputs, n_filters, kernel_size, dropout)

# merge left and right branch outputs (concatenate's default is the last axis,
# i.e. the channel axis of these (H, W, C) feature maps)
merged = concatenate([left_features, right_features])

# feature maps to vector before connecting to Dense
flat = Flatten()(merged)
flat = Dropout(dropout)(flat)
outputs = Dense(num_labels, activation='softmax')(flat)

# build the model in functional API
model = Model([left_inputs, right_inputs], outputs)

# verify the model using graph (plot_model requires pydot and graphviz)
plot_model(model, to_file='cnn-Y-network.png', show_shapes=True)

# verify the model using layer text description
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 input_3 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 28, 28, 32)   320         ['input_2[0][0]']                
                                                                                                  
 conv2d_3 (Conv2D)              (None, 28, 28, 32)   320         ['input_3[0][0]']                
                                                                                              

In [9]:
# Softmax outputs with one-hot targets -> categorical cross-entropy loss;
# Adam with its default settings; report accuracy alongside the loss.
model.compile(optimizer='Adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [10]:
# train the model with input images and labels; the same image batch is fed
# to both branches of the Y-network.
# NOTE: validation_data is the test set, so the per-epoch val_* metrics are
# test-set metrics. They are only monitored here (no callbacks select on them),
# but tuning hyperparameters on them would leak test information.
# The returned History is kept for plotting loss/accuracy curves; assigning it
# also avoids the uninformative History repr as cell output.
history = model.fit([x_train, x_train],
                    y_train,
                    validation_data=([x_test, x_test], y_test),
                    epochs=25,
                    batch_size=batch_size)

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25


<keras.callbacks.History at 0x21426815f40>

In [13]:
# Evaluate on the test set, feeding the same images to both inputs.
# evaluate() returns [loss, accuracy], following the metrics given to compile().
score = model.evaluate(x=[x_test, x_test],
                       y=y_test,
                       batch_size=batch_size,
                       verbose=0)
test_accuracy = score[1]

print("\nTest accuracy: %.1f%%" % (100.0 * test_accuracy))


Test accuracy: 99.3%
