In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [32]:
train_dataset = datasets.MNIST(root='/data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='/data', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=True)

In [59]:
for batch_idx, sample in enumerate(train_loader):
    image, label = sample
    print(image.size())
    print(label)

    total = [0 for i in range(10)]
    for i in range(10):
        total[i] = (label == i).sum()
    print(total)
    break

torch.Size([64, 1, 28, 28])
tensor([8, 4, 5, 3, 0, 6, 2, 2, 8, 5, 3, 5, 6, 6, 5, 9, 4, 7, 2, 0, 6, 9, 3, 5,
        3, 5, 7, 1, 7, 7, 7, 6, 2, 5, 3, 8, 8, 2, 9, 2, 7, 1, 1, 0, 1, 9, 4, 0,
        0, 3, 1, 3, 9, 3, 6, 6, 9, 7, 6, 3, 9, 6, 4, 0])
[tensor(6), tensor(5), tensor(6), tensor(9), tensor(4), tensor(7), tensor(9), tensor(7), tensor(4), tensor(7)]


In [52]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [53]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4*4*16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = x.reshape(64, 16*4*4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet()

In [54]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [55]:
epochs = 20

for epoch in range(epochs):
    for batch_idx, samples in enumerate(train_loader):
        image, label = samples[0], samples[1]
        prediction = model(image)

        loss = criterion(prediction, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("epoch: {:2d}/{}, loss: {}".format(epoch, epochs, loss.item()))

epoch:  0/20, loss: 2.255955696105957
epoch:  1/20, loss: 0.2959239184856415
epoch:  2/20, loss: 0.27757230401039124
epoch:  3/20, loss: 0.3543558120727539
epoch:  4/20, loss: 0.08398246020078659
epoch:  5/20, loss: 0.07292014360427856
epoch:  6/20, loss: 0.06101829931139946
epoch:  7/20, loss: 0.04728495329618454
epoch:  8/20, loss: 0.22426337003707886
epoch:  9/20, loss: 0.06843440234661102
epoch: 10/20, loss: 0.06562858074903488
epoch: 11/20, loss: 0.01922835037112236
epoch: 12/20, loss: 0.0981874167919159
epoch: 13/20, loss: 0.03330638259649277
epoch: 14/20, loss: 0.0488969087600708
epoch: 15/20, loss: 0.002911282004788518
epoch: 16/20, loss: 0.08511878550052643
epoch: 17/20, loss: 0.002738929819315672
epoch: 18/20, loss: 0.04523089528083801
epoch: 19/20, loss: 0.04316696897149086


In [57]:
correct = 0
total = 0

with torch.no_grad():
    for batch_idx, samples in enumerate(test_loader):
        image, label = samples
        prediction = model(image)
        _, label_pred = prediction.max(dim=1)
        
        correct += (label_pred == label).sum()
        total += label_pred.shape[0]

    print("accuracy = {:.1f}%".format(100*correct/total))


accuracy = 98.7%


In [63]:
correct = [0 for i in range(10)]
total = [0 for i in range(10)]

with torch.no_grad():
    for batch_idx, samples in enumerate(test_loader):
        image, label = samples
        prediction = model(image)
        _, label_pred = prediction.max(dim=1)
    
        for i in range(10):
            total[i] += (label == i).sum()
            correct[i] += ((label == i) & (label_pred == label)).sum()

for i in range(10):
    print("accuracy for {} = {:.1f}%".format(i, 100*correct[i]/total[i]))


accuracy for 0 = 99.2%
accuracy for 1 = 99.3%
accuracy for 2 = 98.9%
accuracy for 3 = 99.0%
accuracy for 4 = 98.8%
accuracy for 5 = 98.1%
accuracy for 6 = 98.6%
accuracy for 7 = 98.5%
accuracy for 8 = 98.3%
accuracy for 9 = 98.1%


In [48]:
for i in range(10):
    print(i)

0
1
2
3
4
5
6
7
8
9
