In [1]:
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split, GridSearchCV

# Load the MNIST dataset
mnist = fetch_openml('mnist_784')
X, y = mnist["data"], mnist["target"]

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the parameter grid for hyperparameter tuning
param_grid = {
    'n_neighbors': [3, 5, 7],
    'weights': ['uniform', 'distance']
}

# Create the KNN classifier
knn = KNeighborsClassifier()

# Create a GridSearchCV object to find the best hyperparameters
grid_search = GridSearchCV(knn, param_grid, cv=5, verbose=2)

# Fit the GridSearchCV object on the training data
grid_search.fit(X_train, y_train)

# Print the best hyperparameters and the corresponding accuracy score
print(f"Best parameters: {grid_search.best_params_}")
print(f"Best accuracy score: {grid_search.best_score_}")

# Evaluate the best estimator on the testing data
best_estimator = grid_search.best_estimator_
accuracy = best_estimator.score(X_test, y_test)
print(f"Accuracy: {accuracy}")

# Predict a sample image
sample = X_test[0].reshape(1, -1)
prediction = best_estimator.predict(sample)
print(f"Prediction: {prediction}")


Fitting 5 folds for each of 6 candidates, totalling 30 fits
[CV] END .....................n_neighbors=3, weights=uniform; total time=  24.3s
[CV] END .....................n_neighbors=3, weights=uniform; total time=  20.1s
[CV] END .....................n_neighbors=3, weights=uniform; total time=  20.9s
[CV] END .....................n_neighbors=3, weights=uniform; total time=  19.5s
[CV] END .....................n_neighbors=3, weights=uniform; total time=  20.5s
[CV] END ....................n_neighbors=3, weights=distance; total time=  21.7s
[CV] END ....................n_neighbors=3, weights=distance; total time=  19.7s
[CV] END ....................n_neighbors=3, weights=distance; total time=  19.4s
[CV] END ....................n_neighbors=3, weights=distance; total time=  18.9s
[CV] END ....................n_neighbors=3, weights=distance; total time=  19.4s
[CV] END .....................n_neighbors=5, weights=uniform; total time=  23.5s
[CV] END .....................n_neighbors=5, weig

KeyError: 0