This notebook contains code for training a small densely-connected neural network on the MNIST dataset to classify handwritten digits.



In [None]:
import keras 
from keras.datasets import mnist
from keras import models
from keras import layers
from keras.utils import to_categorical

# Load the MNIST dataset. load_data() returns a training split and a test
# split, each as an (images, labels) pair; the standard distribution has
# 60000 training images and 10000 test images.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Flatten every 28 x 28 image into a vector of 28 * 28 features, giving a 2D
# tensor of shape (samples, features). The sample count is read from the array
# itself instead of being hard-coded, so the same code keeps working on a
# subset or a differently sized split.
train_images = train_images.reshape((train_images.shape[0], 28 * 28))

# Cast to float32 and divide by 255 (the maximum pixel value) so that every
# pixel becomes a float between 0 and 1.
train_images = train_images.astype('float32') / 255

# One-hot encode the labels. num_classes is pinned to 10 so the encoding width
# always matches the 10-unit output layer; without it to_categorical infers
# the width from the largest label present, which would be too narrow for a
# split that happens not to contain the highest digit.
train_labels = to_categorical(train_labels, num_classes=10)

# Apply exactly the same preparation to the test split.
test_images = test_images.reshape((test_images.shape[0], 28 * 28))
test_images = test_images.astype('float32') / 255
test_labels = to_categorical(test_labels, num_classes=10)

# The model is a plain linear stack of layers (each layer feeds the next one),
# so it is expressed as a Sequential model built from a list of layers:
#   1. A densely-connected hidden layer of 512 units applied to the flattened
#      28 * 28 input. It computes relu(dot(input, weights) + bias), where
#      relu(x) is max(x, 0) and dot is the dot product.
#   2. A densely-connected 10-way softmax layer that returns one probability
#      score per digit; the 10 scores sum to 1.
network = models.Sequential([
    layers.Dense(512, activation='relu', input_shape=(28 * 28,)),
    layers.Dense(10, activation='softmax'),
])

# Configure the training process:
#   - categorical_crossentropy is the loss function being minimised.
#   - rmsprop selects the optimizer, a variant of mini-batch stochastic
#     gradient descent.
#   - accuracy is the metric reported while training and evaluating.
network.compile(
    loss='categorical_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'],
)

# Training loop: make 5 full passes (epochs) over the training data, updating
# the weights after each mini-batch of 128 samples.
network.fit(
    x=train_images,
    y=train_labels,
    batch_size=128,
    epochs=5,
)

# Measure how well the trained network does on the held-out test images and
# report the resulting accuracy.
test_loss, test_acc = network.evaluate(x=test_images, y=test_labels)
print('test_acc:', test_acc)



Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test_acc: 0.9805999994277954
