In [17]:
import numpy as np
from keras.layers import Dense, Dropout, Input
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout
from keras.models import Model
from keras.layers.merge import concatenate
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.utils import plot_model

In [30]:
!pip install pydot

Collecting pydot
  Downloading https://files.pythonhosted.org/packages/33/d1/b1479a770f66d962f545c2101630ce1d5592d90cb4f083d38862e93d16d2/pydot-1.4.1-py2.py3-none-any.whl
Installing collected packages: pydot
Successfully installed pydot-1.4.1


In [9]:
#Loading the data
((x_train, y_train), (x_test, y_test)) = mnist.load_data()

In [10]:
x_train.shape, x_test.shape, y_train.shape, y_test.shape

((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))

In [11]:
num_labels = len(np.unique(y_train))
num_labels

10

In [12]:
y_train = to_categorical(y_train, num_classes=num_labels)
y_test = to_categorical(y_test, num_classes=num_labels)

In [15]:
#Reshaping and Normalizing the data
image_size = x_train.shape[1]
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255

In [18]:
#Parameters
input_shape = (image_size, image_size, 1)
batch_size = 32
kernel_size = 3
dropout = 0.4
n_filters = 32

In [19]:
#Creating the left branch of the Y-Network
left_input = Input(shape = input_shape)
x = left_input
filters = n_filters 

for _ in range(3):
    x = Conv2D(filters = filters, kernel_size= kernel_size, padding='same', activation='relu')(x)
    x = Dropout(dropout)(x)
    x = MaxPooling2D()(x)
    
    filters = filters*2

In [23]:
#Creating the right branch of the Y-Network
right_input = Input(shape = input_shape)
y = right_input
filters = n_filters 

for _ in range(3):
    y = Conv2D(filters = filters, kernel_size= kernel_size, padding='same', activation='relu', dilation_rate=2)(y)
    y = Dropout(dropout)(y)
    y = MaxPooling2D()(y)
    
    filters = filters*2

In [24]:
#Merging left and right branches
y = concatenate([x,y])
y = Flatten()(y)
y = Dropout(dropout)(y)
output = Dense(num_labels, activation='softmax')(y)

In [27]:
model = Model([left_input, right_input], output)
#plot_model(model, to_file='y-network.png', show_shapes=True)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 28, 28, 32)   320         input_3[0][0]                    
__________________________________________________________________________________________________
dropout_1 

In [31]:
plot_model(model, to_file='y-network.png', show_shapes=True)

ImportError: Failed to import `pydot`. Please install `pydot`. For example with `pip install pydot`.

In [33]:
model.compile(loss = 'categorical_crossentropy',
              optimizer = 'adam',
              metrics = ['accuracy'])

In [34]:
model.fit([x_train, x_train],
          y_train,
          validation_data = ([x_test, x_test], y_test),
          epochs = 20,
          batch_size=batch_size)

Train on 60000 samples, validate on 10000 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7f93e1b36f98>