Skip to content

Commit

Permalink
feat: return new model when calling fit (#91)
Browse files Browse the repository at this point in the history
Closes #69.

### Summary of Changes

The `fit` method of classifiers/regressors now returns a new (fitted)
classifier/regressor. The receiver of the method call is not changed
anymore. This is consistent with the methods on the `Table` class and
other data containers. Furthermore, `fit` is now a pure function, which
works better in notebooks and our own [execution
strategy](https://arxiv.org/abs/2302.14556).

---------

Co-authored-by: lars-reimann <lars-reimann@users.noreply.github.com>
  • Loading branch information
lars-reimann and lars-reimann committed Mar 27, 2023
1 parent 494d69a commit 165c97c
Show file tree
Hide file tree
Showing 35 changed files with 473 additions and 200 deletions.
6 changes: 3 additions & 3 deletions docs/tutorials/machine_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"from safeds.ml.regression import LinearRegression\n",
"\n",
"model = LinearRegression()\n",
"model.fit(tagged_table)"
"fitted_model = model.fit(tagged_table)"
],
"metadata": {
"collapsed": false
Expand All @@ -71,7 +71,7 @@
"source": [
"## Predicting new values\n",
"\n",
"The `fit` method trains the model in place. This means that the model object is modified and can be used to make predictions. Predictions are made by calling the `predict` method on the model object. The `predict` method takes a `Table` as input and returns a `Table` with the predictions:"
"The `fit` method returns the fitted model, the original model is **not** changed. Predictions are made by calling the `predict` method on the fitted model. The `predict` method takes a `Table` as input and returns a `Table` with the predictions:"
],
"metadata": {
"collapsed": false
Expand All @@ -87,7 +87,7 @@
" \"b\": [2, 0, 5, 2, 7],\n",
" \"c\": [1, 4, 3, 2, 1]})\n",
"\n",
"model.predict(dataset=test_set)\n"
"fitted_model.predict(dataset=test_set)\n"
],
"metadata": {
"collapsed": false
Expand Down
8 changes: 6 additions & 2 deletions src/safeds/ml/_util_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

from safeds.data.tabular.containers import Table, TaggedTable
from safeds.exceptions import LearningError, PredictionError
Expand Down Expand Up @@ -34,7 +34,7 @@ def fit(model: Any, tagged_table: TaggedTable) -> None:


# noinspection PyProtectedMember
def predict(model: Any, dataset: Table, target_name: str) -> TaggedTable:
def predict(model: Any, dataset: Table, target_name: Optional[str]) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first.
Expand All @@ -57,6 +57,10 @@ def predict(model: Any, dataset: Table, target_name: str) -> TaggedTable:
PredictionError
If predicting with the given dataset failed.
"""

if model is None or target_name is None:
raise PredictionError("The model was not trained")

dataset_df = dataset._data
dataset_df.columns = dataset.schema.get_column_names()
try:
Expand Down
33 changes: 25 additions & 8 deletions src/safeds/ml/classification/_ada_boost.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Optional

from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml._util_sklearn import fit, predict
from sklearn.ensemble import AdaBoostClassifier as sk_AdaBoostClassifier
Expand All @@ -12,25 +16,38 @@ class AdaBoost(Classifier):
"""

def __init__(self) -> None:
self._wrapped_classifier = sk_AdaBoostClassifier()
self._target_name = ""
self._wrapped_classifier: Optional[sk_AdaBoostClassifier] = None
self._target_name: Optional[str] = None

def fit(self, training_set: TaggedTable) -> None:
def fit(self, training_set: TaggedTable) -> AdaBoost:
"""
Fit this model given a tagged table.
Create a new classifier based on this one and fit it with the given training data. This classifier is not
modified.
Parameters
----------
training_set : TaggedTable
The tagged table containing the feature and target vectors.
The training data containing the feature and target vectors.
Returns
-------
fitted_classifier : AdaBoost
The fitted classifier.
Raises
------
LearningError
If the tagged table contains invalid values or if the training failed.
If the training data contains invalid values or if the training failed.
"""
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target.name

wrapped_classifier = sk_AdaBoostClassifier()
fit(wrapped_classifier, training_set)

result = AdaBoost()
result._wrapped_classifier = wrapped_classifier
result._target_name = training_set.target.name

return result

def predict(self, dataset: Table) -> TaggedTable:
"""
Expand Down
16 changes: 12 additions & 4 deletions src/safeds/ml/classification/_classifier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod

from safeds.data.tabular.containers import Table, TaggedTable
Expand All @@ -10,19 +12,25 @@ class Classifier(ABC):
"""

@abstractmethod
def fit(self, training_set: TaggedTable) -> None:
def fit(self, training_set: TaggedTable) -> Classifier:
"""
Fit this model given a tagged table.
Create a new classifier based on this one and fit it with the given training data. This classifier is not
modified.
Parameters
----------
training_set : TaggedTable
The tagged table containing the feature and target vectors.
The training data containing the feature and target vectors.
Returns
-------
fitted_classifier : Classifier
The fitted classifier.
Raises
------
LearningError
If the tagged table contains invalid values or if the training failed.
If the training data contains invalid values or if the training failed.
"""

@abstractmethod
Expand Down
33 changes: 25 additions & 8 deletions src/safeds/ml/classification/_decision_tree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Optional

from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml._util_sklearn import fit, predict
from sklearn.tree import DecisionTreeClassifier as sk_DecisionTreeClassifier
Expand All @@ -12,25 +16,38 @@ class DecisionTree(Classifier):
"""

def __init__(self) -> None:
self._wrapped_classifier = sk_DecisionTreeClassifier()
self._target_name = ""
self._wrapped_classifier: Optional[sk_DecisionTreeClassifier] = None
self._target_name: Optional[str] = None

def fit(self, training_set: TaggedTable) -> None:
def fit(self, training_set: TaggedTable) -> DecisionTree:
"""
Fit this model given a tagged table.
Create a new classifier based on this one and fit it with the given training data. This classifier is not
modified.
Parameters
----------
training_set : TaggedTable
The tagged table containing the feature and target vectors.
The training data containing the feature and target vectors.
Returns
-------
fitted_classifier : DecisionTree
The fitted classifier.
Raises
------
LearningError
If the tagged table contains invalid values or if the training failed.
If the training data contains invalid values or if the training failed.
"""
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target.name

wrapped_classifier = sk_DecisionTreeClassifier()
fit(wrapped_classifier, training_set)

result = DecisionTree()
result._wrapped_classifier = wrapped_classifier
result._target_name = training_set.target.name

return result

def predict(self, dataset: Table) -> TaggedTable:
"""
Expand Down
33 changes: 25 additions & 8 deletions src/safeds/ml/classification/_gradient_boosting_classification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Optional

from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml._util_sklearn import fit, predict
from sklearn.ensemble import GradientBoostingClassifier as sk_GradientBoostingClassifier
Expand All @@ -12,25 +16,38 @@ class GradientBoosting(Classifier):
"""

def __init__(self) -> None:
self._wrapped_classifier = sk_GradientBoostingClassifier()
self._target_name = ""
self._wrapped_classifier: Optional[sk_GradientBoostingClassifier] = None
self._target_name: Optional[str] = None

def fit(self, training_set: TaggedTable) -> None:
def fit(self, training_set: TaggedTable) -> GradientBoosting:
"""
Fit this model given a tagged table.
Create a new classifier based on this one and fit it with the given training data. This classifier is not
modified.
Parameters
----------
training_set : TaggedTable
The tagged table containing the feature and target vectors.
The training data containing the feature and target vectors.
Returns
-------
fitted_classifier : GradientBoosting
The fitted classifier.
Raises
------
LearningError
If the tagged table contains invalid values or if the training failed.
If the training data contains invalid values or if the training failed.
"""
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target.name

wrapped_classifier = sk_GradientBoostingClassifier()
fit(wrapped_classifier, training_set)

result = GradientBoosting()
result._wrapped_classifier = wrapped_classifier
result._target_name = training_set.target.name

return result

# noinspection PyProtectedMember
def predict(self, dataset: Table) -> TaggedTable:
Expand Down
36 changes: 26 additions & 10 deletions src/safeds/ml/classification/_k_nearest_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Optional

from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml._util_sklearn import fit, predict
from sklearn.neighbors import KNeighborsClassifier as sk_KNeighborsClassifier
Expand All @@ -16,27 +20,39 @@ class KNearestNeighbors(Classifier):
"""

def __init__(self, n_neighbors: int) -> None:
self._wrapped_classifier = sk_KNeighborsClassifier(
n_jobs=-1, n_neighbors=n_neighbors
)
self._target_name = ""
self._n_neighbors = n_neighbors

def fit(self, training_set: TaggedTable) -> None:
self._wrapped_classifier: Optional[sk_KNeighborsClassifier] = None
self._target_name: Optional[str] = None

def fit(self, training_set: TaggedTable) -> KNearestNeighbors:
"""
Fit this model given a tagged table.
Create a new classifier based on this one and fit it with the given training data. This classifier is not
modified.
Parameters
----------
training_set : TaggedTable
The tagged table containing the feature and target vectors.
The training data containing the feature and target vectors.
Returns
-------
fitted_classifier : KNearestNeighbors
The fitted classifier.
Raises
------
LearningError
If the tagged table contains invalid values or if the training failed.
If the training data contains invalid values or if the training failed.
"""
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target.name
wrapped_classifier = sk_KNeighborsClassifier(self._n_neighbors, n_jobs=-1)
fit(wrapped_classifier, training_set)

result = KNearestNeighbors(self._n_neighbors)
result._wrapped_classifier = wrapped_classifier
result._target_name = training_set.target.name

return result

def predict(self, dataset: Table) -> TaggedTable:
"""
Expand Down
33 changes: 25 additions & 8 deletions src/safeds/ml/classification/_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Optional

from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml._util_sklearn import fit, predict
from sklearn.linear_model import LogisticRegression as sk_LogisticRegression
Expand All @@ -12,25 +16,38 @@ class LogisticRegression(Classifier):
"""

def __init__(self) -> None:
self._wrapped_classifier = sk_LogisticRegression(n_jobs=-1)
self._target_name = ""
self._wrapped_classifier: Optional[sk_LogisticRegression] = None
self._target_name: Optional[str] = None

def fit(self, training_set: TaggedTable) -> None:
def fit(self, training_set: TaggedTable) -> LogisticRegression:
"""
Fit this model given a tagged table.
Create a new classifier based on this one and fit it with the given training data. This classifier is not
modified.
Parameters
----------
training_set : TaggedTable
The tagged table containing the feature and target vectors.
The training data containing the feature and target vectors.
Returns
-------
fitted_classifier : LogisticRegression
The fitted classifier.
Raises
------
LearningError
If the tagged table contains invalid values or if the training failed.
If the training data contains invalid values or if the training failed.
"""
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target.name

wrapped_classifier = sk_LogisticRegression(n_jobs=-1)
fit(wrapped_classifier, training_set)

result = LogisticRegression()
result._wrapped_classifier = wrapped_classifier
result._target_name = training_set.target.name

return result

def predict(self, dataset: Table) -> TaggedTable:
"""
Expand Down

0 comments on commit 165c97c

Please sign in to comment.