# Training a Digit Classifier

In [None]:
from pathlib import Path

import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
from PIL import Image

# Render every single-channel image (MNIST digits) in grayscale by default.
plt.rcParams.update({"image.cmap": "gray"})

In [None]:
# Expected location of the extracted MNIST_SAMPLE dataset (the `~` is expanded
# to the current user's home directory).
mnist_path = Path("~/.fastai/data/mnist_sample/").expanduser()

In [None]:
if not mnist_path.exists():
    import fastbook

    # untar_data returns the path of the extracted dataset. Use that path
    # rather than trusting the hard-coded one: fastai's data directory is
    # configurable, so the data may land somewhere else.
    mnist_path = fastbook.untar_data(fastbook.URLs.MNIST_SAMPLE)

In [None]:
# Path.glob yields files in an arbitrary, filesystem-dependent order. Sort the
# results so the lists, and indices like threes[1], are reproducible.
threes = sorted(mnist_path.glob("train/3/*.png"))
sevens = sorted(mnist_path.glob("train/7/*.png"))

In [None]:
# Number of training images for each digit.
(len(threes), len(sevens))

In [None]:
# Open the second training "3" (binding its path too) and show it inline.
im3 = Image.open(im3_path := threes[1])
im3

In [None]:
# Peek at a 5x5 patch of raw (not yet rescaled) pixel values.
jnp.array(im3)[4:9, 4:9]

When images are stored as floats, pixel values are conventionally expected to lie between 0 and 1,
so we also divide by 255 here.

In [None]:
# Load every training image as an array (presumably 2-D grayscale), stack them
# along a new leading axis, and rescale the 0-255 pixel values to [0, 1].
stacked_threes = jnp.stack([jnp.array(Image.open(path)) for path in threes]) / 255
stacked_sevens = jnp.stack([jnp.array(Image.open(path)) for path in sevens]) / 255
(stacked_threes.shape, stacked_sevens.shape)

In [None]:
# Sanity check: the rescaled first "3" still renders as a digit.
plt.imshow(stacked_threes[0])

In [None]:
# Validation images, loaded and rescaled exactly like the training stacks
# (first the 3s, then the 7s).
valid_3s, valid_7s = (
    jnp.stack([jnp.array(Image.open(path)) for path in mnist_path.glob(pattern)])
    / 255
    for pattern in ("valid/3/*.png", "valid/7/*.png")
)

### Comparing with the ideal (average) 3 and 7

In [None]:
# Pixel-wise average over all training 3s: the "ideal" 3.
mean_3 = stacked_threes.mean(axis=0)
plt.imshow(mean_3)

In [None]:
# The same 5x5 patch as before, now of the averaged 3 (values in [0, 1]).
mean_3[4:9, 4:9]

In [None]:
# Pixel-wise average over all training 7s: the "ideal" 7.
mean_7 = stacked_sevens.mean(axis=0)
plt.imshow(mean_7)

In [None]:
def mnist_distance(a: jnp.ndarray, b: jnp.ndarray):
    """Return the mean absolute pixel difference between ``a`` and ``b``.

    The mean is taken over the last two axes (height, width). ``a`` may be a
    single image or a batch of images, and ``b`` is broadcast against it, so
    the result is one distance per image (a scalar for a single image).
    """
    abs_diff = jnp.abs(a - b)
    return abs_diff.mean(axis=(-2, -1))

In [None]:
# A single training 3, used to try out the distance function.
example_3 = stacked_threes[0]

In [None]:
# Distance of the example 3 to the ideal 3 and to the ideal 7 (smaller wins).
tuple(mnist_distance(example_3, prototype) for prototype in (mean_3, mean_7))

In [None]:
# One distance per validation image; mnist_distance broadcasts the prototype
# over the whole stack. Grouped by prototype (ideal 3, then ideal 7).
valid_3_to_3 = mnist_distance(valid_3s, mean_3)
valid_7_to_3 = mnist_distance(valid_7s, mean_3)
valid_3_to_7 = mnist_distance(valid_3s, mean_7)
valid_7_to_7 = mnist_distance(valid_7s, mean_7)

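What fraction of the validation 3s is classified as 3? (This is the *recall* for 3s, even though the variable below is named `percision_3`.)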

In [None]:
# Fraction of validation 3s that are closer to the ideal 3 than to the ideal 7
# (ties count as 7).
# NOTE(review): this is the *recall* for 3s, not precision. The (misspelled)
# name is kept because later cells use it.
percision_3 = jnp.sum(valid_3_to_3 < valid_3_to_7) / len(valid_3s)
percision_3

In [None]:
# Fraction of validation 7s that are closer to the ideal 7 than to the ideal 3
# (ties count as 3).
# NOTE(review): this is the *recall* for 7s. The name is kept for later cells.
percision_7 = jnp.sum(valid_7_to_7 < valid_7_to_3) / len(valid_7s)
percision_7

In [None]:
# Unweighted mean of the two per-class rates (i.e. balanced accuracy).
0.5 * (percision_3 + percision_7)

When an image is classified as a 3, how likely is it to actually be a 3? (This is the *precision* for 3s, even though the variable below is named `recall_3`.)

In [None]:
# NOTE(review): despite the name, this is the *precision* for 3s:
#   (3s classified as 3) / (3s classified as 3 + 7s classified as 3),
# rebuilt from the per-class rates and the validation set sizes.
recall_3 = (percision_3 * len(valid_3s)) / (
    percision_3 * len(valid_3s) + (1 - percision_7) * len(valid_7s)
)
recall_3

### SGD

In [None]:
# Inputs: every image flattened into one 28*28-long row (all 3s, then all 7s).
train_x = jnp.concatenate([stacked_threes, stacked_sevens]).reshape(-1, 28 * 28)
# Labels in the same order, as an (N, 1) column: 1.0 for a 3, 0.0 for a 7.
train_y = jnp.concatenate(
    [jnp.ones(len(stacked_threes)), jnp.zeros(len(stacked_sevens))]
)[:, None]
(train_x.shape, train_y.shape)

In [None]:
# Root PRNG key. Every random draw below uses a fresh subkey split from it.
key = jax.random.PRNGKey(seed=42)

In [None]:
def init_params(key, shape, std=1.0):
    """Draw a parameter array of ``shape`` from a normal distribution.

    Args:
        key: JAX PRNG key. The same key always yields the same values.
        shape: shape of the returned array.
        std: standard deviation; standard-normal samples are scaled by it.

    Returns:
        ``jax.random.normal(key, shape) * std``.
    """
    standard_normal = jax.random.normal(key, shape)
    return standard_normal * std

In [None]:
# One weight per pixel, as a (784, 1) column so `xb @ weights` yields (N, 1).
key, subkey = jax.random.split(key)
weights = init_params(subkey, shape=(28 * 28, 1))

In [None]:
# A single bias term, drawn from a fresh subkey.
key, subkey = jax.random.split(key)
bias = init_params(subkey, shape=(1,))

In [None]:
# Prediction for the first image, written out as an explicit weighted sum.
jnp.sum(train_x[0] * weights.T) + bias

In [None]:
# (N, 784) @ (784, 1) -> (N, 1): these shapes make the matrix product valid.
(train_x.shape, weights.shape)

In [None]:
def linear1(xb, w=None, b=None):
    """Apply the linear model ``xb @ w + b``.

    Args:
        xb: batch of flattened images, one per row.
        w: weight matrix. Defaults to the notebook-global ``weights``.
        b: bias. Defaults to the notebook-global ``bias``.

    The globals are looked up at call time, so cells that rebind ``weights``
    or ``bias`` are reflected. Passing ``w``/``b`` explicitly makes the model
    a pure function of its parameters, which ``jax.grad`` needs in order to
    differentiate with respect to them.

    Returns:
        ``xb @ w + b``.
    """
    if w is None:
        w = weights
    if b is None:
        b = bias
    return xb @ w + b

In [None]:
# Predictions for the whole training set with the current random parameters.
pred = linear1(xb=train_x)
pred

In [None]:
# A positive output means "3". Compare with the 0/1 labels and average.
corrects = (pred > 0) == train_y
corrects.mean()

> The problem is that a small change in weights from `x_old` to `x_new` isn’t likely to cause any prediction to change, so `(y_new – y_old)` will almost always be 0. In other words, the gradient is 0 almost everywhere.

Instead of modifying the 0th weight, let's modify the weight of the pixel that is brightest on average, which is the one most likely to affect the predictions.
The 0th pixel is usually 0, so changing the 0th weight would make no difference at all. If even changing this most influential weight leaves the accuracy unchanged, then accuracy really is a poor loss function.

In [None]:
# Mean intensity of every pixel over the training set.
pixels_mean = jnp.mean(train_x, axis=0)
# Reuse pixels_mean instead of recomputing it. For a given weight, a brighter
# pixel contributes more to `xb @ weights`, so changing its weight moves the
# predictions more.
significant_index = jnp.argmax(pixels_mean).item()
pixels_mean[0].item(), significant_index

In [None]:
# Scale the chosen weight by 1.0001 (functional update, since JAX arrays are
# immutable), then recompute the accuracy to see whether it changes.
weights = weights.at[significant_index].multiply(1.0001)
((linear1(train_x) > 0) == train_y).mean()

> A very small change in the value of a weight will often not change the accuracy at all. This means it is not useful to use accuracy as a loss function.

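Instead, we measure how far the predictions are from the true values, using the mean squared error. Unlike accuracy, it changes smoothly as the weights change. For example, with

```python
trgts = jnp.array([1, 0, 1])
prds = jnp.array([0.9, 0.4, 0.2])
```

the mean squared error is (0.1² + 0.4² + 0.8²) / 3 = 0.27.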

In [None]:
def mnist_loss(predictions, targets, *, normalize=False):
    """Return the mean squared error between ``predictions`` and ``targets``.

    Args:
        predictions: model outputs. With ``normalize=False`` (the default)
            they are used as-is and should ideally already lie in [0, 1].
        targets: 0/1 labels (e.g. 1 for a 3, 0 for a 7), with the same shape
            as ``predictions``. NOTE: mismatched shapes such as (N,) vs (N, 1)
            broadcast to (N, N) and silently produce a wrong loss.
        normalize: if True, first squash ``predictions`` into (0, 1) with
            ``jax.nn.sigmoid``.

    Returns:
        Scalar mean of the squared differences.
    """
    if normalize:
        predictions = jax.nn.sigmoid(predictions)
    return jnp.mean((predictions - targets) ** 2)

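Unlike accuracy, this loss should change (if only slightly) after the small weight tweak above. Compare the loss of the earlier predictions `pred` with the loss of the current model in the two cells below. That sensitivity is what gives gradient-based training a useful signal.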

In [None]:
# Loss of `pred`, computed before the weight was nudged.
mnist_loss(pred, train_y)

In [None]:
# Loss with the current (nudged) weights, for comparison with the cell above.
mnist_loss(linear1(train_x), train_y)

In [None]:
?jnp.sigmoid