-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Implement Proximity Forest classifier #1729
base: main
Are you sure you want to change the base?
Changes from 14 commits
dc8bd1d
d4ce3e3
7406dff
7ed886f
d063b99
81f86ee
1ab94aa
3359fe3
80f1ca8
db136c6
4631e8d
b7d0461
59a8175
8905461
0294311
2d74d4d
7953c00
b7505ad
c853e55
8959efa
1cb74a4
e5a095f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
"""Proximity Forest Classifier. | ||
|
||
The Proximity Forest is an ensemble of Proximity Trees. | ||
""" | ||
|
||
__all__ = ["ProximityForest"] | ||
|
||
from typing import Type, Union | ||
|
||
import numpy as np | ||
from joblib import Parallel, delayed | ||
|
||
from aeon.classification.base import BaseClassifier | ||
from aeon.classification.distance_based._proximity_tree import ProximityTree | ||
|
||
|
||
class ProximityForest(BaseClassifier): | ||
"""Proximity Forest Classifier. | ||
|
||
The Proximity Forest is a distance-based classifier that creates an | ||
ensemble of decision trees, where the splits are based on the | ||
similarity between time series measured using various parameterised | ||
distance measures. | ||
|
||
Parameters | ||
---------- | ||
n_trees: int, default = 100 | ||
The number of trees, by default an ensemble of 100 trees is formed. | ||
n_splitters: int, default = 5 | ||
The number of candidate splitters to be evaluated at each node. | ||
max_depth: int, default = None | ||
The maximum depth of the tree. If None, then nodes are expanded until all | ||
leaves are pure or until all leaves contain less than min_samples_split samples. | ||
min_samples_split: int, default = 2 | ||
The minimum number of samples required to split an internal node. | ||
random_state : int, RandomState instance or None, default=None | ||
If `int`, random_state is the seed used by the random number generator; | ||
If `RandomState` instance, random_state is the random number generator; | ||
If `None`, the random number generator is the `RandomState` instance used | ||
by `np.random`. | ||
n_jobs : int, default = 1 | ||
The number of parallel jobs to run for neighbors search. | ||
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. | ||
``-1`` means using all processors. See :term:`Glossary <n_jobs>` | ||
for more details. Parameter for compatibility purposes, still unimplemented. | ||
|
||
Notes | ||
----- | ||
For the Java version, see | ||
`ProximityForest | ||
<https://github.com/fpetitjean/ProximityForest>`_. | ||
|
||
References | ||
---------- | ||
.. [1] Lucas, B., Shifaz, A., Pelletier, C., O’Neill, L., Zaidi, N., Goethals, B., | ||
Petitjean, F. and Webb, G.I., 2019. Proximity forest: an effective and scalable | ||
distance-based classifier for time series. Data Mining and Knowledge Discovery, | ||
33(3), pp.607-635. | ||
|
||
Examples | ||
-------- | ||
>>> from aeon.datasets import load_unit_test | ||
>>> from aeon.classification.distance_based import ProximityForest | ||
>>> X_train, y_train = load_unit_test(split="train") | ||
>>> X_test, y_test = load_unit_test(split="test") | ||
>>> classifier = ProximityForest(n_trees = 10, n_splitters = 3) | ||
>>> classifier.fit(X_train, y_train) | ||
ProximityForest(...) | ||
>>> y_pred = classifier.predict(X_test) | ||
""" | ||
|
||
_tags = { | ||
"capability:multivariate": False, | ||
"capability:unequal_length": False, | ||
"capability:multithreading": True, | ||
"algorithm_type": "distance", | ||
"X_inner_type": ["numpy2D"], | ||
} | ||
|
||
def __init__( | ||
self, | ||
n_trees=100, | ||
n_splitters: int = 5, | ||
max_depth: int = None, | ||
min_samples_split: int = 2, | ||
random_state: Union[int, Type[np.random.RandomState], None] = None, | ||
n_jobs: int = 1, | ||
): | ||
self.n_trees = n_trees | ||
self.n_splitters = n_splitters | ||
self.max_depth = max_depth | ||
self.min_samples_split = min_samples_split | ||
self.random_state = random_state | ||
self.n_jobs = n_jobs | ||
super().__init__() | ||
|
||
def _fit(self, X, y): | ||
self.classes_ = list(np.unique(y)) | ||
self.trees_ = Parallel(n_jobs=self.n_jobs)( | ||
delayed(self._fit_tree)(X, y) for _ in range(self.n_trees) | ||
) | ||
|
||
def _fit_tree(self, X, y): | ||
clf = ProximityTree( | ||
n_splitters=self.n_splitters, | ||
max_depth=self.max_depth, | ||
min_samples_split=self.min_samples_split, | ||
random_state=self.random_state, | ||
n_jobs=self.n_jobs, | ||
) | ||
clf.fit(X, y) | ||
return clf | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar comment for predict, but I think it might be better to define the function you parallelize with joblib outside of the object you call them from. Something to do with the fact that joblib pickling the objects you parallelize, if I remember right ? This might mean that you create a copy of the To avoid that, you would define There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing this out. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this true? I think we have functions elsewhere that do this. Interesting to see if that needs to be changed. |
||
|
||
def _predict_proba(self, X): | ||
output_probas = Parallel(n_jobs=self.n_jobs)( | ||
delayed(self._predict_proba_tree)(tree, X) for tree in self.trees_ | ||
) | ||
output_probas = np.sum(output_probas, axis=0) | ||
output_probas = np.divide(output_probas, self.n_trees) | ||
return output_probas | ||
|
||
def _predict_proba_tree(self, tree, X): | ||
return tree.predict_proba(X) | ||
|
||
def _predict(self, X): | ||
probas = self._predict_proba(X) | ||
idx = np.argmax(probas, axis=1) | ||
preds = np.asarray([self.classes_[x] for x in idx]) | ||
return preds |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Test for Proximity Forest.""" | ||
|
||
import pytest | ||
from sklearn.metrics import accuracy_score | ||
|
||
from aeon.classification.distance_based import ProximityForest | ||
from aeon.testing.data_generation import make_example_3d_numpy | ||
|
||
|
||
@pytest.fixture | ||
def time_series_dataset(): | ||
"""Generate time series dataset for testing.""" | ||
n_samples = 100 # Total number of samples (should be even) | ||
n_timepoints = 24 # Length of each time series | ||
n_channels = 1 | ||
data, labels = make_example_3d_numpy(n_samples, n_channels, n_timepoints) | ||
return data, labels | ||
|
||
|
||
def test_univariate(time_series_dataset): | ||
"""Test that the function gives appropriate error message.""" | ||
X, y = time_series_dataset | ||
X_multivariate = X.reshape((100, 2, 12)) | ||
clf = ProximityForest(n_trees=5, random_state=42, n_jobs=-1) | ||
with pytest.raises(ValueError): | ||
clf.fit(X_multivariate, y) | ||
|
||
|
||
def test_proximity_forest(time_series_dataset): | ||
"""Test the fit method of ProximityTree.""" | ||
X, y = time_series_dataset | ||
clf = ProximityForest(n_trees=5, n_splitters=3, max_depth=4) | ||
clf.fit(X, y) | ||
X_test, y_test = time_series_dataset | ||
y_pred = clf.predict(X_test) | ||
score = accuracy_score(y_test, y_pred) | ||
assert score >= 0.9 | ||
itsdivya1309 marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is done in the base class, don't want to do so here. Caused an issue when trying to run it 🙂