In [1]:
import numpy as np

# Load the pre-computed terrain arrays from .npy files in the working directory.
# terrain_data / terrain_labels are used below as the feature matrix and the
# class labels; terrain_columns presumably holds the feature column names
# (not used in the cells visible here -- TODO confirm).
terrain_data = np.load('processed_terrain_data.npy')
terrain_labels = np.load('terrain_data_labels.npy')
terrain_columns = np.load('terrains_columns_metadata.npy')

# Decision Tree

In [2]:
import pickle
from sklearn import neighbors
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix
from sklearn import tree


# Evaluate a class-balanced decision tree with stratified k-fold cross-validation,
# report fold-averaged accuracy and per-class precision/recall, then persist a
# model trained on the full data set.

foldsNum = 5
SEED = 42  # fixed so fold assignment and tree construction are repeatable across runs

# Fix the label order once so that every fold's confusion matrix has shape
# (n_classes, n_classes) with the same row/column meaning, even if a class
# happens to be absent from a fold's test split and predictions.
class_labels = np.unique(terrain_labels)
n_classes = len(class_labels)

kf = StratifiedKFold(n_splits=foldsNum, shuffle=True, random_state=SEED)
clf = tree.DecisionTreeClassifier(class_weight='balanced', random_state=SEED)

# Running sums over folds; divided by foldsNum after the loop.
acc = 0
recall = np.zeros(n_classes)
precision = np.zeros(n_classes)
for train_index, test_index in kf.split(terrain_data, terrain_labels):

    # Training phase
    x_train = terrain_data[train_index, :]
    y_train = terrain_labels[train_index]
    clf.fit(x_train, y_train)

    # Test phase
    x_test = terrain_data[test_index, :]
    y_test = terrain_labels[test_index]
    y_pred = clf.predict(x_test)

    # Confusion matrix: rows are true classes, columns are predicted classes.
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    print('Confusion matrix\n', cm)

    true_pos = np.diag(cm)
    acc += true_pos.sum() / len(y_test)

    # Recall divides by the row sums (samples that truly belong to the class),
    # precision by the column sums (samples predicted as the class).
    actual_counts = cm.sum(axis=1)
    predicted_counts = cm.sum(axis=0)

    # `where=` skips empty classes and leaves the 0.0 from `out`, so a class that
    # is never present / never predicted in a fold contributes 0 instead of NaN.
    recall += np.divide(true_pos, actual_counts,
                        out=np.zeros(n_classes), where=actual_counts > 0)
    precision += np.divide(true_pos, predicted_counts,
                           out=np.zeros(n_classes), where=predicted_counts > 0)

# Print results (averaged over the folds)
acc = acc/foldsNum
print('Acc: ', acc)
precision = precision/foldsNum
print('Precision: ', precision)
recall = recall/foldsNum
print('Recall: ', recall)

# After the loop `clf` has only been fitted on the last fold's training split.
# Refit on all samples so the persisted model uses every labelled example; the
# cross-validation figures above remain the estimate of its performance.
clf.fit(terrain_data, terrain_labels)

with open('terrain_decision_tree.pkl', 'wb') as f:
    pickle.dump(clf, f)

Confusion matrix
 [[191   4   9]
 [ 10  44  32]
 [  5  27  89]]
Confusion matrix
 [[184  17   2]
 [  6  52  28]
 [  4  25  92]]
Confusion matrix
 [[182  14   8]
 [ 13  52  21]
 [ 14  24  82]]
Confusion matrix
 [[193   5   6]
 [ 14  44  28]
 [  9  28  83]]
Confusion matrix
 [[185   9  10]
 [ 11  50  24]
 [  9  28  84]]
Acc:  0.7835178921132278
Precision:  [0.9084818  0.57275543 0.71974372]
Recall:  [0.9175553  0.56415869 0.71301653]
