In [3]:
# -*- coding: utf-8 -*-
"""
Training an image classifier
----------------------------

We will do the following steps in order:

1. Load and normalizing the CIFAR10 training and test datasets using
   ``torchvision``
2. Define a Convolution Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data

1. Loading and normalizing CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Using ``torchvision``, it’s extremely easy to load CIFAR10.
"""
import torch
import torchvision
import torchvision.transforms as transforms

########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1]

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
    transform=transform, download=True)
validset = torchvision.datasets.CIFAR10(root='./data', train=True, 
    transform=transform, download=False)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, 
    transform=transform, download=False)
                          
# Split training into train and validation
indices = torch.randperm(len(trainset))
trainIndices = indices[0:400]

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                           sampler=torch.utils.data.sampler.SubsetRandomSampler(trainIndices),
                                           shuffle=False, num_workers=2)

testloader = torch.utils.data.DataLoader(testset, batch_size=4, 
                                            shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

########################################################################
# Let us show some of the training images, for fun.

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))

# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))


########################################################################
# 2. Define a Convolution Neural Network
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Copy the neural network from the Neural Networks section before and modify it to
# take 3-channel images (instead of 1-channel images as it was defined).

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

########################################################################
# 3. Define a Loss function and optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let's use a Classification Cross-Entropy loss and SGD with momentum

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

Files already downloaded and verified
 bird  deer  bird  deer


In [4]:
########################################################################
# 4. Train the network
# ^^^^^^^^^^^^^^^^^^^^
#
# This is when things start to get interesting.
# We simply have to loop over our data iterator, and feed the inputs to the
# network and optimize

for epoch in range(4):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.data[0]
        if i % 10 == 9:    # print every 10 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0

print('Finished Initial Training')

[1,    10] loss: 2.305
[1,    20] loss: 2.308
[1,    30] loss: 2.301
[1,    40] loss: 2.300
[1,    50] loss: 2.314
[1,    60] loss: 2.305
[1,    70] loss: 2.304
[1,    80] loss: 2.290
[1,    90] loss: 2.293
[1,   100] loss: 2.309
[2,    10] loss: 2.298
[2,    20] loss: 2.308
[2,    30] loss: 2.296
[2,    40] loss: 2.303
[2,    50] loss: 2.302
[2,    60] loss: 2.294
[2,    70] loss: 2.297
[2,    80] loss: 2.296
[2,    90] loss: 2.304
[2,   100] loss: 2.313
[3,    10] loss: 2.298
[3,    20] loss: 2.298
[3,    30] loss: 2.305
[3,    40] loss: 2.301
[3,    50] loss: 2.321
[3,    60] loss: 2.292
[3,    70] loss: 2.297
[3,    80] loss: 2.302
[3,    90] loss: 2.283
[3,   100] loss: 2.296
[4,    10] loss: 2.303
[4,    20] loss: 2.306
[4,    30] loss: 2.298
[4,    40] loss: 2.295
[4,    50] loss: 2.298
[4,    60] loss: 2.294
[4,    70] loss: 2.298
[4,    80] loss: 2.292
[4,    90] loss: 2.302
[4,   100] loss: 2.297
Finished Initial Training


In [16]:
# Repeat hard data mining 
for k in range(100):
    indices = torch.randperm(len(trainset))
    validIndices = indices[0:400]

    validloader = torch.utils.data.DataLoader(validset, batch_size=4,
                                           sampler=torch.utils.data.sampler.SubsetRandomSampler(validIndices),
                                           shuffle=False, num_workers=2)

    # Hard data mine
    print('[%d ITER] BEGIN HARD NEGATIVE MINING' % k)
    for epoch in range(4):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(validloader, 0):
            images, labels = data
            outputs = net(Variable(images))
            _, predicted = torch.max(outputs.data, 1)
            if (predicted == labels).all(): continue
            else: 
                # training model with corrected label

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(Variable(images))
                loss = criterion(outputs, Variable(labels))
                loss.backward()
                optimizer.step()
                running_loss += loss.data[0]

                if i % 10 == 9:    # print every 10 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 10))    # print statistics
                    running_loss = 0.0

    print('[%d ITER] Finished Retraining' %k)

[0 ITER] BEGIN HARD NEGATIVE MINING
[1,    20] loss: 2.766
[1,    30] loss: 1.221
[1,    40] loss: 1.287
[1,    50] loss: 1.452
[1,    60] loss: 1.363
[1,    70] loss: 2.089
[1,    80] loss: 1.410
[1,    90] loss: 1.417
[1,   100] loss: 1.404
[2,    10] loss: 0.755
[2,    20] loss: 1.054
[2,    30] loss: 0.847
[2,    40] loss: 0.750
[2,    50] loss: 0.490
[2,    60] loss: 0.664
[2,    80] loss: 1.585
[2,    90] loss: 0.812
[2,   100] loss: 0.778
[3,    50] loss: 2.477
[3,    70] loss: 0.952
[4,    50] loss: 1.014
[4,    90] loss: 1.309
[4,   100] loss: 0.374
[0 ITER] Finished Retraining
[1 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.845
[1,    20] loss: 1.861
[1,    30] loss: 1.450
[1,    40] loss: 1.679
[1,    50] loss: 1.052
[1,    60] loss: 1.524
[1,    70] loss: 1.281
[1,    80] loss: 1.361
[1,    90] loss: 1.545
[1,   100] loss: 1.132
[2,    10] loss: 0.913
[2,    40] loss: 2.496
[2,    60] loss: 1.360
[2,    70] loss: 0.380
[2,    80] loss: 0.734
[2,    90] loss: 0.534
[2

[2,    10] loss: 0.876
[2,    20] loss: 0.846
[2,    50] loss: 1.912
[2,    80] loss: 2.585
[2,    90] loss: 0.645
[2,   100] loss: 0.816
[3,    10] loss: 0.297
[3,    30] loss: 1.054
[3,    60] loss: 1.487
[3,    70] loss: 0.703
[3,    80] loss: 0.525
[3,   100] loss: 0.846
[4,    10] loss: 0.173
[4,    30] loss: 0.333
[4,    40] loss: 0.289
[4,   100] loss: 1.243
[11 ITER] Finished Retraining
[12 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.386
[1,    20] loss: 1.668
[1,    30] loss: 1.673
[1,    40] loss: 1.997
[1,    50] loss: 1.479
[1,    60] loss: 1.617
[1,    70] loss: 1.398
[1,    90] loss: 2.797
[1,   100] loss: 1.155
[2,    10] loss: 0.743
[2,    30] loss: 1.366
[2,    50] loss: 1.660
[2,    60] loss: 0.691
[2,    80] loss: 1.595
[2,    90] loss: 1.064
[2,   100] loss: 0.873
[3,    20] loss: 0.519
[3,    30] loss: 0.483
[3,    50] loss: 0.733
[3,    80] loss: 1.249
[4,    20] loss: 0.356
[4,    40] loss: 0.504
[4,    50] loss: 0.440
[4,    60] loss: 0.259
[4,    70] lo

[4,    10] loss: 0.159
[4,    20] loss: 0.258
[4,    30] loss: 0.245
[4,    40] loss: 0.173
[4,    70] loss: 0.936
[4,    80] loss: 0.294
[23 ITER] Finished Retraining
[24 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.436
[1,    20] loss: 1.135
[1,    30] loss: 1.751
[1,    50] loss: 2.533
[1,    60] loss: 1.506
[1,    70] loss: 1.061
[1,    80] loss: 1.170
[1,    90] loss: 1.246
[1,   100] loss: 1.259
[2,    10] loss: 0.524
[2,    20] loss: 0.591
[2,    70] loss: 3.215
[2,    80] loss: 0.696
[2,    90] loss: 0.703
[2,   100] loss: 0.660
[3,    10] loss: 0.224
[3,    20] loss: 0.571
[3,    30] loss: 0.308
[3,    50] loss: 0.629
[3,    70] loss: 0.357
[3,    90] loss: 0.898
[4,    10] loss: 0.064
[4,    60] loss: 1.087
[4,    80] loss: 0.497
[24 ITER] Finished Retraining
[25 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.264
[1,    20] loss: 1.566
[1,    30] loss: 1.506
[1,    40] loss: 1.618
[1,    50] loss: 1.510
[1,    60] loss: 1.368
[1,    70] loss: 1.313
[1,    80] loss

[1,    10] loss: 1.625
[1,    20] loss: 1.220
[1,    40] loss: 3.162
[1,    50] loss: 1.834
[1,    60] loss: 1.058
[1,    70] loss: 1.103
[1,    80] loss: 1.367
[1,   100] loss: 2.356
[2,    10] loss: 0.751
[2,    20] loss: 0.799
[2,    30] loss: 0.771
[2,    40] loss: 0.652
[2,    70] loss: 1.776
[2,    90] loss: 1.439
[3,    40] loss: 1.460
[3,    50] loss: 0.375
[3,    60] loss: 0.428
[3,    70] loss: 0.148
[3,    80] loss: 0.574
[3,   100] loss: 0.776
[4,    20] loss: 0.603
[4,    90] loss: 1.464
[36 ITER] Finished Retraining
[37 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.327
[1,    20] loss: 1.620
[1,    30] loss: 2.116
[1,    40] loss: 1.220
[1,    50] loss: 1.138
[1,    70] loss: 2.396
[1,    80] loss: 1.680
[1,    90] loss: 1.418
[2,    20] loss: 1.574
[2,    40] loss: 1.178
[2,    50] loss: 0.668
[2,    60] loss: 0.752
[2,    70] loss: 0.566
[2,    80] loss: 1.169
[2,   100] loss: 1.414
[3,    20] loss: 0.904
[3,    40] loss: 0.825
[3,    60] loss: 0.827
[3,    80] lo

[48 ITER] Finished Retraining
[49 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.820
[1,    20] loss: 1.395
[1,    30] loss: 1.408
[1,    40] loss: 1.363
[1,    50] loss: 1.450
[1,    60] loss: 1.198
[1,    80] loss: 2.586
[1,    90] loss: 1.481
[2,    10] loss: 0.813
[2,    30] loss: 1.724
[2,    60] loss: 2.023
[2,    70] loss: 0.625
[2,    80] loss: 0.682
[2,    90] loss: 0.972
[2,   100] loss: 0.656
[3,    10] loss: 0.347
[3,    20] loss: 0.425
[3,    30] loss: 0.475
[3,    60] loss: 1.216
[3,    70] loss: 0.644
[3,    90] loss: 0.654
[4,    30] loss: 0.698
[4,    40] loss: 0.225
[4,    50] loss: 0.283
[4,    60] loss: 0.150
[4,    90] loss: 0.916
[49 ITER] Finished Retraining
[50 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.347
[1,    20] loss: 1.101
[1,    30] loss: 0.998
[1,    40] loss: 1.570
[1,    50] loss: 1.342
[1,    60] loss: 1.525
[1,    70] loss: 1.426
[1,    80] loss: 1.048
[1,    90] loss: 1.389
[1,   100] loss: 1.368
[2,    10] loss: 0.567
[2,    20] loss

[2,   100] loss: 3.152
[3,    30] loss: 1.122
[3,    40] loss: 0.512
[3,    50] loss: 0.410
[3,    70] loss: 0.594
[3,   100] loss: 1.212
[4,    10] loss: 0.191
[4,    40] loss: 0.625
[4,    60] loss: 0.546
[61 ITER] Finished Retraining
[62 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.387
[1,    40] loss: 4.062
[1,    50] loss: 1.804
[1,    60] loss: 1.063
[1,    70] loss: 1.258
[1,    90] loss: 2.042
[1,   100] loss: 1.377
[2,    20] loss: 1.605
[2,    30] loss: 0.927
[2,    40] loss: 0.720
[2,    80] loss: 2.701
[2,    90] loss: 0.987
[2,   100] loss: 0.454
[3,    20] loss: 0.860
[3,    30] loss: 0.394
[3,    40] loss: 0.497
[3,    50] loss: 0.658
[3,    60] loss: 0.574
[3,    90] loss: 0.878
[4,    30] loss: 0.348
[4,    70] loss: 1.299
[62 ITER] Finished Retraining
[63 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.068
[1,    20] loss: 1.290
[1,    30] loss: 1.018
[1,    40] loss: 1.402
[1,    50] loss: 1.304
[1,    70] loss: 2.811
[1,    90] loss: 2.173
[1,   100] loss

[2,    10] loss: 0.561
[2,    20] loss: 0.801
[2,    30] loss: 0.715
[2,    40] loss: 0.744
[2,    50] loss: 0.827
[2,    60] loss: 0.613
[2,    70] loss: 0.914
[2,    80] loss: 0.812
[2,   100] loss: 1.082
[3,    40] loss: 1.334
[3,    50] loss: 0.294
[3,    80] loss: 1.135
[3,   100] loss: 0.966
[4,    30] loss: 0.535
[4,    40] loss: 0.222
[4,    80] loss: 0.985
[4,    90] loss: 0.235
[75 ITER] Finished Retraining
[76 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.745
[1,    20] loss: 1.013
[1,    30] loss: 1.320
[1,    40] loss: 1.159
[1,    60] loss: 2.756
[1,    70] loss: 0.696
[1,    80] loss: 1.586
[1,    90] loss: 1.246
[1,   100] loss: 1.448
[2,    10] loss: 0.874
[2,    20] loss: 0.780
[2,    30] loss: 0.563
[2,    40] loss: 0.454
[2,    60] loss: 1.263
[2,    90] loss: 1.438
[3,    10] loss: 0.301
[3,    40] loss: 0.778
[3,    70] loss: 0.672
[3,    90] loss: 0.850
[3,   100] loss: 0.371
[4,    10] loss: 0.271
[4,    60] loss: 0.656
[4,    70] loss: 0.096
[76 ITER] Fin

[1,    10] loss: 1.062
[1,    20] loss: 1.015
[1,    30] loss: 1.244
[1,    40] loss: 1.223
[1,    50] loss: 1.103
[1,    60] loss: 1.146
[1,    70] loss: 0.859
[1,    80] loss: 1.463
[1,   100] loss: 1.709
[2,    20] loss: 1.027
[2,    30] loss: 0.278
[2,    80] loss: 2.169
[2,    90] loss: 0.365
[2,   100] loss: 0.466
[3,    30] loss: 0.774
[3,    70] loss: 1.256
[4,    30] loss: 0.505
[4,    60] loss: 0.580
[89 ITER] Finished Retraining
[90 ITER] BEGIN HARD NEGATIVE MINING
[1,    10] loss: 1.522
[1,    20] loss: 1.213
[1,    30] loss: 1.332
[1,    40] loss: 1.439
[1,    60] loss: 2.450
[1,    70] loss: 0.799
[1,    80] loss: 1.323
[1,   100] loss: 2.926
[2,    10] loss: 0.879
[2,    40] loss: 1.524
[2,    50] loss: 0.882
[2,    60] loss: 0.591
[2,    70] loss: 0.677
[2,    80] loss: 0.553
[2,    90] loss: 0.547
[3,    20] loss: 0.756
[3,    40] loss: 1.023
[3,   100] loss: 1.759
[4,    40] loss: 0.711
[4,    70] loss: 0.633
[90 ITER] Finished Retraining
[91 ITER] BEGIN HARD NEGATIVE

In [17]:
########################################################################
# TEST
#
# Let us look at how the network performs on the whole dataset.

correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 53 %


In [13]:
########################################################################
# 5. Test the network on the test data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We have trained the network for 2 passes over the training dataset.
# But we need to check if the network has learnt anything at all.
#
# We will check this by predicting the class label that the neural network
# outputs, and checking it against the ground-truth. If the prediction is
# correct, we add the sample to the list of correct predictions.
#
# Okay, first step. Let us display an image from the test set to get familiar.


dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

# ########################################################################
# # Okay, now let us see what the neural network thinks these examples above are:

# outputs = net(Variable(images))

# ########################################################################
# # The outputs are energies for the 10 classes.
# # Higher the energy for a class, the more the network
# # thinks that the image is of the particular class.
# # So, let's get the index of the highest energy:
# _, predicted = torch.max(outputs.data, 1)
# # print(predicted)
# # print(predicted[1][0])
# print('Predicted: ', ' '.join('%5s' % classes[predicted[j][0]] for j in range(4)))


GroundTruth:    cat  ship  ship plane


In [18]:
########################################################################
# That looks waaay better than chance, which is 10% accuracy (randomly picking
# a class out of 10 classes).
# Seems like the network learnt something.
#
# Hmmm, what are the classes that performed well, and the classes that did
# not perform well:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels).squeeze()
    for i in range(4):
        label = labels[i]
        class_correct[label] += c[i]
        class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of plane : 58 %
Accuracy of   car : 62 %
Accuracy of  bird : 39 %
Accuracy of   cat : 38 %
Accuracy of  deer : 45 %
Accuracy of   dog : 33 %
Accuracy of  frog : 63 %
Accuracy of horse : 63 %
Accuracy of  ship : 69 %
Accuracy of truck : 58 %
