In [119]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix

In [120]:
def chi_squared(x, y):
    """Return the chi-squared term for a single pair of values.

    Computes ``(x - y) ** 2 / (x + y)``. When ``x + y`` equals zero the
    quotient is undefined, so ``0`` is returned instead.

    Args:
        x: First value.
        y: Second value.

    Returns:
        The chi-squared term for ``x`` and ``y``, or ``0`` if they sum to zero.
    """
    total = x + y
    if total == 0:
        # Guard against division by zero.
        return 0
    diff = x - y
    return (diff ** 2) / total

In [121]:
class CustomKNN:
    """k-nearest-neighbours classifier using a summed per-feature distance.

    The distance between two samples is the sum, over feature positions, of a
    per-feature distance. By default that per-feature distance is the
    module-level ``chi_squared`` function.

    Args:
        k: Number of nearest training samples that vote on each prediction.
        distance_fn: Optional callable ``distance_fn(a, b)`` returning the
            distance contribution of one feature pair. ``None`` (the default)
            selects ``chi_squared``.
    """

    def __init__(self, k, distance_fn=None):
        self.k = k
        self.distance_fn = distance_fn

    def fit(self, X_train, y_train):
        """Store the training samples and their labels.

        Args:
            X_train: Sequence of training samples (each an indexable sequence
                of feature values).
            y_train: Sequence of labels, one per training sample.

        Raises:
            ValueError: If ``X_train`` and ``y_train`` differ in length.
        """
        if len(X_train) != len(y_train):
            raise ValueError(
                f'X_train and y_train must have the same length, '
                f'got {len(X_train)} and {len(y_train)}'
            )
        self.X_train = X_train
        self.y_train = y_train

    def predict(self, X_test):
        """Predict a label for every sample in ``X_test``.

        Args:
            X_test: Iterable of samples to classify.

        Returns:
            A list with one predicted label per test sample. When several
            labels share the highest vote count, the one belonging to the
            nearest neighbour among them wins.

        Raises:
            ValueError: If ``k`` is smaller than 1 or larger than the number
                of stored training samples.
        """
        n_train = len(self.X_train)
        if not 1 <= self.k <= n_train:
            raise ValueError(
                f'k must be between 1 and the number of training samples '
                f'({n_train}), got {self.k}'
            )

        # Resolved at call time so the module-level chi_squared is only
        # required when no custom distance_fn was supplied.
        feature_distance = (
            self.distance_fn if self.distance_fn is not None else chi_squared
        )

        pred = []
        for test_sample in X_test:
            n_features = len(test_sample)
            distances = []
            # Pair rows with labels positionally rather than via y_train[i].
            for train_sample, label in zip(self.X_train, self.y_train):
                dis = sum(
                    feature_distance(test_sample[col], train_sample[col])
                    for col in range(n_features)
                )
                distances.append((dis, label))

            # Tuples sort by distance first, then by label on exact ties.
            distances.sort()

            freq = {}
            for _, label in distances[:self.k]:
                freq[label] = freq.get(label, 0) + 1

            # max() returns the first key with the top count, i.e. the label
            # seen earliest (nearest) among equally frequent ones.
            pred.append(max(freq, key=freq.get))

        return pred

In [122]:
# Load the Iris dataset bundled with scikit-learn and list the attributes of the returned object.
iris = load_iris()
dir(iris)

['DESCR',
 'data',
 'data_module',
 'feature_names',
 'filename',
 'frame',
 'target',
 'target_names']

In [123]:
# Show the names of the feature columns.
iris.feature_names

['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

In [124]:
# Split the dataset object into the feature matrix and the target labels.
X, y = iris.data, iris.target

In [125]:
# Hold out 20% of the samples for evaluation; the fixed random_state keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [126]:
# Sweep k = 1..9, printing each accuracy and remembering the first k that
# reaches the highest value (strict ">" keeps the earliest k on ties).
# NOTE(review): k is chosen using accuracy on X_test, the same split used for
# the evaluation further down — consider a validation split or cross-validation.
bestAccuracy, bestK = 0, 0

for k in range(1, 10):
    knn = CustomKNN(k)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)

    acc = accuracy_score(y_test, y_pred)

    if acc > bestAccuracy:
        bestAccuracy, bestK = acc, k

    print(acc)

1.0
1.0


1.0
1.0
1.0
1.0
1.0
1.0
1.0


In [127]:
# Train the final model and predict on the held-out samples.
# NOTE(review): k is fixed at 5 here rather than taken from bestK — confirm this is intended.
FINAL_K = 5
knn = CustomKNN(FINAL_K)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)

In [128]:
# Report accuracy and the confusion matrix for the held-out predictions.
final_accuracy = accuracy_score(y_test, y_pred)
final_confusion = confusion_matrix(y_test, y_pred)
print(f'Accuracy : {final_accuracy}')
print(f'Confusion Matrix:\n{final_confusion}')

Accuracy : 1.0
Confusion Matrix:
[[10  0  0]
 [ 0  9  0]
 [ 0  0 11]]
