# MNIST Image Classifier Project

This notebook implements a Convolutional Neural Network (CNN) from scratch using PyTorch to classify handwritten digits.

### 1. Imports
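We import the necessary libraries: PyTorch for building and training the network, PIL for opening image files, torchvision for converting images to tensors, and NumPy for storing and indexing the list of file paths.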

In [None]:
import os, io, torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import transforms
import numpy as np

def print_named_params(model):
  """Print a "<name>: <element count>" line for every named parameter of *model*."""
  for entry in model.named_parameters():
    name, param = entry
    n_elements = param.numel()
    print(f"{name}: {n_elements}")

### 2. Data Preparation Functions
These functions handle loading the files from folders and converting images into PyTorch tensors.

In [None]:
def load_filepaths(target_dir):
  """Return the path of every entry in *target_dir*.

  Each name reported by os.listdir is joined onto *target_dir*; the order is
  whatever os.listdir yields.
  """
  return [os.path.join(target_dir, entry) for entry in os.listdir(target_dir)]

def prepare_data(target_dir):
  """Collect file paths and integer labels from digit-named sub-folders.

  Sub-directories of *target_dir* named '0' through '9' are scanned; every
  entry found inside sub-directory 'k' is recorded with label k. Digits with
  no matching sub-directory are skipped.

  Args:
    target_dir: directory expected to contain the digit sub-folders.

  Returns:
    A tuple (filepaths, labels): a NumPy array of path strings and a torch
    int64 tensor of the same length. Both are empty when nothing is found.
  """
  filepaths = []
  labels = []

  for digit in range(10):
    class_dir = os.path.join(target_dir, str(digit))
    # isdir rather than exists: a stray *file* named like a digit must not be
    # handed to os.listdir, which would raise NotADirectoryError.
    if not os.path.isdir(class_dir):
      continue
    class_paths = load_filepaths(class_dir)
    filepaths.extend(class_paths)
    labels.extend([digit] * len(class_paths))

  # Explicit dtypes keep the result types stable even when no files were found
  # (empty lists would otherwise be inferred as float64 / float32).
  return np.array(filepaths, dtype=str), torch.tensor(labels, dtype=torch.long)

def load_images(filepaths):
  """Load image files and stack them into a single batch tensor.

  Each file is opened with PIL and converted with torchvision's ToTensor
  (for standard 8-bit images the values land in [0.0, 1.0]). All images must
  produce tensors of the same shape so they can be stacked.

  Args:
    filepaths: iterable of image file paths (list or NumPy array).

  Returns:
    A tensor with a leading batch dimension, e.g. [N, 1, 28, 28] for N
    single-channel 28x28 images, or None when *filepaths* is empty.
  """
  paths = list(filepaths)
  # Preserve the existing contract: nothing to load -> None.
  if not paths:
    return None

  to_tensor = transforms.ToTensor()
  images = []
  for path in paths:
    # The context manager releases the file handle even if decoding fails.
    with Image.open(path) as image:
      images.append(to_tensor(image))

  # A single stack adds the batch dimension and avoids re-copying the growing
  # batch on every image, which repeated torch.cat calls would do.
  return torch.stack(images, dim=0)

### 3. CNN Architecture
Here we define the neural network structure: 2 Convolutional Layers, Max Pooling, and 2 Fully Connected Layers.

In [None]:
class SimpleCNN(nn.Module):
  """Small CNN mapping 1-channel 28x28 images to 10 class logits.

  Pipeline: conv(1->16) -> ReLU -> pool -> conv(16->32) -> ReLU -> pool
  -> flatten -> fc(7*7*32 -> 128) -> ReLU -> fc(128 -> 10).
  Both convolutions use 3x3 kernels with padding 1 (spatial size preserved);
  each 2x2 max-pool halves height and width (28 -> 14 -> 7).
  """

  def __init__(self):
    super().__init__()
    # in_channels=1 (grayscale), out_channels=16 filters
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)

    # in_channels=16, out_channels=32 filters
    self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

    # Max pooling: downsample height and width by a factor of 2
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Fully connected 1: 7*7*32 flattened features -> 128
    self.fc1 = nn.Linear(in_features=7 * 7 * 32, out_features=128)

    # Fully connected 2: 128 -> 10 classes
    self.fc2 = nn.Linear(in_features=128, out_features=10)

    # Shared activation (stateless, so one instance is reused)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Return raw class logits of shape [N, 10] for input of shape [N, 1, 28, 28].

    A single unbatched image of shape [1, 28, 28] is also accepted and
    yields logits of shape [1, 10].

    Raises:
      RuntimeError: if the spatial size is not 28x28, because the flattened
        feature count then no longer matches fc1.
    """
    # Treat an unbatched [C, H, W] image as a batch of one.
    if x.dim() == 3:
      x = x.unsqueeze(0)

    # Convolution + ReLU + pooling, twice
    x = self.pool(self.relu(self.conv1(x)))
    x = self.pool(self.relu(self.conv2(x)))

    # Flatten per sample. Keeping the batch dimension fixed (instead of
    # view(-1, ...)) makes a wrong input size fail at fc1 rather than being
    # silently reinterpreted as extra samples.
    x = torch.flatten(x, start_dim=1)

    # Fully connected layers
    x = self.relu(self.fc1(x))

    # Output layer: no activation here, CrossEntropyLoss applies log-softmax itself
    return self.fc2(x)

### 4. Train and Test Functions
The training loop updates the model's weights, while the testing loop evaluates accuracy without learning.

In [None]:
def train(model, criterion, optimizer, filepaths, labels, n_epochs=2, batch_size=64):
  """Train *model* with shuffled mini-batches and print running statistics.

  Args:
    model: network to optimise; called on the tensor returned by load_images.
    criterion: loss function called as criterion(outputs, batch_labels).
      The running loss assumes it returns the batch mean (the default
      reduction of nn.CrossEntropyLoss).
    optimizer: optimizer already bound to the model's parameters.
    filepaths: NumPy array of image paths (indexed with a batch of indices).
    labels: tensor of integer class labels aligned with *filepaths*.
    n_epochs: number of passes over the data.
    batch_size: number of samples per optimisation step.

  After every batch one line is printed with the running per-sample loss and
  accuracy for the current epoch.
  """
  total_samples = len(filepaths)

  # Make sure training-mode layers (dropout, batch norm, ...) behave as such.
  model.train()

  for epoch in range(n_epochs):
    samples_trained = 0
    run_loss = 0.0
    correct_preds = 0

    # Fresh random order every epoch
    permutation = torch.randperm(total_samples)
    for start in range(0, total_samples, batch_size):
      indices = permutation[start : start + batch_size]
      batch_inputs = load_images(filepaths[indices])
      batch_labels = labels[indices]
      batch_count = len(batch_labels)

      # Forward pass and loss
      outputs = model(batch_inputs)
      loss = criterion(outputs, batch_labels)

      # Backward pass and parameter update
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # The criterion yields a per-batch mean, so weight it by the batch size
      # before accumulating; dividing by the sample count below then gives a
      # true per-sample average (also correct for a short final batch).
      run_loss += loss.item() * batch_count

      # argmax over logits selects the same class as argmax over softmax
      preds = outputs.argmax(dim=1)
      correct_preds += (preds == batch_labels).sum().item()

      samples_trained += batch_count
      avg_loss = run_loss / samples_trained
      accuracy = correct_preds / samples_trained

      print(f"Epoch {epoch+1} ({samples_trained}/{total_samples}): Loss={avg_loss:.5f}, Accuracy={accuracy:.5f}")


def test(model, filepaths, labels, batch_size=64):
  """Evaluate *model* batch by batch and print the running accuracy.

  Args:
    model: network to evaluate; called on the tensor returned by load_images.
    filepaths: sliceable sequence of image paths.
    labels: tensor of integer class labels aligned with *filepaths*.
    batch_size: number of samples evaluated per forward pass.

  The model is switched to evaluation mode for the duration of the call and
  its previous training/eval mode is restored afterwards. No gradients are
  tracked.
  """
  samples_tested = 0
  correct_preds = 0
  total_samples = len(filepaths)

  was_training = model.training
  model.eval()

  print("\nStarting Testing...")
  try:
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
      for start in range(0, total_samples, batch_size):
        batch_inputs = load_images(filepaths[start : start + batch_size])
        batch_labels = labels[start : start + batch_size]

        outputs = model(batch_inputs)
        # argmax over logits selects the same class as argmax over softmax
        preds = outputs.argmax(dim=1)

        samples_tested += len(batch_labels)
        correct_preds += (preds == batch_labels).sum().item()
        accuracy = correct_preds / samples_tested

        print(f"({samples_tested}/{total_samples}): Accuracy={accuracy:.5f}")
  finally:
    # Leave the model in whatever mode the caller had it in.
    model.train(was_training)

### 5. Main Execution Block
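Run this cell to start training and testing. It expects two folders, `mnist_train` and `mnist_test`, in the current working directory, each containing sub-folders named `0` through `9` that hold the images for that digit.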

In [None]:
# Build the network together with its loss function and optimizer
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# --- Training phase ---
# The folder is looked up relative to the current working directory
dir_train = "mnist_train"
if not os.path.exists(dir_train):
    print("Folder 'mnist_train' not found.")
else:
    filepaths, labels = prepare_data(dir_train)
    if len(filepaths) == 0:
        print("No training images found in 'mnist_train'.")
    else:
        train(model, criterion, optimizer, filepaths, labels)

# --- Testing phase ---
# Runs regardless of whether training happened above
dir_test = "mnist_test"
if not os.path.exists(dir_test):
    print("Folder 'mnist_test' not found.")
else:
    filepaths, labels = prepare_data(dir_test)
    if len(filepaths) == 0:
        print("No testing images found in 'mnist_test'.")
    else:
        test(model, filepaths, labels)