In [21]:
import torch
import os
from sklearn.metrics import recall_score, precision_score, f1_score
from sklearn.tree import DecisionTreeClassifier
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

In [12]:
FEATURE_SPACE = 'layer3'
FEATURE_FILE = '../data/extracted_features.pt'

In [14]:
# Load features
features = torch.load(FEATURE_FILE, weights_only=False)

selected_features = [elem[FEATURE_SPACE] for elem in features]
labels = [elem['class'] for elem in features]

In [22]:
X = np.array(selected_features)
y = np.array(labels)

In [23]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state = 99)

In [38]:
clf = DecisionTreeClassifier(criterion='entropy', max_depth=8, min_samples_leaf=5,
                       min_samples_split=16, random_state=1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
recall = recall_score(y_test, y_pred, average='macro')
print(f'Recall: {recall:.4f}')
precision = precision_score(y_test, y_pred, average='macro')
print(f'Precision: {precision:.4f}')
f1 = f1_score(y_test, y_pred, average='macro')
print(f'F1 Score: {f1:.4f}')

Recall: 0.8566
Precision: 0.8584
F1 Score: 0.8569


In [None]:
### IMPORTANT: Hyperparameter tuning with GridSearchCV  NO NEED TO RUN THIS PART EVERY TIME ###
from sklearn.model_selection import GridSearchCV

# Hyperparameter to fine tune
param_grid = {
    'max_depth': range(1, 10, 1),
    'min_samples_leaf': range(1, 20, 2),
    'min_samples_split': range(2, 20, 2),
    'criterion': ["entropy", "gini"]
}

tree = DecisionTreeClassifier(random_state=1)
# GridSearchCV
grid_search = GridSearchCV(estimator=tree, param_grid=param_grid, 
                           cv=5, verbose=2, n_jobs=-1)
grid_search.fit(X_train, y_train)

print("best accuracy", grid_search.best_score_)
print(grid_search.best_estimator_)

Fitting 5 folds for each of 1620 candidates, totalling 8100 fits
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=2; total time=   0.3s
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=4; total time=   0.3s
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=2; total time=   0.3s
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=6; total time=   0.3s
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=8; total time=   0.3s
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=2; total time=   0.3s
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=4; total time=   0.3s
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=2; total time=   0.3s
[CV] END criterion=entropy, max_depth=1, min_samples_leaf=1, min_samples_split=4; total time=   0.3s
[CV] END criterion=entropy

In [None]:
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# 1) Make a list of feature‐names
#    – if X is a pandas DataFrame, you can just use X.columns.
#    – otherwise, generate something like "f0", "f1", … "f1023".
try:
    feature_names = list(X.columns)
except AttributeError:
    feature_names = [f"f{i}" for i in range(X.shape[1])]

# 2) Make a list of class‐names
#    – tree_clf.classes_ holds the unique class labels seen during fitting.
#    – convert them to strings if they aren’t already.
class_names = [str(c) for c in clf.classes_]

# 3) (Optional) limit the depth of the plotted tree for readability
#    – if you try to plot all 1 024 features you’ll get a massive, unreadable blob.
#    – you can pass e.g. max_depth=3 to plot_tree to only show the top 3 levels.
plt.figure(figsize=(180, 150))
plot_tree(
    clf,
    filled=True,
    feature_names=feature_names,
    class_names=class_names,
    max_depth=9  # <-- adjust or remove as needed
)
plt.show()
