In [3]:
import torch
from sklearn.metrics import accuracy_score, confusion_matrix
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [4]:
# Preprocessing: convert each image to a tensor, then normalize its single
# channel with mean 0.5 and std 0.5.
# The pipeline gets its own name on purpose: rebinding `transforms` replaced
# the imported torchvision module, so re-running this cell raised
# AttributeError ('Compose' object has no attribute 'Compose').
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train/test splits, stored in (and downloaded to) the current directory.
train = datasets.MNIST('.', train=True, download=True, transform=mnist_transform)
test = datasets.MNIST('.', train=False, download=True, transform=mnist_transform)

# Only the training batches are shuffled; the evaluation metrics computed on
# the test set do not depend on sample order, so it is iterated deterministically.
train_loader = DataLoader(train, batch_size=64, shuffle=True)
test_loader = DataLoader(test, batch_size=64, shuffle=False)

In [12]:
class CNN(nn.Module):
    """Small convolutional classifier that maps an image to 10 logits.

    The feature extractor is three unpadded 3x3 convolutions, each followed by
    ReLU and a 2x2 max-pool with stride 2. The head expects exactly 32 features
    per sample, i.e. the feature extractor must reduce the image to
    32 x 1 x 1. That holds for 1 x 28 x 28 inputs
    (28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1).
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride=2),
                                 nn.Conv2d(32, 64, kernel_size=3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride=2),
                                 nn.Conv2d(64, 32, kernel_size=3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride=2))
        # NOTE(review): there is no activation between the two Linear layers,
        # so the head as a whole is a single affine map 32 -> 10. Left as-is to
        # keep the parameter layout and state_dict keys unchanged.
        self.classify_head = nn.Sequential(nn.Linear(32, 20, bias=True),
                                           nn.Linear(20, 10, bias=True))

    def forward(self, x):
        """Return class logits for `x`.

        Args:
            x: image tensor of shape (N, 1, H, W), or a single unbatched image
                of shape (1, H, W). H and W must make the feature extractor
                output 32 x 1 x 1 (e.g. 28 x 28).

        Returns:
            Tensor of shape (N, 10); (1, 10) for an unbatched image.

        Raises:
            RuntimeError: if the input size leaves more than 32 features per
                sample, because the first Linear layer rejects the shape.
        """
        features = self.net(x)
        if features.dim() == 3:
            # Unbatched image: add the batch axis so the result is (1, 10).
            features = features.unsqueeze(0)
        # Flatten per sample instead of reshape(-1, 32): the reshape silently
        # moved leftover spatial positions into the batch dimension, returning
        # more rows than input samples for any input that is not reduced to 1x1.
        return self.classify_head(features.flatten(start_dim=1))

In [13]:
# Seed torch's global RNG before the model is built so that weight
# initialization is repeatable across Restart & Run All. The training
# DataLoader's shuffle order also follows this seed, as long as no explicit
# generator is passed to the loader.
SEED = 0
LEARNING_RATE = 0.001
torch.manual_seed(SEED)

model = CNN()
# CrossEntropyLoss takes raw logits and integer class targets.
criterion = torch.nn.CrossEntropyLoss()
# Plain SGD (no momentum) over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [14]:
# Train for 10 epochs. After each epoch, print the SUM of the per-batch mean
# losses (not an average), so the value scales with the number of batches.
for epoch in range(10):
    model.train()
    running_loss = 0.0
    # The batch variables are deliberately not named `input`: at notebook top
    # level that name is a global and would replace the built-in input() for
    # the rest of the kernel session.
    for images, labels in train_loader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch - {epoch}, loss = {running_loss}')

Epoch - 0, loss = 2151.112156867981
Epoch - 1, loss = 2131.0279171466827
Epoch - 2, loss = 2099.186537027359
Epoch - 3, loss = 2033.7613525390625
Epoch - 4, loss = 1874.346396803856
Epoch - 5, loss = 1557.3057669401169
Epoch - 6, loss = 1191.4873133897781
Epoch - 7, loss = 865.344142138958
Epoch - 8, loss = 629.1617791950703
Epoch - 9, loss = 493.97765186429024


In [15]:
# Evaluate on the test set: collect the predicted and true class of every
# sample, then print the confusion matrix and the trainable-parameter count.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    # `images` rather than `input`, so the built-in input() is not overwritten.
    for images, labels in test_loader:
        output = model(images)
        # The predicted class is the index of the largest logit. Store plain
        # Python ints: extending with the tensors themselves builds a list of
        # zero-dim tensors, which is slow to convert and fails for tensors
        # that do not live on the CPU.
        all_preds.extend(output.argmax(dim=1).tolist())
        all_labels.extend(labels.tolist())
# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(all_labels, all_preds)
print(cm)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

[[ 904    0    4    0    4   16   49    1    1    1]
 [   0 1112    3    1    0    3    2    3   11    0]
 [   1    9  854   16    7    9   13   36   76   11]
 [   0    1   17  932    0   15    0   14   18   13]
 [   2    4    0    0  870    4   39    4    9   50]
 [   6    9    2   25    3  760   29    5   26   27]
 [  50    7    3    0   48   17  824    0    8    1]
 [   4    5   56   31   14    3    1  863   15   36]
 [   0   10   26   28    5   56   14   15  749   71]
 [   7    5    3   17   33   26    8   34   34  842]]
38150


In [18]:
# Fraction of test samples whose predicted class equals the true class.
# accuracy_score is imported in the notebook's import cell.
print(accuracy_score(all_labels, all_preds))

0.871
