Import Libraries

In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from keras import layers
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense,Dropout,Flatten
from keras.layers import Conv2D,MaxPool2D
from keras import backend as k

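Variables:
- `batch_size`: number of training samples processed per gradient update; the training set is split into mini-batches of this size.
- `num_classes`: number of distinct labels (classes) in the data — the ten digits 0–9 for MNIST.
- `epochs`: number of full passes over the training data; one epoch runs one forward pass and one backward pass for every mini-batch.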

In [2]:
# Training configuration: mini-batch size, number of label classes, and
# number of full passes over the training set.
batch_size, num_classes, epochs = 128, 10, 4

Load the MNIST training and test data

In [3]:
# Height and width (in pixels) used when reshaping the images below.
img_rows = img_cols = 28

# load_data() returns a (train, test) pair, each an (images, labels) pair.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

Reshape the images (add a channel axis) and scale the pixel values to [0, 1]

In [4]:
# Conv2D expects an explicit channel axis; where it goes depends on the
# backend's configured data format. Both the arrays and `input_shape` must agree.
if k.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Scale pixel intensities (presumably 0-255 integers) to [0, 1]. Casting to
# float32 first avoids an implicit float64 copy of the whole dataset.
# NOTE: this rebinds x_train/x_test, so re-running the cell would scale twice —
# restart and re-run from the data-loading cell instead.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
print('x_train shape:',x_train.shape,'\nx_test shape:',x_test.shape)

x_train shape: (60000, 28, 28, 1) 
x_test shape: (10000, 28, 28, 1)


Convert class vectors to binary class matrices

In [5]:
# One-hot encode the integer class labels: each label becomes a row of length
# num_classes, as required by categorical_crossentropy.
# NOTE: not idempotent — re-running this cell would re-encode already one-hot labels.
y_train, y_test = [keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test)]

Design the CNN architecture

In [None]:
# Attention Class
# Define attention mechanism layer
class Attention(layers.Layer):
    """Additive (Bahdanau-style) attention layer.

    Scores every step of ``features`` against a ``hidden`` query vector with
    ``V(tanh(W1(features) + W2(hidden)))``, softmax-normalises the scores over
    axis 1, and returns the score-weighted sum of ``features``.

    Note: this layer is currently not used by the Sequential model in this notebook.

    Args:
        units: Width of the intermediate ``W1``/``W2`` projections.
        **kwargs: Standard ``Layer`` arguments (e.g. ``name``, ``dtype``),
            forwarded to the base class.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can rebuild the layer on model reload.
        self.units = units
        self.W1 = layers.Dense(units)
        self.W2 = layers.Dense(units)
        self.V = layers.Dense(1)

    def call(self, features, hidden):
        """Compute the attention context vector.

        Args:
            features: Sequence of feature vectors; axis 1 is the axis attended
                over. Presumably (batch, steps, feature_dim) — TODO confirm.
            hidden: Query tensor without the step axis; axis 1 is inserted so it
                broadcasts against every step of ``features``. Presumably
                (batch, hidden_dim) — TODO confirm.

        Returns:
            Tuple ``(context_vector, attention_weights)``.
            ``context_vector`` is the weighted sum of ``features`` over axis 1.
            ``attention_weights`` holds the softmax scores, with a trailing
            size-1 axis from ``V``.
        """
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)
        return context_vector, attention_weights

    def get_config(self):
        """Return the layer config (base Layer config plus ``units``) for serialization."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [6]:
# Three convolution stages extract spatial features; a small dense head maps
# them to per-class probabilities.
model = Sequential()

model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
# TODO: the Attention layer defined above is not wired into this model.
# Its call() takes a (features, hidden) pair, but each Sequential layer
# receives a single input. Build the model with the functional API if
# attention is wanted. Never assign to `model.add` — that replaces the method.
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))  # class probabilities
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 5, 5, 64)         0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 3, 3, 64)          36928     
                                                                 
 flatten (Flatten)           (None, 576)               0

Compile and train the model

In [7]:
# Adam optimizer with categorical cross-entropy (labels are one-hot encoded).
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.categorical_crossentropy,
              metrics=['accuracy'])

# Keep the returned History so per-epoch loss/accuracy can be plotted later
# (history.history is a dict of metric name -> list of per-epoch values).
# NOTE(review): the test split doubles as validation data here, so the reported
# val_accuracy is effectively test accuracy. Hold out part of the training set
# (e.g. validation_split=0.1) if validation results are ever used for tuning.
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x7f62158d4e80>

Save the Model

In [9]:
# Persist the trained model to disk.
# NOTE(review): despite the "Att" prefix, the model built in this notebook has
# no attention layer. The '.h5' extension selects Keras's legacy HDF5 format;
# consider the native '.keras' format (changing the path affects downstream loaders).
model_path = 'AttCNNmodel.h5'
model.save(model_path)
print("model is saved")

model is saved
