In [45]:
import pandas as pd
from sklearn import metrics
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

from Convert import convert_to_dataframe, single_y_test_pred

In [46]:
iris = convert_to_dataframe(load_iris())

In [47]:
iris.sample(10)

Unnamed: 0,sepallength,sepalwidth,petallength,petalwidth,target
103,6.3,2.9,5.6,1.8,virginica
4,5.0,3.6,1.4,0.2,setosa
68,6.2,2.2,4.5,1.5,versicolor
39,5.1,3.4,1.5,0.2,setosa
75,6.6,3.0,4.4,1.4,versicolor
95,5.7,3.0,4.2,1.2,versicolor
11,4.8,3.4,1.6,0.2,setosa
104,6.5,3.0,5.8,2.2,virginica
42,4.4,3.2,1.3,0.2,setosa
106,4.9,2.5,4.5,1.7,virginica


In [48]:
X = iris.drop(['target'], axis=1)
Y = iris['target']
feature_names = iris.columns.values.tolist()[:-1]
class_names = Y.unique().tolist()
print(feature_names)
print(class_names)

['sepallength', 'sepalwidth', 'petallength', 'petalwidth']
['setosa', 'versicolor', 'virginica']


In [49]:
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, shuffle=True)

In [50]:
x_train.head()

Unnamed: 0,sepallength,sepalwidth,petallength,petalwidth
145,6.7,3.0,5.2,2.3
121,5.6,2.8,4.9,2.0
51,6.4,3.2,4.5,1.5
109,7.2,3.6,6.1,2.5
47,4.6,3.2,1.4,0.2


In [51]:
clf = KNeighborsClassifier()

In [52]:
clf = clf.fit(x_train, y_train)

In [53]:
y_pred = clf.predict(x_test)

In [54]:
print(single_y_test_pred(y_test, y_pred))

    index      target      y_pred
0      49      setosa      setosa
1      30      setosa      setosa
2      40      setosa      setosa
3      70  versicolor  versicolor
4     138   virginica   virginica
5     114   virginica   virginica
6      75  versicolor  versicolor
7       4      setosa      setosa
8      88  versicolor  versicolor
9      97  versicolor  versicolor
10      3      setosa      setosa
11     16      setosa      setosa
12     80  versicolor  versicolor
13      0      setosa      setosa
14     41      setosa      setosa
15    127   virginica   virginica
16    129   virginica   virginica
17     28      setosa      setosa
18     14      setosa      setosa
19     53  versicolor  versicolor
20     61  versicolor  versicolor
21     54  versicolor  versicolor
22    125   virginica   virginica
23    136   virginica   virginica
24     17      setosa      setosa
25      9      setosa      setosa
26     62  versicolor  versicolor
27    147   virginica   virginica
28    142   vi

In [55]:
print(metrics.classification_report(y_test, y_pred))

              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        12
  versicolor       0.90      1.00      0.95         9
   virginica       1.00      0.89      0.94         9

    accuracy                           0.97        30
   macro avg       0.97      0.96      0.96        30
weighted avg       0.97      0.97      0.97        30



In [56]:
print("Confusion matrix:")
print(metrics.confusion_matrix(y_test, y_pred, labels=class_names))

Confusion matrix:
[[12  0  0]
 [ 0  9  0]
 [ 0  1  8]]


In [57]:
accuracy_test = metrics.accuracy_score(y_test, y_pred) * 100
accuracy_train = metrics.accuracy_score(y_train, clf.predict(x_train)) * 100

print(f"Accuracy: {round(accuracy_test, 2)}% on Test Data")
print(f"Accuracy: {round(accuracy_train, 2)}% on Training Data")

Accuracy: 96.67% on Test Data
Accuracy: 96.67% on Training Data


In [58]:
clf.score(x_test, y_test)

0.9666666666666667

In [59]:
clf.kneighbors()

(array([[0.24494897, 0.36055513, 0.37416574, 0.37416574, 0.42426407],
        [0.31622777, 0.33166248, 0.45825757, 0.60827625, 0.64031242],
        [0.26457513, 0.34641016, 0.37416574, 0.38729833, 0.41231056],
        [0.63245553, 0.67082039, 0.70710678, 0.75498344, 0.80622577],
        [0.14142136, 0.2236068 , 0.2236068 , 0.2236068 , 0.3       ],
        [0.14142136, 0.26457513, 0.28284271, 0.3       , 0.34641016],
        [0.31622777, 0.37416574, 0.42426407, 0.42426407, 0.51961524],
        [0.14142136, 0.26457513, 0.26457513, 0.26457513, 0.3       ],
        [0.1       , 0.31622777, 0.33166248, 0.37416574, 0.38729833],
        [0.1       , 0.28284271, 0.3       , 0.33166248, 0.33166248],
        [0.2       , 0.2236068 , 0.2236068 , 0.24494897, 0.28284271],
        [0.37416574, 0.42426407, 0.42426407, 0.46904158, 0.48989795],
        [0.26457513, 0.31622777, 0.42426407, 0.42426407, 0.42426407],
        [0.26457513, 0.28284271, 0.31622777, 0.34641016, 0.50990195],
        [0.53851648,