In [1]:
import torch
from scipy.stats import mode
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [2]:
class KNN:
  """Brute-force k-nearest-neighbours classifier built on torch tensors.

  The model is lazy: nothing is learned at construction time. Each call to
  ``fit_predict`` compares every query point against all training points and
  predicts the majority label among the ``k`` closest ones.
  """

  def __init__(self, k, X=None):
    """
      k: number of neighbors; must be >= 1.
      X: accepted for backward compatibility and ignored -- the training data
         is passed to ``fit_predict`` instead.

      raises: ValueError if k < 1.
    """
    if k < 1:
      raise ValueError(f'k must be a positive integer, got {k}')
    self.k = k

  def distance(self, point1, point2, default='euclidean', p=2):
    """
      Distance between point1 and point2, reduced over the last (feature) dim.

      The inputs broadcast against each other: two (d,) vectors give a 0-d
      tensor, while a (n, d) matrix and a (d,) vector give the n distances from
      each row of the matrix to the vector.

      default: one of 'euclidean', 'manhattan', 'minkowski'.
      p: order of the Minkowski distance (only used when default='minkowski').

      raises: ValueError for an unrecognised distance name.
    """
    if default not in ('euclidean', 'manhattan', 'minkowski'):
      raise ValueError('Unknown similarity distance type')

    abs_diff = torch.abs(point1 - point2)
    if default == 'euclidean':
      return torch.norm(abs_diff, 2, -1)
    if default == 'manhattan':
      # L1 distance: plain sum of absolute coordinate differences.
      return torch.sum(abs_diff, dim=-1)
    # Minkowski: (sum_i |a_i - b_i|^p)^(1/p) -- the outer root is required.
    return torch.sum(abs_diff ** p, dim=-1) ** (1 / p)

  def fit_predict(self, X, y, item, default='euclidean', p=2):
    """
      Classify every row of ``item`` by majority vote of its k nearest rows in X.

      - For each query point, compute its distance to all training points at
        once (broadcast over the rows of X) via ``self.distance``.
      - argsort those distances; the first k indices are the nearest training
        points.
      - Predict the most frequent label among them. Ties go to the smallest
        label (the same rule scipy.stats.mode uses).

      X: training inputs, one sample per row.
      y: labels for the rows of X.
      item: query points to classify, one per row.
      default, p: distance name and Minkowski order, forwarded to ``distance``.

      return: 1-D tensor with one predicted label per row of ``item``
              (empty, with y's dtype, when ``item`` has no rows).
    """
    X = torch.as_tensor(X)
    y = torch.as_tensor(y)
    item = torch.as_tensor(item)

    y_predict = []
    for query in item:
      point_distances = self.distance(X, query, default=default, p=p)
      k_neighbors = torch.argsort(point_distances)[:self.k]
      labels, counts = torch.unique(y[k_neighbors], return_counts=True)
      # unique() returns labels sorted ascending and argmax() returns the first
      # maximum, so a tied vote resolves to the smallest label.
      y_predict.append(labels[torch.argmax(counts)])

    if not y_predict:
      return torch.empty(0, dtype=y.dtype)
    return torch.stack(y_predict)

In [None]:
# Experiment settings: number of neighbours, held-out fraction, and RNG seed.
K_NEIGHBORS = 5
TEST_SIZE = 0.2
SEED = 0

iris = load_iris()
X = torch.tensor(iris.data)
y = torch.tensor(iris.target)

# train_test_split draws from numpy's RNG, not torch's, so torch.manual_seed
# alone does not fix the split -- it needs an explicit random_state.
torch.manual_seed(SEED)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=SEED
)
model = KNN(K_NEIGHBORS, X_train)
y_pred = model.fit_predict(X_train, y_train, X_test)
print(f'Accuracy: {accuracy_score(y_test, y_pred)}')