In [21]:
import torch
import os
from sklearn.metrics import recall_score, precision_score, f1_score
from sklearn.tree import DecisionTreeClassifier
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

In [12]:
# Configuration: the key of the representation used as model input, and the
# torch-serialized feature records (path relative to the kernel's working dir).
FEATURE_SPACE, FEATURE_FILE = "layer3", "../data/extracted_features.pt"

In [14]:
# Load the feature records.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# which can execute code. Only load files you produced yourself or trust.
features = torch.load(FEATURE_FILE, weights_only=False)

# Each record is indexed by FEATURE_SPACE (model input) and by 'class'
# (target label). Split them into two parallel lists in record order.
selected_features = []
labels = []
for record in features:
    selected_features.append(record[FEATURE_SPACE])
    labels.append(record['class'])

In [22]:
# Build the design matrix and label vector. sklearn estimators require a 2-D
# X of shape (n_samples, n_features), so per-sample features are flattened:
# one scalar per sample becomes a single column, and multi-dimensional
# activations (if the chosen layer yields them) become one flat vector per
# sample. Input that is already 2-D passes through unchanged.
X = np.array(selected_features)
if X.ndim == 1:
    X = X.reshape(-1, 1)
elif X.ndim > 2:
    # np.prod keeps this valid even for an empty leading dimension.
    X = X.reshape(X.shape[0], int(np.prod(X.shape[1:])))
y = np.array(labels)
assert X.shape[0] == y.shape[0], 'number of feature rows and labels differ'

In [23]:
# Hold out 30% of the samples for evaluation. The fixed seed makes the split
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=99,
    test_size=0.3,
)

In [26]:
# Baseline: an untuned decision tree, scored on the held-out split with
# macro-averaged metrics (each class weighted equally).
clf = DecisionTreeClassifier(random_state=1)
y_pred = clf.fit(X_train, y_train).predict(X_test)

recall = recall_score(y_test, y_pred, average='macro')
precision = precision_score(y_test, y_pred, average='macro')
f1 = f1_score(y_test, y_pred, average='macro')
for metric_name, metric_value in (('Recall', recall),
                                  ('Precision', precision),
                                  ('F1 Score', f1)):
    print(f'{metric_name}: {metric_value:.4f}')

Recall: 0.8321
Precision: 0.8329
F1 Score: 0.8323


In [30]:
from sklearn.model_selection import GridSearchCV

# Hyperparameters to tune: 5 depths x 4 leaf sizes x 4 split sizes x 2
# criteria = 160 candidates, i.e. 480 fits with 3-fold CV.
param_grid = {
    'max_depth': range(1, 10, 2),
    'min_samples_leaf': range(1, 20, 5),
    'min_samples_split': range(2, 20, 5),
    'criterion': ["entropy", "gini"]
}

tree = DecisionTreeClassifier(random_state=1)
# n_jobs=-1 runs the independent fits in parallel on all cores. A recorded
# sequential run ended in KeyboardInterrupt, presumably because it was slow.
# With a fixed random_state the chosen model does not depend on n_jobs.
# scoring='accuracy' is the classifier default, made explicit so the
# "best accuracy" label below is correct by construction.
grid_search = GridSearchCV(estimator=tree, param_grid=param_grid,
                           scoring='accuracy', cv=3, n_jobs=-1, verbose=True)
grid_search.fit(X_train, y_train)

print("best accuracy", grid_search.best_score_)
print(grid_search.best_estimator_)

# best_score_ is a cross-validation score on the training split only. Score
# the refit best model on the held-out split with the same macro metrics as
# the baseline so the two results are directly comparable.
y_pred_best = grid_search.best_estimator_.predict(X_test)
print(f'Test accuracy: {accuracy_score(y_test, y_pred_best):.4f}')
print(f'Test recall: {recall_score(y_test, y_pred_best, average="macro"):.4f}')
print(f'Test precision: {precision_score(y_test, y_pred_best, average="macro"):.4f}')
print(f'Test F1 score: {f1_score(y_test, y_pred_best, average="macro"):.4f}')

Fitting 3 folds for each of 160 candidates, totalling 480 fits


KeyboardInterrupt: 

In [None]:
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# Visualise the tuned tree. The inputs are extracted features, not the iris
# dataset (no `iris` object exists in this notebook). Feature and class labels
# are therefore derived from the fitted model itself.
tree_clf = grid_search.best_estimator_

fig, ax = plt.subplots(figsize=(18, 15))
plot_tree(
    tree_clf,
    filled=True,
    # The feature columns are unnamed; label them by index within FEATURE_SPACE.
    feature_names=[f'{FEATURE_SPACE}[{i}]' for i in range(tree_clf.n_features_in_)],
    # plot_tree expects strings in the same (sorted) order as classes_.
    class_names=[str(c) for c in tree_clf.classes_],
    ax=ax,
)
ax.set_title(f'Tuned decision tree ({FEATURE_SPACE} features)')
plt.show()