In [1]:
import numpy as np
from collections import Counter

In [5]:
class KNN:
    """
    k-nearest-neighbours classifier using Euclidean distance and majority vote.

    Parameters:
    k : int, default 3
        Number of closest training points that vote on each prediction.
        Must be at least 1.
    """

    def __init__(self, k=3):
        # Reject k < 1 up front; otherwise prediction would fail later with an
        # unhelpful IndexError (k=0) or silently use a strange neighbour set (k<0).
        if k < 1:
            raise ValueError("k must be at least 1")
        self.k = k

    def fit(self, X, y):
        """
        Store the training data for later predictions. Returns None.

        The points are stored as a float NumPy array. This lets lists of lists
        work as multi-feature points. It also prevents integer or unsigned
        inputs (e.g. uint8 pixels) from overflowing or wrapping around when
        differences are squared. The labels are stored as a list, in the same
        order as the points.

        Parameters:
        X : sequence of floats, or sequence of equal-length float sequences
            The training data points (one scalar or one vector per point).
        y : sequence of labels
            The training data labels, one per point in X.

        Raises:
        ValueError
            If X is empty, or if X and y have different lengths.
        """
        if len(X) == 0:
            raise ValueError("No training data provided")
        if len(X) != len(y):
            raise ValueError(
                "The length of training data points and corresponding labels should be the same"
            )
        self.X_train = np.asarray(X, dtype=float)  # float: no integer overflow / uint wraparound
        self.y_train = list(y)                     # positional access regardless of input container

    def predict(self, X):
        """
        Predict a label for every point in X.

        Each label is the most frequent label among the k training points
        closest to that point. When several labels are equally frequent, the
        label whose member is nearest to the point wins.

        Parameters:
        X : iterable of data points
            The points to classify. Each point has the same form as one
            training point (a scalar or a vector).

        Returns:
        list
            One predicted label per point in X, in the same order.
        """
        return [self._predict(x) for x in X]

    @staticmethod
    def euclidean_distance(x1, x2):
        """
        Return the Euclidean distance between two points.

        Both inputs are converted to float arrays first, so plain lists work
        and integer dtypes cannot overflow or wrap around.

        Parameters:
        x1, x2 : float or sequence of floats
            Two data points of the same dimension (scalars are 1-D points).
        """
        diff = np.asarray(x1, dtype=float) - np.asarray(x2, dtype=float)
        return np.sqrt(np.sum(diff ** 2))

    def _predict(self, x):
        """
        Helper that predicts the label of a single data point.

        Parameters:
        x : float or sequence of floats
            A single data point.

        Returns:
        The majority label among the k nearest training points.
        """
        x = np.asarray(x, dtype=float)
        distances = [self.euclidean_distance(x, x_train) for x_train in self.X_train]

        # Stable sort: among equidistant training points the earlier one wins,
        # so results do not depend on the sorting algorithm.
        k_indices = np.argsort(distances, kind="stable")[:self.k]
        k_nearest_labels = [self.y_train[i] for i in k_indices]

        # Counter.most_common orders equal counts by first occurrence, and the
        # labels are ordered nearest-first, so ties go to the nearer neighbour.
        return Counter(k_nearest_labels).most_common(1)[0][0]
        
        