In [3]:
import requests
import tarfile
import os
import pickle
import numpy as np

In [8]:
def download_data():
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    file_path = os.path.join("data", "cifar-10-python.tar.gz")

    # Download the dataset
    if not os.path.isfile(file_path):
        response = requests.get(url, stream=True)
        if not os.path.isdir("data/"):
            os.makedirs("data/")
        with open(file_path, "wb") as file:
            # receive 8kb chunks
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    file.write(chunk)
    else:
        # Extract the dataset and remove the tar file after extraction
        with tarfile.open(file_path, "r:gz") as tar:
            tar.extractall("data/CIFAR-10")
            os.remove(file_path)


def load_batch(file_path):
    with open(file_path, "rb") as file:
        batch = pickle.load(file, encoding="bytes")
        return batch


def load_CIFAR10(folder_path):
    train_batches = []
    test_batch = None

    if not os.path.isdir(folder_path):
        download_data()
        print("Downloaded CIFAR-10 dataset to: data/cifar-10-batches-py")
    
    for file_name in os.listdir(folder_path):
        # print(file_name)
        file_path = os.path.join(folder_path, file_name)
        if "data_batch" in file_name:
            train_batches.append(load_batch(file_path))
        elif "test_batch" in file_name:
            test_batch = load_batch(file_path)

    # print(train_batches)
    train_data = np.concatenate([batch[b"data"] for batch in train_batches])
    train_labels = np.concatenate([batch[b"labels"] for batch in train_batches])

    test_data = test_batch[b"data"]
    test_labels = np.array(test_batch[b"labels"])

    return train_data, train_labels, test_data, test_labels

In [9]:
dataset_folder = "data/CIFAR-10/cifar-10-batches-py"
Xtr, Ytr, Xte, Yte = load_CIFAR10(dataset_folder)

In [10]:
class KNearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, Y):
        # the nearest neighbor classifier simply remembers all the training data
        self.Xtr = X
        self.ytr = Y

    def predict(self, X, k=1): 
        num_images = X.shape[0]
        Ypred = np.zeros(num_images, dtype=self.ytr.dtype)

        for i in range(num_images):
            # find the k-nearest training images to the i'th image
            # using the L1 distance (sum of absolute value differences)
            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1) # broadcasting happens and X[i, :] is broadcasted to match (50000, 3072)
            # get the indices of the k smallest distances
            min_indices = np.argsort(distances)[:k]
            # get the labels of the k nearest examples
            k_nearest_labels = self.ytr[min_indices]
            # predict the label based on the majority class among k-nearest neighbors
            unique_labels, counts = np.unique(k_nearest_labels, return_counts=True)
            predicted_label = unique_labels[np.argmax(counts)]

            Ypred[i] = predicted_label

        return Ypred

In [33]:
# training
nn = KNearestNeighbor()
nn.train(Xtr, Ytr)
# takes roughly 25 minutes to make the predictions. 
# calculating L1 distance for every 10000 of the testing images is computationally expensive 
# 10000 * whatever time it takes to calculate the distance for a single image
prediction = nn.predict(Xte)
print("accuracy: %f" % (np.mean(prediction == Yte))) 
# untrained model = 10% accuracy (10 classes, 1/10). KNN = 25% accuracy, which is really bad

accuracy: 0.249200


In [12]:
# validation
Xval = Xtr[:1000, :]
Yval = Ytr[:1000]
Xtr = Xtr[1000:, :]
Ytr = Ytr[1000:]

validation_accuracies = []
for k in range(1, 10, 2):
        nn = KNearestNeighbor()
        nn.train(Xtr, Ytr)
        prediction = nn.predict(Xval, k=k)
        accuracy = np.mean(prediction == Yval)
        # print(f"accuracy: {accuracy}")
        validation_accuracies.append((k, accuracy))
for k, a in validation_accuracies:
        print(f"(K, Accuracy): {(k, a)}" )

(K, Accuracy): (1, 0.27)
(K, Accuracy): (3, 0.239)
(K, Accuracy): (5, 0.236)
(K, Accuracy): (7, 0.255)
(K, Accuracy): (9, 0.268)
