# In [None]:
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt

# -----------------------------
# 1) Load MNIST
# -----------------------------
# Fetch the 70,000-sample MNIST dataset from OpenML (downloaded on first use).
mnist = fetch_openml(name='mnist_784', version=1)
X = mnist["data"]
# Cast the targets to small unsigned integers so they behave as numeric labels.
y = mnist["target"].astype(np.uint8)

# Split by position: the first 60,000 rows are used for training and the
# remaining rows for testing.
train_rows, test_rows = slice(None, 60000), slice(60000, None)
X_train, y_train = X[train_rows], y[train_rows]
X_test, y_test = X[test_rows], y[test_rows]

# -----------------------------
# 2) Stratified Cross-Validation + Grid Search
# -----------------------------
# 3 neighbour counts x 2 weightings x 3 metrics = 18 candidates; each one is
# evaluated on the 5 stratified folds defined below.
param_grid = {
    "n_neighbors": [3, 7, 15],
    "weights": ["uniform", "distance"],
    "metric": ["cosine", "euclidean", "manhattan"]
}

knn = KNeighborsClassifier()
# Shuffle with a fixed seed so the fold assignment is reproducible.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# refit=True is the GridSearchCV default; it is spelled out because the rest
# of the script relies on it: after the search, the best parameter combination
# is re-trained on the whole of (X_train, y_train) and exposed as
# `best_estimator_`.
grid = GridSearchCV(knn, param_grid, cv=cv, refit=True, n_jobs=-1, verbose=1)
grid.fit(X_train, y_train)

best_knn = grid.best_estimator_
print("Best parameters:", grid.best_params_)
print("Best CV accuracy:", grid.best_score_)

# -----------------------------
# 3) Train final model on all training data
# -----------------------------
# Nothing left to do here: `best_knn` was already fitted on the full training
# set by the refit step of the grid search above, so calling fit() again
# would only repeat the same work.

# -----------------------------
# 4) Evaluate on test set
# -----------------------------
# Predict once and reuse the predictions for every metric below.
y_pred = best_knn.predict(X_test)

test_accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print("Test accuracy:", test_accuracy)

cm = confusion_matrix(y_true=y_test, y_pred=y_pred)
print("Confusion matrix:\n", cm)

# -----------------------------
# 5) Show classification report
# -----------------------------
report = classification_report(y_true=y_test, y_pred=y_pred)
print(report)

# -----------------------------
# 6) Visualize misclassified samples
# -----------------------------
# Boolean mask of disagreements between predictions and true labels, turned
# into the 0-based positions (within the test predictions) of the errors.
mismatch = np.asarray(y_pred != y_test)
mis_idx = mismatch.nonzero()[0]
print("Number of misclassified images:", mis_idx.size)

def show_misclassified(n=6):
    """Plot the first ``n`` misclassified test images side by side.

    Reads the module-level ``mis_idx`` (positions of the wrong predictions),
    ``X_test``, ``y_test`` and ``y_pred``. Each image is drawn as a 28x28
    grayscale picture titled with its true and predicted label.

    Args:
        n: Maximum number of misclassified samples to display.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when there are no samples to show.
    """
    samples = mis_idx[:n]
    # len() rather than truthiness: the truth value of a multi-element
    # ndarray is ambiguous and raises.
    if len(samples) == 0:
        return

    # `samples` holds *positions*, so rows must be selected positionally.
    # On a pandas DataFrame, X_test[idx] would look up a column named idx
    # instead of a row, hence .iloc; a plain ndarray is indexed directly.
    rows = X_test.iloc[samples] if hasattr(X_test, "iloc") else X_test[samples]
    images = np.asarray(rows).reshape(-1, 28, 28)

    # Converting to ndarray drops any pandas index (y_test sliced from the
    # full target would otherwise be addressed by its original labels).
    true_labels = np.asarray(y_test)[samples]
    pred_labels = np.asarray(y_pred)[samples]

    plt.figure(figsize=(10, 4))
    # Size the grid by the number of images actually available, which may be
    # fewer than n.
    n_cols = len(samples)
    triples = zip(images, true_labels, pred_labels)
    for col, (img, true_label, pred_label) in enumerate(triples, start=1):
        plt.subplot(1, n_cols, col)
        plt.imshow(img, cmap="gray")
        plt.title(f"True: {true_label}, Pred: {pred_label}")
        plt.axis("off")
    plt.show()

# Display up to six of the misclassified test digits.
show_misclassified(n=6)
