Feature importances #5

Open · wants to merge 12 commits into master
404 changes: 404 additions & 0 deletions docs/tutorials/runtime_comparison.ipynb

Large diffs are not rendered by default.

529 changes: 529 additions & 0 deletions docs/tutorials/test_feature_importance.ipynb

Large diffs are not rendered by default.

64 changes: 56 additions & 8 deletions oblique_forests/morf.py
@@ -1,4 +1,8 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.ensemble._forest import ForestClassifier
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.fixes import _joblib_parallel_args

from .tree.morf_tree import Conv2DObliqueTreeClassifier

@@ -175,11 +179,55 @@ def __init__(
# self.min_impurity_split = min_impurity_split

# s-rerf params
# self.discontiguous_height = discontiguous_height
# self.discontiguous_width = discontiguous_width
# self.image_height = image_height
# self.image_width = image_width
# self.patch_height_max = patch_height_max
# self.patch_height_min = patch_height_min
# self.patch_width_max = patch_width_max
# self.patch_width_min = patch_width_min
self.discontiguous_height = discontiguous_height
self.discontiguous_width = discontiguous_width
self.image_height = image_height
self.image_width = image_width
self.patch_height_max = patch_height_max
self.patch_height_min = patch_height_min
self.patch_width_max = patch_width_max
self.patch_width_min = patch_width_min

@property
def feature_importances_(self):
"""
Computes the importance of every unique feature used to make a split
in each tree of the forest.

Parameters
----------
normalize : bool, default=True
A boolean to indicate whether to normalize feature importances.

Returns
-------
importances : array of shape [n_features]
Array of count-based feature importances.
"""
# TODO: Parallelize this and see if there is an equivalent way to express this better
# 1. Find all unique atoms in the forest
# 2. Compute number of times each atom appears across all trees
forest_projections = [
node.proj_vec
for tree in self.estimators_
if tree.tree_.node_count > 0
for node in tree.tree_.nodes
if node.proj_vec is not None
]
unique_projections, counts = np.unique(
forest_projections, axis=0, return_counts=True
)

if counts.sum() == 0:
return np.zeros(self.n_features_, dtype=np.float64)

# 3. Count how many times each feature gets nonzero weight in unique projections
importances = np.zeros(self.n_features_)
for proj_vec, count in zip(unique_projections, counts):
importances[np.nonzero(proj_vec)] += count

# 4. Normalize by number of unique projections
if len(unique_projections) > 0:
importances /= len(unique_projections)

return importances
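For reference, the counting in steps 1-4 above can be checked in isolation. A minimal sketch with made-up projection vectors (hypothetical, not taken from a real fit):

import numpy as np

# Three split-node projections over 4 features; the first vector repeats.
forest_projections = np.array([
    [1, 0, 0, 1],
    [1, 0, 0, 1],
    [0, 1, 0, 0],
])

unique_projections, counts = np.unique(
    forest_projections, axis=0, return_counts=True
)
# unique_projections -> [[0, 1, 0, 0], [1, 0, 0, 1]], counts -> [1, 2]

importances = np.zeros(forest_projections.shape[1])
for proj_vec, count in zip(unique_projections, counts):
    importances[np.nonzero(proj_vec)] += count

# Normalize by the number of unique projections.
importances /= len(unique_projections)
# importances -> [1.0, 0.5, 0.0, 1.0]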
50 changes: 49 additions & 1 deletion oblique_forests/sporf.py
@@ -1,6 +1,10 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.ensemble._forest import ForestClassifier
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.fixes import _joblib_parallel_args

from .tree.oblique_tree import ObliqueTreeClassifier
from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier


class ObliqueForestClassifier(ForestClassifier):
@@ -95,3 +99,47 @@ def __init__(
# self.max_leaf_nodes = max_leaf_nodes
# self.min_impurity_decrease = min_impurity_decrease
# self.min_impurity_split = min_impurity_split

@property
def feature_importances_(self):
"""
Computes the importance of every unique feature used to make a split
in each tree of the forest.

Parameters
----------
normalize : bool, default=True
A boolean to indicate whether to normalize feature importances.

Returns
-------
importances : array of shape [n_features]
Array of count-based feature importances.
"""
# TODO: Parallelize this and see if there is an equivalent way to express this better
# 1. Find all unique atoms in the forest
# 2. Compute number of times each atom appears across all trees
forest_projections = [
node.proj_vec
for tree in self.estimators_
if tree.tree_.node_count > 0
for node in tree.tree_.nodes
if node.proj_vec is not None
]
unique_projections, counts = np.unique(
forest_projections, axis=0, return_counts=True
)

if counts.sum() == 0:
return np.zeros(self.n_features_, dtype=np.float64)

# 3. Count how many times each feature gets nonzero weight in unique projections
importances = np.zeros(self.n_features_, dtype=np.float64)
for proj_vec, count in zip(unique_projections, counts):
importances[np.nonzero(proj_vec)] += count

# 4. Normalize by number of unique projections
if len(unique_projections) > 0:
importances /= len(unique_projections)

return importances
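A minimal end-to-end sketch of the new property, assuming a flat (n_samples, n_features) input; the constructor arguments mirror those used in the test script below:

import numpy as np
from oblique_forests.sporf import ObliqueForestClassifier

# Toy data where feature 0 fully determines the label.
rng = np.random.default_rng(0)
X = rng.random((100, 10))
y = (X[:, 0] > 0.5).astype(int)

clf = ObliqueForestClassifier(n_estimators=10, random_state=0)
clf.fit(X, y)

# Count-based importances; feature 0 should receive the largest weight.
print(clf.feature_importances_)  # ndarray of shape (10,)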
4 changes: 2 additions & 2 deletions oblique_forests/tree/morf_tree.py
@@ -265,15 +265,15 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None):
splitter = self._set_splitter(X, y)

# create the Oblique tree
self.tree = ObliqueTree(
self.tree_ = ObliqueTree(
splitter,
self.min_samples_split,
self.min_samples_leaf,
self.max_depth,
self.min_impurity_split,
self.min_impurity_decrease,
)
self.tree.build()
self.tree_.build()
return self


93 changes: 88 additions & 5 deletions oblique_forests/tree/oblique_tree.py
@@ -3,6 +3,7 @@
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
from sklearn.utils.fixes import _joblib_parallel_args

from ._split import BaseObliqueSplitter
from .oblique_base import BaseManifoldSplitter, Node, SplitInfo, StackRecord
@@ -528,6 +529,37 @@ def predict(self, X, check_input=True):

return predictions

def compute_feature_importances(self):
"""
Computes the importance of each feature (aka variable).

Parameters
----------
unique_projections : ndarray of shape (n_proj, n_features)
Array of unique sampling projection vectors.

Returns
-------
importances : ndarray of shape (n_features,)
Normalized importance of each feature of the data matrix.
"""
projections = [
node.proj_vec for node in self.nodes if node.proj_vec is not None
]
unique_projections, counts = np.unique(projections, axis=0, return_counts=True)

if counts.sum() == 0:
return np.zeros((self.splitter.n_features,))

importances = np.zeros((self.splitter.n_features,))
for proj_vec, count in zip(unique_projections, counts):
importances[np.nonzero(proj_vec)] += count

if len(unique_projections) > 0:
importances /= len(unique_projections)

return importances


class ObliqueTreeClassifier(BaseEstimator):
"""
@@ -600,6 +632,7 @@ def __init__(

# Max features
self.max_features = max_features
self.n_jobs = n_jobs

self.n_classes = None
@@ -640,15 +673,15 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None):
tree_func = self._tree_class()

# instantiate the tree and build it
self.tree = tree_func(
self.tree_ = tree_func(
splitter,
self.min_samples_split,
self.min_samples_leaf,
self.max_depth,
self.min_impurity_split,
self.min_impurity_decrease,
)
self.tree.build()
self.tree_.build()

return self

@@ -666,7 +699,7 @@ def apply(self, X):
pred_nodes : array of shape[n_samples]
The indices for each test sample's final node in the oblique tree.
"""
pred_nodes = self.tree.predict(X).astype(int)
pred_nodes = self.tree_.predict(X).astype(int)
return pred_nodes

def predict(self, X, check_input=True):
@@ -689,7 +722,7 @@ def predict(self, X, check_input=True):
pred_nodes = self.apply(X)
for k in range(len(pred_nodes)):
id = pred_nodes[k]
preds[k] = self.tree.nodes[id].label
preds[k] = self.tree_.nodes[id].label

return preds

@@ -713,7 +746,7 @@ def predict_proba(self, X, check_input=True):
pred_nodes = self.apply(X)
for k in range(len(preds)):
id = pred_nodes[k]
preds[k] = self.tree.nodes[id].proba
preds[k] = self.tree_.nodes[id].proba

return preds

@@ -737,3 +770,53 @@ def predict_log_proba(self, X, check_input=True):
# TODO: Actually do this function
def _validate_X_predict(self, X, check_input=True):
return X

@property
def feature_importances_(self):
"""
Return the feature importances.

The importance of a feature is computed as the number of times it
is used in a projection across all split nodes, normalized by the
number of unique projections in the tree.

Returns
-------
feature_importances_ : ndarray of shape (n_features,)
Array of count-based feature importances.
"""
check_is_fitted(self)

return self.tree_.compute_feature_importances()

def compute_projection_counts(self, unique_projections=None):
"""
Count the number of times each unique projection in the tree appears.

Parameters
----------
unique_projections : ndarray of shape (n_proj, n_features), optional
Array of unique projection vectors to count. If None (default), the
unique projections are computed from this tree's split nodes.

Returns
-------
counts : ndarray of shape (n_proj,)
Number of times each unique projection is used at a split node.
unique_projections : ndarray of shape (n_proj, n_features)
The unique projection vectors that were counted.
"""
check_is_fitted(self)

if unique_projections is None:
projections = [
node.proj_vec
for node in self.tree_.nodes
if node.proj_vec is not None
]
unique_projections, counts = np.unique(projections, axis=0, return_counts=True)
return counts, unique_projections

# TODO: see if joblib will speed up at all for this for loop
n_proj = len(unique_projections)
counts = np.zeros(n_proj)
for node in self.tree_.nodes:
    # Leaf nodes have no projection; skip them.
    if node.proj_vec is None:
        continue
    projection_idx = np.where((unique_projections == node.proj_vec).all(axis=1))
    counts[projection_idx] += 1

return counts, unique_projections
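And a brief sketch of the tree-level API, assuming an ObliqueTreeClassifier constructed as in the benchmark script below:

import numpy as np
from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier

rng = np.random.default_rng(0)
X = rng.random((50, 6))
y = ((X[:, 0] + X[:, 1]) > 1.0).astype(int)

clf = ObliqueTreeClassifier(random_state=0)
clf.fit(X, y)

# counts[i] is how many split nodes used unique_projections[i];
# the same vectors drive clf.feature_importances_.
counts, unique_projections = clf.compute_projection_counts()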
62 changes: 61 additions & 1 deletion oblique_forests/tree/tests/test_morf_tree.py
@@ -10,6 +10,10 @@
from sklearn.utils.validation import check_random_state

from oblique_forests.tree.morf_split import Conv2DSplitter
from oblique_forests.sporf import ObliqueForestClassifier as SPORF
from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier as OTC
from oblique_forests.morf import Conv2DObliqueForestClassifier as MORF
from oblique_forests.tree.morf_tree import Conv2DObliqueTreeClassifier

# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
@@ -43,7 +47,7 @@ def test_convolutional_splitter():
y[:25] = 0

splitter = Conv2DSplitter(
X,
X.reshape(n, -1),
y,
max_features=1,
feature_combinations=1.5,
@@ -52,4 +56,60 @@
image_width=d,
patch_height_max=2,
patch_height_min=2,
patch_width_max=3,
patch_width_min=3,
)

splitter.sample_proj_mat(splitter.indices)


if __name__ == "__main__":

test_convolutional_splitter()

# from sklearn.datasets import fetch_openml
from keras.datasets import mnist
import time

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Get 100 samples of 3s and 5s
num = 100
threes = np.where(y_train == 3)[0][:num]
fives = np.where(y_train == 5)[0][:num]
train_idx = np.concatenate((threes, fives))

# Subset train data
Xtrain = X_train[train_idx]
ytrain = y_train[train_idx]

# Apply random shuffling
permuted_idx = np.random.permutation(len(train_idx))
Xtrain = Xtrain[permuted_idx]
ytrain = ytrain[permuted_idx]

# Subset test data
test_idx = np.where(y_test == 3)[0]
Xtest = X_test[test_idx]
ytest = y_test[test_idx]

print(f"-----{2 * num} samples")

clf = OTC(random_state=0)
start = time.time()
clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain)
elapsed = time.time() - start
print(f"SPORF Tree: {elapsed} sec")

clf = Conv2DObliqueTreeClassifier(image_height=28, image_width=28, random_state=0)
start = time.time()
clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain)
elapsed = time.time() - start
print(f"MORF Tree: {elapsed} sec")

clf = SPORF(n_estimators=100, random_state=0)
start = time.time()
clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain)
elapsed = time.time() - start
print(f"SPORF: {elapsed} sec")