In [None]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Preprocessing applied to every image: convert to a tensor, then shift by 0.5
# (the divisor of 1.0 leaves the scale unchanged).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,)),
])

# Training split first, then the testing split; both share the same
# preprocessing and are downloaded to ../data if not already present.
trainset, testset = (
    datasets.MNIST(
        root='../data', train=is_train,
        download=True, transform=transform)
    for is_train in (True, False)
)

In [None]:
# Look at the first training example: show the image, print its label.
example = trainset[0]
x, y = example
# Drop the size-1 dimensions so imshow receives a plain 2-D image.
plt.imshow(x.squeeze())
print(y)

In [None]:
# List comprehension demo: square every element of a small list.
elems = list(range(6))
[elem ** 2 for elem in elems]

In [None]:
def _stack_examples(dataset):
  """Collect all (image, label) pairs of `dataset` into two tensors.

  The dataset is iterated exactly once, so each example is loaded (and
  passed through the dataset's transform) a single time instead of once
  for the images and once more for the labels.

  Args:
    dataset: iterable yielding (image_tensor, int_label) pairs.

  Returns:
    (images, labels): images stacked along a new leading dimension, and
    labels as a 1-D integer tensor in the same order.
  """
  images, labels = [], []
  for image, label in dataset:
    images.append(image)
    labels.append(label)
  return torch.stack(images), torch.tensor(labels)

images_train, labels_train = _stack_examples(trainset)
images_test, labels_test = _stack_examples(testset)
print('images_train:', images_train.shape)
print('labels_train:', labels_train.shape)

In [None]:
# Collapse the trailing three dimensions of every image into one feature
# vector, leaving the leading (example) dimension untouched.
x_train = images_train.flatten(start_dim=-3)
x_test = images_test.flatten(start_dim=-3)
print('x_train:', x_train.shape)
print('x_test:', x_test.shape)

In [None]:
# Obtain the squared 2-norm distance between each (train, test) pair of
# examples. This cell computes the cross term; the norms are added later.

# Use identity
#   |u - v|^2 = u'u - 2 u'v + v'v
# to avoid constructing [n_train, n_test, n_feat] array.

# Compute dot product between each pair of (train, test) examples.
# 'id,jd->ij' sums over the shared feature index d, so the result has one
# row per training example and one column per testing example.
dot = torch.einsum('id,jd->ij', x_train, x_test)
print(dot.shape)

# This is equivalent to the following code (which is too slow):
#
# n_train = x_train.shape[0]
# n_test = x_test.shape[0]
# dot = torch.zeros([n_train, n_test])
# for i in range(n_train):
#   for j in range(n_test):
#     dot[i, j] = torch.dot(x_train[i], x_test[j])

In [None]:
# Squared norm of every example, shaped so that broadcasting against `dot`
# lines up: training norms form a column, testing norms form a row.
norm_train = x_train.pow(2).sum(dim=1, keepdim=True)
norm_test = x_test.pow(2).sum(dim=1)[None, :]
print('norm_train:', norm_train.shape)
print('norm_test:', norm_test.shape)

In [None]:
# Squared distance via |u - v|^2 = u'u + v'v - 2 u'v.
# The broadcast sum allocates the one [n_train, n_test] result buffer; the
# cross term is then subtracted in place (alpha=2 scales `dot` on the fly),
# so no further full-size temporaries are created. Doubling a float is
# exact, so the values match `norm_train + norm_test - 2 * dot`.
dist_euc = (norm_train + norm_test).sub_(dot, alpha=2)
print('dist_euc:', dist_euc.shape)

In [None]:
# For every testing example (a column of dist_euc), pick the training
# example (row) with the smallest distance.
index_nearest = dist_euc.argmin(dim=0)
# The prediction is simply the label of that closest training example.
pred = labels_train[index_nearest]

# Fraction of testing examples whose prediction matches the true label.
(pred == labels_test).float().mean()

In [None]:
# Majority vote over the k closest training examples.
k = 3
# Smallest k distances along the training dimension, for each testing example.
_, index_neighbors = dist_euc.topk(k, dim=0, largest=False)
print('index_neighbors:', index_neighbors.shape)

# One-hot labels: summing them over neighbors later yields per-class counts.
y_train = F.one_hot(labels_train).to(torch.float32)
print('y_train:', y_train.shape)
# Look up the one-hot label of every selected neighbor.
y_neighbors = y_train[index_neighbors]
print('y_neighbors:', y_neighbors.shape)

In [None]:
# Sum the one-hot labels over the neighbor dimension: per-class vote counts.
freq_neighbors = y_neighbors.sum(dim=0)
print('freq_neighbors:', freq_neighbors.shape)

In [None]:
# The prediction is the class with the most votes amongst the neighbors.

# Vote counts can tie, so give the i-th nearest neighbor a bonus of 0.5^i.
weight = 0.5 ** torch.arange(1, k + 1)
print('weight:', weight)

# Contract the weights against the neighbor dimension of the one-hot labels.
tie_break = torch.tensordot(weight, y_neighbors, dims=1)
tie_break.max()  # Must stay below 1 so it never outweighs a whole vote.

In [None]:
# Vote counts plus the sub-unit tie-break bonus; the best class wins.
score = freq_neighbors + tie_break
pred = score.argmax(dim=1)

# Fraction of testing examples whose prediction matches the true label.
(pred == labels_test).float().mean()

In [None]:
# Drop the names of the two full pairwise matrices so that Python can
# return their memory to the system.
del dot
del dist_euc

In [None]:
# Put it all together in a function.

def predict_knn(x_train, y_train, x_test, k, chunk_size=1000):
  """Classify each row of `x_test` by a vote of its k nearest training rows.

  The class with the most votes wins. When several classes share the top
  count, the one that owns the closest neighbor wins. That is the outcome
  the 0.5^i weighting scheme aims for, but it is computed from neighbor
  ranks so it stays exact for large k (summed float32 powers of 0.5 fall
  below the rounding error of the vote counts once k reaches a few dozen).

  Args:
    x_train: [n_train, n_feat] training features.
    y_train: [n_train, n_class] one-hot training labels (float or integer).
    x_test: [n_test, n_feat] features to classify.
    k: number of neighbors that vote; must be at least 1.
    chunk_size: number of test rows handled at a time. Both the distance
      matrix and the voting intermediates are built per chunk, which keeps
      memory bounded when n_test is large (e.g. the 60k training set).

  Returns:
    [n_test] tensor with the predicted class index of every test row.

  Raises:
    ValueError: if k is smaller than 1.
  """
  if k < 1:
    raise ValueError(f'k must be at least 1, got {k}')
  # cat(map(f, split(x))): neighbors are found and votes are counted chunk by
  # chunk, so only the per-chunk predictions are kept around.
  return torch.cat([
    _vote(y_train[find_neighbors(x_train, x_chunk, k)])
    for x_chunk in torch.split(x_test, chunk_size)
  ], dim=0)


def _vote(y_neighbors):
  """Pick the winning class from [k, n, n_class] one-hot neighbor labels.

  The first dimension must be ordered from nearest to farthest neighbor.
  """
  k = y_neighbors.shape[0]
  # Per-class vote counts; float64 keeps integer counts exact.
  freq = torch.sum(y_neighbors, dim=0).double()
  # Rank (0 = nearest) of the closest neighbor of each class; k if absent.
  rank = torch.arange(k, device=y_neighbors.device).view(k, 1, 1)
  absent = torch.full_like(rank, k)
  first_rank = torch.where(y_neighbors > 0, rank, absent).amin(dim=0)
  # Bonus in [0, 1): larger for a closer first neighbor, never worth a vote.
  tie_break = (k - first_rank).double() / (k + 1)
  return torch.argmax(freq + tie_break, dim=1)


def find_neighbors(x_train, x_test, k):
  """Return indices of the k training rows closest to each test row.

  Args:
    x_train: [n_train, n_feat] training features.
    x_test: [n_test, n_feat] query features.
    k: number of neighbors to return per query.

  Returns:
    [k, n_test] tensor of row indices into `x_train`, ordered from nearest
    to farthest along the first dimension.
  """
  # Squared distance via |u - v|^2 = u'u - 2 u'v + v'v, which avoids building
  # an [n_train, n_test, n_feat] array of differences.
  dot = torch.einsum('id,jd->ij', x_train, x_test)
  norm_train = torch.sum(x_train ** 2, dim=1).unsqueeze(dim=1)
  norm_test = torch.sum(x_test ** 2, dim=1).unsqueeze(dim=0)
  dist_sq = norm_train + norm_test - 2 * dot
  # Squaring preserves the ordering, so no square root is needed.
  _, neighbors = torch.topk(dist_sq, k, largest=False, sorted=True, dim=0)
  return neighbors

In [None]:
# Test-set accuracy of k-nearest-neighbor classification for several k.

for k in (1, 3, 10, 100):
  pred_test = predict_knn(x_train, y_train, x_test, k)
  acc_test = (pred_test == labels_test).float().mean().item()
  print(f'k: {k}\ntest acc {acc_test:.2%}\n')

In [None]:
# Do the same thing for the training set (slow!!)
# The values of k match the test-set evaluation so that train and test
# accuracy can be compared side by side. For k=1 each training example
# should find itself as its nearest neighbor, so accuracy near 100% is
# expected there and says little about generalization.

for k in [1, 3, 10, 100]:
  pred_train = predict_knn(x_train, y_train, x_train, k)
  acc_train = torch.mean((pred_train == labels_train).float()).item()
  print('k:', k)
  print(f'train acc {acc_train:.2%}')
  print()