In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import os
import joblib # Import joblib for saving and loading models

# Output layout: model_output/ holds a folder of saved plots and a folder of
# persisted estimators. Every directory is created on demand (idempotently).
OUTPUT_DIR = "model_output"
PLOT_DIR = os.path.join(OUTPUT_DIR, "classification_plots")
MODEL_DIR = os.path.join(OUTPUT_DIR, "saved_models")
for _required_dir in (OUTPUT_DIR, PLOT_DIR, MODEL_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# Files the training sections write to and the prediction section reloads from.
KNN_MODEL_PATH = os.path.join(MODEL_DIR, "knn_model.pkl")
RANDOM_FOREST_MODEL_PATH = os.path.join(MODEL_DIR, "random_forest_model.pkl")

# --- 1. Fetching and preparing the MNIST dataset ---
print("Fetching MNIST dataset...")
# With as_frame=True, ``data`` is a DataFrame (one column per pixel) and
# ``target`` is a Series of labels.
mist = fetch_openml('mnist_784', as_frame=True, parser='auto')

# Normalise dtypes up front: pixels to uint8 (the dtype fetch_openml returns
# depends on the scikit-learn version/parser — TODO confirm for your version),
# and labels to uint8 so models and reports work with numeric digit classes.
X = mist.data.astype(np.uint8)
y = mist.target.astype(np.uint8)

# --- 2. Visualizing a digit (and saving it) ---
def plot_digit(image_data, filename="digit.png", title_prefix="Digit"):
    """
    Render one MNIST digit as a 28x28 image and save it under PLOT_DIR.

    Args:
        image_data (array-like): 784 pixel values. A pandas Series, a NumPy array
            (flat or already 28x28), a list, or a tuple.
        filename (str): Name of the image file written inside ``PLOT_DIR``.
        title_prefix (str): Title shown above the image (e.g. "Digit", "Noisy Digit").

    Raises:
        ValueError: If ``image_data`` does not hold exactly 784 values.
    """
    # np.asarray handles Series, ndarrays, lists and tuples uniformly, so
    # callers are not limited to objects that already have ``.reshape``.
    image = np.asarray(image_data).reshape(28, 28)

    fig = plt.figure(figsize=(4, 4))
    try:
        plt.imshow(image, cmap='binary')
        plt.axis('off')
        plt.title(str(title_prefix))
        plt.savefig(os.path.join(PLOT_DIR, filename))
    finally:
        # Release the figure even if saving fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)

# Take the very first sample, report its label, and save a picture of it.
some_digit = X.iloc[0]
first_label = y.iloc[0]
print(f"Label for the first digit: {first_label}")
example_filename = "first_digit_example.png"
plot_digit(
    some_digit,
    filename=example_filename,
    title_prefix=f"Example Digit: {first_label}",
)
print(f"Saved example digit plot to {os.path.join(PLOT_DIR, example_filename)}")

# --- 3. Splitting the dataset into training and testing sets ---
# Fixed positional split: the first 60,000 rows train, the remainder test.
TRAIN_SIZE = 60000
x_train, y_train = X[:TRAIN_SIZE], y[:TRAIN_SIZE]
x_test, y_test = X[TRAIN_SIZE:], y[TRAIN_SIZE:]
for split_name, split_rows in (("Training", x_train), ("Test", x_test)):
    print(f"{split_name} set size: {len(split_rows)} samples")

# --- 4. Training a K-Nearest Neighbors (KNN) Classifier ---
print("\n--- Training K-Nearest Neighbors (KNN) Classifier ---")
knn_clf = KNeighborsClassifier(n_neighbors=5)
knn_clf.fit(x_train, y_train)

# Persist the fitted KNN model so the prediction section can reload it.
joblib.dump(knn_clf, KNN_MODEL_PATH)
print(f"KNN model saved to {KNN_MODEL_PATH}")

# knn_clf was fitted on a DataFrame, so predict on a one-row DataFrame that
# carries the same pixel column names. Passing a bare list of arrays makes
# scikit-learn warn that "X does not have valid feature names".
# to_frame().T turns the Series (indexed by pixel name) into a 1 x 784 frame.
some_digit_frame = some_digit.to_frame().T
knn_prediction_first_digit = knn_clf.predict(some_digit_frame)
print(f"KNN prediction for the first digit: {knn_prediction_first_digit[0]}")

# --- 5. Evaluating the KNN Classifier on Test Set ---
print("\n--- Evaluating KNN Classifier on Test Set ---")
y_test_pred_knn = knn_clf.predict(x_test)

overall_accuracy_knn = accuracy_score(y_test, y_test_pred_knn)
print(f"Overall Accuracy for KNN on Test Set: {overall_accuracy_knn:.4f}")

print("\nPer-number Accuracy for KNN on Test Set:")
# Pass ``labels`` explicitly so each target name is tied to its digit.
# Without it, classification_report infers labels from the data and raises
# ValueError when a digit is missing from both y_test and the predictions.
# zero_division=0 avoids the undefined-metric warning for such digits.
digit_labels = list(range(10))
report_knn = classification_report(
    y_test,
    y_test_pred_knn,
    labels=digit_labels,
    target_names=[str(i) for i in digit_labels],
    output_dict=True,
    zero_division=0,
)
for digit in digit_labels:
    digit_stats = report_knn[str(digit)]
    # Per-class recall = share of this digit's test samples predicted correctly.
    if digit_stats['support'] > 0:
        print(f"  Accuracy for digit '{digit}': {digit_stats['recall']:.4f}")
    else:
        print(f"  Digit '{digit}' not found in test set or predictions.")

# --- 6. Adding Random Noise and Denoising with KNN ---
# NOTE(review): despite the heading, this section trains a classifier on noisy
# images (the targets are digit labels); it does not reconstruct clean pixels.
# A true denoiser would use the clean images (x_train) as multi-output targets.
print("\n--- Denoising with KNN ---")

# Seeded generator so the noisy datasets and the printed prediction are reproducible.
noise_rng = np.random.default_rng(42)


def _add_pixel_noise(images, rng, max_noise=100):
    """Return ``images`` plus uniform integer noise in [0, max_noise), clipped to 0-255 as uint8.

    Clipping is required: a bright pixel (255) plus noise (up to 99) does not
    fit in uint8, and casting out-of-range values to uint8 produces wrapped or
    platform-dependent values instead of saturating.
    """
    pixels = np.asarray(images, dtype=np.int16)  # int16 holds 0..354 without overflow
    noisy = pixels + rng.integers(0, max_noise, size=pixels.shape, dtype=np.int16)
    return np.clip(noisy, 0, 255).astype(np.uint8)


x_train_mod = _add_pixel_noise(x_train, noise_rng)
x_test_mod = _add_pixel_noise(x_test, noise_rng)

# Targets remain the digit labels; copies keep y_train/y_test untouched.
y_train_mod = y_train.copy()
y_test_mod = y_test.copy()

print("Noisy digit (x_test_mod[0]):")
plot_digit(x_test_mod[0], filename="noisy_digit_example.png", title_prefix="Noisy Digit")
print(f"Saved noisy digit plot to {os.path.join(PLOT_DIR, 'noisy_digit_example.png')}")

knn_clf_denoise = KNeighborsClassifier(n_neighbors=5)
knn_clf_denoise.fit(x_train_mod, y_train_mod)

# Slice rather than index so predict receives a 2-D (1, 784) array.
clean_digit_prediction = knn_clf_denoise.predict(x_test_mod[:1])[0]
print(f"Predicted clean digit label from noisy image: {clean_digit_prediction}")

# --- 7. Random Forest Classifier for Multiclass Classification ---
print("\n--- Training Random Forest Classifier ---")
# 100 trees with a fixed seed, so repeated runs grow the same forest.
# fit() returns the estimator itself, so construction and training chain.
forest_clf = RandomForestClassifier(random_state=42, n_estimators=100).fit(x_train, y_train)

# Persist the fitted forest so the prediction section can reload it.
joblib.dump(forest_clf, RANDOM_FOREST_MODEL_PATH)
print(f"Random Forest model saved to {RANDOM_FOREST_MODEL_PATH}")

# --- 8. Evaluating the Random Forest Classifier on Test Set ---
print("\n--- Evaluating Random Forest Classifier on Test Set ---")
y_test_pred_forest = forest_clf.predict(x_test)

overall_accuracy_forest = accuracy_score(y_test, y_test_pred_forest)
print(f"Overall Accuracy for Random Forest on Test Set: {overall_accuracy_forest:.4f}")

print("\nPer-number Accuracy for Random Forest on Test Set:")
# Pass ``labels`` explicitly so each target name is tied to its digit.
# Without it, classification_report infers labels from the data and raises
# ValueError when a digit is missing from both y_test and the predictions.
# zero_division=0 avoids the undefined-metric warning for such digits.
digit_labels = list(range(10))
report_forest = classification_report(
    y_test,
    y_test_pred_forest,
    labels=digit_labels,
    target_names=[str(i) for i in digit_labels],
    output_dict=True,
    zero_division=0,
)
for digit in digit_labels:
    digit_stats = report_forest[str(digit)]
    # Per-class recall = share of this digit's test samples predicted correctly.
    if digit_stats['support'] > 0:
        print(f"  Accuracy for digit '{digit}': {digit_stats['recall']:.4f}")
    else:
        print(f"  Digit '{digit}' not found in test set or predictions.")

# --- 9. Prediction from saved models (User Input) ---
print("\n--- Prediction from Saved Models ---")

def get_user_input_digit(read_input=None):
    """
    Prompt until the user supplies a valid digit image, or chooses to stop.

    The user may type 784 comma-separated pixel values (0-255) directly, or
    ``file <path>`` to read them from a UTF-8 text file. Values may also be
    separated by whitespace (e.g. one value per line in a file).
    Typing ``quit``, ``exit`` or ``q``, or closing stdin (EOF), ends the prompt.

    Args:
        read_input (callable | None): Function that takes a prompt string and
            returns one line of input. ``None`` (the default) means the built-in
            ``input``, resolved at call time so environments that replace
            ``input`` (e.g. notebook kernels) keep working.

    Returns:
        numpy.ndarray | None: A ``(1, 784)`` uint8 array ready for
        ``model.predict``, or ``None`` if the user chose to stop.
    """
    if read_input is None:
        read_input = input
    quit_words = {"quit", "exit", "q"}
    prompt = (
        "\nEnter 784 comma-separated pixel values (0-255) for a digit, "
        "or type 'file <path_to_text_file>' to load from a file "
        "(or 'quit' to stop): "
    )

    while True:
        try:
            user_input = read_input(prompt).strip()
        except EOFError:
            # stdin closed (e.g. non-interactive run): nothing more to read.
            return None

        if user_input.lower() in quit_words:
            return None

        if user_input.lower().startswith("file "):
            file_path = user_input[5:].strip()
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    pixel_str = f.read()
            except FileNotFoundError:
                print(f"Error: File not found at {file_path}. Please try again.")
                continue
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error reading file: {e}. Please try again.")
                continue
            print(f"Loading pixels from file: {file_path}")
        else:
            pixel_str = user_input

        # Treat commas and any whitespace as separators; empty fields are ignored.
        tokens = pixel_str.replace(",", " ").split()
        try:
            pixels = [int(token) for token in tokens]
        except ValueError:
            print("Error: Invalid input. Please enter 784 comma-separated integers. Try again.")
            continue

        if len(pixels) != 784:
            print(f"Error: Expected 784 pixel values, but got {len(pixels)}. Please try again.")
            continue
        if not all(0 <= p <= 255 for p in pixels):
            print("Error: Pixel values must be between 0 and 255. Please try again.")
            continue
        # Shape (1, 784): a single-sample batch for predict().
        return np.array(pixels, dtype=np.uint8).reshape(1, -1)

def _load_saved_model(model_path, model_name):
    """Load a joblib-persisted estimator and report the outcome; return None on failure.

    Security note: joblib.load unpickles the file, which can execute arbitrary
    code. Only load model files produced by this script, never untrusted ones.
    """
    try:
        model = joblib.load(model_path)
        print(f"Successfully loaded {model_name} model from {model_path}")
        return model
    except FileNotFoundError:
        print(f"Error: {model_name} model file not found at {model_path}. Please run the training section first.")
    except Exception as e:
        # Best-effort boundary: report any load failure and carry on without the model.
        print(f"Error loading {model_name} model: {e}")
    return None


# Load the trained models
loaded_knn_model = _load_saved_model(KNN_MODEL_PATH, "KNN")
loaded_forest_model = _load_saved_model(RANDOM_FOREST_MODEL_PATH, "Random Forest")

def _predict_one(model, pixels):
    """Return ``model``'s predicted label for a single (1, 784) pixel row.

    Estimators fitted on a DataFrame record their column names in
    ``feature_names_in_``. Wrapping the row in a DataFrame with those names
    avoids scikit-learn's "X does not have valid feature names" warning.
    """
    feature_names = getattr(model, "feature_names_in_", None)
    if feature_names is not None:
        pixels = pd.DataFrame(pixels, columns=feature_names)
    return model.predict(pixels)[0]


# Compare against None explicitly: ensemble estimators such as
# RandomForestClassifier define __len__, so their truthiness reflects the
# number of fitted trees rather than whether loading succeeded.
if loaded_knn_model is not None and loaded_forest_model is not None:
    while True:
        input_digit_data = get_user_input_digit()
        if input_digit_data is None:  # Prompt helper signalled there is no more input
            break

        # Make predictions
        knn_prediction = _predict_one(loaded_knn_model, input_digit_data)
        forest_prediction = _predict_one(loaded_forest_model, input_digit_data)

        print(f"\nPrediction using Loaded KNN Model: {knn_prediction}")
        print(f"Prediction using Loaded Random Forest Model: {forest_prediction}")

        # Optionally, plot the input digit
        plot_digit(input_digit_data[0], filename="user_input_digit.png", title_prefix=f"User Input Digit (Predicted KNN: {knn_prediction}, RF: {forest_prediction})")
        print(f"Saved user input digit plot to {os.path.join(PLOT_DIR, 'user_input_digit.png')}")

        try:
            another_prediction = input("Do you want to predict another digit? (yes/no): ").strip().lower()
        except EOFError:
            # stdin closed: stop asking instead of crashing.
            break
        if another_prediction not in ("yes", "y"):
            break
else:
    print("\nModels were not loaded successfully. Cannot proceed with predictions.")

print("\nScript finished.")


Fetching MNIST dataset...
Label for the first digit: 5
Saved example digit plot to model_output\classification_plots\first_digit_example.png
Training set size: 60000 samples
Test set size: 10000 samples

--- Training K-Nearest Neighbors (KNN) Classifier ---
KNN model saved to model_output\saved_models\knn_model.pkl




KNN prediction for the first digit: 5

--- Evaluating KNN Classifier on Test Set ---
Overall Accuracy for KNN on Test Set: 0.9688

Per-number Accuracy for KNN on Test Set:
  Accuracy for digit '0': 0.9939
  Accuracy for digit '1': 0.9982
  Accuracy for digit '2': 0.9603
  Accuracy for digit '3': 0.9663
  Accuracy for digit '4': 0.9613
  Accuracy for digit '5': 0.9664
  Accuracy for digit '6': 0.9864
  Accuracy for digit '7': 0.9611
  Accuracy for digit '8': 0.9374
  Accuracy for digit '9': 0.9534

--- Denoising with KNN ---
Noisy digit (x_test_mod[0]):
Saved noisy digit plot to model_output\classification_plots\noisy_digit_example.png
Predicted clean digit label from noisy image: 7

--- Training Random Forest Classifier ---
Random Forest model saved to model_output\saved_models\random_forest_model.pkl

--- Evaluating Random Forest Classifier on Test Set ---
Overall Accuracy for Random Forest on Test Set: 0.9705

Per-number Accuracy for Random Forest on Test Set:
  Accuracy for digit '0