# PyTorch CNN
> Example: training a LeNet-style convolutional neural network on the MNIST dataset.

In [20]:
import numpy as np
import MNISTtools
import matplotlib as pyplot
import torch
import torch.nn as nn
import torch.nn.functional as F

* Load the MNIST training and test sets
* Normalize the pixel values to the range [-1, 1]
* Reshape the images to the (N, C, H, W) = (N, 1, 28, 28) layout that PyTorch layers expect, and convert the arrays to torch tensors.

In [21]:
def normalize_MNIST_images(x, low=-1.0, high=1.0):
    """Linearly rescale an image array so its values span ``[low, high]``.

    The input is cast to float32, then mapped so that its minimum becomes
    ``low`` and its maximum becomes ``high`` (default: [-1, 1]).

    Args:
        x: array-like of pixel intensities (any shape).
        low: value the minimum of ``x`` is mapped to.
        high: value the maximum of ``x`` is mapped to.

    Returns:
        float32 ndarray with the same shape as ``x``. If every element of
        ``x`` is equal (zero dynamic range), the result is filled with the
        midpoint ``(low + high) / 2`` instead of dividing by zero (which
        would otherwise produce NaNs).

    Raises:
        ValueError: if ``x`` is empty (propagated from ``np.min``).
    """
    x = np.asarray(x).astype(np.float32)
    x_min = np.min(x)
    x_max = np.max(x)
    span = x_max - x_min
    if span == 0:
        # Constant image(s): there is no range to stretch, so avoid 0/0 -> NaN.
        return np.full_like(x, (low + high) / 2)
    return low + (high - low) * (x - x_min) / span

# Load raw images/labels and rescale pixel values to [-1, 1].
xtrain, ltrain = MNISTtools.load(dataset="training")
xtest, ltest = MNISTtools.load(dataset="testing")
xtrain = normalize_MNIST_images(xtrain)
xtest = normalize_MNIST_images(xtest)


def to_nchw(images):
    """Convert flat MNIST images to PyTorch's (N, C, H, W) layout.

    Assumes the sample index is the LAST axis and that each sample holds
    28*28 pixels (e.g. (28, 28, N) or (784, N)) -- TODO confirm against
    MNISTtools.load. The sample count is inferred (``-1``) instead of being
    hard-coded, so any split size works.
    """
    images = images.reshape(28, 28, 1, -1)     # (..., N) => (28, 28, 1, N)
    return np.moveaxis(images, [2, 3], [1, 0])  # (28, 28, 1, N) => (N, 1, 28, 28)


xtrain = torch.from_numpy(to_nchw(xtrain))
ltrain = torch.from_numpy(ltrain)
xtest = torch.from_numpy(to_nchw(xtest))
ltest = torch.from_numpy(ltest)

# Every image must have exactly one label.
assert xtrain.shape[0] == ltrain.shape[0], "train images/labels count mismatch"
assert xtest.shape[0] == ltest.shape[0], "test images/labels count mismatch"

* Build the network (a LeNet-style CNN) by subclassing PyTorch's `nn.Module`

In [22]:
class LeNet(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images with 10 output classes.

    Two conv(5x5) -> ReLU -> 2x2 max-pool stages shrink a 28x28 input to
    16 feature maps of 4x4 (28 -> 24 -> 12 -> 8 -> 4), i.e. 256 features,
    which three fully connected layers map to 10 unnormalized class scores.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 6 -> 16 channels with 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: 256 = 16 channels * 4 * 4 spatial positions.
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for an input batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        # No activation on the last layer: CrossEntropyLoss expects raw logits.
        return self.fc3(x)

    def num_flat_features(self, x):
        """Number of elements per sample, i.e. product of all non-batch dims."""
        per_sample_dims = x.size()[1:]
        return np.prod(per_sample_dims)

* Define the training loop (mini-batch SGD with momentum and cross-entropy loss)

In [23]:
def train(xtrain, ltrain, net, T, B=100, gamma=.001, rho=.9, print_every=100):
    """Train ``net`` in place with mini-batch SGD + momentum on cross-entropy.

    The training set is cut into consecutive minibatches of ``B`` samples;
    each epoch visits those minibatches in a random order (drawn with
    ``np.random.permutation``) and performs one optimizer step per batch.

    Args:
        xtrain: input tensor whose first dimension indexes samples.
        ltrain: targets for ``nn.CrossEntropyLoss`` (class indices, one per
            sample, as prepared by the caller).
        net: the ``nn.Module`` to train; its parameters are updated in place.
        T: number of epochs.
        B: minibatch size. If the number of samples is not a multiple of
            ``B``, the final minibatch is smaller so that no sample is dropped.
        gamma: SGD learning rate.
        rho: SGD momentum.
        print_every: print the mean loss over this many minibatches.

    Raises:
        ValueError: if ``B`` or ``print_every`` is smaller than 1.
    """
    if B < 1 or print_every < 1:
        raise ValueError("B and print_every must both be >= 1")
    N = xtrain.size()[0]  # Training set size
    NB = (N + B - 1) // B  # Number of minibatches (ceil: keep the remainder)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=gamma, momentum=rho)
    net.train()  # ensure training-mode behavior for any dropout/batchnorm layers
    for epoch in range(T):
        running_loss = 0.0
        for k, batch in enumerate(np.random.permutation(NB)):
            # Extract the minibatch as a slice (a view, no copy); a slice that
            # runs past N is truncated, which yields the smaller last batch.
            start = B * int(batch)
            inputs = xtrain[start:start + B]
            labels = ltrain[start:start + B]

            optimizer.zero_grad()              # reset accumulated gradients
            outputs = net(inputs)              # forward propagation
            loss = criterion(outputs, labels)  # error evaluation
            loss.backward()                    # back propagation
            optimizer.step()                   # parameter update

            # .item() returns a detached Python float; no torch.no_grad() needed.
            running_loss += loss.item()
            if k % print_every == print_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, k + 1, running_loss / print_every))
                running_loss = 0.0
    print('Finished!')

* Train the network (on the GPU via CUDA when available, otherwise on the CPU) and evaluate its accuracy on the test set

In [27]:
# Fix the RNGs so weight initialization (torch) and minibatch order (numpy)
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('The device using for the training is:',device)

net = LeNet().to(device)
xtrain = xtrain.to(device)
ltrain = ltrain.to(device, dtype=torch.int64)  # CrossEntropyLoss needs int64 class indices
xtest = xtest.to(device)
ltest = ltest.to(device, dtype=torch.int64)
train(xtrain, ltrain, net, T=10)

# Evaluate in inference mode: no autograd graph over the whole test set.
net.eval()
with torch.no_grad():
    y = net(xtest)
_, lpred = y.max(1)
accuracy = 100 * (ltest == lpred).float().mean().item()
print('The final accuracy is:', round(accuracy, 2))

The device using for the training is: cuda
[1,   100] loss: 2.299
[1,   200] loss: 2.284
[1,   300] loss: 2.261
[1,   400] loss: 2.211
[1,   500] loss: 2.041
[1,   600] loss: 1.407
[2,   100] loss: 0.794
[2,   200] loss: 0.613
[2,   300] loss: 0.511
[2,   400] loss: 0.473
[2,   500] loss: 0.371
[2,   600] loss: 0.377
[3,   100] loss: 0.348
[3,   200] loss: 0.321
[3,   300] loss: 0.314
[3,   400] loss: 0.276
[3,   500] loss: 0.271
[3,   600] loss: 0.258
[4,   100] loss: 0.238
[4,   200] loss: 0.225
[4,   300] loss: 0.222
[4,   400] loss: 0.208
[4,   500] loss: 0.204
[4,   600] loss: 0.182
[5,   100] loss: 0.180
[5,   200] loss: 0.187
[5,   300] loss: 0.170
[5,   400] loss: 0.159
[5,   500] loss: 0.146
[5,   600] loss: 0.157
[6,   100] loss: 0.152
[6,   200] loss: 0.156
[6,   300] loss: 0.139
[6,   400] loss: 0.134
[6,   500] loss: 0.133
[6,   600] loss: 0.132
[7,   100] loss: 0.125
[7,   200] loss: 0.124
[7,   300] loss: 0.123
[7,   400] loss: 0.117
[7,   500] loss: 0.116
[7,   600] los