# KNN From Scratch

In [1]:
import random
from sklearn import datasets  # for sample dataset
from math import sqrt  # for euclidean distance func
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix



In [2]:
class KNN_Classifier:
    """k-nearest-neighbours classifier using Euclidean distance.

    ``fit`` only stores the training data; all work happens in ``predict``,
    which computes the distance from each query row to every training row
    and takes a majority vote among the ``k`` closest ones.
    """

    def fit(self, X_train, y_train, k=3):
        """Store the training samples, their labels and the neighbour count.

        Args:
            X_train: Sequence of feature rows (each row a sequence of numbers).
            y_train: Sequence of hashable labels, one per row of ``X_train``.
            k: Number of nearest neighbours that vote on each prediction.

        Raises:
            ValueError: If ``k`` is smaller than 1, the training set is empty,
                or ``X_train`` and ``y_train`` differ in length.
        """
        if k < 1:
            raise ValueError("k must be at least 1")
        if len(X_train) != len(y_train):
            raise ValueError("X_train and y_train must have the same length")
        if len(X_train) == 0:
            raise ValueError("training data must not be empty")
        self.X_train = X_train
        self.y_train = y_train
        self.k = k

    def predict(self, X_test):
        """Return a list with one predicted label per row of ``X_test``."""
        return [
            self.__get_label(self.__get_nearest_neighbours(row))
            for row in X_test
        ]

    def __get_nearest_neighbours(self, row):
        """Return the labels of the ``k`` training rows closest to ``row``.

        Labels are ordered nearest first. If fewer than ``k`` training rows
        exist, all of them are returned.
        """
        distances = [
            (self.__euc(sample, row), label)
            for sample, label in zip(self.X_train, self.y_train)
        ]
        # Sort on distance only: labels may not be orderable, and the stable
        # sort keeps training order among equally distant rows.
        distances.sort(key=lambda pair: pair[0])
        return [label for _, label in distances[:self.k]]

    def __get_label(self, labels):
        """Return the most frequent label; ties go to the nearest neighbour.

        ``labels`` must be ordered nearest first, as produced by
        ``__get_nearest_neighbours``.
        """
        counts = {}
        for label in labels:
            counts[label] = counts.get(label, 0) + 1
        best = max(counts.values())
        # Walk in nearest-first order so a tie is resolved deterministically
        # in favour of the closest neighbour.
        for label in labels:
            if counts[label] == best:
                return label

    def __euc(self, x, y):
        """Return the Euclidean distance between points ``x`` and ``y``.

        Raises:
            ValueError: If the two points have different dimensions.
        """
        if len(x) != len(y):
            # Returning False here would silently act as a distance of 0.
            raise ValueError(
                "points must have the same number of dimensions"
            )
        return sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))

In [3]:
# Load the Iris sample dataset bundled with scikit-learn.
iris = datasets.load_iris()

In [4]:
# Feature matrix (X) and class labels (y) taken from the loaded dataset.
X, y = iris.data, iris.target

In [5]:
# Hold out half of the samples for testing; the other half is used to train.
# NOTE(review): no random_state is passed, so the split differs between runs.
# NOTE(review): `sklearn.cross_validation` (imported above) was removed in
# scikit-learn 0.20; newer releases expose train_test_split from
# `sklearn.model_selection`.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.5,
)

In [6]:
# Create the classifier; the training data and k are supplied later via fit().
classifer = KNN_Classifier()

In [7]:
# Hand the training split to the classifier, passing k = 1.
classifer.fit(X_train, y_train, k=1)

In [8]:
# Predict a label for every row of the held-out test split.
predictions = classifer.predict(X_test)

In [9]:
# Compare the true test labels with the predictions. Labels are 0, 1, 2;
# rows correspond to true labels and columns to predicted labels.
confusion_matrix(y_test, predictions)

array([[25,  0,  0],
       [ 0, 28,  1],
       [ 0,  6, 15]])