In [1]:
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.utils import np_utils


In [2]:
# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
print(x_train.shape, x_test.shape)

(60000, 28, 28) (10000, 28, 28)


In [4]:
x_train = x_train.reshape(60000, 28, 28, 1)
x_test = x_test.reshape(10000, 28, 28, 1)
input_shape = (28, 28, 1)

In [5]:
# one-hot encoding using keras' numpy-related utilities
n_classes = 10
y_train = np_utils.to_categorical(y_train, n_classes)
y_test = np_utils.to_categorical(y_test, n_classes)

In [6]:
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

In [7]:
x_train /= 255
x_test /= 255

In [8]:
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [9]:
model = Sequential()

# Convnet Layer:
model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

# Classification Layer:
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))

model.add(Dense(10, activation='softmax'))

In [10]:
model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')

In [11]:
#fit
model.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1e073404280>

In [12]:
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

Test loss: 0.028285622596740723
Test accuracy: 0.9921000003814697


In [13]:
model.predict(x_test)

array([[6.7113606e-21, 3.9396012e-13, 2.0971738e-12, ..., 1.0000000e+00,
        6.3532077e-17, 1.5465559e-13],
       [1.3642009e-13, 3.4963563e-13, 1.0000000e+00, ..., 4.7837222e-13,
        1.9272604e-11, 8.6280682e-17],
       [1.3714609e-14, 1.0000000e+00, 6.5354167e-10, ..., 1.0113143e-09,
        1.3655230e-10, 1.4824292e-12],
       ...,
       [9.3567953e-23, 2.6477079e-14, 1.0922853e-15, ..., 4.8099300e-15,
        6.1982971e-14, 4.2500186e-13],
       [4.2483401e-17, 1.0304561e-17, 1.3975730e-20, ..., 3.8962283e-21,
        7.2649985e-12, 6.1123849e-13],
       [1.6251621e-14, 8.8524683e-19, 2.9500547e-18, ..., 1.6365449e-24,
        2.7082449e-13, 9.7370842e-19]], dtype=float32)