In [None]:
import math

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

In [None]:
# Convert images to tensors in [0, 1], then subtract 0.5 (std 1.0 leaves the
# scale unchanged), giving pixel values in [-0.5, 0.5].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(1.0,)),
])

# Training and testing splits; downloaded under ../data on first use.
trainset = datasets.MNIST(root='../data', train=True, download=True,
                          transform=transform)
testset = datasets.MNIST(root='../data', train=False, download=True,
                         transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



In [None]:
# Materialize the whole training split as dense tensors. Every dataset access
# re-runs `transform`, so walk the dataset once and unzip the (image, label)
# pairs instead of iterating over it twice (once for images, once for labels).
train_images_seq, train_labels_seq = zip(*trainset)
images_train = torch.stack(train_images_seq)
labels_train = torch.tensor(train_labels_seq)
del train_images_seq, train_labels_seq
print('images_train:', images_train.shape)
print('labels_train:', labels_train.shape)

images_train: torch.Size([60000, 1, 28, 28])
labels_train: torch.Size([60000])


In [None]:
# Collapse each (channel, height, width) image into one feature vector,
# keeping the leading sample dimension.
x_train = images_train.flatten(start_dim=-3)
print('x_train:', x_train.shape)

x_train: torch.Size([60000, 784])


In [None]:
# Obtain one-hot representation of targets.
# MNIST has 10 digit classes. State the width explicitly instead of letting
# one_hot infer it from labels_train.max() + 1, which would silently shrink if
# the highest class were absent (e.g. when working on a subset).
y_train = F.one_hot(labels_train, num_classes=10).float()
print('y_train:', y_train.shape)

y_train: torch.Size([60000, 10])


In [None]:
# Find W that minimizes |X W - Y|^2.
# (Adopt the Frobenius norm for the residual matrix, i.e. the 2-norm of the
# vectorized matrix.)
# This is equivalent to solving for each column of W independently:
# |X W - Y|^2 = sum_i |X w_i - y_i|^2

# Use the (reduced) QR decomposition X = Q R where
#   R is square and upper-triangular
#   Q is rectangular (tall) with orthonormal columns, Q' Q = I
#
# min_w  |X w - y|^2
# min_w  w' X' X w - 2 y' X w + y' y
# X' X w = X' y                      (normal equations)
# R' Q' Q R w = R' Q' y
# R' R w = R' Q' y
# R w = Q' y  (assuming X' X is non-singular and hence R too)
#
# We can solve for all columns of W at once by back-substitution on:
# R W = Q' Y

q, r = torch.linalg.qr(x_train)  # default mode='reduced': r is square
# torch.triangular_solve is deprecated; torch.linalg.solve_triangular takes the
# triangular matrix first and requires `upper` to be given explicitly.
weights = torch.linalg.solve_triangular(r, q.T @ y_train, upper=True)
print('weights:', weights.shape)

weights: torch.Size([784, 10])


In [None]:
# Score every training image against each class column of W; the predicted
# digit is the column with the highest score.
output = torch.matmul(x_train, weights)
pred = output.argmax(dim=-1)

print('output:', output.shape)
print('pred:', pred.shape)

output: torch.Size([60000, 10])
pred: torch.Size([60000])


In [None]:
# Fraction of training images whose predicted digit matches the label.
# For reference, guessing at random among 10 classes scores about 0.1.
(pred == labels_train).float().mean()

tensor(0.0987)

In [None]:
# Check the condition number of the problem: the ratio of the largest to the
# smallest singular value of X (the default 2-norm condition number).
# A high condition number means some singular values are close to zero
# relative to the largest one. Then there is a subspace of solutions with
# similar loss, and small changes in Y can cause large changes in the solution.
# Acceptable values might be around 10^3 or 10^4.
# NOTE(review): float32 cannot resolve very small singular values, so this
# value may understate how ill-conditioned X really is.
torch.linalg.cond(x_train, p=None)

tensor(49350108.)

In [None]:
# Repeat in double precision (64 bit) to see the true scale of the problem.
# A value on the order of 1/eps for float64 (~10^16) means X is effectively
# rank-deficient. Presumably this is because some pixels (e.g. the border) are
# constant across images; TODO confirm.
torch.linalg.cond(x_train.to(torch.float64))

tensor(2.3693e+16, dtype=torch.float64)

In [None]:
# Add a regularizer.
# Find W that minimizes (1/n) |X W - Y|^2 + alpha |W|^2.
#
# Re-write as a single least-squares problem |A W - B|^2 by stacking blocks:
#
# (1/n) |X W - Y|^2 + alpha |W|^2
#
# = | 1/sqrt(n) (X W - Y) |^2
#   |   sqrt(alpha) W     |
#
# = | [ 1/sqrt(n) X  ] W - [ 1/sqrt(n) Y ] |^2
#   | [ sqrt(alpha) I]     [      0      ] |
#
# For alpha > 0 the sqrt(alpha) I block gives A full column rank, so the R
# factor below is non-singular even though X itself is ill-conditioned.

alpha = 1.0

n, d = x_train.shape
_, c = y_train.shape

# Build the identity / zero blocks with the same dtype and device as the data
# so the concatenation never mixes types (e.g. float64 inputs, GPU tensors).
a = torch.cat([(1 / math.sqrt(n)) * x_train,
               math.sqrt(alpha) * torch.eye(d, dtype=x_train.dtype,
                                            device=x_train.device)], dim=0)
b = torch.cat([(1 / math.sqrt(n)) * y_train,
               torch.zeros(d, c, dtype=y_train.dtype,
                           device=y_train.device)], dim=0)

print('a:', a.shape)
print('b:', b.shape)

q, r = torch.linalg.qr(a)
# torch.triangular_solve is deprecated; solve R W = Q' B with the linalg API.
weights = torch.linalg.solve_triangular(r, q.T @ b, upper=True)
print('weights:', weights.shape)

a: torch.Size([60784, 784])
b: torch.Size([60784, 10])
weights: torch.Size([784, 10])


In [None]:
# Check the condition number of the stacked matrix A. The sqrt(alpha) I block
# keeps the smallest singular value of A at least sqrt(alpha), so this is far
# smaller than the condition number of X alone.
torch.linalg.cond(a, p=None)

tensor(11.4129)

In [None]:
# Evaluate model on training set.
# Use `dim=` (as in the rest of this notebook) rather than the numpy-style
# `axis=` alias.
pred_train = torch.argmax(x_train @ weights, dim=-1)
# Get the accuracy.
torch.mean((pred_train == labels_train).float())

tensor(0.8165)

In [None]:
# Ready to evaluate model on testing set.
# First construct tensors of vectorized images and labels. Walk the dataset
# once and unzip the (image, label) pairs; each access re-runs `transform`, so
# iterating twice would transform every test image twice.
test_images_seq, test_labels_seq = zip(*testset)
images_test = torch.stack(test_images_seq)
labels_test = torch.tensor(test_labels_seq)
del test_images_seq, test_labels_seq
x_test = torch.flatten(images_test, start_dim=-3)


In [None]:
# Then make predictions and check accuracy.
# Use `dim=` consistently with the rest of the notebook (not the `axis=` alias).
pred_test = torch.argmax(x_test @ weights, dim=-1)
torch.mean((pred_test == labels_test).float())

tensor(0.8242)

In [None]:
# Define functions to train and evaluate models.

def train_model(x, y, alpha):
  """Fit ridge-regularized least-squares weights using a QR decomposition.

  Minimizes (1/n) |x W - y|^2 + alpha |W|^2 by stacking it into the single
  least-squares problem |A W - B|^2 with
    A = [x / sqrt(n); sqrt(alpha) I],  B = [y / sqrt(n); 0],
  and then solving the triangular system R W = Q' B where A = Q R.

  Args:
    x: (n, d) feature matrix.
    y: (n, c) target matrix (e.g. one-hot labels). Assumed to share dtype and
      device with x.
    alpha: non-negative regularization strength. With alpha = 0 this is plain
      least squares, and R is singular if x is rank-deficient.

  Returns:
    (d, c) weight matrix.

  Raises:
    ValueError: if alpha is negative (from math.sqrt).
  """
  n, d = x.shape
  _, c = y.shape
  # Create the stacked blocks with the data's dtype/device so the concatenation
  # works for float64 or non-CPU inputs without implicit type mixing.
  eye = torch.eye(d, dtype=x.dtype, device=x.device)
  zeros = torch.zeros(d, c, dtype=y.dtype, device=y.device)
  a = torch.cat([(1 / math.sqrt(n)) * x, math.sqrt(alpha) * eye], dim=0)
  b = torch.cat([(1 / math.sqrt(n)) * y, zeros], dim=0)

  q, r = torch.linalg.qr(a)
  # torch.triangular_solve is deprecated; solve_triangular takes R first and
  # needs `upper` stated explicitly.
  return torch.linalg.solve_triangular(r, q.T @ b, upper=True)

def evaluate_model(images, labels, weights):
  """Return the classification accuracy of a linear model as a Python float.

  The last three dimensions of `images` are flattened into one feature vector
  per sample. Each sample is scored with `weights`, and the highest-scoring
  column is taken as the predicted label.
  """
  features = images.flatten(start_dim=-3)
  predictions = (features @ weights).argmax(dim=-1)
  hits = predictions == labels
  return hits.float().mean().item()

In [None]:
# Sweep the regularization strength: alpha = 0 (no regularization), then
# 10^-3 through 10^3 on a log scale, reporting train and test accuracy.

for alpha in [0, *(10 ** k for k in range(-3, 4))]:
  weights = train_model(x_train, y_train, alpha)
  acc_train = evaluate_model(images_train, labels_train, weights)
  acc_test = evaluate_model(images_test, labels_test, weights)
  print(f'alpha: {alpha}')
  print(f'train acc {acc_train:.2%}, test acc {acc_test:.2%}')
  print()

alpha: 0
train acc 9.87%, test acc 9.80%

alpha: 0.001
train acc 85.70%, test acc 86.22%

alpha: 0.01
train acc 85.72%, test acc 86.45%

alpha: 0.1
train acc 85.26%, test acc 86.16%

alpha: 1
train acc 81.65%, test acc 82.42%

alpha: 10
train acc 72.47%, test acc 74.30%

alpha: 100
train acc 47.03%, test acc 49.39%

alpha: 1000
train acc 22.27%, test acc 23.14%

