In [1]:
import os

import kagglehub

# Previously downloaded copy of the dataset inside the local kagglehub cache.
# This location only exists on the machine the notebook was authored on, so it
# is used as a fast path and never as the sole source of the data.
_CACHED_DATASET = '/home/tibless/.cache/kagglehub/datasets/pankrzysiu/cifar10-python/versions/1'

if os.path.isdir(_CACHED_DATASET):
    _dataset_root = _CACHED_DATASET
else:
    # Download CIFAR-10 when the cached copy is not present on this machine.
    _dataset_root = kagglehub.dataset_download("pankrzysiu/cifar10-python")

# Directory holding the batch files read by the loading cell.
path = _dataset_root + '/cifar-10-batches-py'

print("Path to dataset files:", path)

Path to dataset files: /home/tibless/.cache/kagglehub/datasets/pankrzysiu/cifar10-python/versions/1/cifar-10-batches-py


In [2]:
import os
import time
import numpy as np

# Human-readable label names; the position in the list is the integer class id.
class2name = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',  # 0-4
    'dog', 'frog', 'horse', 'ship', 'truck',  # 5-9
]

def unpickle(file):
    """Return the object stored in the pickle file at ``file``.

    The file is read in binary mode and decoded with ``encoding='bytes'``.

    NOTE: unpickling can execute arbitrary code, so only call this on files
    from a trusted source.
    """
    import pickle
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def load_cifar10(path):
    """Read the five training batches and the test batch found under ``path``.

    Args:
        path: Directory containing ``data_batch_1`` .. ``data_batch_5`` and
            ``test_batch``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` where the image arrays have
        shape ``(N, 32, 32, 3)`` and the label arrays have shape ``(N,)``.
    """
    def as_images(flat):
        # Each row is unpacked to (3, 32, 32) and the channel axis moved last.
        return flat.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    image_chunks = []
    labels = []
    for batch_no in range(1, 6):
        batch = unpickle(os.path.join(path, f'data_batch_{batch_no}'))
        image_chunks.append(as_images(batch[b'data']))
        labels.extend(batch[b'labels'])

    x_train = np.vstack(image_chunks)
    y_train = np.array(labels)

    held_out = unpickle(os.path.join(path, 'test_batch'))
    x_test = as_images(held_out[b'data'])
    y_test = np.array(held_out[b'labels'])

    return (x_train, y_train), (x_test, y_test)

# Load both splits from disk, then report their array shapes as a sanity check.
_train_split, _test_split = load_cifar10(path)
x_train, y_train = _train_split
x_test, y_test = _test_split

print(f"Training data shape: {x_train.shape}, Training labels shape: {y_train.shape}")
print(f"Testing data shape: {x_test.shape}, Testing labels shape: {y_test.shape}")

Training data shape: (50000, 32, 32, 3), Training labels shape: (50000,)
Testing data shape: (10000, 32, 32, 3), Testing labels shape: (10000,)


In [3]:
import jax.numpy as jnp
from jax import random
# JAX PRNG key seeded with 42.
key = random.PRNGKey(42)

# Number of leading samples kept from the training and test splits.
TRAIN = 50000
TEST = 1000


def _flatten_and_scale(images, limit):
    """Keep the first ``limit`` images, flatten each to 32*32*3 values, divide by 255."""
    return jnp.array(images[:limit]).reshape(-1, 32 * 32 * 3) / 255.


# NOTE: the names below are rebound in place, so running this cell a second
# time without re-running the loading cell divides the pixels by 255 again.
x_train = _flatten_and_scale(x_train, TRAIN)
y_train = jnp.array(y_train[:TRAIN])
x_test = _flatten_and_scale(x_test, TEST)
y_test = jnp.array(y_test[:TEST])

In [4]:
import jax.numpy as jnp
from jax import jit, vmap, lax

def _l1(x, y):
    """Sum of absolute element-wise differences between two samples."""
    return jnp.sum(jnp.abs(x - y))


# Map over the rows of the second argument while the first stays fixed ...
_one_to_many = vmap(_l1, in_axes=(None, 0))

# ... then over the rows of the first argument, giving a matrix whose entry
# [i, j] is the distance between row i of the first input and row j of the second.
distance = jit(vmap(_one_to_many, in_axes=(0, None)))


class KNN:
    """k-nearest-neighbour classifier operating on JAX arrays.

    Training only stores the samples. Prediction computes the distance from
    every query to every stored sample with the module-level ``distance``
    function and takes a majority vote over the labels of the ``k`` closest
    stored samples.
    """

    def __init__(self, k, num_class):
        """Store the neighbour count ``k`` and the number of classes ``num_class``."""
        self.k = k
        self.n = num_class

    def train(self, x_train, y_train):
        """Keep the training samples and their labels as JAX arrays.

        ``y_train`` must hold integer class ids in ``[0, num_class)`` because
        the labels are counted with ``jnp.bincount`` during prediction.
        """
        self.x = jnp.array(x_train)
        self.y = jnp.array(y_train)

    def predict(self, x_test, batch_size=1000):
        """Predict a class id for every row of ``x_test``.

        Args:
            x_test: Query samples, one per row, with the same per-sample shape
                as the training samples.
            batch_size: Number of training samples handled per scan step. It
                does not have to divide the training-set size; the final
                partial batch is zero-padded and the padded columns are
                discarded before voting.

        Returns:
            Array of shape ``(x_test.shape[0],)`` with the predicted class ids.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')

        n_train = self.x.shape[0]
        # Never use a batch larger than the training set, to avoid pure padding work.
        batch_size = min(batch_size, max(n_train, 1))
        n_batches = -(-n_train // batch_size)  # ceiling division
        n_padded = n_batches * batch_size

        # lax.scan needs equally sized batches, so pad the sample axis with
        # zero rows up to a whole number of batches before splitting it.
        pad_width = [(0, n_padded - n_train)] + [(0, 0)] * (self.x.ndim - 1)
        batches = jnp.pad(self.x, pad_width).reshape(
            (n_batches, batch_size) + self.x.shape[1:]
        )

        def step(carry, x_t):
            # Write the distances for this batch into its column block.
            mat, ix = carry
            dis_batch = distance(x_test, x_t)
            new_mat = lax.dynamic_update_slice(mat, dis_batch, (0, ix))
            return (new_mat, ix + batch_size), None

        dismat = jnp.zeros((x_test.shape[0], n_padded))
        (dismat, _), _ = lax.scan(step, (dismat, 0), batches)

        # Columns beyond n_train belong to padding rows and must not vote.
        dismat = dismat[:, :n_train]

        def vote(d):
            # Labels of the k closest training samples, then the most frequent one.
            kns = self.y[jnp.argsort(d)[:self.k]]
            cnt = jnp.bincount(kns, length=self.n)
            return jnp.argmax(cnt)

        return vmap(vote, in_axes=(0,))(dismat)

In [5]:
# Classifier that votes among the 10 nearest neighbours over 10 classes.
knn = KNN(num_class=10, k=10)
knn.train(x_train, y_train)


def _fraction_equal(x, y):
    """Mean of the element-wise equality between ``x`` and ``y``."""
    return jnp.mean(x == y)


# Jitted accuracy metric.
acc = jit(_fraction_equal)

In [7]:
import time

# JAX dispatches work asynchronously, so the result must be awaited with
# block_until_ready() before the clock is stopped; otherwise only the dispatch
# overhead would be measured.
start = time.perf_counter()
y_pred = knn.predict(x_test).block_until_ready()
elapsed = time.perf_counter() - start
print()
print(f'time: {elapsed} s')

# Accuracy is a unitless fraction, so it is printed without a unit suffix.
print(f'acc: {acc(y_test, y_pred)}')


time: 0.43906235694885254 s
acc: 0.38200002908706665 s


![knn](./assets/knn.png)