In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as functional 

In [29]:
# Load MNIST (train and test splits) into the current directory.
# The first run downloads the dataset from the internet, which can take a while.
# ToTensor() scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
data_train = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform)
data_test = torchvision.datasets.MNIST(root='./', train=False, download=True, transform=transform)

In [30]:
# Human-readable label for each class index (MNIST digits 0-9).
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
batch_size = 4
# Shuffle the training data each epoch so mini-batches are not ordered by class.
trainloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, num_workers=2)
# Evaluation does not need shuffling; a fixed order keeps test passes reproducible.
testloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, shuffle=False, num_workers=2)

In [31]:
import matplotlib.pyplot as plt
import numpy as np

def imageShow(image):
    """Display an image tensor that was normalized with Normalize((0.5,), (0.5,)).

    The normalization is undone (mapping [-1, 1] back to [0, 1]) before plotting.

    Args:
        image: tensor laid out as (channels, height, width). A single-channel
            image is drawn as 2-D grayscale; other channel counts are passed
            to ``plt.imshow`` as (height, width, channels).
    """
    image = image / 2 + 0.5  # invert Normalize(mean=0.5, std=0.5)
    # detach()/cpu() make the conversion work for grad-tracking or GPU tensors.
    np_image = image.detach().cpu().numpy()
    hwc_image = np.transpose(np_image, (1, 2, 0))  # CHW -> HWC for matplotlib
    if hwc_image.shape[-1] == 1:
        # imshow expects (H, W) for scalar data, not (H, W, 1).
        plt.imshow(hwc_image[:, :, 0], cmap='gray')
    else:
        plt.imshow(hwc_image)
    plt.show()

In [32]:
class Net(nn.Module):
    """Small CNN producing 10 class logits from 1-channel images.

    Shape walk-through for a 1x28x28 input (the MNIST image size):
    conv1 (3x3) -> 6x26x26, pool -> 6x13x13, conv2 (3x3) -> 16x11x11,
    pool -> 16x5x5, flatten -> 400 features, which is why fc1 takes 400 inputs.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)  # 2x2 window, stride 2; shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(400, 128)  # 16 channels * 5 * 5 positions (for 28x28 input)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # one logit per class

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) raw logits.

        No softmax is applied here; the loss used for training
        (nn.CrossEntropyLoss) expects unnormalized logits.
        """
        # Use the registered pooling module rather than re-specifying pooling
        # inline; MaxPool2d(2, 2) is numerically identical to max_pool2d(x, 2).
        x = self.pool(functional.relu(self.conv1(x)))
        x = self.pool(functional.relu(self.conv2(x)))
        x = x.view(-1, self.num_flat_features(x))
        x = functional.relu(self.fc1(x))
        x = functional.relu(self.fc2(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of all non-batch dims)."""
        num_features = 1
        for dim in x.size()[1:]:
            num_features *= dim
        return num_features


net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)


In [33]:
# Classification loss on the raw logits returned by Net.forward,
# minimised with plain SGD plus momentum.
crit = nn.CrossEntropyLoss()
opt = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)

In [34]:
# Train for two epochs, printing the mean loss over each window of 2000 mini-batches.
for epoch in range(2):
    running_loss = 0.0
    for step, (inputs, labels) in enumerate(trainloader):
        opt.zero_grad()  # gradients accumulate across backward() calls unless cleared
        outputs = net(inputs)
        loss = crit(outputs, labels)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        batches_done = step + 1
        if batches_done % 2000 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batches_done, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')

[1,  2000] loss: 1.331
[1,  4000] loss: 0.257
[1,  6000] loss: 0.179
[1,  8000] loss: 0.131
[1, 10000] loss: 0.120
[1, 12000] loss: 0.102
[1, 14000] loss: 0.104
[2,  2000] loss: 0.078
[2,  4000] loss: 0.081
[2,  6000] loss: 0.071
[2,  8000] loss: 0.064
[2, 10000] loss: 0.061
[2, 12000] loss: 0.066
[2, 14000] loss: 0.064
Finished Training


In [37]:
# Overall test-set accuracy: a prediction is the index of the largest logit.
correct, total = 0, 0
with torch.no_grad():  # inference only; no autograd graph is needed
    for images, labels in testloader:
        outputs = net(images)
        predicted = torch.max(outputs, 1)[1]
        correct += (predicted == labels).sum().item()
        total += len(labels)

print('Accuracy of the network %d %%' % (100 * correct / total))

Accuracy of the network 98 %


In [None]:
# Per-class test accuracy. Iterate over the samples each batch actually holds
# instead of assuming exactly 4, so this stays correct if batch_size changes
# or the final batch is shorter than the rest.
class_cor = [0.0] * len(classes)
class_total = [0.0] * len(classes)
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        hits = (predicted == labels).tolist()  # one bool per sample; no squeeze() needed
        for label, hit in zip(labels.tolist(), hits):
            class_cor[label] += hit
            class_total[label] += 1


for i, name in enumerate(classes):
    if class_total[i] == 0:
        # Avoid ZeroDivisionError for a class that never appears in the test set.
        print('Accuracy of %5s : n/a (no test samples)' % name)
        continue
    print('Accuracy of %5s : %2d %%' % (name, 100 * class_cor[i] / class_total[i]))