In [1]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import cifar10

# ------------------------------
# Clase k-NN con distancia L2
# ------------------------------
class NearestNeighborL2:
    def __init__(self):
        pass

    def train(self, X, y):
        self.Xtr = X
        self.ytr = y

    def predict(self, X, k=1):
        num_test = X.shape[0]
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        for i in range(num_test):
            distances = np.sqrt(np.sum(np.square(self.Xtr - X[i, :]), axis=1))
            min_indices = np.argsort(distances)[:k]
            closest_y = self.ytr[min_indices]
            counts = np.bincount(closest_y)
            Ypred[i] = np.argmax(counts)
        return Ypred

# ------------------------------
# Cargar datos desde keras
# ------------------------------
(Xtr, Ytr), (Xte, Yte) = cifar10.load_data()
Ytr = Ytr.flatten()
Yte = Yte.flatten()

# Aplanar imágenes
Xtr_rows = Xtr.reshape(Xtr.shape[0], -1)  # 50000 x 3072
Xte_rows = Xte.reshape(Xte.shape[0], -1)  # 10000 x 3072

# Separar conjunto de validación
Xval_rows = Xtr_rows[:1000, :]
Yval = Ytr[:1000]
Xtr_rows = Xtr_rows[1000:, :]
Ytr = Ytr[1000:]

# ------------------------------
# Búsqueda del mejor k
# ------------------------------
validation_accuracies = []
for k in [1, 3, 5, 10, 20, 50, 100]:
    nn = NearestNeighborL2()
    nn.train(Xtr_rows, Ytr)
    Yval_predict = nn.predict(Xval_rows, k=k)
    acc = np.mean(Yval_predict == Yval)
    print(f'k = {k}, validation accuracy: {acc:.4f}')
    validation_accuracies.append((k, acc))

# Elegir mejor k
best_k = max(validation_accuracies, key=lambda x: x[1])[0]
print(f'\nMejor k encontrado: {best_k}')

# Evaluar en conjunto de prueba
nn = NearestNeighborL2()
nn.train(Xtr_rows, Ytr)
Yte_predict = nn.predict(Xte_rows, k=best_k)
print(f'Test accuracy: {np.mean(Yte_predict == Yte):.4f}')

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step
k = 1, validation accuracy: 0.2650
k = 3, validation accuracy: 0.2290
k = 5, validation accuracy: 0.2450
k = 10, validation accuracy: 0.2360
k = 20, validation accuracy: 0.2410
k = 50, validation accuracy: 0.2340
k = 100, validation accuracy: 0.2410

Mejor k encontrado: 1
Test accuracy: 0.2521
