# Classification on MNIST Handwritten Digits
## Introduction
In this notebook, we'll train a simple fully connected neural network to classify the MNIST dataset. MNIST is a collection of 28x28-pixel grayscale images of handwritten digits from 0 to 9, split into a training set of 60,000 images and a test set of 10,000 images. The goal is to assign each image to its correct digit class.

## 1. Set Up
Install and import the necessary libraries.

In [None]:
# Install required libraries

#!pip install torch 
#!pip install torchvision

In [5]:
import torch 
import torchvision 

from torch import nn, optim
from torchvision import transforms 
from torch.utils.data import DataLoader

In [6]:
# Set the device 
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps
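Selecting a device only takes effect once the model's parameters and every input batch are moved onto it. PyTorch raises an error when one operation mixes tensors on different devices. Below is a minimal sketch of the pattern. The small `nn.Linear` layer and the random batch are hypothetical stand-ins, not part of this notebook's model.

```python
import torch
from torch import nn

# Same fallback order as above: CUDA, then Apple's MPS backend, then CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Hypothetical stand-in layer and batch, used only to illustrate device placement
layer = nn.Linear(4, 2).to(device)    # Module.to() moves the parameters and returns the module
batch = torch.randn(3, 4).to(device)  # Tensor.to() returns a tensor on the target device

out = layer(batch)  # works because parameters and input share a device
print(out.device.type, tuple(out.shape))
```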


## 2. Import Data
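We'll download MNIST through `torchvision.datasets` and wrap the training and test sets in `DataLoader`s that serve mini-batches.

**Note: after `transforms.ToTensor()`, each MNIST image is a 1x28x28 float tensor (channels x height x width) with pixel values in [0, 1].**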

In [16]:
# Set image path 
image_path = './data'

# Transform images to tensors
transform = transforms.Compose([transforms.ToTensor()])

# Download training dataset
mnist_train = torchvision.datasets.MNIST(
    root=image_path, 
    train=True, 
    transform=transform,
    download=True)

# Download test dataset
mnist_test = torchvision.datasets.MNIST(
    root=image_path, 
    train=False, 
    transform=transform,
    download=True)

# Set batch size
batch_size = 64

# Wrap the training set in a DataLoader that yields shuffled mini-batches
train_loader = DataLoader(
    dataset=mnist_train, 
    batch_size=batch_size, 
    shuffle=True)

# Wrap the test set in a DataLoader (order does not matter for evaluation)
test_loader = DataLoader(
    dataset=mnist_test, 
    batch_size=batch_size, 
    shuffle=False)
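With `batch_size = 64`, the training loader yields ceil(60000 / 64) = 938 batches per epoch and the test loader yields ceil(10000 / 64) = 157. Because `drop_last` defaults to `False`, the final batch is simply smaller: 32 images for training and 16 for test. The sketch below checks this arithmetic without downloading anything. It uses index-only `TensorDataset`s with the same lengths as the MNIST splits, plus a small random image tensor for the batch shapes. Both are illustrative stand-ins.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 64

# Stand-in datasets with the same lengths as the MNIST splits (indices only, no images)
train_like = TensorDataset(torch.arange(60_000))
test_like = TensorDataset(torch.arange(10_000))

train_batches = list(DataLoader(train_like, batch_size=batch_size, shuffle=True))
test_batches = list(DataLoader(test_like, batch_size=batch_size, shuffle=False))

print(len(train_batches), train_batches[-1][0].shape[0])  # 938 32
print(len(test_batches), test_batches[-1][0].shape[0])    # 157 16

# Stand-in images shaped like transformed MNIST digits: (N, 1, 28, 28)
fake_images = torch.rand(200, 1, 28, 28)
fake_labels = torch.randint(0, 10, (200,))
images, labels = next(iter(DataLoader(TensorDataset(fake_images, fake_labels), batch_size=batch_size)))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```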

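## 3. Fully Connected Neural Network
We'll start with a simple fully connected network: the flattened 784-pixel input feeds one hidden layer of 128 ReLU units, followed by an output layer of 10 logits, one per digit. The ReLU makes the network nonlinear, even though it is built from linear layers.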
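The network computes the following for a flattened input $x \in \mathbb{R}^{784}$. Here $W_1, b_1$ and $W_2, b_2$ are our names for the weights and biases of the two `nn.Linear` layers; PyTorch stores each weight as (out_features, in_features):

$$
z = W_2\,\mathrm{ReLU}(W_1 x + b_1) + b_2, \qquad W_1 \in \mathbb{R}^{128 \times 784},\; b_1 \in \mathbb{R}^{128},\; W_2 \in \mathbb{R}^{10 \times 128},\; b_2 \in \mathbb{R}^{10}
$$

The output $z \in \mathbb{R}^{10}$ holds one unnormalized score (logit) per digit, and the predicted class is $\arg\max_k z_k$. Counting entries gives $784 \cdot 128 + 128 + 128 \cdot 10 + 10 = 101{,}770$ trainable parameters.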

### 3.1 Create a Two-Layer Neural Network

In [9]:
class TwoLayerNet(nn.Module):
    def __init__(self):
        super().__init__()

        # Define the layers of the network
        self.layers = nn.Sequential(
            nn.Flatten(),           # (N, 1, 28, 28) -> (N, 784)
            nn.Linear(28*28, 128),  # (N, 784) -> (N, 128)
            nn.ReLU(),              # Elementwise nonlinearity
            nn.Linear(128, 10)      # (N, 128) -> (N, 10): one raw logit per digit
        )
    
    def forward(self, x):
        # Return raw logits; cross_entropy applies log-softmax internally
        return self.layers(x)
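As a quick sanity check, the sketch below rebuilds the same layer stack as a plain `nn.Sequential` so it runs on its own. It confirms that a batch of 1x28x28 images maps to one row of 10 logits per image and counts the trainable parameters. The batch is random noise, not MNIST data.

```python
import torch
from torch import nn

# Same layer stack as TwoLayerNet, rebuilt here so the snippet is self-contained
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

batch = torch.rand(64, 1, 28, 28)  # random stand-in for one MNIST batch
logits = net(batch)
print(logits.shape)  # torch.Size([64, 10])

# 784*128 + 128 (first layer) + 128*10 + 10 (second layer)
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(num_params)  # 101770
```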

### 3.2 Train the Model

In [15]:
# Initialize the model and move its parameters to the selected device
model = TwoLayerNet().to(device)

# Define the optimizer: SGD with momentum
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
)

# Train the model
for epoch in range(10):
    model.train()  # Set the model to training mode
    num_correct = 0  # Count correct training predictions in this epoch

    for image, label in train_loader:  # Iterate over the training dataset by batch
        image, label = image.to(device), label.to(device)  # Move the batch to the model's device
        pred = model(image)  # Forward pass: logits of shape (batch_size, 10)
        loss = nn.functional.cross_entropy(pred, label)  # Cross-entropy loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update the weights
        optimizer.zero_grad()  # Reset the gradients
        is_correct = (torch.argmax(pred, dim=1) == label).float()  # Check which predictions are correct
        num_correct += is_correct.sum().item()  # Sum the correct predictions
    # Training accuracy, measured while the weights are still being updated
    accuracy = num_correct / len(mnist_train)
    print(f"Epoch: {epoch}, Accuracy: {accuracy:.4f}")

Epoch: 0, Accuracy: 0.8697
Epoch: 1, Accuracy: 0.9339
Epoch: 2, Accuracy: 0.9501
Epoch: 3, Accuracy: 0.9597
Epoch: 4, Accuracy: 0.9672
Epoch: 5, Accuracy: 0.9715
Epoch: 6, Accuracy: 0.9749
Epoch: 7, Accuracy: 0.9781
Epoch: 8, Accuracy: 0.9806
Epoch: 9, Accuracy: 0.9822
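`optim.SGD` with `momentum=0.9` keeps a running velocity buffer per parameter. With dampening, weight decay and Nesterov left at their defaults, PyTorch's documented update is $v \leftarrow \mu v + g$ followed by $\theta \leftarrow \theta - \eta v$. Here $g$ is the gradient, $\mu$ the momentum and $\eta$ the learning rate. The sketch below replays that rule by hand for a single scalar parameter and compares it against `optim.SGD`. The quadratic loss $(\theta - 3)^2$ and the starting value are illustrative assumptions.

```python
import torch
from torch import optim

lr, momentum = 0.01, 0.9

# One scalar parameter trained by optim.SGD on the illustrative loss (theta - 3)^2
theta = torch.tensor(0.0, requires_grad=True)
optimizer = optim.SGD([theta], lr=lr, momentum=momentum)

# Manual replica of the same update rule, in plain Python floats
theta_manual, velocity = 0.0, 0.0

for step in range(5):
    loss = (theta - 3.0) ** 2
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    grad = 2.0 * (theta_manual - 3.0)      # derivative of (theta - 3)^2
    velocity = momentum * velocity + grad  # v <- mu * v + g
    theta_manual -= lr * velocity          # theta <- theta - lr * v

print(theta.item(), theta_manual)  # the two values agree up to float32 rounding
```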

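The network's last layer outputs raw logits with no softmax. This works because `nn.functional.cross_entropy` applies log-softmax internally and then averages the negative log-probability of the true class over the batch, i.e. $-\frac{1}{N}\sum_i \log \mathrm{softmax}(z_i)_{y_i}$. The sketch below checks this equivalence three ways on a random stand-in batch of 4 examples. The logits and labels are made up, not model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # stand-in batch: 4 examples, 10 classes
labels = torch.tensor([3, 0, 7, 9])  # stand-in true digit labels

builtin = F.cross_entropy(logits, labels)
two_step = F.nll_loss(F.log_softmax(logits, dim=1), labels)

# Direct formula: mean over the batch of -log softmax(z_i)[y_i]
probs = torch.softmax(logits, dim=1)
manual = -torch.log(probs[torch.arange(len(labels)), labels]).mean()

print(builtin.item(), two_step.item(), manual.item())  # all three agree up to rounding
```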

### 3.3 Evaluate the Model

In [19]:
# Evaluate the model on the test dataset
model.eval()  # Set the model to evaluation mode
num_correct = 0  # Count correct test predictions

with torch.no_grad():  # Disable gradient tracking
    for image, label in test_loader:  # Iterate over the test dataset by batch
        image, label = image.to(device), label.to(device)  # Move the batch to the model's device
        pred = model(image)  # Forward pass
        is_correct = (torch.argmax(pred, dim=1) == label).float()  # Check which predictions are correct
        num_correct += is_correct.sum().item()  # Sum the correct predictions

accuracy = num_correct / len(mnist_test)  # Calculate the test accuracy
print(f"Test accuracy: {accuracy:.4f}")

Test accuracy: 0.9752
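The test loop repeats the counting logic used during training, so it can be factored into a reusable function. The sketch below shows one way to do that. `evaluate` is a hypothetical helper name. The toy `AlwaysZero` model and synthetic labels exist only to make the example checkable: a model that always predicts digit 0 should score exactly the fraction of zeros among the labels.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the fraction of examples in `loader` that `model` classifies correctly."""
    model.eval()  # switch off training-only behavior such as dropout, if any
    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for image, label in loader:
            image, label = image.to(device), label.to(device)
            pred = model(image).argmax(dim=1)
            num_correct += (pred == label).sum().item()
            num_seen += label.numel()
    return num_correct / num_seen


class AlwaysZero(nn.Module):
    """Toy model whose logits always favor digit 0."""
    def forward(self, x):
        logits = torch.zeros(x.shape[0], 10, device=x.device)
        logits[:, 0] = 1.0
        return logits


images = torch.rand(100, 1, 28, 28)
labels = torch.cat([torch.zeros(25, dtype=torch.long), torch.ones(75, dtype=torch.long)])
loader = DataLoader(TensorDataset(images, labels), batch_size=64)

print(evaluate(AlwaysZero(), loader, torch.device("cpu")))  # 0.25
```

Counting `num_seen` instead of using `len(dataset)` keeps the helper correct for any loader, including one with `drop_last=True`. In this notebook it would be called as `evaluate(model, test_loader, device)`.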
