In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from torch import nn
from sklearn.metrics import accuracy_score as acc

In [2]:
# Preprocessing applied to every image: convert to a tensor, then normalise the
# single channel with mean 0.5 and std 0.5.
# The pipeline is bound to its own name so the imported `transforms` module is
# not shadowed; rebinding `transforms` makes this cell fail with AttributeError
# the second time it is run.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train / test splits, downloaded into ./Data when not already present.
train = datasets.MNIST(root='./Data', train=True, download=True, transform=mnist_transform)
test = datasets.MNIST(root='./Data', train=False, download=True, transform=mnist_transform)

In [3]:
# Mini-batch loaders (64 samples per batch).
# Training batches are reshuffled on every pass. The test loader keeps the
# dataset order: the metrics computed from it do not depend on order, and a
# fixed order keeps evaluation runs reproducible.
tr_load = DataLoader(dataset=train, batch_size=64, shuffle=True)
tes_load = DataLoader(dataset=test, batch_size=64, shuffle=False)

In [4]:
class CNN(nn.Module):
    """Small convolutional classifier: 1-channel image in, 10 class scores out.

    `net` stacks three Conv(3x3) -> ReLU -> MaxPool(2x2, stride 2) stages with
    64, 128 and 64 output channels. For a 28x28 input the spatial size shrinks
    28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1, leaving a 64x1x1 feature map.
    `classify_head` is a 64 -> 20 -> 10 fully connected network applied to that
    64-value feature vector.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: no padding, so every 3x3 convolution trims 2 pixels.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2), stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2), stride=2),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2), stride=2)
        )
        # Classifier head: outputs raw (unnormalised) scores for 10 classes.
        self.classify_head = nn.Sequential(
            nn.Linear(in_features=64, out_features=20, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=20, out_features=10, bias=True)
        )

    def forward(self, x):
        """Return class scores for a batch of images.

        Args:
            x: image tensor whose spatial size reduces to 1x1 after the three
                conv/pool stages (28x28 images do).

        Returns:
            Tensor with one row of 10 scores per input sample.

        Raises:
            ValueError: if the feature map is not 1x1. Without this check,
                `reshape(-1, 64)` would fold the extra spatial positions into
                the batch dimension and return more rows than input samples.
        """
        identifier = self.net(x)
        if tuple(identifier.shape[-2:]) != (1, 1):
            raise ValueError(
                "CNN expects inputs whose feature map reduces to 1x1 "
                "(e.g. 28x28 images); got a feature map of size "
                f"{tuple(identifier.shape[-2:])}"
            )
        return self.classify_head(identifier.reshape(-1,64))

In [6]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [8]:
# Network instance, moved onto the selected device.
model = CNN().to(device)

# Training objective: cross-entropy between the class scores and the labels.
cr = nn.CrossEntropyLoss()

# Plain stochastic gradient descent over every model parameter.
opt = torch.optim.SGD(params=model.parameters(), lr=1e-3)

In [7]:
# Ask the user how many passes over the training set to run.
epoch_reply = input("Enter Number of Epochs : ")
epochs = int(epoch_reply)

In [10]:
# Train for `epochs` passes over `tr_load`, reporting progress about five times.
# The reporting interval is an integer clamped to at least 1: with the float
# `epochs / 5`, `i % (epochs / 5) == 0` only hits the intended epochs when
# `epochs` is a multiple of 5 (for other values, or for fewer than 5 epochs,
# usually only epoch 0 is reported).
log_every = max(1, epochs // 5)

for i in range(epochs):
    model.train()
    epoch_loss = 0.0  # running sum of the per-batch loss values for this epoch
    for ip, op in tr_load:
        ip, op = ip.to(device), op.to(device)
        opt.zero_grad()   # clear gradients left over from the previous step
        out = model(ip)
        loss = cr(out, op)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
    if i % log_every == 0:
        print(f"Epoch : {i} Loss : {epoch_loss}")
        print('-X' * 50 + '-')

Epoch : 0 Loss : 2130.7792110443115
-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-
Epoch : 3 Loss : 1922.7530776262283
-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-
Epoch : 6 Loss : 770.2952352762222
-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-
Epoch : 9 Loss : 404.22583578526974
-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-
Epoch : 12 Loss : 283.6595372930169
-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-


In [11]:
# Switch the model to evaluation mode before running inference. As the last
# expression of the cell, the returned module is also displayed.
model.eval()

CNN(
  (net): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classify_head): Sequential(
    (0): Linear(in_features=64, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=10, bias=True)
  )
)

In [12]:
# Accumulators for the predicted and the true class of every test sample.
all_pred, all_label = [], []

In [None]:
# Collect the predicted and the true class for every sample in `tes_load`.
# The accumulators are (re)created here so that re-running this cell starts
# from empty lists instead of appending a second copy of the results.
all_pred, all_label = [], []

with torch.no_grad():  # inference only, so no autograd bookkeeping is needed
    for ip, op in tes_load:
        ip, op = ip.to(device), op.to(device)
        out = model(ip)
        # Predicted class = index of the largest score in each row.
        idx = out.argmax(dim=1)
        # .tolist() stores plain Python ints rather than 0-dim tensors, which
        # is what the scikit-learn metrics expect.
        all_pred.extend(idx.cpu().tolist())
        all_label.extend(op.cpu().tolist())

In [15]:
# Confusion matrix of true labels against predicted labels on the test set.
cm = confusion_matrix(y_true=all_label, y_pred=all_pred)

In [16]:
# Total number of scalar values held in parameters that require gradients.
learnable_para = sum(
    tensor.numel()
    for tensor in filter(lambda t: t.requires_grad, model.parameters())
)
learnable_para

149798

In [17]:
# Print the confusion matrix. Rows index the true digit and columns the
# predicted digit (scikit-learn's convention for confusion_matrix(y_true, y_pred)).
print(cm)

[[ 939    1    1    0    4   13   18    2    2    0]
 [   0 1110    5    2    0    3    3    0   12    0]
 [   2    2  958    7    7    6    5   25   16    4]
 [   0    0   20  948    0   11    0   16   12    3]
 [   2    5    2    0  933    0   12    2    2   24]
 [   9    4    2   20    2  829    3    4   17    2]
 [  33    3    3    0   11    6  895    0    7    0]
 [   2    4   43   14    2    1    0  919    7   36]
 [   5    1    9   15    9   16   10   11  868   30]
 [  10    3    4    2   18   17    1   24   11  919]]


In [18]:
# Fraction of test samples whose predicted class equals the true class.
ac = acc(y_true=all_label, y_pred=all_pred)
print(f"Accuracy : {ac}")

Accuracy : 0.9318
