Skip to content

Commit

Permalink
feat: Add parameter number_of_tree to RandomForest classifier and…
Browse files Browse the repository at this point in the history
… regressor (#230)

Closes #161.

### Summary of Changes

Added number_of_trees parameter to initiator of random_forest_classifier
and random_forest_regressor.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Alexander <47296670+Marsmaennchen221@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
4 people committed Apr 22, 2023
1 parent 4f08a2c commit 414336a
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 10 deletions.
24 changes: 19 additions & 5 deletions src/safeds/ml/classical/classification/_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,23 @@


class RandomForest(Classifier):
"""Random forest classification."""

def __init__(self) -> None:
"""Random forest classification.
Parameters
----------
number_of_trees : int
The number of trees to be used in the random forest. Has to be greater than 0.
Raises
------
ValueError
If the number of trees is less than or equal to 0.
"""

def __init__(self, number_of_trees: int = 100) -> None:
if number_of_trees < 1:
raise ValueError("The number of trees has to be greater than 0.")
self.number_of_trees = number_of_trees
self._wrapped_classifier: sk_RandomForestClassifier | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None
Expand All @@ -41,10 +55,10 @@ def fit(self, training_set: TaggedTable) -> RandomForest:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_classifier = sk_RandomForestClassifier(n_jobs=-1)
wrapped_classifier = sk_RandomForestClassifier(self.number_of_trees, n_jobs=-1)
fit(wrapped_classifier, training_set)

result = RandomForest()
result = RandomForest(self.number_of_trees)
result._wrapped_classifier = wrapped_classifier
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
24 changes: 19 additions & 5 deletions src/safeds/ml/classical/regression/_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,23 @@


class RandomForest(Regressor):
"""Random forest regression."""

def __init__(self) -> None:
"""Random forest regression.
Parameters
----------
number_of_trees : int
The number of trees to be used in the random forest. Has to be greater than 0.
Raises
------
ValueError
If the number of trees is less than or equal to 0.
"""

def __init__(self, number_of_trees: int = 100) -> None:
if number_of_trees < 1:
raise ValueError("The number of trees has to be greater than 0.")
self.number_of_trees = number_of_trees
self._wrapped_regressor: sk_RandomForestRegressor | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None
Expand All @@ -41,10 +55,10 @@ def fit(self, training_set: TaggedTable) -> RandomForest:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_RandomForestRegressor(n_jobs=-1)
wrapped_regressor = sk_RandomForestRegressor(self.number_of_trees, n_jobs=-1)
fit(wrapped_regressor, training_set)

result = RandomForest()
result = RandomForest(self.number_of_trees)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
17 changes: 17 additions & 0 deletions tests/safeds/ml/classical/classification/test_random_forest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.ml.classical.classification import RandomForest


def test_number_of_trees_invalid() -> None:
with pytest.raises(ValueError, match="The number of trees has to be greater than 0."):
RandomForest(-1)


def test_number_of_trees_valid() -> None:
training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
tagged_training_set = training_set.tag_columns(target_name="col1", feature_names=["col2"])

random_forest = RandomForest(10).fit(tagged_training_set)
assert random_forest._wrapped_classifier is not None
assert random_forest._wrapped_classifier.n_estimators == 10
17 changes: 17 additions & 0 deletions tests/safeds/ml/classical/regression/test_random_forest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.ml.classical.regression import RandomForest


def test_number_of_trees_invalid() -> None:
with pytest.raises(ValueError, match="The number of trees has to be greater than 0."):
RandomForest(-1)


def test_number_of_trees_valid() -> None:
training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
tagged_training_set = training_set.tag_columns(target_name="col1", feature_names=["col2"])

random_forest = RandomForest(10).fit(tagged_training_set)
assert random_forest._wrapped_regressor is not None
assert random_forest._wrapped_regressor.n_estimators == 10

0 comments on commit 414336a

Please sign in to comment.