In [1]:
import time

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.metrics import accuracy_score

In [2]:
# Fashion-MNIST is obtained through torchvision's open datasets. Both splits use
# the same storage root, download flag and image transform, so those arguments
# are declared once and reused.
_shared_dataset_args = dict(root="data", download=True, transform=ToTensor())

# training split
training_data = datasets.FashionMNIST(train=True, **_shared_dataset_args)

# test split
test_data = datasets.FashionMNIST(train=False, **_shared_dataset_args)

In [3]:
def _as_flat_pixels(dataset):
    """Return ``dataset.data`` as a 2-D NumPy array (one row per sample), divided by 255."""
    raw_images = dataset.data.numpy()
    # -1 lets NumPy infer the row length, i.e. every pixel of one image.
    # NOTE(review): dividing by 255 presumably maps 8-bit pixels to [0, 1] — confirm dtype.
    return raw_images.reshape(len(dataset), -1) / 255.0


# Flatten each 28×28 image into a vector of size 784; labels are exported unchanged.
X_train, y_train = _as_flat_pixels(training_data), training_data.targets.numpy()
X_test, y_test = _as_flat_pixels(test_data), test_data.targets.numpy()

print("Train shape:", X_train.shape)
print("Test shape:", X_test.shape)

Train shape: (60000, 784)
Test shape: (10000, 784)


In [5]:
# Candidate values of k: the odd numbers 1, 3, ..., 11.
param_grid = {"n_neighbors": list(range(1, 12, 2))}

knn = KNeighborsClassifier()

# 10-fold splitter. Samples are shuffled before being assigned to folds, with a
# fixed seed so the fold assignment is the same on every run.
ten_fold_cv = KFold(n_splits=10, shuffle=True, random_state=42)

# Exhaustive search over `param_grid`, scoring each candidate by mean fold accuracy.
grid = GridSearchCV(
    estimator=knn,
    param_grid=param_grid,
    scoring="accuracy",
    cv=ten_fold_cv,
    n_jobs=-1,  # run the fits in parallel
    verbose=1,
)

In [6]:
# Run the grid search, then time a fit and a prediction pass of the selected
# model and report its accuracy on the held-out test set.
#
# All durations use time.perf_counter(): it is a monotonic, high-resolution clock
# intended for measuring elapsed time, whereas time.time() follows the system
# wall-clock, which can be adjusted mid-run and may be too coarse for the
# sub-second fit measured below.
print("Running cross-validation...")

t_cv = time.perf_counter()
grid.fit(X_train, y_train)
cv_time = time.perf_counter() - t_cv

print("\n===== Cross validation results =====")
print(f"\nBest params: {grid.best_params_}")
print(f"Best CV accuracy: {grid.best_score_:.4f}")
print(f"CV time: {cv_time:.2f} sec")

best_knn = grid.best_estimator_

# training on full dataset
# NOTE(review): `grid` is created without overriding `refit`, so best_estimator_
# should already be fitted on the whole training set; this explicit fit is kept
# only to time training in isolation. `grid.refit_time_` may report the same
# quantity — confirm before removing.
t_train = time.perf_counter()
best_knn.fit(X_train, y_train)
train_time = time.perf_counter() - t_train

# evaluate on the test set
t_pred = time.perf_counter()
test_preds = best_knn.predict(X_test)
pred_time = time.perf_counter() - t_pred

test_accuracy = accuracy_score(y_test, test_preds)

print("\n===== Final results =====")
print(f"Test accuracy:  {test_accuracy:.4f}")
print(f"Train time:     {train_time:.4f} sec")
print(f"Predict time:   {pred_time:.4f} sec")

Running cross-validation...
Fitting 10 folds for each of 6 candidates, totalling 60 fits

===== Cross validation results =====

Best params: {'n_neighbors': 7}
Best CV accuracy: 0.8569
CV time: 112.87 sec

===== Final results =====
Test accuracy:  0.8540
Train time:     0.0114 sec
Predict time:   3.2942 sec
