In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import math

In [3]:
import torch 
from torchvision import transforms
from torchvision.datasets import MNIST

In [4]:
# MNIST training split; images go through ToTensor, and download=True fetches the files if missing.
dataset = MNIST(root='/workspace/data/', transform=transforms.ToTensor(), download=True)
dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspace/data/
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
dataset.data.size()  # (number of images, height, width)

torch.Size([60000, 28, 28])

In [6]:
n_train = 50_000
n_valid = len(dataset.data) - n_train
# Flatten each image into a row vector and scale the pixel values by 1/255.
x_train = dataset.data[:n_train].view(n_train, -1) / 255
y_train = dataset.targets[:n_train]
x_valid = dataset.data[n_train:].view(n_valid, -1) / 255
y_valid = dataset.targets[n_train:]

In [7]:
# Scalar statistics over every pixel of the training set (the validation set is not used here).
train_mean = x_train.mean()
train_std = x_train.std()
train_mean, train_std

(tensor(0.1310), tensor(0.3085))

In [8]:
def normalize(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    """Standardize ``x`` as ``(x - mean) / std``.

    ``mean`` and ``std`` are broadcast against ``x``; in this notebook they are
    scalar statistics computed on the training set.
    """
    return x.sub(mean).div(std)

In [9]:
# Normalize starting from the raw pixels rather than the current `x_train`/`x_valid`:
# `x_train = normalize(x_train, ...)` would normalize a second time whenever this cell
# is re-run. Building from `dataset.data` keeps the cell idempotent.
x_train = normalize(dataset.data[:n_train].view(n_train, -1) / 255, train_mean, train_std)
# NOTE: use training, not validation mean and std for validation set
x_valid = normalize(dataset.data[n_train:].view(n_valid, -1) / 255, train_mean, train_std)

In [10]:
(x_train.mean(), x_train.std())  # expected to be ~(0, 1) after normalization

(tensor(2.1126e-08), tensor(1.))

In [11]:
def test_near_zero(x, tol=1e-3):
    return x.abs() < tol

In [12]:
# both entries should be True: mean ~ 0 and std ~ 1
(test_near_zero(x_train.mean()),
 test_near_zero(x_train.std() - 1))

(tensor(True), tensor(True))

In [13]:
n, m = tuple(x_train.shape)  # number of samples, number of input features
c = 1 + y_train.max()        # number of classes, assuming labels are 0..c-1
n, m, c

(50000, 784, tensor(10))

## Basic architecture

In [14]:
nh = 50  # number of hidden units

In [15]:
# Seed the RNG so this cell, and the statistics computed from these weights,
# come out the same on every re-run.
torch.manual_seed(42)
# Naive init: standard-normal weights and zero biases.
w1 = torch.randn(m, nh)
b1 = torch.zeros(nh)  # it gets broadcasted to (n, nh)
w2 = torch.randn(nh, 1)
b2 = torch.zeros(1)

In [16]:
# normalized with the *training* statistics, so only roughly (0, 1)
(x_valid.mean(), x_valid.std())

(tensor(-0.0059), tensor(0.9924))

In [17]:
def linear(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    """Affine layer ``x @ w + b``; the bias ``b`` is broadcast over the rows of ``x @ w``."""
    return x.matmul(w).add(b)

In [18]:
t = linear(x=x_valid, w=w1, b=b1)
(t.mean(), t.std())

(tensor(1.6759), tensor(26.1915))

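This is a pretty terrible result. Each output sums $m = 784$ products of roughly unit-variance inputs and standard-normal weights, so its standard deviation is on the order of $\sqrt{784} = 28$. Let's use a simplified Kaiming He init to bring the mean and std of the output activation closer to (0, 1). This init divides the weights by $\sqrt{m}$ (the fan-in) and omits the $\sqrt{2}$ ReLU gain of the full Kaiming init; this variant is also known as LeCun init.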

In [19]:
# "Simplified Kaiming He" init: divide by sqrt(fan_in) so that, for roughly
# unit-variance inputs, each output of x @ w also has roughly unit variance.
# (This is LeCun-style scaling; the full Kaiming init for ReLU uses sqrt(2 / fan_in).)
torch.manual_seed(42)  # reproducible on re-run
w1 = torch.randn(m, nh) / math.sqrt(m)
b1 = torch.zeros(nh)  # it gets broadcasted to (n, nh)
w2 = torch.randn(nh, 1) / math.sqrt(nh)
b2 = torch.zeros(1)

In [20]:
# w1 should have ~zero mean and a std of ~1/sqrt(m) after the scaled init
(test_near_zero(w1.mean()),
 test_near_zero(w1.std() - 1 / math.sqrt(m)))

(tensor(True), tensor(True))

In [21]:
def linear(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    """Affine layer ``x @ w + b`` (bias broadcast over rows).

    NOTE(review): this redefines exactly the same function as the earlier
    ``linear`` cell; one of the two definitions could be dropped.
    """
    out = torch.matmul(x, w)
    return out + b

In [22]:
# ... and so should this: dividing the weights by sqrt(fan_in) is designed to keep
# the linear layer's output at roughly unit variance for unit-variance inputs
t = linear(x_valid, w1, b1)
t.mean(), t.std()

(tensor(0.0342), tensor(0.9978))

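This is promising; however, it doesn't take the non-linear activation into account. Modern networks use ReLU, Swish, Mish, etc., and the plain $1/\sqrt{m}$ scaling doesn't preserve the activation statistics through such non-linearities. The full Kaiming He init was designed to address exactly this for ReLU: it scales the weights by $\sqrt{2/m}$ instead of $1/\sqrt{m}$.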

In [23]:
def relu(x: torch.Tensor):
    """Rectified linear unit: replace every negative entry of ``x`` with 0."""
    return x.clamp(min=0.)

In [24]:
t = relu(linear(x=x_valid, w=w1, b=b1))  # hidden activations after the non-linearity

In [25]:
(t.mean(), t.std())

(tensor(0.4086), tensor(0.5913))

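As you can see, the output is neither centered at 0 nor of unit standard deviation. ReLU zeroes out the negative half of its roughly standard-normal input. This shifts the mean to about $1/\sqrt{2\pi} \approx 0.40$ and shrinks the std to about $\sqrt{1/2 - 1/(2\pi)} \approx 0.58$.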