In [1]:
import numpy as np

# Load the preprocessed inputs from .npy files in the working directory.
# (Plain string literals: the paths contain no placeholders, so the f-prefix
# was unnecessary.)
terrain_data = np.load('processed_terrain_data.npy')
terrain_labels = np.load('terrain_data_labels.npy')
# Presumably the names of the feature columns in terrain_data -- TODO confirm.
terrain_columns = np.load('terrains_columns_metadata.npy')

# The classification cell indexes features and labels with the same fold
# indices, so both arrays must describe the same number of samples.
assert len(terrain_data) == len(terrain_labels), (
    'terrain_data and terrain_labels must contain the same number of samples'
)


# KNN Classification

In [2]:
import pickle
from sklearn import neighbors
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix


# Stratified k-fold evaluation of a KNN classifier on the terrain data.
# Reports accuracy plus per-class precision/recall averaged over the folds,
# then persists a model trained on the full dataset.
foldsNum = 5
SEED = 42  # fixed seed so the shuffled folds (and the reported metrics) are reproducible
kf = StratifiedKFold(n_splits=foldsNum, shuffle=True, random_state=SEED)
clf = neighbors.KNeighborsClassifier(n_neighbors=18)

# Fix the class order once so every fold's confusion matrix has the same
# n_classes x n_classes layout, even if a class were absent from a fold.
class_labels = np.unique(terrain_labels)
n_classes = len(class_labels)

acc = 0
recall = np.zeros(n_classes)
precision = np.zeros(n_classes)
for train_index, test_index in kf.split(terrain_data, terrain_labels):

    # Training phase
    x_train = terrain_data[train_index, :]
    y_train = terrain_labels[train_index]
    clf.fit(x_train, y_train)

    # Test phase
    x_test = terrain_data[test_index, :]
    y_test = terrain_labels[test_index]
    y_pred = clf.predict(x_test)

    # Confusion matrix: rows are true classes, columns are predicted classes.
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    print('Confusion matrix\n', cm)

    # Accuracy: correctly classified samples (the diagonal) over all test samples.
    true_positives = np.diag(cm)
    acc += true_positives.sum() / len(y_test)

    # Recall divides by the row totals (samples that truly belong to the
    # class); precision divides by the column totals (samples predicted as
    # the class). A class with a zero total contributes 0 instead of NaN so
    # one degenerate fold cannot poison the averaged metric.
    actual_totals = cm.sum(axis=1)
    predicted_totals = cm.sum(axis=0)
    recall += np.divide(true_positives, actual_totals,
                        out=np.zeros(n_classes), where=actual_totals > 0)
    precision += np.divide(true_positives, predicted_totals,
                           out=np.zeros(n_classes), where=predicted_totals > 0)

# Print results (averaged over the folds)
acc = acc/foldsNum
print('Acc: ', acc)
precision = precision/foldsNum
print('Precision: ', precision)
recall = recall/foldsNum
print('Recall: ', recall)

# Refit on the complete dataset before saving: after the loop `clf` only holds
# the model trained on the last fold's training split, which would otherwise
# be the (arbitrary) model written to disk.
clf.fit(terrain_data, terrain_labels)
with open('terrain_knn.pkl', 'wb') as f:
    pickle.dump(clf, f)


Confusion matrix
 [[198   2   4]
 [ 19  49  18]
 [  9  21  91]]
Confusion matrix
 [[197   3   3]
 [ 13  52  21]
 [ 15  19  87]]
Confusion matrix
 [[199   3   2]
 [ 20  41  25]
 [  8  18  94]]
Confusion matrix
 [[198   5   1]
 [ 12  49  25]
 [ 14  17  89]]
Confusion matrix
 [[195   6   3]
 [ 12  53  20]
 [ 19  23  79]]
Acc:  0.8147207880837932
Precision:  [0.87501483 0.67620618 0.78287517]
Recall:  [0.96859847 0.56889193 0.72979339]
