In [1]:
import os

import kagglehub

# Download CIFAR-10 through kagglehub. The batch files live in the
# 'cifar-10-batches-py' sub-folder of the directory that dataset_download returns;
# os.path.join is used instead of hard-coding the '/' separator.
path = os.path.join(
    kagglehub.dataset_download("pankrzysiu/cifar10-python"),
    'cifar-10-batches-py',
)

print("Path to dataset files:", path)

Path to dataset files: /home/tibless/.cache/kagglehub/datasets/pankrzysiu/cifar10-python/versions/1/cifar-10-batches-py


In [2]:
import os
import time
import numpy as np

# Human-readable label names; the list index is the integer class id (0-9).
class2name = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',   # ids 0-4
    'dog', 'frog', 'horse', 'ship', 'truck',           # ids 5-9
]

def unpickle(file):
    """Deserialize the pickle file at ``file`` and return the stored object.

    The file is opened in binary mode and loaded with ``encoding='bytes'``,
    so 8-bit strings written by Python 2 come back as ``bytes`` objects
    (e.g. dictionary keys look like ``b'data'``).

    Parameters
    ----------
    file : str or path-like
        Path of the pickle file to read.

    Returns
    -------
    object
        Whatever object was pickled into the file.
    """
    import pickle  # imported locally so the function is self-contained

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only call this on files from a source you trust.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')

def load_cifar10(path):
    """Read the CIFAR-10 python batches stored in the directory ``path``.

    The five files ``data_batch_1`` .. ``data_batch_5`` form the training
    split and ``test_batch`` forms the test split. Each batch's ``b'data'``
    array is reshaped to ``(-1, 3, 32, 32)`` and transposed so that images
    come out as ``(N, 32, 32, 3)``.

    Returns
    -------
    ((x_train, y_train), (x_test, y_test))
        Image arrays and their label arrays for both splits.
    """
    def read_batch(file_name):
        # Load one batch file and return (channel-last images, raw labels).
        batch = unpickle(os.path.join(path, file_name))
        images = batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        return images, batch[b'labels']

    train_batches = [read_batch(f'data_batch_{idx}') for idx in range(1, 6)]

    # Stack the per-batch image arrays and flatten the per-batch label lists.
    x_train = np.vstack([images for images, _ in train_batches])
    y_train = np.array([label for _, labels in train_batches for label in labels])

    x_test, test_labels = read_batch('test_batch')
    y_test = np.array(test_labels)

    return (x_train, y_train), (x_test, y_test)

# Read every batch file under `path` and unpack the train / test splits.
(x_train, y_train), (x_test, y_test) = load_cifar10(path)

# Report the array shapes as a quick sanity check of the load.
print(f"Training data shape: {x_train.shape}, Training labels shape: {y_train.shape}")
print(f"Testing data shape: {x_test.shape}, Testing labels shape: {y_test.shape}")

Training data shape: (50000, 32, 32, 3), Training labels shape: (50000,)
Testing data shape: (10000, 32, 32, 3), Testing labels shape: (10000,)


In [3]:
# Number of samples kept from the training and test splits.
TRAIN, TEST = 50000, 10000

# Keep the leading TRAIN / TEST samples and flatten every image into a single
# row of 32 * 32 * 3 values; labels are truncated to the same length.
x_train, y_train = x_train[:TRAIN].reshape(-1, 32 * 32 * 3), y_train[:TRAIN]
x_test, y_test = x_test[:TEST].reshape(-1, 32 * 32 * 3), y_test[:TEST]

In [4]:
import jax.numpy as jnp
from jax import random, jit, vmap

class KNNClf:
    """Brute-force k-nearest-neighbour classifier computed with JAX.

    ``fit`` stores the training set. ``predict_proba`` computes the distance
    from every query row to every training row in tiles of
    ``batch_size[0]`` query rows by ``batch_size[1]`` training rows, picks the
    ``k`` closest training rows per query and returns the fraction of those
    neighbours that belong to each class.

    Parameters
    ----------
    k : int
        Number of neighbours that vote. Must be >= 1. If ``k`` exceeds the
        number of training rows, all training rows vote.
    d : str
        Distance metric: ``'euclid'``, ``'manhattan'``, ``'cosine'`` or
        ``'chebyshev'``. Any other value prints a warning and falls back to
        the Euclidean distance.
    num_class : int
        Number of classes. Training labels must lie in ``[0, num_class)``.
    batch_size : sequence of two ints
        ``(query rows per tile, training rows per tile)``.

    Attributes
    ----------
    static : jax array of shape (num_class,)
        Relative frequency of every class in the training labels (set by ``fit``).
    n_k_neighbors : numpy array of shape (n_queries, k)
        Training-row indices of the neighbours found by the last
        ``predict_proba`` call, closest first.
    training_time, testing_time : float or None
        Wall-clock seconds spent in the last ``fit`` / ``predict_proba`` call.
    """

    def __init__(self, k=1, d='euclid', num_class=10, batch_size=(128, 2048)):
        if k < 1:
            raise ValueError(f'[x] k should be a number >=1, but get {k}')
        if num_class < 1:
            raise ValueError(f'[x] num_class should be a number >=1, but get {num_class}')
        self.k = k
        self.d = d
        self.batch_size = batch_size

        # Choose the distance function that matches the requested metric.
        metrics = {
            'euclid': self.__euclid_distance,
            'manhattan': self.__manhattan_distance,
            'cosine': self.__cosine_distances,
            'chebyshev': self.__chebyshev_distances,
        }
        if d not in metrics:
            print('[!] d should be euclid, manhattan, cosine or chebyshev !')
            print('[!] use default d: euclid')
        self.distance = metrics.get(d, self.__euclid_distance)

        # Honour the constructor argument (it used to be overwritten with 10).
        self.num_class = num_class
        self.training_time = None
        self.testing_time = None
        self.n_k_neighbors = None

    @staticmethod
    def __to_jnp(x):
        """Convert ``x`` to a float32 JAX array."""
        return jnp.array(x, dtype=jnp.float32)

    @staticmethod
    def __euclid_distance(x, y):
        """Euclidean distances, shape ``(len(y), len(x))``.

        Uses the expansion ``|y|^2 + |x|^2 - 2 y.x``. Rounding can make that
        expression slightly negative for (near-)identical rows, which would
        turn into NaN under ``sqrt``, so it is clamped at zero first.
        """
        squared = (
            jnp.sum(y**2, axis=1, keepdims=True)
            + jnp.sum(x**2, axis=1)
            - 2 * jnp.dot(y, x.T)
        )
        return jnp.sqrt(jnp.maximum(squared, 0.0))

    @staticmethod
    def __manhattan_distance(x, y):
        """Sum of absolute coordinate differences, shape ``(len(y), len(x))``."""
        # Broadcasting builds a (len(y), len(x), n_features) intermediate.
        return jnp.sum(jnp.abs(y[:, None] - x), axis=2)

    @staticmethod
    def __chebyshev_distances(x, y):
        """Largest absolute coordinate difference, shape ``(len(y), len(x))``."""
        # Broadcasting builds a (len(y), len(x), n_features) intermediate.
        return jnp.max(jnp.abs(y[:, None] - x), axis=2)

    @staticmethod
    def __cosine_distances(x, y):
        """``1 - cosine similarity``, shape ``(len(y), len(x))``.

        Row norms are floored at a tiny epsilon so an all-zero row yields a
        similarity of 0 (distance 1) instead of a division by zero.
        """
        eps = 1e-12
        x_norm = jnp.maximum(jnp.linalg.norm(x, ord=2, axis=1, keepdims=True), eps)
        y_norm = jnp.maximum(jnp.linalg.norm(y, ord=2, axis=1, keepdims=True), eps)
        similarity = jnp.dot(y / y_norm, (x / x_norm).T)
        return 1 - similarity

    def fit(self, X_train, y_train):
        """Store the training data and compute the class frequencies.

        Parameters
        ----------
        X_train : array-like of shape (n_samples, n_features)
        y_train : array-like of integer labels; flattened to 1-D.

        Raises
        ------
        ValueError
            If there are no labels, the smallest label is not 0, a label is
            >= ``num_class``, or the number of rows and labels differ.
        """
        start = time.perf_counter()

        labels = np.asarray(y_train).reshape(-1).astype(np.int32)
        if labels.size == 0:
            raise ValueError('[x] Training set is empty !')
        if labels.min() != 0:
            raise ValueError('[x] Make sure y is start form 0 !')
        if labels.max() >= self.num_class:
            raise ValueError(
                f'[x] Label {labels.max()} is out of range for num_class={self.num_class} !'
            )

        features = self.__to_jnp(X_train)
        if features.shape[0] != labels.size:
            raise ValueError(
                f'[x] X_train has {features.shape[0]} rows but y_train has {labels.size} labels !'
            )

        self.x = features
        self.y = jnp.array(labels, dtype=jnp.int32)

        # One frequency per class id, including classes that never occur, so
        # the vector always lines up with the columns of predict_proba.
        counts = np.bincount(labels, minlength=self.num_class)
        self.static = jnp.asarray(counts / counts.sum(), dtype=jnp.float32)

        self.training_time = time.perf_counter() - start

    def predict_proba(self, x_test):
        """Return neighbour-vote fractions for every query row.

        Parameters
        ----------
        x_test : array-like of shape (n_queries, n_features)

        Returns
        -------
        numpy.ndarray of shape (n_queries, num_class)
            Entry ``[i, c]`` is the share of the selected neighbours of query
            ``i`` whose label is ``c``.

        Side effects: sets ``self.n_k_neighbors`` and ``self.testing_time``
        and prints two progress lines.
        """
        x_test = self.__to_jnp(x_test)
        num_class = self.num_class
        query_step = self.batch_size[0]
        train_step = self.batch_size[1]

        @jit
        def calculate_proba(n_k_neighbors, y):
            # Labels of the neighbours: (batch, k).
            neighbor_labels = y[n_k_neighbors]
            # One-hot encode and count votes per class: (batch, num_class).
            one_hot = jnp.eye(num_class, dtype=jnp.float32)[neighbor_labels]
            return jnp.sum(one_hot, axis=1) / n_k_neighbors.shape[1]

        start = time.perf_counter()
        proba_batches = []
        neighbor_batches = []

        for i in range(0, x_test.shape[0], query_step):
            batch_x_test = x_test[i:i + query_step]

            # Distances to every training tile, joined along the training axis.
            # Collecting tiles and concatenating once avoids copying the whole
            # distance matrix on every tile.
            tiles = [
                self.distance(self.x[j:j + train_step], batch_x_test)
                for j in range(0, self.x.shape[0], train_step)
            ]
            distance = jnp.concatenate(tiles, axis=1)

            n_k_neighbors = jnp.argsort(distance, axis=1)[:, :self.k]
            neighbor_batches.append(n_k_neighbors)
            proba_batches.append(calculate_proba(n_k_neighbors, self.y))

        print('calcu end.')

        if proba_batches:
            # Converting to NumPy waits for JAX's asynchronously dispatched
            # work, so the timer below covers the real computation.
            proba = np.asarray(jnp.concatenate(proba_batches, axis=0))
            self.n_k_neighbors = np.asarray(jnp.concatenate(neighbor_batches, axis=0))
        else:
            # No query rows: return correctly shaped empty results.
            proba = np.zeros((0, num_class), dtype=np.float32)
            self.n_k_neighbors = np.zeros((0, self.k), dtype=np.int32)

        self.testing_time = time.perf_counter() - start
        print('predict end.')
        return proba

    def predict(self, x_test):
        """Predict one class id per query row.

        The chosen class is the one whose neighbour-vote fraction exceeds its
        training-set frequency (``self.static``) by the largest margin.
        """
        proba = self.predict_proba(x_test)
        diff = proba - np.asarray(self.static)
        return np.argmax(diff, axis=1)

    def get_testing_time(self):
        """Seconds spent in the last ``predict_proba`` call (None before any call)."""
        return self.testing_time

    def get_training_time(self):
        """Seconds spent in the last ``fit`` call (None before any call)."""
        return self.training_time

# Evaluate a 10-nearest-neighbour classifier with Euclidean distance.
# The batch sizes equal the full test / training set sizes, so all pairwise
# distances are computed in a single tile.
knn = KNNClf(
    k=10,
    d='euclid',
    num_class=10,
    batch_size=(x_test.shape[0], x_train.shape[0]),
)
knn.fit(x_train, y_train)
y_pred = knn.predict(x_test)

# Fraction of correct predictions and the time reported by the classifier.
print(f'acc: {(y_pred == y_test).mean()}')
print(f'time: {knn.get_testing_time()}')

calcu end.
predict end.
acc: 0.3386
time: 0.37079310417175293
