# Handwritten Digit Recognition

## 1 - Packages

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from utils import *
%matplotlib inline

## 2 - MNIST Dataset

### 2.1 Load the MNIST Dataset

In [None]:
# Preprocessing: PIL image -> float tensor in [0, 1], then standardize with
# (0.1307, 0.3081), the commonly quoted MNIST training-set mean and std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Both splits are read from ./ with download=False, so the MNIST files must
# already be present there.
train_dataset = torchvision.datasets.MNIST(
    root='./', train=True, download=False, transform=transform
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)

test_dataset = torchvision.datasets.MNIST(
    root='./', train=False, download=False, transform=transform
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

# `.data` is the dataset's stored image tensor (it is not passed through
# `transform`), so the shapes and previews below show the raw images.
for _caption, _split in (
    ('The shape of train dataset:', train_dataset),
    ('The shape of test dataset:', test_dataset),
):
    print(_caption, _split.data.shape)

for _caption, _split in (
    ('The first image in train dataset', train_dataset),
    ('The first image in test dataset', test_dataset),
):
    plt.imshow(_split.data[0], cmap='gray')
    plt.title(_caption)
    plt.show()

### 2.2 Visualize the MNIST Dataset

In [None]:
# Plot sample training images with the project helper from `utils`
# (its exact output is defined there, not in this notebook).
visualize_MNIST_dataset(train_loader)

## 3 - Neural Network Model

### 3.1 Define the Neural Network Model

In [None]:
class CNNModel(nn.Module):
    """Small convolutional classifier producing log-probabilities over 10 classes.

    Layers: conv(1->32, 3x3, pad 1) -> ReLU -> 2x2 max-pool ->
    conv(32->64, 3x3, pad 1) -> ReLU -> 2x2 max-pool ->
    fc(64*7*7 -> 128) -> ReLU -> fc(128 -> 10) -> log_softmax.

    The ``64 * 7 * 7`` input size of ``fc1`` assumes 1x28x28 inputs: the padded
    3x3 convolutions keep the spatial size and the two 2x2 pools reduce
    28 -> 14 -> 7.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10).

        Args:
            x: image batch of shape (batch, 1, 28, 28).

        Raises:
            RuntimeError: if the spatial size does not reduce to 7x7, because
                the flattened features then do not match ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten all dims except the batch dim. Unlike view(-1, 64 * 7 * 7),
        # this never reinterprets extra spatial data as additional samples;
        # a wrong input size fails loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = CNNModel()

# CNNModel.forward already ends in log_softmax, so the matching loss is
# NLLLoss. CrossEntropyLoss would apply log_softmax a second time: the loss
# value is unchanged (log_softmax is idempotent) but the extra pass is
# redundant and obscures what the model outputs.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# List every learnable tensor with its shape, then the total element count.
print('The parameters of the CNN model:')
for name, param in model.named_parameters():
    print(name, param.shape)
total_params = sum(p.numel() for p in model.parameters())
print('total_params', total_params)

### 3.2 Train the Neural Network Model

In [None]:
num_epochs = 20
log_every = 100  # mini-batches between progress prints
epoch_losses = []  # mean training loss of each epoch, for the plot below

for epoch in range(num_epochs):
    # Ensure training mode, e.g. when re-running after an evaluation put the
    # model in eval mode.
    model.train()
    window_loss = 0.0  # loss summed since the last progress print
    epoch_loss_sum = 0.0  # loss summed over the whole epoch
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        window_loss += batch_loss
        epoch_loss_sum += batch_loss
        if i % log_every == log_every - 1:
            average_loss = window_loss / log_every
            print('[Epoch: %d, Mini-batch: %d] Average Loss: %.5f' % (epoch + 1, i + 1, average_loss))
            window_loss = 0.0

    # window_loss is reset every `log_every` batches, so the epoch mean must
    # come from the separate full-epoch accumulator.
    epoch_loss = epoch_loss_sum / len(train_loader)
    epoch_losses.append(epoch_loss)
    print('Epoch %d completed. Average Loss: %.5f' % (epoch + 1, epoch_loss))

plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o')
plt.title('Loss vs. Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

### 3.3 Test the Neural Network Model

In [None]:
# Evaluation mode: layers such as dropout/batch-norm (none in the current
# model) would otherwise keep their training-time behavior.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        # The predicted class is the index of the highest score in each row.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test dataset: %.2f%%' % (100 * correct / total))

visualize_prediction_results(model, test_loader)

### 3.4 Compute the Average Test Loss

In [None]:
model.eval()
test_loss = 0.0  # summed per-sample loss over the whole test set
total = 0  # number of test samples seen

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        n = labels.size(0)
        # criterion uses its default reduction='mean' (per-batch average), so
        # re-weight by the batch size; dividing the sum by `total` below then
        # gives a true per-sample mean, including for the smaller last batch.
        test_loss += criterion(outputs, labels).item() * n
        total += n

print('Average loss on the test dataset: %.5f' % (test_loss / total))

### 3.5 Display Some Examples of Mispredictions

In [None]:
# Show test images the trained model got wrong, via the project helper from
# `utils` (selection and layout are defined there, not in this notebook).
visualize_incorrect_predictions(model, test_loader)

## 4 - References

[MNIST Dataset](https://www.kaggle.com/datasets/hojjatk/mnist-dataset): The MNIST database of handwritten digits.