# Digit classifier

A computer-vision (CV) model that takes an image of a handwritten digit as input and classifies it as one of the digits 0 to 9.

## Setup

### Dependencies

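* torch: PyTorch, used for tensor operations and automatic differentiation.
* PIL.Image: open image files.
* pathlib.Path: represent file and directory paths as `Path` objects (a `PosixPath` on Unix-like systems), which are convenient to join and manipulate.
* URLs: URLs of the datasets provided by fastai.
* untar_data: function that fetches a dataset from one of the URLs and extracts the archive.
* tensor: function to convert an image to its tensor representation.
* matplotlib.pyplot (imported as `plot`): plot the data and the model's predictions.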

In [None]:
import torch

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plot

from fastai.vision.all import (
    URLs,
    untar_data,
    tensor
)

### Variables

1. Define dataset path.

In [None]:
# Fetch and extract the MNIST archive; untar_data returns the dataset's root folder.
dataset_path_string = untar_data(URLs.MNIST)
# Training images live under 'training', in one sub-folder per digit.
dataset_path = Path(dataset_path_string).joinpath('training')

2. Collect the sorted list of image paths for each digit.

In [None]:
# One sorted list of training image paths per digit, read from sub-folders '0'..'9'.
zeroes, ones, twos, threes, fours, fives, sixs, sevens, eights, nines = (
    (dataset_path / str(digit)).ls().sorted() for digit in range(10)
)

### Execution

#### Pixel Similarity

Start by finding the average value of every pixel for each digit, i.e. the mean image of each digit class.

1. Convert every image into a rank-2 tensor. The images are grayscale (no color channels), so each one has only height and width dimensions, hence rank 2.

In [None]:
# Load every training image of every digit as its own tensor (one list per digit).
(
    zeroes_tensors_list, ones_tensors_list, twos_tensors_list,
    threes_tensors_list, fours_tensors_list, fives_tensors_list,
    sixs_tensors_list, sevens_tensors_list, eights_tensors_list,
    nines_tensors_list,
) = (
    [tensor(Image.open(image_path)) for image_path in digit_paths]
    for digit_paths in (
        zeroes, ones, twos, threes, fours,
        fives, sixs, sevens, eights, nines,
    )
)

2. Stack all the image tensors of each digit into a single rank-3 tensor (one stack per digit).
Some PyTorch operations (such as `mean`) require floating-point tensors, so the stacks are cast to float here.
When images are represented as floats, pixel values are expected to lie between 0 and 1, so also divide by 255.

In [None]:
# Stack each digit's image tensors into one tensor, cast to float and scale by 1/255.
(
    stacked_zeroes, stacked_ones, stacked_twos, stacked_threes,
    stacked_fours, stacked_fives, stacked_sixs, stacked_sevens,
    stacked_eights, stacked_nines,
) = (
    torch.stack(tensors_list).float() / 255
    for tensors_list in (
        zeroes_tensors_list, ones_tensors_list, twos_tensors_list,
        threes_tensors_list, fours_tensors_list, fives_tensors_list,
        sixs_tensors_list, sevens_tensors_list, eights_tensors_list,
        nines_tensors_list,
    )
)

3. Compute the 'ideal' representation of each digit: the pixel-wise mean over all of its training images.

In [None]:
# The 'ideal' digit: average over the stacking dimension (dim 0), i.e. a pixel-wise mean image.
(
    zeroes_mean, ones_mean, twos_mean, threes_mean, fours_mean,
    fives_mean, sixs_mean, sevens_mean, eights_mean, nines_mean,
) = (
    stacked.mean(0)
    for stacked in (
        stacked_zeroes, stacked_ones, stacked_twos, stacked_threes,
        stacked_fours, stacked_fives, stacked_sixs, stacked_sevens,
        stacked_eights, stacked_nines,
    )
)

4. Build the same stacked tensors for the validation dataset (MNIST's `testing` folder) so that the accuracy metric can be calculated on it.

In [None]:
# Validation images come from the dataset's 'testing' folder, one sub-folder per digit.
validation_dataset_path = Path(dataset_path_string) / 'testing'

# Sorted validation image paths for each digit '0'..'9'.
(
    validation_zeroes, validation_ones, validation_twos, validation_threes,
    validation_fours, validation_fives, validation_sixs, validation_sevens,
    validation_eights, validation_nines,
) = (
    (validation_dataset_path / str(digit)).ls().sorted() for digit in range(10)
)

# Every validation image converted to a tensor (one list per digit).
(
    validation_zeroes_tensors_list, validation_ones_tensors_list,
    validation_twos_tensors_list, validation_threes_tensors_list,
    validation_fours_tensors_list, validation_fives_tensors_list,
    validation_sixs_tensors_list, validation_sevens_tensors_list,
    validation_eights_tensors_list, validation_nines_tensors_list,
) = (
    [tensor(Image.open(image_path)) for image_path in digit_paths]
    for digit_paths in (
        validation_zeroes, validation_ones, validation_twos, validation_threes,
        validation_fours, validation_fives, validation_sixs, validation_sevens,
        validation_eights, validation_nines,
    )
)

# Stack per digit, cast to float and scale by 1/255, exactly as for the training set.
(
    validation_stacked_zeroes, validation_stacked_ones,
    validation_stacked_twos, validation_stacked_threes,
    validation_stacked_fours, validation_stacked_fives,
    validation_stacked_sixs, validation_stacked_sevens,
    validation_stacked_eights, validation_stacked_nines,
) = (
    torch.stack(tensors_list).float() / 255
    for tensors_list in (
        validation_zeroes_tensors_list, validation_ones_tensors_list,
        validation_twos_tensors_list, validation_threes_tensors_list,
        validation_fours_tensors_list, validation_fives_tensors_list,
        validation_sixs_tensors_list, validation_sevens_tensors_list,
        validation_eights_tensors_list, validation_nines_tensors_list,
    )
)

5. Define a function that calculates the mean absolute difference (MAD) between digit image(s) and an ideal (mean) digit image, averaging over the last two (height and width) dimensions.

In [None]:
def mnist_distance(digit_tensor, ideal_mean_for_that_digit):
    """Return the mean absolute difference (MAD) between image(s) and an ideal image.

    The absolute pixel-wise difference is averaged over the last two
    dimensions (height and width). A single image therefore yields a scalar,
    and a stack of images yields one distance per image. The ideal image is
    broadcast against ``digit_tensor``.
    """
    absolute_difference = torch.abs(ideal_mean_for_that_digit - digit_tensor)
    return absolute_difference.mean((-1, -2))

6. Define one function per digit. Each calculates the MAD between the input image(s) and the ideal image of every digit, picks the closest ideal digit, and returns whether that closest digit is the one the function checks for.

In [None]:
def _closest_digit(digit_tensor):
    """Return the digit (0-9) whose ideal mean image is nearest to each input image.

    Distances are measured with ``mnist_distance`` against ``zeroes_mean`` ...
    ``nines_mean``, in digit order, so the position in the stack is the digit
    itself. ``argmin`` returns the first minimum, so ties resolve to the lower
    digit. For a single image the result is a scalar tensor; for a stack of N
    images it has one entry per image.
    """
    ideal_digits = (
        zeroes_mean, ones_mean, twos_mean, threes_mean, fours_mean,
        fives_mean, sixs_mean, sevens_mean, eights_mean, nines_mean,
    )
    # Row d holds the distance(s) to the ideal image of digit d.
    distances = torch.stack(
        [mnist_distance(digit_tensor, ideal) for ideal in ideal_digits]
    )
    return distances.argmin(dim=0)

def is_zero(digit_tensor):
    """Return whether each image is closest to the ideal 0 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 0

def is_one(digit_tensor):
    """Return whether each image is closest to the ideal 1 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 1

def is_two(digit_tensor):
    """Return whether each image is closest to the ideal 2 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 2

def is_three(digit_tensor):
    """Return whether each image is closest to the ideal 3 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 3

def is_four(digit_tensor):
    """Return whether each image is closest to the ideal 4 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 4

def is_five(digit_tensor):
    """Return whether each image is closest to the ideal 5 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 5

def is_six(digit_tensor):
    """Return whether each image is closest to the ideal 6 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 6

def is_seven(digit_tensor):
    """Return whether each image is closest to the ideal 7 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 7

def is_eight(digit_tensor):
    """Return whether each image is closest to the ideal 8 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 8

def is_nine(digit_tensor):
    """Return whether each image is closest to the ideal 9 (boolean tensor)."""
    return _closest_digit(digit_tensor) == 9

7. Calculate the model's accuracy on the validation dataset, separately for each digit.

In [None]:
# Fraction of each digit's validation images that the model assigns to that digit.
(
    accuracy_zeroes, accuracy_ones, accuracy_twos, accuracy_threes,
    accuracy_fours, accuracy_fives, accuracy_sixs, accuracy_sevens,
    accuracy_eights, accuracy_nines,
) = (
    is_digit(validation_images).float().mean()
    for is_digit, validation_images in (
        (is_zero, validation_stacked_zeroes),
        (is_one, validation_stacked_ones),
        (is_two, validation_stacked_twos),
        (is_three, validation_stacked_threes),
        (is_four, validation_stacked_fours),
        (is_five, validation_stacked_fives),
        (is_six, validation_stacked_sixs),
        (is_seven, validation_stacked_sevens),
        (is_eight, validation_stacked_eights),
        (is_nine, validation_stacked_nines),
    )
)

In [None]:
def f(time, parameters):
    """Evaluate the quadratic ``a * time**2 + b * time + c``.

    ``parameters`` is unpacked into the three coefficients ``(a, b, c)``.
    ``time`` may be a scalar or a tensor; with a tensor the formula is
    applied elementwise.
    """
    quadratic_coef, linear_coef, constant = parameters
    quadratic_term = quadratic_coef * time ** 2
    linear_term = linear_coef * time
    return quadratic_term + linear_term + constant

def mse(prediction, target):
    """Return the mean squared error between ``prediction`` and ``target`` tensors."""
    squared_error = (prediction - target) ** 2
    return squared_error.mean()

In [None]:
# Twenty float32 time steps: 0.0, 1.0, ..., 19.0.
time = torch.arange(20, dtype=torch.float32)
# Synthetic speeds: a parabola centred on t = 9.5, plus standard-normal noise scaled by 3.
speed = torch.randn(20) * 3 + 0.75 * (time - 9.5) ** 2 + 1

In [None]:
# Three random starting coefficients (a, b, c) for f; autograd tracks gradients for them.
parameters = torch.randn(3, requires_grad=True)

In [None]:
# Keep a copy (separate storage) of the starting coefficients so they can be restored later.
original_parameters = torch.clone(parameters)

In [None]:
# Predictions of the untrained quadratic at every time step.
predictions = f(time=time, parameters=parameters)

In [None]:
def show_predictions(predictions, ax=None):
    """Scatter the observed ``speed`` and the ``predictions`` against ``time``.

    Observed points use the default colour and predictions are drawn in red.
    The y-axis is fixed to [-300, 100].

    Args:
        predictions: tensor of predicted speeds. It is detached and moved to
            the CPU before plotting, so it may carry autograd history.
        ax: matplotlib axes to draw on. When omitted, a new figure and axes
            are created with ``plot.subplots()``.
    """
    axes = ax if ax is not None else plot.subplots()[1]
    axes.scatter(time, speed)
    predicted_speed = predictions.detach().cpu().numpy()
    axes.scatter(time, predicted_speed, color='red')
    axes.set_ylim(-300, 100)

In [None]:
# Plot the untrained predictions on a fresh set of axes.
show_predictions(predictions, ax=None)

In [None]:
# Loss of the untrained model; the bare name on the last line displays it.
loss = mse(prediction=predictions, target=speed)
loss

In [None]:
# Back-propagate the loss to compute its gradient with respect to `parameters`.
loss.backward()

In [None]:
# Display the gradient of the loss with respect to each of the three coefficients.
parameters.grad

In [None]:
# Preview the step size: the gradient scaled by 1e-5, the learning rate used for the update.
1e-5 * parameters.grad

In [None]:
# Display the current coefficients, before the manual gradient-descent update.
parameters

In [None]:
lr = 1e-5
# One gradient-descent step. Updating in place under torch.no_grad() keeps the
# operation out of the autograd graph while still respecting autograd's in-place
# (version-counter) checks, which mutating `.data` silently bypasses.
with torch.no_grad():
    parameters -= lr * parameters.grad

In [None]:
# Clear the gradient so the next backward() call starts fresh
# (PyTorch accumulates gradients into .grad across backward calls).
parameters.grad = None

In [None]:
# Re-evaluate with the updated coefficients and display the new loss.
predictions = f(time=time, parameters=parameters)
mse(prediction=predictions, target=speed)

In [None]:
# Plot the predictions after one manual update step.
show_predictions(predictions, ax=None)

In [None]:
def apply_step(parameters, prn=True):
    """Run one gradient-descent step on ``parameters`` and return the predictions.

    Uses the module-level ``time``, ``speed`` and learning rate ``lr``.

    Args:
        parameters: leaf tensor of coefficients ``(a, b, c)`` created with
            ``requires_grad=True``; it is updated in place.
        prn: when true, print the loss computed before the update.

    Returns:
        The predictions made with the coefficients *before* this step.
    """
    predictions = f(time, parameters)
    loss = mse(predictions, speed)
    loss.backward()
    # Update in place without recording the op in the autograd graph. This is
    # preferred over mutating `.data`, which bypasses autograd's in-place checks.
    with torch.no_grad():
        parameters -= lr * parameters.grad
    # PyTorch accumulates gradients across backward() calls, so clear them each step.
    parameters.grad = None
    if prn:
        print(loss.item())
    return predictions

In [None]:
# Ten gradient-descent steps; each prints the loss measured before its update.
for i in range(10):
    apply_step(parameters, prn=True)

In [None]:
# Restore the starting coefficients as a fresh leaf tensor that tracks gradients.
# `.clone()` gives it its own storage: `detach()` alone returns a tensor that
# shares memory with `original_parameters`, so the in-place updates made by
# apply_step would silently overwrite the saved snapshot, and re-running this
# cell would no longer restore the original values.
parameters = original_parameters.detach().clone().requires_grad_()

In [None]:
# Four more gradient-descent steps, plotting each step's (pre-update) predictions side by side.
_, axs = plot.subplots(nrows=1, ncols=4, figsize=(12, 3))
for ax in axs:
    show_predictions(apply_step(parameters, prn=False), ax=ax)
plot.tight_layout()