In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from heapq import nlargest

In [2]:
# Folder holding the CSV splits; file names are appended to it verbatim.
DATAROOT = r'./data/'
# Read every split and keep it as a plain NumPy array, keyed by split name.
data = {
    split: pd.read_csv(DATAROOT + filename).to_numpy()
    for split, filename in (
        ('train', 'train.csv'),
        ('val', 'val.csv'),
        ('test', 'test_data.csv'),
    )
}

In [26]:
class knn_euclidean:
    """k-nearest-neighbour classifier using Euclidean distance.

    The training array is split once at construction time: every column
    except the last is used as a feature, the last column as the label.
    """

    def __init__(self, train):
        """Store the training features and labels.

        Args:
            train: 2-D array; columns ``[:-1]`` are features, column ``-1``
                is the class label.
        """
        self.data  = train[:, :-1]
        self.label = train[:, -1]

    def predict(self, x, k):
        """Return the majority label among the ``k`` training points nearest to ``x``.

        Args:
            x: 1-D feature vector with one entry per feature column.
            k: number of neighbours to consult; must satisfy
                ``1 <= k <= number of training samples``.

        Returns:
            The label with the most votes. Ties are broken deterministically
            in favour of the smallest label (``np.unique`` sort order).

        Raises:
            ValueError: if ``k`` is outside ``[1, n_samples]``.
        """
        n_samples = self.data.shape[0]
        if not 1 <= k <= n_samples:
            raise ValueError(
                'k must be between 1 and the number of training samples '
                '({}), got {}'.format(n_samples, k)
            )
        distance = np.linalg.norm(
            self.data - x,
            axis=-1
        )
        # kth=k-1 places the k smallest distances in the first k slots and,
        # unlike kth=k, stays in bounds when k equals the training-set size.
        nearest = np.argpartition(distance, kth=k - 1)[:k]
        labels, votes = np.unique(self.label[nearest], return_counts=True)
        # argmax returns the first maximum, i.e. the smallest tied label.
        return labels[np.argmax(votes)]

In [32]:
# Score the Euclidean k-NN model (k = 5) on the validation split: the last
# column of each row is the reference label, the rest is the query vector.
model = knn_euclidean(data['train'])
val_rows = data['val']
n_correct = sum(
    1
    for row in val_rows
    if model.predict(row[:-1], 5) == row[-1]
)
acc = n_correct / val_rows.shape[0]
print('accuracy: {}'.format(acc))

accuracy: 1.0


In [None]:
class knn_mahalanobis:
    """k-nearest-neighbour classifier under the metric ``d(x, y) = ||A (y - x)||``.

    ``A`` is a linear transform applied to feature differences before the
    Euclidean norm is taken. ``fit`` derives it from the training covariance
    (classical Mahalanobis distance); a custom matrix may also be assigned to
    ``self.A`` directly, in which case ``fit`` need not be called.
    """

    def __init__(self, train):
        """Store the training features and labels.

        Args:
            train: 2-D array; columns ``[:-1]`` are features, column ``-1``
                is the class label.
        """
        self.data  = train[:, :-1]
        self.label = train[:, -1]
        # Metric transform; stays None until fit() runs or a matrix is assigned.
        self.A = None

    def fit(self):
        """Set ``self.A`` so that ``A.T @ A`` is the (pseudo-)inverse covariance.

        The sample covariance of the training features is eigendecomposed and
        each retained eigen-direction is scaled by ``1 / sqrt(eigenvalue)``.
        Directions with numerically zero variance are dropped, so a singular
        covariance does not produce infinities. With fewer than two samples,
        no feature columns, or no direction of positive variance, the identity
        is used (plain Euclidean distance).
        """
        features = np.asarray(self.data, dtype=float)
        n_samples, n_features = features.shape
        if n_samples < 2 or n_features == 0:
            self.A = np.eye(n_features)
            return
        # atleast_2d: np.cov returns a 0-d array for a single feature column.
        cov = np.atleast_2d(np.cov(features, rowvar=False))
        eigvals, eigvecs = np.linalg.eigh(cov)
        tol = eigvals.max() * n_features * np.finfo(float).eps
        keep = eigvals > tol
        if not keep.any():
            self.A = np.eye(n_features)
            return
        # Rows of A are eigenvectors scaled by 1/sqrt(eigenvalue).
        self.A = (eigvecs[:, keep] / np.sqrt(eigvals[keep])).T

    def predict(self, x, k):
        """Return the majority label among the ``k`` training points nearest to ``x``.

        Args:
            x: 1-D feature vector with one entry per feature column.
            k: number of neighbours to consult; must satisfy
                ``1 <= k <= number of training samples``.

        Returns:
            The label with the most votes. Ties are broken deterministically
            in favour of the smallest label (``np.unique`` sort order).

        Raises:
            ValueError: if ``k`` is outside ``[1, n_samples]``.
        """
        n_samples = self.data.shape[0]
        if not 1 <= k <= n_samples:
            raise ValueError(
                'k must be between 1 and the number of training samples '
                '({}), got {}'.format(n_samples, k)
            )
        if self.A is None:
            # No transform supplied yet: estimate it from the training data.
            self.fit()
        diff = np.asarray(self.data - x, dtype=float)
        # Row-wise A @ d is d @ A.T; one matmul replaces the per-sample loop.
        distance = np.linalg.norm(diff @ np.asarray(self.A).T, axis=-1)
        # kth=k-1 places the k smallest distances in the first k slots and,
        # unlike kth=k, stays in bounds when k equals the training-set size.
        nearest = np.argpartition(distance, kth=k - 1)[:k]
        labels, votes = np.unique(self.label[nearest], return_counts=True)
        # argmax returns the first maximum, i.e. the smallest tied label.
        return labels[np.argmax(votes)]