Add multi-output regression support for CascadeForestRegressor #40

Merged · 5 commits, merged on Feb 22, 2021 · Changes shown below are from 2 commits
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -30,6 +30,7 @@ Version 0.1.*
 .. |Fix| replace:: :raw-html:`<span class="badge badge-danger">Fix</span>` :raw-latex:`{\small\sc [Fix]}`
 .. |API| replace:: :raw-html:`<span class="badge badge-warning">API Change</span>` :raw-latex:`{\small\sc [API Change]}`

+- |Feature| add multi-output support for :obj:`CascadeForestRegressor` (`#40 <https://github.com/LAMDA-NJU/Deep-Forest/pull/40>`__) @Alex-Medium
 - |Feature| add layer-wise feature importances (`#39 <https://github.com/LAMDA-NJU/Deep-Forest/pull/39>`__) @xuyxu
 - |Feature| add scikit-learn backend (`#36 <https://github.com/LAMDA-NJU/Deep-Forest/pull/36>`__) @xuyxu
 - |Feature| add official support for Mac-OS (`#34 <https://github.com/LAMDA-NJU/Deep-Forest/pull/34>`__) @T-Allen-sudo
2 changes: 1 addition & 1 deletion deepforest/_io.py
@@ -318,7 +318,7 @@ def model_loadobj(dirname, obj_type, d=None):
         # Build a temporary layer
         layer_ = Layer(
             layer_idx=layer_idx,
-            n_classes=d["n_outputs"],
+            n_outputs=d["n_outputs"],
             criterion=d["criterion"],
             n_estimators=d["n_estimators"],
             partial_mode=d["partial_mode"],
33 changes: 9 additions & 24 deletions deepforest/_layer.py
@@ -46,7 +46,7 @@ class Layer(object):
     def __init__(
         self,
         layer_idx,
-        n_classes,
+        n_outputs,
         criterion,
         n_estimators=2,
         n_trees=100,
@@ -61,7 +61,7 @@ def __init__(
         is_classifier=True,
     ):
         self.layer_idx = layer_idx
-        self.n_classes = n_classes
+        self.n_outputs = n_outputs
         self.criterion = criterion
         self.n_estimators = n_estimators * 2  # internal conversion
         self.n_trees = n_trees
@@ -135,10 +135,7 @@ def fit_transform(self, X, y, sample_weight=None):
         n_samples, self.n_features = X.shape

         X_aug = []
-        if self.is_classifier:
-            oob_decision_function = np.zeros((n_samples, self.n_classes))
-        else:
-            oob_decision_function = np.zeros((n_samples, 1))
+        oob_decision_function = np.zeros((n_samples, self.n_outputs))

         # A random forest and an extremely random forest will be fitted
         for estimator_idx in range(self.n_estimators // 2):
@@ -181,12 +178,12 @@ def fit_transform(self, X, y, sample_weight=None):
         self.oob_decision_function_ = oob_decision_function / self.n_estimators
         if self.is_classifier:
             y_pred = np.argmax(oob_decision_function, axis=1)
-            self.val_acc_ = accuracy_score(
+            self.val_performance_ = accuracy_score(
                 y, y_pred, sample_weight=sample_weight
             )
         else:
             y_pred = self.oob_decision_function_
-            self.val_acc_ = mean_squared_error(
+            self.val_performance_ = mean_squared_error(
                 y, y_pred, sample_weight=sample_weight
             )

@@ -198,10 +195,7 @@ def transform(self, X, is_classifier):
         Return the concatenated transformation results from all base
         estimators."""
         n_samples, _ = X.shape
-        if is_classifier:
-            X_aug = np.zeros((n_samples, self.n_classes * self.n_estimators))
-        else:
-            X_aug = np.zeros((n_samples, self.n_estimators))
+        X_aug = np.zeros((n_samples, self.n_outputs * self.n_estimators))
         for idx, (key, estimator) in enumerate(self.estimators_.items()):
             if self.verbose > 1:
                 msg = "{} - Evaluating estimator = {:<5} in layer = {}"
@@ -211,21 +205,15 @@ def transform(self, X, is_classifier):
             # Load the estimator from the buffer
             estimator = self.buffer.load_estimator(estimator)

-            if is_classifier:
-                left, right = self.n_classes * idx, self.n_classes * (idx + 1)
-            else:
-                left, right = idx, (idx + 1)
+            left, right = self.n_outputs * idx, self.n_outputs * (idx + 1)
             X_aug[:, left:right] += estimator.predict(X)

         return X_aug

     def predict_full(self, X, is_classifier):
         """Return the concatenated predictions from all base estimators."""
         n_samples, _ = X.shape
-        if is_classifier:
-            pred = np.zeros((n_samples, self.n_classes * self.n_estimators))
-        else:
-            pred = np.zeros((n_samples, self.n_estimators))
+        pred = np.zeros((n_samples, self.n_outputs * self.n_estimators))
         for idx, (key, estimator) in enumerate(self.estimators_.items()):
             if self.verbose > 1:
                 msg = "{} - Evaluating estimator = {:<5} in layer = {}"
@@ -235,10 +223,7 @@ def predict_full(self, X, is_classifier):
             # Load the estimator from the buffer
             estimator = self.buffer.load_estimator(estimator)

-            if is_classifier:
-                left, right = self.n_classes * idx, self.n_classes * (idx + 1)
-            else:
-                left, right = idx, (idx + 1)
+            left, right = self.n_outputs * idx, self.n_outputs * (idx + 1)
             pred[:, left:right] += estimator.predict(X)

         return pred
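
Note on the width change above: after the rename, a layer always allocates `n_outputs` columns per (internally doubled) estimator instead of special-casing regression to a single column. A minimal numpy sketch of the resulting layout, with illustrative sizes standing in for a fitted `Layer`:

```python
import numpy as np

# Hypothetical sizes mirroring Layer.transform above: each of the
# internally doubled estimators contributes `n_outputs` columns,
# stacked side by side into the augmented features.
n_samples, n_outputs, n_estimators = 100, 3, 2 * 2  # 2 forests, doubled
X_aug = np.zeros((n_samples, n_outputs * n_estimators))

for idx in range(n_estimators):
    left, right = n_outputs * idx, n_outputs * (idx + 1)
    # Stand-in for estimator.predict(X), which returns (n_samples, n_outputs)
    X_aug[:, left:right] += np.random.rand(n_samples, n_outputs)

assert X_aug.shape == (100, 12)  # n_outputs * n_estimators columns
```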
14 changes: 6 additions & 8 deletions deepforest/cascade.py
@@ -519,7 +519,7 @@ def _get_n_output(self, y):
         if is_classifier(self):
             n_output = np.unique(y).shape[0]  # classification
             return n_output
-        return 1  # this parameter are not used in regression
+        return y.shape[1] if len(y.shape) > 1 else 1  # regression

     def _get_layer(self, layer_idx):
         """Get the layer from the internal container according to the index."""
@@ -705,9 +705,7 @@ def predict(self, X):

     @property
     def n_aug_features_(self):
-        if is_classifier(self):
-            return 2 * self.n_estimators * self.n_outputs_
-        return 2 * self.n_estimators
+        return 2 * self.n_estimators * self.n_outputs_

     # flake8: noqa: E501
     def fit(self, X, y, sample_weight=None):
@@ -763,7 +761,7 @@ def fit(self, X, y, sample_weight=None):
             training_time = toc - tic

             # Set the reference performance
-            pivot = layer_.val_acc_
+            pivot = layer_.val_performance_

             if self.verbose > 0:
                 msg = "{} layer = {:<2} | {} | Elapsed = {:.3f} s"
@@ -844,7 +842,7 @@ def fit(self, X, y, sample_weight=None):
                 toc = time.time()
                 training_time = toc - tic

-                new_pivot = layer_.val_acc_
+                new_pivot = layer_.val_performance_

                 if self.verbose > 0:
                     msg = "{} layer = {:<2} | {} | Elapsed = {:.3f} s"
@@ -1408,7 +1406,7 @@ def __init__(
         )

     def _repr_performance(self, pivot):
-        msg = "Val Acc = {:.3f}"
+        msg = "Val MSE = {:.5f}"
         return msg.format(pivot)

     @deepforest_model_doc(
@@ -1495,5 +1493,5 @@ def predict(self, X):
             _y = predictor.predict(X_middle_test_)
         else:
             _y = layer.predict_full(X_middle_test_, is_classifier(self))
-            _y = _y.sum(axis=1) / _y.shape[1]
+            _y = _utils.merge_proba(_y, self.n_outputs_)
         return _y
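
`_utils.merge_proba` is defined elsewhere in the package and its body is not shown in this diff, so the sketch below is only a plausible equivalent of the averaging it replaces: the old `_y.sum(axis=1) / _y.shape[1]` collapsed everything into one column, while multi-output predictions need each output averaged across estimators, turning `(n_samples, n_outputs * n_estimators)` into `(n_samples, n_outputs)`. All names here are illustrative:

```python
import numpy as np

def merge_proba_sketch(y_full, n_outputs):
    """Average each output across estimators.

    `y_full` has shape (n_samples, n_outputs * n_estimators), laid out
    estimator by estimator as in Layer.predict_full above.
    """
    n_samples = y_full.shape[0]
    # Regroup into (n_samples, n_estimators, n_outputs), then mean over
    # the estimator axis.
    return y_full.reshape(n_samples, -1, n_outputs).mean(axis=1)

# Two estimators, two outputs: columns [e0_o0, e0_o1, e1_o0, e1_o1]
y_full = np.hstack([np.ones((5, 2)), 3 * np.ones((5, 2))])
print(merge_proba_sketch(y_full, n_outputs=2))  # every row -> [2., 2.]
```

With `n_outputs=1` this reduces to the old column mean, so single-output behavior is unchanged.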
10 changes: 6 additions & 4 deletions deepforest/forest.py
@@ -135,10 +135,12 @@ def _parallel_build_trees(
         if not value.flags["C_CONTIGUOUS"]:
             value = np.ascontiguousarray(value)

-        value = np.squeeze(value, axis=1)
-
         if is_classifier:
+            value = np.squeeze(value, axis=1)
             value /= value.sum(axis=1)[:, np.newaxis]
+        else:
+            if len(value.shape) == 3:
+                value = np.squeeze(value, axis=2)

         # Set the OOB predictions
         oob_prediction = _C_FOREST.predict(
@@ -454,7 +456,7 @@ def fit(self, X, y, sample_weight=None):
                 (n_samples, self.classes_[0].shape[0])
             )
         else:
-            oob_decision_function = np.zeros((n_samples, 1))
+            oob_decision_function = np.zeros((n_samples, self.n_outputs_))
         mask = np.zeros(n_samples)

         lock = threading.Lock()
@@ -790,7 +792,7 @@ def predict(self, X):
         n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)

         # avoid storing the output of every estimator by summing them here
-        y_hat = np.zeros((X.shape[0], 1), dtype=np.float64)
+        y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)

         # Parallel loop
         lock = threading.Lock()
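
The new `is_classifier` branch in `_parallel_build_trees` reflects the different trailing axes of the per-tree value buffer: roughly `(n_samples, 1, n_classes)` for classification and `(n_samples, n_outputs, 1)` for multi-output regression. A small shape check under those assumptions (the shapes are inferred, not taken from this diff):

```python
import numpy as np

# Classification: (n_samples, 1, n_classes); drop axis 1, normalize rows.
clf_value = np.random.rand(8, 1, 3)
clf_value = np.squeeze(clf_value, axis=1)          # -> (8, 3)
clf_value /= clf_value.sum(axis=1)[:, np.newaxis]  # rows now sum to 1

# Regression: (n_samples, n_outputs, 1); drop the trailing axis instead.
reg_value = np.random.rand(8, 2, 1)
if len(reg_value.shape) == 3:
    reg_value = np.squeeze(reg_value, axis=2)      # -> (8, 2)

assert clf_value.shape == (8, 3) and reg_value.shape == (8, 2)
```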
6 changes: 3 additions & 3 deletions deepforest/tree/tree.py
@@ -429,14 +429,14 @@ def predict(self, X, check_input=True):
         """
         check_is_fitted(self)
         X = self._validate_X_predict(X, check_input)
-        proba = self.tree_.predict(X)
+        pred = self.tree_.predict(X)

         # Classification
         if is_classifier(self):
-            return self.classes_.take(np.argmax(proba, axis=1), axis=0)
+            return self.classes_.take(np.argmax(pred, axis=1), axis=0)
         # Regression
         else:
-            return proba[:, 0]
+            return np.squeeze(pred)


 class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
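
`np.squeeze(pred)` keeps the regression return value backward compatible: a single-output `(n_samples, 1)` array still comes back as `(n_samples,)`, while a multi-output `(n_samples, n_outputs)` array passes through untouched. One edge case worth noting, shown in this small check: predicting on a single sample squeezes that axis as well.

```python
import numpy as np

pred_single = np.ones((4, 1))  # single-output regression
pred_multi = np.ones((4, 2))   # multi-output regression

print(np.squeeze(pred_single).shape)  # (4,)   -- same as the old pred[:, 0]
print(np.squeeze(pred_multi).shape)   # (4, 2) -- outputs preserved

# Edge case: one test sample collapses to a 0-d array.
print(np.squeeze(np.ones((1, 1))).shape)  # ()
```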
6 changes: 3 additions & 3 deletions tests/test_layer_estimator.py
@@ -21,7 +21,7 @@
 # Parameters
 classifier_layer_kwargs = {
     "layer_idx": 0,
-    "n_classes": 10,
+    "n_outputs": 10,
     "criterion": "gini",
     "n_estimators": 1,
     "n_trees": 10,
@@ -46,7 +46,7 @@

 regressor_layer_kwargs = {
     "layer_idx": 0,
-    "n_classes": 1,
+    "n_outputs": 1,
     "criterion": "mse",
     "n_estimators": 1,
     "n_trees": 10,
@@ -87,7 +87,7 @@ def test_classifier_layer_properties_after_fitting():
     # Output dim
     expect_dim = (
         2
-        * classifier_layer_kwargs["n_classes"]
+        * classifier_layer_kwargs["n_outputs"]
         * classifier_layer_kwargs["n_estimators"]
     )
     assert X_aug.shape[1] == expect_dim
70 changes: 66 additions & 4 deletions tests/test_tree_same.py
@@ -1,10 +1,11 @@
 """
-Testing cases here make sure that the outputs of the reduced implementation
-on `DecisionTreeClassifier` and `ExtraTreeClassifier` are exactly the same as
-the original version in Scikit-Learn after the data binning.
+Testing cases here make sure that predictions of the reduced implementation
+of decision trees are exactly the same as those of the original version in
+Scikit-Learn after data binning.
 """

 import pytest
+import numpy as np
 from numpy.testing import assert_array_equal
 from sklearn.tree import (
     DecisionTreeClassifier as sklearn_DecisionTreeClassifier,
@@ -19,7 +20,7 @@
 from sklearn.model_selection import train_test_split
 from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper

-# Toy classification datasets
+# Toy datasets
 from sklearn.datasets import load_iris, load_wine, load_boston

 from deepforest import DecisionTreeClassifier
@@ -137,3 +138,64 @@ def test_extra_tree_regressor_pred(load_func):
     expected_pred = model.predict(X_test_binned)

     assert_array_equal(actual_pred, expected_pred)
+
+
+@pytest.mark.parametrize("load_func", [load_boston])
+def test_tree_regressor_multi_output_pred(load_func):
+
+    X, y = load_func(return_X_y=True)
+
+    # Generate pseudo multi-output targets
+    y = np.expand_dims(y, axis=1)
+    y = np.concatenate((y, -y), axis=1)
+
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=test_size, random_state=random_state
+    )
+
+    # Data binning
+    binner = _BinMapper(random_state=random_state)
+    X_train_binned = binner.fit_transform(X_train)
+    X_test_binned = binner.transform(X_test)
+
+    # Ours
+    model = DecisionTreeRegressor(random_state=random_state)
+    model.fit(X_train_binned, y_train)
+    actual_pred = model.predict(X_test_binned)
+
+    # Sklearn
+    model = sklearn_DecisionTreeRegressor(random_state=random_state)
+    model.fit(X_train_binned, y_train)
+    expected_pred = model.predict(X_test_binned)
+
+    assert_array_equal(actual_pred, expected_pred)
+
+
+@pytest.mark.parametrize("load_func", [load_boston])
+def test_extra_tree_regressor_multi_output_pred(load_func):
+    X, y = load_func(return_X_y=True)
+
+    # Generate pseudo multi-output targets
+    y = np.expand_dims(y, axis=1)
+    y = np.concatenate((y, -y), axis=1)
+
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=test_size, random_state=random_state
+    )
+
+    # Data binning
+    binner = _BinMapper(random_state=random_state)
+    X_train_binned = binner.fit_transform(X_train)
+    X_test_binned = binner.transform(X_test)
+
+    # Ours
+    model = ExtraTreeRegressor(random_state=random_state)
+    model.fit(X_train_binned, y_train)
+    actual_pred = model.predict(X_test_binned)
+
+    # Sklearn
+    model = sklearn_ExtraTreeRegressor(random_state=random_state)
+    model.fit(X_train_binned, y_train)
+    expected_pred = model.predict(X_test_binned)
+
+    assert_array_equal(actual_pred, expected_pred)
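
End to end, the feature this PR adds can be exercised as below; a minimal usage sketch (dataset and hyperparameters are illustrative, mirroring the pseudo multi-output targets used in the tests):

```python
import numpy as np
from sklearn.datasets import load_boston
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

from deepforest import CascadeForestRegressor

X, y = load_boston(return_X_y=True)
y = np.stack([y, -y], axis=1)  # pseudo multi-output targets: shape (n, 2)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

model = CascadeForestRegressor(random_state=1)
model.fit(X_train, y_train)     # 2-D y is now accepted
y_pred = model.predict(X_test)  # shape (n_test, 2)

print("Test MSE:", mean_squared_error(y_test, y_pred))
```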