In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as dset
from torchsummary import summary
from torchvision import datasets, transforms

In [3]:
# GPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)

GPU State: cpu


In [4]:
def training_loop(model, loss, optimizer, train_Loader, n_epochs, verbose=True, device=device):
    """
    Run training of a model given a loss function, optimizer and a set of training and validation data.
    """

    # Train
    for epoch in range(epochs):
        running_loss = 0.0

        for times, data in enumerate(train_Loader):
            inputs, labels = data[0].to(device), data[1].to(device)
            inputs = inputs.view(inputs.shape[0], -1)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Foward + backward + optimize
            outputs = net(inputs)
            loss_tensor = loss(outputs, labels)
            loss_tensor.backward()
            optimizer.step()

            # Print statistics
            running_loss += loss_tensor.item()
            if verbose:
                if times % 100 == 99 or times+1 == len(trainLoader):
                    print('[%d/%d, %d/%d] loss: %.3f' % (epoch+1, epochs, times+1, len(trainLoader), running_loss/2000))      

In [5]:
def evaluate_model(model, data_loader, device=device):
    """
    Evaluate a model 'model' on all batches of a torch DataLoader 'data_loader'.
    
    Returns: the total number of correct classifications,
             the total number of images
             the list of the per class correct classification,
             the list of the per class total number of images.
    """
    
    # Test
    correct = 0
    total = 0

    with torch.no_grad():
        for data in testLoader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.view(inputs.shape[0], -1)

            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    class_correct = [0 for i in range(10)]
    class_total = [0 for i in range(10)]

    with torch.no_grad():
        for data in testLoader:
            inputs, labels = data[0].to(device), data[1].to(device)
            inputs = inputs.view(inputs.shape[0], -1)

            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(10):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    
    return (correct, total, class_correct, class_total)

In [6]:
# Transform
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,)),]
)

In [7]:
# Data
trainSet = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
testSet = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)
trainLoader = dset.DataLoader(trainSet, batch_size=64, shuffle=True)
testLoader = dset.DataLoader(testSet, batch_size=64, shuffle=False)

In [14]:
# Model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=(1,1)),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=0, dilation=1, ceil_mode=False),
            nn.Conv2d(16, 32, 3, stride=(1,1)),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=0, dilation=1, ceil_mode=False),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(800, 10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, input):
        input = input.view(input.shape[0], 1, 28, 28)
        return self.main(input)


net = Net().to(device)
print(net)
print(summary(net, (1, 28, 28)))

Net(
  (main): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=800, out_features=10, bias=True)
    (8): LogSoftmax(dim=1)
  )
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 26, 26]             160
              ReLU-2           [-1, 16, 26, 26]               0
         MaxPool2d-3           [-1, 16, 13, 13]               0
            Conv2d-4           [-1, 32, 11, 11]           4,640
              ReLU-5           [-1, 32, 11, 11]               0
         MaxPool2d-6             [-1, 32, 5, 5]               0
       

In [17]:
16 * 32 * 3*3 + 32

4640

In [15]:
# Parameters
epochs = 4
lr = 0.002
loss = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)

# Train
print('Training on %d images' % trainSet.data.shape[0])
training_loop(net, loss, optimizer, trainLoader, epochs)
print('Training Finished.\n')

# Test
correct, total, class_correct, class_total = evaluate_model(net, testLoader)
print('Accuracy of the network on the %d test images: %d %%' % (testSet.data.shape[0], (100*correct / total)))
for i in range(10):
    print('Accuracy of %d: %3f' % (i, (class_correct[i]/class_total[i])))

Training on 60000 images
[1/4, 100/938] loss: 0.091
[1/4, 200/938] loss: 0.121
[1/4, 300/938] loss: 0.139
[1/4, 400/938] loss: 0.153
[1/4, 500/938] loss: 0.165
[1/4, 600/938] loss: 0.175
[1/4, 700/938] loss: 0.184
[1/4, 800/938] loss: 0.193
[1/4, 900/938] loss: 0.201
[1/4, 938/938] loss: 0.204
[2/4, 100/938] loss: 0.008
[2/4, 200/938] loss: 0.015
[2/4, 300/938] loss: 0.021
[2/4, 400/938] loss: 0.028
[2/4, 500/938] loss: 0.034
[2/4, 600/938] loss: 0.040
[2/4, 700/938] loss: 0.045
[2/4, 800/938] loss: 0.051
[2/4, 900/938] loss: 0.056
[2/4, 938/938] loss: 0.059
[3/4, 100/938] loss: 0.005
[3/4, 200/938] loss: 0.010
[3/4, 300/938] loss: 0.015
[3/4, 400/938] loss: 0.020
[3/4, 500/938] loss: 0.025
[3/4, 600/938] loss: 0.030
[3/4, 700/938] loss: 0.034
[3/4, 800/938] loss: 0.039
[3/4, 900/938] loss: 0.043
[3/4, 938/938] loss: 0.044
[4/4, 100/938] loss: 0.004
[4/4, 200/938] loss: 0.008
[4/4, 300/938] loss: 0.012
[4/4, 400/938] loss: 0.016
[4/4, 500/938] loss: 0.020
[4/4, 600/938] loss: 0.024
[4/