In [20]:
import torch
from torch import nn, optim
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix

In [21]:
# Compose - chains several transformations together; they are applied in order to every sample.
# ToTensor - converts the input PIL image to a PyTorch tensor and scales the pixel
# values from [0, 255] to [0, 1] by dividing by 255.
# Normalize - subtracts the mean and divides by the standard deviation, per channel.
# MNIST images are grayscale, so there is a single channel: with mean (0.5,) and
# std (0.5,), each pixel becomes (x - 0.5) / 0.5, mapping [0, 1] onto [-1, 1].
# NOTE: the pipeline is bound to `transform` (singular) rather than `transforms`, so the
# `torchvision.transforms` module is not shadowed and this cell can be re-run safely.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# '.' stores the dataset in the current directory.
# train=True loads the 60,000-image training split; train=False loads the 10,000-image test split.
# download=True tells torchvision to download the dataset if it isn't already present there.
train = datasets.MNIST('.', train=True, download=True, transform=transform)
test = datasets.MNIST('.', train=False, download=True, transform=transform)

In [22]:
BATCH_SIZE = 64
# Shuffle the training set every epoch so mini-batches are not drawn in dataset order.
# The test set is only used for evaluation, where sample order does not affect the
# metrics, so keep it in a fixed, deterministic order.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

In [23]:
class Network(nn.Module):
    """Fully connected MNIST classifier: 784 -> 128 -> 64 -> 10.

    The last layer emits raw logits. No softmax layer is added because
    ``nn.CrossEntropyLoss`` applies log-softmax to its input internally.
    """

    # Number of features in one flattened 28x28 image.
    _IN_FEATURES = 28 * 28

    def __init__(self):
        super().__init__()
        # Layers are registered in the order fc1, ac1, fc2, ac2, fc3 so parameter
        # names and initialisation order stay the same.
        self.fc1 = nn.Linear(self._IN_FEATURES, 128)
        self.ac1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.ac2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits, one row of 10 values per flattened image.

        The input is reshaped into rows of 784 features, so e.g. a batch of
        shape (N, 1, 28, 28) becomes (N, 784) before the first linear layer.
        """
        flat = x.reshape(-1, self._IN_FEATURES)
        first_hidden = self.ac1(self.fc1(flat))
        second_hidden = self.ac2(self.fc2(first_hidden))
        return self.fc3(second_hidden)

In [24]:
# Seed the global PyTorch RNG so weight initialisation, and the training-set
# shuffle order (the DataLoader is not given its own generator), are reproducible
# on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = Network()
# nn.CrossEntropyLoss applies log-softmax to its input internally, so the network
# outputs raw logits. For prediction, argmax over the logits picks the same class as
# argmax over softmax probabilities, so no softmax is needed there either.
criterion = nn.CrossEntropyLoss()
# NOTE(review): plain SGD at lr=0.001 converges slowly (training loss is still ~0.44
# after 10 epochs); consider momentum=0.9 or a larger learning rate.
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [25]:
NUM_EPOCHS = 10
# Mean training loss of each epoch, kept so the learning curve can be plotted later.
loss_list = []
for epoch in range(NUM_EPOCHS):
    # Call train() inside the epoch loop so that if the model is switched to eval mode
    # mid-training (e.g. to evaluate something), training mode is restored next epoch.
    model.train()
    running_loss = 0.0
    # `images` rather than `input` to avoid shadowing the builtin.
    for images, target in train_loader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Average of the per-batch mean losses over the epoch.
    epoch_loss = running_loss / len(train_loader)
    loss_list.append(epoch_loss)
    print(f'epoch [{epoch + 1} / {NUM_EPOCHS}], loss: {epoch_loss}')


epoch [1 / 10], loss: 2.2217008798107156
epoch [2 / 10], loss: 1.9124657115194081
epoch [3 / 10], loss: 1.3894374986955607
epoch [4 / 10], loss: 0.963588167863614
epoch [5 / 10], loss: 0.7438113994753437
epoch [6 / 10], loss: 0.6274233063591569
epoch [7 / 10], loss: 0.5563453254160851
epoch [8 / 10], loss: 0.5073951357272642
epoch [9 / 10], loss: 0.4709611423869631
epoch [10 / 10], loss: 0.4428183328844845


In [None]:
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    # `images` rather than `input` to avoid shadowing the builtin.
    for images, target in test_loader:
        output = model(images)
        # torch.max(t, dim) returns (values, indices); only the indices -- the
        # predicted class of each sample -- are needed here.
        _, predicted = torch.max(output, 1)
        # predicted and target hold one entry per sample in the batch, so use
        # extend (not append) to add each element to the flat lists.
        all_preds.extend(predicted.tolist())
        all_labels.extend(target.tolist())
cm = confusion_matrix(all_labels, all_preds)
print(cm)
# Correct predictions lie on the confusion-matrix diagonal.
accuracy = cm.trace() / cm.sum()
print(f'Test accuracy: {accuracy:.4f}')
# model.parameters() yields every parameter tensor of every layer; the requires_grad
# filter restricts the count to trainable parameters (nothing in this notebook freezes any).
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of learnable parameters: {num_params}')

[[ 946    0    5    1    0   16    8    1    3    0]
 [   0 1094    2    6    1    2    4    1   25    0]
 [  14   14  882   23   27    4   22   10   33    3]
 [   4    3   26  883    1   44    2   19   24    4]
 [   2    5    4    0  888    2   16    2    9   54]
 [  16    3   11   63   15  709   22   10   37    6]
 [  18    3   14    1   16   22  881    0    3    0]
 [   6   24   26    2    8    1    0  917    7   37]
 [  12    8   14   35   11   31   17    5  819   22]
 [  13    8    5   11   71   13    1   22   13  852]]
Number of learnable parameters: 109386
