In [16]:
import math
import numpy as np
from collections import Counter

class KNN():
    """Minimal brute-force k-nearest-neighbours classifier.

    Training samples are kept in two parallel lists; ``predict`` measures the
    Euclidean distance from the query to every stored sample and returns the
    most common label among the ``k`` closest ones.
    """

    def __init__(self):
        # Parallel lists: feature_vectors[i] is labelled classes[i].
        self.feature_vectors = []
        self.classes = []

    def add_data(self, batch):
        """Append labelled samples to the training set.

        Args:
            batch: iterable of ``(feature_vector, label)`` pairs.
        """
        for X, y in batch:
            self.feature_vectors.append(X)
            self.classes.append(y)

    def l2_distance(self, x: list[float], y: list[float]):
        """Return the Euclidean (L2) distance between ``x`` and ``y``.

        Components are read by index over ``range(len(x))``, so ``y`` must
        have at least ``len(x)`` entries.
        """
        # The generator has to sit *inside* sum(); with the parenthesis closed
        # early, math.sqrt() was handed a generator and raised TypeError.
        return math.sqrt(sum((x[i] - y[i]) ** 2 for i in range(len(x))))

    def predict(self, test: list[float], k: int = 5):
        """Return the majority label among the ``k`` nearest stored samples.

        Args:
            test: query feature vector; must have the same length as the
                stored feature vectors.
            k: number of neighbours that vote. If it exceeds the number of
                stored samples, every sample votes.

        Returns:
            The most common label among the nearest neighbours. When several
            labels are tied, the one whose first vote comes from the nearest
            neighbour wins (Counter keeps first-seen order for equal counts).

        Raises:
            ValueError: if no data has been added, if ``k`` is smaller than 1,
                or if ``test`` has the wrong length.
        """
        if not self.feature_vectors:
            raise ValueError("No training data: call add_data() before predict()")
        if k < 1:
            # k == 0 would leave no votes; a negative k would silently slice
            # off the farthest samples instead of selecting the nearest ones.
            raise ValueError("k must be a positive integer")
        if len(test) != len(self.feature_vectors[0]):
            raise ValueError("Wrong length")
        distances = [self.l2_distance(test, feat) for feat in self.feature_vectors]
        # sorted() is stable, so equally distant samples keep insertion order.
        top_k_indices = [idx for idx, _ in sorted(enumerate(distances), key=lambda x: x[1])[:k]]
        top_k_classes = [self.classes[idx] for idx in top_k_indices]
        return Counter(top_k_classes).most_common(1)[0][0]
    

## fake data

# Seeded local generator: the toy dataset is identical on every
# Restart & Run All, and the global NumPy random state is left untouched.
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 10))            # 100 samples with 10 features each
y = rng.integers(low=0, high=2, size=(100,))  # one binary label (0 or 1) per sample
batch = list(zip(X, y))                       # (feature_vector, label) pairs
toy_knn = KNN()
toy_knn.add_data(batch)

In [26]:
# Shape check: append the labels to the feature matrix as one extra column.
np.concatenate((X, y[:, np.newaxis]), axis=1).shape

(100, 11)

In [29]:
# Shape check: view the 1-D label vector as a single column.
y[:, np.newaxis].shape

(100, 1)

In [37]:
# Scratch check: tuples sort by their first element, so taking the two
# smallest pairs and reading their second elements gives [2, 4].
a = [(1, 2), (3, 3), (2, 4)]
sorted_ = sorted(a)[:2]
[second for _first, second in sorted_]

[2, 4]

In [8]:
# Preview only the first rows; echoing all 100 rows floods the cell output.
np.random.randn(100, 10)[:5]

array([[ 1.17330534e+00,  8.89019450e-01,  1.55344689e+00,
         3.43608559e-01,  1.94687347e+00, -3.70791243e-01,
         6.97109277e-01,  1.10060803e+00, -5.06173855e-01,
        -5.75430542e-01],
       [-5.27617446e-01,  8.09334946e-01,  8.93425553e-01,
         1.48140706e+00, -2.92083774e-01,  9.32627577e-01,
         8.33514826e-01,  9.65782516e-01,  1.20980355e-01,
        -6.50657272e-01],
       [-2.47781323e-01,  2.01903042e+00,  1.69943325e+00,
        -1.43397005e-01,  1.58611279e-02, -1.48727106e+00,
        -5.06552323e-01, -1.42135027e+00, -4.92823920e-02,
         2.82834156e-01],
       [ 1.95811824e-01, -1.05921020e+00, -6.97069825e-01,
         1.44649403e+00,  2.49057581e+00, -1.62830432e-01,
         7.96043153e-01, -5.33714955e-01, -2.12355746e-01,
        -1.42478549e+00],
       [-1.00970988e+00, -8.70078049e-01, -1.20484053e+00,
         1.00762698e-01,  1.64171529e-01, -1.52114403e+00,
         6.11458156e-01,  1.12176892e+00, -1.14976596e+00,
         1.

In [15]:
# Shape check: adding two (100, 1) integer arrays element-wise keeps the shape.
(
    np.random.randint(0, 2, size=(100, 1))
    + np.random.randint(0, 2, size=(100, 1))
).shape

(100, 1)