In [435]:
# Toggle DEBUG to dump the test images as PNGs and classify a hand-drawn
# digit instead of the MNIST test subset (see the data-loading cell).
DEBUG = False

if DEBUG:
  # PIL/numpy are only needed for the DEBUG image round-trip.
  from PIL import Image
  import numpy as np

  def read_image(path):
    """Open the image at `path`, convert it to 8-bit grayscale ('L') and return it as a numpy array."""
    grayscale_img = Image.open(path).convert('L')
    return np.asarray(grayscale_img)

  def write_image(image, path):
    """Save `image` (a 2-D sequence of pixel values) to `path` as an 8-bit grayscale ('L') image."""
    pixel_array = np.array(image)
    Image.fromarray(pixel_array, 'L').save(path)

In [436]:
# Folder used in DEBUG mode for the dumped test PNGs and the hand-drawn our_test.png.
TEST_DIR = 'test/'

# MNIST IDX files, resolved relative to the current working directory.
TEST_DATA_FILENAME = 't10k-images.idx3-ubyte'
TEST_LABELS_FILENAME = 't10k-labels.idx1-ubyte'
TRAIN_DATA_FILENAME = 'train-images.idx3-ubyte'
TRAIN_LABELS_FILENAME = 'train-labels.idx1-ubyte'

In [437]:
def bytes_to_int(byte_data):
  """Interpret `byte_data` as an unsigned big-endian integer (an empty input gives 0)."""
  return int.from_bytes(byte_data, byteorder='big')

In [438]:
def read_images(filename,n_max_images=None):
  """Read grayscale images from an MNIST-style IDX3 file.

  Layout read here: a 16-byte header of four big-endian unsigned 32-bit
  fields (magic number -- skipped, image count, row count, column count),
  then one byte per pixel, stored image by image and row by row.

  Args:
    filename: path to the IDX3 file.
    n_max_images: if not None, read at most this many images. The value is
      clamped to the count stored in the header, and 0 reads no images.

  Returns:
    A list of images. Each image is a list of rows, and each row is a list of
    one-byte ``bytes`` objects (one per pixel).

  Raises:
    ValueError: if the file ends before the header or the requested pixels.
  """
  images=[]
  with open(filename,'rb') as f:
    header=f.read(16)
    if len(header)<16:
      raise ValueError(f'{filename}: truncated IDX header')
    # Bytes 0-3 hold the magic number and are ignored.
    n_images,n_rows,n_columns=(
        int.from_bytes(header[start:start+4],'big') for start in (4,8,12)
    )
    if n_max_images is not None:
      # `is not None` so 0 means "none"; min() so we never read past EOF.
      n_images=max(0,min(n_images,n_max_images))
    for _ in range(n_images):
      image=[]
      for _ in range(n_rows):
        # Read a whole row at once instead of one f.read(1) call per pixel.
        row_bytes=f.read(n_columns)
        if len(row_bytes)<n_columns:
          raise ValueError(f'{filename}: truncated pixel data')
        # Slicing (not indexing) keeps each pixel a 1-byte ``bytes`` object,
        # the element type the rest of the notebook converts with bytes_to_int.
        image.append([row_bytes[col:col+1] for col in range(n_columns)])
      images.append(image)
  return images

In [439]:
def read_labels(filename,n_max_labels=None):
  """Read integer class labels from an MNIST-style IDX1 file.

  Layout read here: an 8-byte header of two big-endian unsigned 32-bit
  fields (magic number -- skipped, label count), then one byte per label.

  Args:
    filename: path to the IDX1 file.
    n_max_labels: if not None, read at most this many labels. The value is
      clamped to the count stored in the header, and 0 reads no labels.

  Returns:
    A list of ints (0-255), one per label.

  Raises:
    ValueError: if the file ends before the header or the requested labels.
  """
  with open(filename,'rb') as f:
    header=f.read(8)
    if len(header)<8:
      raise ValueError(f'{filename}: truncated IDX header')
    # Bytes 0-3 hold the magic number and are ignored.
    n_labels=int.from_bytes(header[4:8],'big')
    if n_max_labels is not None:
      # Clamp to [0, header count]: past EOF the old code silently produced
      # label 0, and f.read() with a negative size would read the whole file.
      n_labels=max(0,min(n_labels,n_max_labels))
    data=f.read(n_labels)
    if len(data)<n_labels:
      raise ValueError(f'{filename}: truncated label data')
  # Iterating a bytes object yields each byte as an int.
  return list(data)

In [440]:
def flatten_list(l):
  """Concatenate the sub-sequences of `l` into one flat list, preserving order."""
  flat = []
  for sublist in l:
    flat.extend(sublist)
  return flat

In [441]:
def extract_features(X):
  """Turn each 2-D sample in `X` (a list of rows) into a flat, row-major feature list."""
  features = []
  for sample in X:
    features.append([pixel for row in sample for pixel in row])
  return features

In [442]:
def dist(x,y):
  """Euclidean distance between two equal-length pixel vectors.

  Every element is converted with `bytes_to_int` before differencing, e.g.
  the one-byte ``bytes`` pixels produced by `read_images`.

  Raises:
    ValueError: if `x` and `y` differ in length. Without this check, `zip`
      would silently ignore the extra trailing pixels and return a wrong
      distance (e.g. for a DEBUG image that is not the MNIST size).
  """
  if len(x)!=len(y):
    raise ValueError(f'feature vectors differ in length: {len(x)} != {len(y)}')
  return sum(
      (bytes_to_int(x_i)-bytes_to_int(y_i))**2 for x_i,y_i in zip(x,y)
  )**0.5

In [443]:
def get_training_distance_for_test_sample(X_train,test_sample):
  """Return `dist(train_sample, test_sample)` for every sample in `X_train`, in order."""
  distances = []
  for train_sample in X_train:
    distances.append(dist(train_sample, test_sample))
  return distances

In [444]:
def get_most_frequent_element(l):
  """Return the element of list `l` that occurs most often.

  On a tie, the tied element appearing earliest in `l` wins. An empty `l`
  raises ValueError from the builtin max().
  """
  occurrences = [l.count(item) for item in l]
  return l[occurrences.index(max(occurrences))]

In [445]:
def knn(X_train,X_test,y_train,k):
  """Classify each test sample by a majority vote of its k nearest training samples.

  Args:
    X_train: training feature vectors.
    X_test: test feature vectors in the same format as `X_train`.
    y_train: labels aligned index-for-index with `X_train`.
    k: number of nearest neighbours that vote; must be >= 1.

  Returns:
    A list with one predicted label per test sample, in `X_test` order.
    Vote ties are broken in favour of the label of the nearer neighbour.

  Raises:
    ValueError: if `k` < 1. Previously a negative `k` silently sliced off the
      farthest neighbours instead of selecting the nearest ones.
  """
  import heapq  # local import: the notebook has no top-level import cell

  if k<1:
    raise ValueError(f'k must be >= 1, got {k}')
  y_pred=[]
  for test_sample in X_test:
    training_distances=get_training_distance_for_test_sample(X_train,test_sample)
    # heapq.nsmallest is documented as equivalent to sorted(...)[:k] (stable
    # on ties), but runs in O(n log k) instead of sorting all n distances.
    nearest_idxs=heapq.nsmallest(
        k,range(len(training_distances)),key=training_distances.__getitem__
    )
    candidates=[y_train[idx] for idx in nearest_idxs]
    y_pred.append(get_most_frequent_element(candidates))
  return y_pred

In [446]:
# Number of samples to load from the start of each MNIST file.
N_TRAIN=10000
N_TEST=5
X_train=read_images(TRAIN_DATA_FILENAME,N_TRAIN)
y_train=read_labels(TRAIN_LABELS_FILENAME,N_TRAIN)
X_test=read_images(TEST_DATA_FILENAME,N_TEST)
y_test=read_labels(TEST_LABELS_FILENAME,N_TEST)
if DEBUG:
  import os
  # Saving into a missing folder fails, so create TEST_DIR first.
  os.makedirs(TEST_DIR,exist_ok=True)
  for idx,test_sample in enumerate(X_test):
    write_image(test_sample,f'{TEST_DIR}{idx}.png')
  # Replace the test set with one hand-drawn image whose label (8) is set by
  # hand. NOTE(review): presumably our_test.png is drawn at the MNIST image
  # size -- confirm, since distances need equal-length feature vectors.
  X_test=[read_image(f'{TEST_DIR}our_test.png')]
  y_test=[8]
# Images and labels must pair up one-to-one.
assert len(X_train)==len(y_train), f'{len(X_train)} train images vs {len(y_train)} labels'
assert len(X_test)==len(y_test), f'{len(X_test)} test images vs {len(y_test)} labels'

In [447]:
# Flatten every 2-D image into a 1-D pixel vector for the distance computation.
# NOTE: this cell is not idempotent -- re-running it without re-running the
# loading cell would flatten the already-flat vectors a second time.
X_train, X_test = extract_features(X_train), extract_features(X_test)

In [448]:
y_pred = knn(X_train, X_test, y_train, k=3)  # 3-nearest-neighbour majority vote

In [449]:
# Fraction of test samples whose predicted label equals the true label.
# Checked first: zip() would silently drop unmatched items, and an empty
# prediction list would otherwise fail with ZeroDivisionError.
assert len(y_pred)==len(y_test), f'{len(y_pred)} predictions vs {len(y_test)} labels'
assert y_pred, 'no predictions to score'
accuracy=sum(
    y_pred_i==y_test_i
    for y_pred_i,y_test_i
    in zip(y_pred,y_test))/len(y_pred)

In [450]:
print(f'Predicted Label: {y_pred}')
# Accuracy is only reported for the MNIST test subset, not for the DEBUG
# hand-drawn sample.
if not DEBUG:
  print(accuracy)

Predicted Label: [7, 2, 1, 0, 4]
1.0
