### How to train a simple neural network on the MNIST dataset (handwritten-digit classification)

In [44]:
import numpy as np

from keras.datasets import mnist
from keras import models
from keras import layers
from keras.utils import to_categorical

In [38]:
# Fetch MNIST via Keras: load_data() returns a (train, test) pair, each of
# which is an (images, labels) pair of arrays.
mnist_train, mnist_test = mnist.load_data()
train_images, train_labels = mnist_train
test_images, test_labels = mnist_test

In [39]:
# The training data has 60000 samples (28x28 images).
# print(...) with a single argument behaves the same on Python 2 and 3,
# unlike the Python-2-only `print x` statement form.
print(train_images.shape)
# Each image has an integer label (0 to 9)
print(len(train_labels))
train_labels

(60000, 28, 28)
60000


array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

In [40]:
# The test data has 10000 samples (28x28 images).
# print(...) with a single argument behaves the same on Python 2 and 3,
# unlike the Python-2-only `print x` statement form.
print(test_images.shape)
# Each image has an integer label (0 to 9)
print(len(test_labels))
test_labels

(10000, 28, 28)
10000


array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

In [41]:
# Fully connected classifier: one hidden layer of 512 ReLU units on the
# flattened 28*28-pixel input, then a 10-unit softmax output (one per digit).
network = models.Sequential([
    layers.Dense(512, activation='relu', input_shape=(28 * 28,)),
    layers.Dense(10, activation='softmax'),
])

In [42]:
# Configure training: RMSprop optimizer, categorical cross-entropy loss
# (which expects one-hot encoded targets), and accuracy as the reported metric.
network.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [43]:
# Flatten each 28x28 image into a 28*28 = 784-element feature vector and
# scale pixel intensities from [0, 255] to [0, 1].
def flatten_and_scale(images):
    """Return `images` reshaped to (n_samples, n_features) as float32 in [0, 1].

    The sample count comes from the array itself rather than a hard-coded
    60000/10000. Only integer (raw 0-255 pixel) input is divided by 255, so
    re-running this cell on already-processed float32 data leaves it
    unchanged instead of shrinking it by another factor of 255.
    """
    flat = images.reshape((images.shape[0], -1))
    if np.issubdtype(flat.dtype, np.integer):
        return flat.astype('float32') / 255
    return flat.astype('float32')


train_images = flatten_and_scale(train_images)
test_images = flatten_and_scale(test_images)

In [45]:
# One-hot encode the integer labels 0-9 as length-10 vectors of 0s and 1s.
# num_classes is pinned to 10 so the encoding always matches the 10-unit
# output layer. The ndim guards keep the cell safe to re-run: feeding
# already one-hot labels back into to_categorical would yield a 3-D array.
NUM_CLASSES = 10
if train_labels.ndim == 1:
    train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)
if test_labels.ndim == 1:
    test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)

print(train_labels)

[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]]


In [47]:
# Train for 10 epochs (full passes over the training set), updating the
# weights after every mini-batch of 128 samples.
network.fit(
    x=train_images,
    y=train_labels,
    batch_size=128,
    epochs=10,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x110498e10>

In [48]:
# Evaluate the trained model on the held-out test data.
# str.format produces the same "test_acc: 0.9826"-style line on Python 2
# and 3; print() with two arguments shows a tuple on Python 2.
test_loss, test_acc = network.evaluate(test_images, test_labels)
print('test_loss: {:.4f}'.format(test_loss))
print('test_acc: {:.4f}'.format(test_acc))

('test_acc:', 0.9826)
