# Make a neural network for reading handwritten digits

Here a neural network is constructed and trained on MNIST data. The input is an image of a handwritten digit (between zero and nine) and the output is 10 numbers, one for each digit from 0 to 9. After a softmax, these numbers can be read as the probability of each digit.

It is a very simple neural network: two fully connected hidden layers followed by an output layer.

Each MNIST example is a grayscale image of 28x28=784 pixels together with a label that says which digit it shows.

In [1]:
import torch
import torch.nn as nn

## Get the MNIST data

In [None]:
import torchvision

# This or similar code is found all over the internet and is more or less
# incomprehensible to most of us. You need to study the docs of the torchvision
# package to understand the details.

# NOTE: ToTensor scales to [0,1] interval
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
mnist_full_train=torchvision.datasets.MNIST(root='MNIST_dir', train=True, transform=transform, download=True)

# The images are found now as mnist_full_train[i][0] (for image number i)
# with corresponding labels mnist_full_train[i][1]
# You can try "print(mnist_full_train[0][0])" to see the actual pixel values for
# the first image

# Here we extract the image dimension
image_dim=mnist_full_train[0][0].shape[1]
print('Number of pixels in images:',image_dim,'x',image_dim)


In [3]:
# Select the first 1000 images for training to make training fast
train_set = torch.utils.data.Subset(mnist_full_train, list(range(1000)))

# Make a "data loader" that will return batches of images and labels
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
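
To see what a data loader actually returns, here is a minimal sketch that uses random tensors in place of MNIST, so it runs without a download. The names `fake_set` and `fake_loader` are made up for this illustration and are not part of the notebook.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the 1000-image subset (random pixels, random labels)
fake_images = torch.rand(1000, 1, 28, 28)
fake_labels = torch.randint(0, 10, (1000,))
fake_set = TensorDataset(fake_images, fake_labels)

fake_loader = DataLoader(fake_set, batch_size=128, shuffle=True)

# 1000 examples in batches of 128: 7 full batches plus one batch of 104
batch_sizes = [images.shape[0] for images, labels in fake_loader]
print(len(fake_loader), batch_sizes)

# One batch has shape (batch, channels, height, width); labels have shape (batch,)
images, labels = next(iter(fake_loader))
print(images.shape, labels.shape)
```

Because `shuffle=True`, the order of examples changes every epoch, but the number and sizes of the batches stay the same.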


## Make the neural network

We make a fully connected feed-forward network.

We use nn.Sequential, which applies the modules one after the other.

Because each image is a 2D grid of pixels (stored as a tensor of shape 1x28x28, with a leading channel dimension), the first module flattens it to 1D: from 28x28 to an array of 784 numbers between 0 and 1 (grayscale).
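
The effect of flattening can be checked on random tensors. This is a small sketch, not part of the original notebook. It relies on the default behaviour of `nn.Flatten`, which keeps the first dimension and flattens the rest.

```python
import torch
import torch.nn as nn

flatten = nn.Flatten()  # default start_dim=1: the first dimension is kept

# A batch of 128 images, each 1 channel x 28 x 28 pixels
batch = torch.rand(128, 1, 28, 28)
flat = flatten(batch)
print(flat.shape)  # torch.Size([128, 784])

# A single image as stored in the dataset has shape (1, 28, 28);
# the channel dimension then plays the role of a batch of size 1
single = torch.rand(1, 28, 28)
flat_single = flatten(single)
print(flat_single.shape)  # torch.Size([1, 784])
```

The second case is why a single image from the dataset can be passed to the network without adding a batch dimension first.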

In [4]:
my_network = nn.Sequential(
    # First, flatten image
    nn.Flatten(),
    # One linear layer with 784 inputs and 100 units
    nn.Linear(image_dim*image_dim,100),
    # ReLU activation function
    nn.ReLU(),
    # Next layer must now have 100 inputs and we chose 50 units in this layer
    nn.Linear(100,50),
    nn.ReLU(),
    # Final layer must have 10 output units, one for each digit
    nn.Linear(50,10)
    )
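
As a sanity check on the layer sizes, one can count the trainable parameters. The sketch below rebuilds the same architecture with the image size written out as 784 (the name `sketch_network` is only for this illustration). Each linear layer has inputs x units weights plus one bias per unit.

```python
import torch.nn as nn

# Same architecture as above, with 28*28 = 784 inputs hard-coded
sketch_network = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 100),
    nn.ReLU(),
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 10),
)

# Count parameters reported by PyTorch
n_params = sum(p.numel() for p in sketch_network.parameters() if p.requires_grad)

# Count by hand: weights + biases for each linear layer
expected = (784 * 100 + 100) + (100 * 50 + 50) + (50 * 10 + 10)
print(n_params, expected)  # 84060 84060
```

Almost all of the parameters sit in the first layer, because it connects every pixel to every one of the 100 units.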

## Set up the loss function and optimizer

In [5]:
# The cross entropy loss function takes the raw network outputs (logits) and
# combines log-softmax and negative log-likelihood loss into one
loss_function = nn.CrossEntropyLoss()

# The optimizer must know which parameters to optimize
# The trainable parameters of a network are returned by the parameters() method
optimizer = torch.optim.Adam(my_network.parameters())
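
The following sketch illustrates what "combines softmax and loss" means, using random numbers instead of real network outputs. With default settings, `nn.CrossEntropyLoss` gives the same value as taking the log-softmax of the outputs, picking the entry for the true label, and averaging the negative values over the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Pretend network outputs (logits) for a batch of 4 images, and their labels
logits = torch.randn(4, 10)
targets = torch.tensor([3, 0, 9, 5])

# Built-in loss applied directly to the logits
loss_builtin = nn.CrossEntropyLoss()(logits, targets)

# Same thing by hand: log-softmax, select the true class, negate, average
log_probs = F.log_softmax(logits, dim=1)
loss_manual = -log_probs[torch.arange(4), targets].mean()

print(loss_builtin.item(), loss_manual.item())
```

This is why the network itself ends in a plain linear layer: the softmax is applied inside the loss function during training.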


## Train the model

In [None]:

# Here we run 100 epochs
for epoch in range(1,101):
    # Iterating over train_loader yields one batch of images and labels at a time
    for image,label in train_loader:
        ### Standard training sequence that can be used always ################
        # Reset all the gradients to zero
        optimizer.zero_grad()
        # Get the outputs of the neural network for the batch
        y = my_network(image)
        # Calculate the loss
        loss = loss_function(y,label)
        # Do the back-propagation
        loss.backward()
        # Update the weights
        optimizer.step()
        ### End of standard training sequence #################################

    # Note: this is the loss of the last batch in the epoch, not an average
    if epoch%10==0:
        print('Epoch:',epoch,'Loss:',loss.item())
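
The loss alone does not say how many digits are classified correctly. A possible way to measure that is sketched below. The helper `accuracy` is hypothetical (it is not defined in the notebook), and it is demonstrated on an untrained stand-in model with random data so that it runs on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def accuracy(model, loader):
    """Fraction of examples where the largest output matches the label."""
    correct, total = 0, 0
    # No gradients are needed when we only evaluate the model
    with torch.no_grad():
        for images, labels in loader:
            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.shape[0]
    return correct / total

# Untrained stand-in model and random data, for illustration only
torch.manual_seed(0)
demo_model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
demo_images = torch.rand(200, 1, 28, 28)
demo_labels = torch.randint(0, 10, (200,))
demo_loader = DataLoader(TensorDataset(demo_images, demo_labels), batch_size=64)

demo_accuracy = accuracy(demo_model, demo_loader)
print(demo_accuracy)
```

In this notebook the same helper could presumably be called as `accuracy(my_network, train_loader)`, or on a loader built from images that were not used for training.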




## Check if results make sense

In [None]:
# Pick an arbitrary example from the training set
image,label = train_set[7]
# Calculate the output of the network and apply softmax
y = nn.functional.softmax(my_network(image).squeeze(),dim=0)
print('True label: ',label)
print('Output after softmax:')
for i in range(10):
    print(i,': ','%.5f' % y[i].item())
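
To turn the softmax output into a single predicted digit, one can take the index of the largest value. The sketch below uses hand-picked numbers in place of real network outputs. Softmax does not change which entry is largest, so the prediction is the same with or without it.

```python
import torch
import torch.nn.functional as F

# Hand-picked stand-in for the 10 raw outputs of the network
logits = torch.tensor([1.0, -0.5, 3.0, 0.0, 0.2, -1.0, 0.5, 2.0, -2.0, 0.1])

# Softmax gives non-negative numbers that sum to 1
probs = F.softmax(logits, dim=0)
print(probs.sum().item())

# The predicted digit is the position of the largest value
predicted_from_probs = probs.argmax().item()
predicted_from_logits = logits.argmax().item()
print(predicted_from_probs, predicted_from_logits)  # 2 2
```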

In [None]:
# Same as above but shown graphically

# And here we are using examples that are NOT among the 1000 images used for training

import matplotlib.pyplot as plt
import numpy as np

# Image number to start from in the full training set
start=40000

# Plot 10 images from the start index
plt.figure(figsize=(10,1))
for i in range(10):
    image,label = mnist_full_train[start+i]
    plt.subplot(1,10,i+1)
    plt.imshow(image.squeeze(),cmap='gray')
    plt.axis('off')
    plt.title(str(label))
plt.show()

# Plot barplot of the output of the network for the same 10 images
plt.figure(figsize=(10,0.5))
for i in range(10):
    image,label = mnist_full_train[start+i]
    y = nn.functional.softmax(my_network(image).squeeze(),dim=0)
    plt.subplot(1,10,i+1)
    plt.bar(range(10),y.detach().numpy())
    # Use smaller font for xticks
    plt.tick_params(axis='x', labelsize=6)
    # Make sure the x-axis is from 0 to 9
    plt.xticks(range(10))
    # remove the y-axis ticks
    plt.yticks([])

plt.show()
