In [27]:
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report, accuracy_score
import joblib

## Load Features


In [28]:
X_train = np.load("../features/X_train_scaled.npy")
X_test = np.load("../features/X_test_scaled.npy")
y_train = np.load("../features/y_train.npy")
y_test = np.load("../features/y_test.npy")

## Global Parameters



In [29]:
param_grid = {
    "n_neighbors": [5, 7, 11, 15],
    "weights": ["distance"],
    "metric": [ "euclidean"]
}

In [30]:
CLASS_MAPPING = {
    "cardboard": 0,
    "glass": 1,
    "metal": 2,
    "paper": 3,
    "plastic": 4,
    "trash": 5,
    "unknown": 6
}


## Model Training


In [31]:

grid = GridSearchCV(
    KNeighborsClassifier(),
    param_grid,
    cv=5,
    scoring="accuracy",
    n_jobs=-1
)

grid.fit(X_train, y_train)

knn = grid.best_estimator_
print("Best KNN params:", grid.best_params_)


Best KNN params: {'metric': 'euclidean', 'n_neighbors': 5, 'weights': 'distance'}


## Model Evaluation

In [32]:
y_pred = knn.predict(X_test)

print("\nAccuracy:", accuracy_score(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred))



Accuracy: 0.8870967741935484

Classification Report:
               precision    recall  f1-score   support

   cardboard       0.93      0.88      0.91        49
       glass       0.86      0.88      0.87        78
       metal       0.85      0.89      0.87        62
       paper       0.90      0.89      0.89        90
     plastic       0.93      0.89      0.91        71
       trash       0.83      0.91      0.87        22

    accuracy                           0.89       372
   macro avg       0.88      0.89      0.89       372
weighted avg       0.89      0.89      0.89       372



## Save Model


In [33]:
joblib.dump(knn, "../models/knn_model.pkl")



['../models/knn_model.pkl']

## Model Prediction with Rejection


In [34]:
def knn_predict_with_rejection(model, X, threshold=0.6):
    """
    Predict classes using KNN with rejection.
    Confidence = fraction of neighbors voting for predicted class.
    """
    distances, indices = model.kneighbors(X)
    neighbor_labels = model._y[indices]

    preds = model.predict(X)
    final_preds = []

    for i, pred in enumerate(preds):
        votes = np.sum(neighbor_labels[i] == pred)
        confidence = votes / model.n_neighbors

        if confidence < threshold:
            final_preds.append(6)  # Unknown
        else:
            final_preds.append(CLASS_MAPPING[pred])

    return np.array(final_preds)



Accuracy (accepted samples): 0.8870967741935484

Classification Report (accepted samples):
               precision    recall  f1-score   support

   cardboard       0.93      0.88      0.91        49
       glass       0.86      0.88      0.87        78
       metal       0.85      0.89      0.87        62
       paper       0.90      0.89      0.89        90
     plastic       0.93      0.89      0.91        71
       trash       0.83      0.91      0.87        22

    accuracy                           0.89       372
   macro avg       0.88      0.89      0.89       372
weighted avg       0.89      0.89      0.89       372



In [35]:
knn_preds = knn_predict_with_rejection(knn, X_test, threshold=0.6)
print("\nAccuracy (accepted samples):", accuracy_score(y_test, y_pred))
print("\nClassification Report (accepted samples):\n", classification_report(y_test, y_pred))


Accuracy (accepted samples): 0.8870967741935484

Classification Report (accepted samples):
               precision    recall  f1-score   support

   cardboard       0.93      0.88      0.91        49
       glass       0.86      0.88      0.87        78
       metal       0.85      0.89      0.87        62
       paper       0.90      0.89      0.89        90
     plastic       0.93      0.89      0.91        71
       trash       0.83      0.91      0.87        22

    accuracy                           0.89       372
   macro avg       0.88      0.89      0.89       372
weighted avg       0.89      0.89      0.89       372

