In [1]:
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

In [2]:
def euclidean_distance(point1, point2):
    """Return the Euclidean (L2) distance between two equal-length vectors."""
    delta = point1 - point2
    return np.sqrt(np.sum(delta ** 2))

In [3]:
def get_neighbors(train, test_row, k):
    """Return the ``k`` training rows closest to ``test_row``.

    The last element of ``test_row`` and of every row in ``train`` (the class
    label in this notebook's layout) is excluded from the distance. Rows are
    returned nearest first; equal distances keep their training order because
    the sort is stable. Raises IndexError if ``k`` exceeds ``len(train)``.
    """
    query = test_row[:-1]
    scored = [(row, euclidean_distance(query, row[:-1])) for row in train]
    ranked = sorted(scored, key=lambda pair: pair[1])
    return [ranked[idx][0] for idx in range(k)]

In [4]:
def vote(neighbors):
    """Return the majority class label among ``neighbors``.

    Each neighbor is a row whose LAST element is its class label (the layout
    produced by ``np.column_stack((X, y))``), matching ``row[-1]`` in
    ``predict_classification``. Ties go to the label that appears first in
    ``neighbors``, i.e. the nearest one when the rows are ordered by distance.

    Raises
    ------
    ValueError
        If ``neighbors`` is empty.
    """
    class_votes = {}
    for neighbor in neighbors:
        label = neighbor[-1]  # label is the last column; index 2 was a feature
        class_votes[label] = class_votes.get(label, 0) + 1
    if not class_votes:
        raise ValueError("vote() requires at least one neighbor")
    # dicts keep insertion order and max() returns the first maximal key,
    # so tied labels resolve to the earliest-seen one.
    return max(class_votes, key=class_votes.get)

In [5]:
def predict_classification(train, test_row, k_num):
    """Predict the class of ``test_row`` by majority vote of its nearest rows.

    Parameters
    ----------
    train : sequence of rows, each ending with its class label
    test_row : row laid out like the ``train`` rows; ``get_neighbors`` drops
        its last element before measuring distance
    k_num : number of neighbours to consult

    Returns
    -------
    The most common label among the ``k_num`` nearest rows. Ties go to the
    tied label that occurs first in nearest-first order.
    """
    neighbors = get_neighbors(train, test_row, k_num)
    output_values = [row[-1] for row in neighbors]
    # Iterate the list (nearest first), not a set: set order for strings
    # depends on hash randomisation, so tied votes used to be resolved
    # differently between runs. max() keeps the first maximal item.
    return max(output_values, key=output_values.count)


In [6]:
def knn(training_set, labels, test_instance, k):
    """Classify one feature vector by majority vote of its k nearest neighbours.

    Unlike ``predict_classification``, features and labels are passed
    separately. ``training_set`` holds only feature rows, ``labels`` holds the
    matching class labels, and ``test_instance`` is a bare feature vector with
    no trailing label.

    Parameters
    ----------
    training_set : array-like, shape (n_samples, n_features), numeric
    labels : array-like, shape (n_samples,)
    test_instance : array-like, shape (n_features,), numeric
    k : int
        Number of neighbours; must satisfy ``1 <= k <= n_samples``.

    Returns
    -------
    The majority label among the ``k`` nearest training rows. Ties go to the
    tied label that occurs first in nearest-first order.

    Raises
    ------
    ValueError
        On mismatched inputs or an out-of-range ``k``.
    """
    features = np.asarray(training_set, dtype=float)
    targets = np.asarray(labels)
    query = np.asarray(test_instance, dtype=float)
    if features.ndim != 2 or len(features) != len(targets):
        raise ValueError("training_set must be 2-D with exactly one label per row")
    if not 1 <= k <= len(features):
        raise ValueError(f"k must be between 1 and {len(features)}, got {k}")
    # Same metric as euclidean_distance, vectorised over all training rows.
    dists = np.sqrt(np.sum((features - query) ** 2, axis=1))
    # Stable sort: equal distances keep training order, like list.sort().
    nearest = np.argsort(dists, kind="stable")[:k]
    nearest_labels = list(targets[nearest])
    # max() over the nearest-first list returns the first maximal label.
    return max(nearest_labels, key=nearest_labels.count)

In [7]:
# Reproducible 70/30 train/test split of the citrus data.
# NOTE(review): the feature columns are used unscaled, so whichever one has
# the largest numeric range dominates the Euclidean distance. Consider
# standardising with training-set statistics (TODO).
RANDOM_STATE = 42
TEST_SIZE = 0.3

df = pd.read_csv('citrus.csv')
features = df[['diameter', 'weight', 'red', 'green', 'blue']].values
labels = df['name'].values
X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=TEST_SIZE, random_state=RANDOM_STATE
)

# get_neighbors / predict_classification rely on this layout: the five
# feature columns first, the class label last. Presumably dtype=object,
# since the labels are strings; confirm if the pandas version changes.
train = np.column_stack((X_train, y_train))
test = np.column_stack((X_test, y_test))
assert len(train) + len(test) == len(df)

k = 3

In [8]:
# Classify every held-out row against the training rows and report metrics.
# Each `test` row carries its true label in the last column, which
# get_neighbors excludes when measuring distance.
predictions = [predict_classification(train, row, k) for row in test]

actual = y_test
predicted = predictions
print(confusion_matrix(actual, predicted))
print(classification_report(actual, predicted))

[[1351  131]
 [ 125 1393]]
              precision    recall  f1-score   support

  grapefruit       0.92      0.91      0.91      1482
      orange       0.91      0.92      0.92      1518

    accuracy                           0.91      3000
   macro avg       0.91      0.91      0.91      3000
weighted avg       0.91      0.91      0.91      3000
