# K-Nearest Neighbors

This notebook demonstrates how to use the `KNN` class from the `knn` module of the `rice2025.supervised_learning` package.  

## Setup
Import the necessary modules and load the data. This example uses the Wine dataset from scikit-learn.

The Wine dataset is a small classification dataset that has:

- **Samples:** 178  
- **Features:** 13 numeric chemical properties of wines  
- **Classes:** 3 types of wine  

**Goal:** Predict the type of wine based on its chemical features.  

In [1]:
# import library
from rice2025.supervised_learning import knn
import rice2025.utilities as util

# load dataset
from sklearn.datasets import load_wine
data = load_wine()
X, y = data.data, data.target

## Data Pre-Processing
Before training, we split the dataset into **training** and **test** sets using `train_test_split`. We can verify the split by printing the shape of each feature matrix. Then we use the `scale` function to scale the features, so that properties measured on larger numeric ranges do not dominate the distance calculation.

In [2]:
# split dataset
X_train, X_test, y_train, y_test = util.train_test_split(X, y, test_size=.2)
print(f"Train size: {X_train.shape}, Test size: {X_test.shape}")

# scale dataset
X_train = util.scale(X_train)
X_test = util.scale(X_test)


Train size: (142, 13), Test size: (36, 13)
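
One caveat about the scaling step above: if `util.scale` computes its statistics from whatever array it receives (its implementation is not shown here, so this is an assumption), then the test set is rescaled with its own statistics rather than those of the training set. A common alternative is to compute the mean and standard deviation on the training split only and reuse them for the test split. The sketch below does this with plain NumPy; `standardize_with_train_stats` is a hypothetical helper, not part of `rice2025`.

```python
import numpy as np

def standardize_with_train_stats(X_train, X_test):
    """Standardize both splits using statistics from the training split only.

    Hypothetical helper for illustration; not part of rice2025.
    """
    X_train = np.asarray(X_train, dtype=float)
    X_test = np.asarray(X_test, dtype=float)
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    std[std == 0] = 1.0  # avoid division by zero for constant features
    return (X_train - mean) / std, (X_test - mean) / std

# Toy data: two features on very different scales
train = np.array([[1.0, 100.0], [2.0, 200.0], [3.0, 300.0]])
test = np.array([[2.0, 400.0]])
train_scaled, test_scaled = standardize_with_train_stats(train, test)
```

With this approach the scaled training features have zero mean and unit variance, while the test features are expressed on the same scale without contributing to the statistics.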


## Initializing and Training the KNN Model

The `KNN` class can be initialized by specifying `k` (number of neighbors).  
Use the `fit()` method to "train" the model on the training data. Because KNN is a lazy learner, no model is actually fitted at this step; the work of comparing points is deferred until prediction time.

In [3]:
model = knn.KNN(k=5)
model.fit(X_train, y_train)
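
To make "lazy learning" concrete, here is a minimal sketch of a KNN classifier, assuming Euclidean distance and a simple majority vote. `TinyKNN` is a hypothetical class written for illustration; it is not the `rice2025` implementation, whose internals are not shown in this notebook.

```python
import math
from collections import Counter

class TinyKNN:
    """Illustrative KNN classifier (hypothetical; not the rice2025 KNN)."""

    def __init__(self, k=5):
        self.k = k

    def fit(self, X, y):
        # Lazy learning: nothing is computed here, the data is only stored.
        self.X = [list(row) for row in X]
        self.y = list(y)
        return self

    def predict(self, X):
        predictions = []
        for point in X:
            # Distance from the query point to every stored training point
            distances = [math.dist(point, row) for row in self.X]
            # Indices of the k closest training points
            order = sorted(range(len(distances)), key=distances.__getitem__)
            nearest = order[: self.k]
            # Majority vote among the neighbors' labels
            votes = Counter(self.y[i] for i in nearest)
            predictions.append(votes.most_common(1)[0][0])
        return predictions

# Toy example: two well-separated clusters
X_toy = [[0.0, 0.0], [0.1, 0.2], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]]
y_toy = [0, 0, 0, 1, 1, 1]
toy_model = TinyKNN(k=3).fit(X_toy, y_toy)
toy_pred = toy_model.predict([[0.1, 0.1], [5.0, 5.1]])
```

In this sketch all of the work happens in `predict`, which is why prediction cost for KNN grows with the size of the training set.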

## Making Predictions
Once the model is trained, the `predict()` method can be used to classify new data points.

In [4]:
y_pred = model.predict(X_test)

## Evaluating the Model

The model's performance can be summarized by **accuracy** or by a more detailed **classification report** with per-class precision, recall, and F1-score.  
The `accuracy_score` and `classification_report` functions from scikit-learn compute both.

In [5]:
from sklearn.metrics import accuracy_score, classification_report

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy on test set: {accuracy:.2f}")

# Detailed report
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=data.target_names))

Accuracy on test set: 0.94

Classification Report:
              precision    recall  f1-score   support

     class_0       0.85      1.00      0.92        11
     class_1       1.00      0.86      0.92        14
     class_2       1.00      1.00      1.00        11

    accuracy                           0.94        36
   macro avg       0.95      0.95      0.95        36
weighted avg       0.95      0.94      0.94        36
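
As a check on how these numbers are defined, the sketch below recomputes accuracy and per-class precision and recall by hand. The helper names are hypothetical, and the labels are made up for illustration rather than taken from the wine results.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)

def precision_recall(y_true, y_pred, label):
    """Precision and recall for one class, treating `label` as positive."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    predicted = sum(p == label for p in y_pred)  # everything labeled `label`
    actual = sum(t == label for t in y_true)     # everything truly `label`
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    return precision, recall

# Made-up labels: one class-1 sample is misclassified as class 0
labels_true = [0, 0, 1, 1, 1, 2]
labels_pred = [0, 0, 0, 1, 1, 2]
toy_accuracy = accuracy(labels_true, labels_pred)
toy_p0, toy_r0 = precision_recall(labels_true, labels_pred, label=0)
```

The same pattern appears in the report above: `class_0` has recall 1.00 but precision 0.85, meaning every true `class_0` wine was found while two `class_1` wines were also labeled `class_0`.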

