In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from torch import nn

In [3]:
# Preprocessing pipeline applied to every MNIST image. It is bound to its own
# name so the imported `transforms` module is not overwritten; this keeps the
# cell safe to re-run.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    # Single-channel normalisation with mean 0.5 and std 0.5.
    transforms.Normalize((0.5,), (0.5,)),
])

# Training and test splits, stored under the current working directory.
train = datasets.MNIST('.', train=True, download=True, transform=mnist_transform)
test = datasets.MNIST('.', train=False, download=True, transform=mnist_transform)

In [4]:
# Mini-batch size shared by both loaders.
BATCH_SIZE = 64

# Training batches are reshuffled each epoch; the test set is read in a fixed
# order because shuffling adds nothing to evaluation and only makes the
# iteration order non-deterministic.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
class CNN(nn.Module):
    """Small convolutional classifier producing 10 class scores per image.

    The feature extractor is three stages of unpadded 3x3 convolution, ReLU
    and 2x2 max-pooling (stride 2). For a 1x28x28 input the spatial size goes
    28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1, so each sample ends up as exactly 64
    features, which is what the first linear layer of the head expects.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: channels 1 -> 64 -> 128 -> 64.
        self.net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride=2),
                                 nn.Conv2d(64, 128, kernel_size=3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride=2),
                                 nn.Conv2d(128, 64, kernel_size=3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride=2))
        # Classification head: 64 features -> 20 hidden units -> 10 logits.
        self.classify_head = nn.Sequential(nn.Linear(64, 20, bias=True),
                                           nn.ReLU(),
                                           nn.Linear(20, 10, bias=True))

    def forward(self, x):
        """Return a (batch, 10) tensor of unnormalised class scores.

        Args:
            x: Image batch of shape (batch, 1, H, W). The head only accepts
                inputs whose feature map collapses to 64 values per sample
                (e.g. 28x28 images).

        Raises:
            RuntimeError: If the flattened feature size is not 64.
        """
        features = self.net(x)
        # Flatten everything except the batch dimension. Unlike
        # reshape(-1, 64), this never folds extra spatial positions into the
        # batch axis: a wrongly sized input fails loudly in the linear layer
        # instead of silently returning more rows than there were samples.
        return self.classify_head(features.flatten(start_dim=1))

In [6]:
# Seed torch's global RNG before the model is built so that weight
# initialisation (and later draws from the same global generator) repeat
# across a fresh "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

model = CNN()
# Cross-entropy over the raw scores returned by the model.
criterion = nn.CrossEntropyLoss()
# Plain SGD (no momentum) over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [7]:
# Train for 10 passes over the training loader.
for epoch in range(10):
    model.train()
    running_loss = 0.0  # sum of the per-batch losses seen in this epoch
    # Loop variables are named `images`/`labels` so the builtin `input` is not
    # overwritten in the notebook's global namespace.
    for images, labels in train_loader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # The printed figure is the sum over all batches, not a per-batch mean.
    print(f'Epoch - {epoch}, loss = {running_loss}')

Epoch - 0, loss = 2154.909630060196
Epoch - 1, loss = 2140.0478971004486
Epoch - 2, loss = 2119.7958755493164
Epoch - 3, loss = 2083.61563038826
Epoch - 4, loss = 2007.311164855957
Epoch - 5, loss = 1826.4274609088898
Epoch - 6, loss = 1403.7756037712097
Epoch - 7, loss = 926.8690323829651
Epoch - 8, loss = 652.020500510931
Epoch - 9, loss = 497.1380241513252


In [11]:
# Evaluate on the test loader and collect predictions and labels as plain ints.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():  # inference only, no gradient tracking needed
    for images, labels in test_loader:
        logits = model(images)
        # Index of the highest score along the class dimension. `.tolist()`
        # stores Python ints rather than 0-d tensors, which is what the
        # sklearn metric functions are meant to receive.
        all_preds.extend(logits.argmax(dim=1).tolist())
        all_labels.extend(labels.tolist())
# Rows are true labels, columns are predicted labels.
cm = confusion_matrix(all_labels, all_preds)
print(cm)
# remember that model.parameters needs () after it
# Total number of trainable parameters in the model.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))


[[ 930    0    8    0    1    9   19    5    5    3]
 [   0 1097    2    3    0   10    2    1   20    0]
 [  10    1  884   33    6   12    8   28   33   17]
 [   0    2   36  883    0   50    0   20   13    6]
 [   2    4    0    0  902    1   29    1    3   40]
 [   7   13    6   27   33  708   15   20   51   12]
 [  16    5    2    0   33    6  890    0    6    0]
 [   4    4   41   56    0    6    0  876    3   38]
 [   3   14   10   46   15   48   14    6  776   42]
 [  11    7    4    9   54   49    3   21   12  839]]
149798


In [12]:
# Overall accuracy computed from the labels and predictions collected during
# evaluation. `accuracy_score` is imported in the notebook's top import cell.
print(accuracy_score(all_labels, all_preds))

0.8785
