In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
train_dataset = datasets.MNIST(root='/data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='/data', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=True)

In [3]:
for batch_idx, sample in enumerate(train_loader):
    image, label = sample
    print(image.size())
    print(label)

    total = [0 for i in range(10)]
    for i in range(10):
        total[i] = (label == i).sum()
    print(total)
    break

torch.Size([64, 1, 28, 28])
tensor([8, 9, 7, 4, 3, 2, 7, 9, 0, 1, 2, 8, 4, 4, 4, 2, 3, 0, 0, 6, 8, 2, 0, 2,
        4, 3, 9, 1, 0, 9, 5, 6, 6, 7, 1, 9, 5, 7, 2, 1, 6, 2, 0, 9, 2, 2, 9, 0,
        1, 7, 8, 1, 7, 1, 3, 1, 8, 8, 8, 6, 5, 1, 2, 9])
[tensor(7), tensor(9), tensor(10), tensor(4), tensor(5), tensor(3), tensor(5), tensor(6), tensor(7), tensor(8)]


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4*4*16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*4*4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet()

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.99, 0.999))

In [11]:
epochs = 20

for epoch in range(epochs):
    for batch_idx, samples in enumerate(train_loader):
        image, label = samples[0], samples[1]
        prediction = model(image)

        loss = criterion(prediction, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("epoch: {:2d}/{}, loss: {}".format(epoch, epochs, loss.item()))

epoch:  0/20, loss: 0.08449505269527435
epoch:  1/20, loss: 0.029525453224778175
epoch:  2/20, loss: 0.06896868348121643
epoch:  3/20, loss: 0.026662152260541916
epoch:  4/20, loss: 0.057392366230487823
epoch:  5/20, loss: 0.06648974865674973
epoch:  6/20, loss: 0.003932134248316288
epoch:  7/20, loss: 0.025181367993354797
epoch:  8/20, loss: 0.004681329242885113
epoch:  9/20, loss: 0.0008657730650156736
epoch: 10/20, loss: 0.0004540483350865543
epoch: 11/20, loss: 0.010420321486890316
epoch: 12/20, loss: 0.00011453322076704353
epoch: 13/20, loss: 0.00022322646691463888
epoch: 14/20, loss: 0.000803330447524786
epoch: 15/20, loss: 0.015048198401927948
epoch: 16/20, loss: 7.533235475420952e-05
epoch: 17/20, loss: 0.08263624459505081
epoch: 18/20, loss: 0.008780090138316154
epoch: 19/20, loss: 0.00031637129723094404


In [12]:
correct = 0
total = 0

with torch.no_grad():
    for batch_idx, samples in enumerate(test_loader):
        image, label = samples
        prediction = model(image)
        _, label_pred = prediction.max(dim=1)
        
        correct += (label_pred == label).sum()
        total += label_pred.shape[0]

    print("accuracy = {:.1f}%".format(100*correct/total))


accuracy = 98.9%


In [13]:
correct = [0 for i in range(10)]
total = [0 for i in range(10)]

with torch.no_grad():
    for batch_idx, samples in enumerate(test_loader):
        image, label = samples
        prediction = model(image)
        label_pred = prediction.argmax(dim=1)
    
        for i in range(10):
            total[i] += (label == i).sum()
            correct[i] += ((label == i) & (label_pred == label)).sum()

for i in range(10):
    print("accuracy for {} = {:.1f}%".format(i, 100*correct[i]/total[i]))


accuracy for 0 = 99.9%
accuracy for 1 = 99.6%
accuracy for 2 = 99.8%
accuracy for 3 = 98.3%
accuracy for 4 = 98.5%
accuracy for 5 = 98.1%
accuracy for 6 = 98.7%
accuracy for 7 = 99.1%
accuracy for 8 = 98.7%
accuracy for 9 = 98.3%


In [48]:
for i in range(10):
    print(i)

0
1
2
3
4
5
6
7
8
9
