Skip to content

Commit

Permalink
feat: return TaggedTable from predict (#73)
Browse files Browse the repository at this point in the history
### Summary of Changes

The `predict` method returns the original feature vector combined with
the predicted target values. This is exactly what a `TaggedTable` can
express, so we return this instead of a `Table`.

---------

Co-authored-by: lars-reimann <lars-reimann@users.noreply.github.com>
  • Loading branch information
lars-reimann and lars-reimann committed Mar 24, 2023
1 parent 8655521 commit 5d5f5a6
Show file tree
Hide file tree
Showing 18 changed files with 147 additions and 195 deletions.
12 changes: 3 additions & 9 deletions src/safeds/ml/_util_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# noinspection PyProtectedMember
def fit(model: Any, tagged_table: TaggedTable) -> str:
def fit(model: Any, tagged_table: TaggedTable) -> None:
"""
Fit a model for a given tagged table.
Expand All @@ -17,11 +17,6 @@ def fit(model: Any, tagged_table: TaggedTable) -> str:
tagged_table : TaggedTable
The tagged table containing the feature and target vectors.
Returns
-------
target_name : str
The target column name, inferred from the tagged table.
Raises
------
LearningError
Expand All @@ -32,15 +27,14 @@ def fit(model: Any, tagged_table: TaggedTable) -> str:
tagged_table.feature_vectors._data,
tagged_table.target_values._data,
)
return tagged_table.target_values.name
except ValueError as exception:
raise LearningError(str(exception)) from exception
except Exception as exception:
raise LearningError(None) from exception


# noinspection PyProtectedMember
def predict(model: Any, dataset: Table, target_name: str) -> Table:
def predict(model: Any, dataset: Table, target_name: str) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first.
Expand Down Expand Up @@ -73,7 +67,7 @@ def predict(model: Any, dataset: Table, target_name: str) -> Table:
f"Dataset already contains '{target_name}' column. Please rename this column"
)
result_set[target_name] = predicted_target_vector
return Table(result_set)
return TaggedTable(Table(result_set), target_column=target_name)
except NotFittedError as exception:
raise PredictionError("The model was not trained") from exception
except ValueError as exception:
Expand Down
21 changes: 8 additions & 13 deletions src/safeds/ml/classification/_ada_boost.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# noinspection PyProtectedMember
import safeds.ml._util_sklearn
from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml._util_sklearn import fit, predict
from sklearn.ensemble import AdaBoostClassifier as sk_AdaBoostClassifier

from ._classifier import Classifier
Expand All @@ -14,8 +14,8 @@ class AdaBoost(Classifier):
"""

def __init__(self) -> None:
self._classification = sk_AdaBoostClassifier()
self.target_name = ""
self._wrapped_classifier = sk_AdaBoostClassifier()
self._target_name = ""

def fit(self, training_set: TaggedTable) -> None:
"""
Expand All @@ -31,11 +31,10 @@ def fit(self, training_set: TaggedTable) -> None:
LearningError
If the tagged table contains invalid values or if the training failed.
"""
self.target_name = safeds.ml._util_sklearn.fit(
self._classification, training_set
)
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target_values.name

def predict(self, dataset: Table) -> Table:
def predict(self, dataset: Table) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first.
Expand All @@ -46,16 +45,12 @@ def predict(self, dataset: Table) -> Table:
Returns
-------
table : Table
table : TaggedTable
A dataset containing the given feature vectors and the predicted target vector.
Raises
------
PredictionError
If prediction with the given dataset failed.
"""
return safeds.ml._util_sklearn.predict(
self._classification,
dataset,
self.target_name,
)
return predict(self._wrapped_classifier, dataset, self._target_name)
8 changes: 6 additions & 2 deletions src/safeds/ml/classification/_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@


class Classifier(ABC):
"""
Abstract base class for all classifiers.
"""

@abstractmethod
def fit(self, training_set: TaggedTable) -> None:
"""
Expand All @@ -21,7 +25,7 @@ def fit(self, training_set: TaggedTable) -> None:
"""

@abstractmethod
def predict(self, dataset: Table) -> Table:
def predict(self, dataset: Table) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first.
Expand All @@ -32,7 +36,7 @@ def predict(self, dataset: Table) -> Table:
Returns
-------
table : Table
table : TaggedTable
A dataset containing the given feature vectors and the predicted target vector.
Raises
Expand Down
21 changes: 8 additions & 13 deletions src/safeds/ml/classification/_decision_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# noinspection PyProtectedMember
import safeds.ml._util_sklearn
from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml._util_sklearn import fit, predict
from sklearn.tree import DecisionTreeClassifier as sk_DecisionTreeClassifier

from ._classifier import Classifier
Expand All @@ -14,8 +14,8 @@ class DecisionTree(Classifier):
"""

def __init__(self) -> None:
self._classification = sk_DecisionTreeClassifier()
self.target_name = ""
self._wrapped_classifier = sk_DecisionTreeClassifier()
self._target_name = ""

def fit(self, training_set: TaggedTable) -> None:
"""
Expand All @@ -31,11 +31,10 @@ def fit(self, training_set: TaggedTable) -> None:
LearningError
If the tagged table contains invalid values or if the training failed.
"""
self.target_name = safeds.ml._util_sklearn.fit(
self._classification, training_set
)
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target_values.name

def predict(self, dataset: Table) -> Table:
def predict(self, dataset: Table) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first.
Expand All @@ -46,16 +45,12 @@ def predict(self, dataset: Table) -> Table:
Returns
-------
table : Table
table : TaggedTable
A dataset containing the given feature vectors and the predicted target vector.
Raises
------
PredictionError
If prediction with the given dataset failed.
"""
return safeds.ml._util_sklearn.predict(
self._classification,
dataset,
self.target_name,
)
return predict(self._wrapped_classifier, dataset, self._target_name)
23 changes: 9 additions & 14 deletions src/safeds/ml/classification/_gradient_boosting_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# noinspection PyProtectedMember
import safeds.ml._util_sklearn
from safeds.data.tabular.containers import Table, TaggedTable
from sklearn.ensemble import GradientBoostingClassifier
from safeds.ml._util_sklearn import fit, predict
from sklearn.ensemble import GradientBoostingClassifier as sk_GradientBoostingClassifier

from ._classifier import Classifier

Expand All @@ -14,8 +14,8 @@ class GradientBoosting(Classifier):
"""

def __init__(self) -> None:
self._classification = GradientBoostingClassifier()
self.target_name = ""
self._wrapped_classifier = sk_GradientBoostingClassifier()
self._target_name = ""

def fit(self, training_set: TaggedTable) -> None:
"""
Expand All @@ -31,12 +31,11 @@ def fit(self, training_set: TaggedTable) -> None:
LearningError
If the tagged table contains invalid values or if the training failed.
"""
self.target_name = safeds.ml._util_sklearn.fit(
self._classification, training_set
)
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target_values.name

# noinspection PyProtectedMember
def predict(self, dataset: Table) -> Table:
def predict(self, dataset: Table) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first.
Expand All @@ -47,16 +46,12 @@ def predict(self, dataset: Table) -> Table:
Returns
-------
table : Table
table : TaggedTable
A dataset containing the given feature vectors and the predicted target vector.
Raises
------
PredictionError
If prediction with the given dataset failed.
"""
return safeds.ml._util_sklearn.predict(
self._classification,
dataset,
self.target_name,
)
return predict(self._wrapped_classifier, dataset, self._target_name)
25 changes: 13 additions & 12 deletions src/safeds/ml/classification/_k_nearest_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# noinspection PyProtectedMember
import safeds.ml._util_sklearn
from safeds.data.tabular.containers import Table, TaggedTable
from sklearn.neighbors import KNeighborsClassifier
from safeds.ml._util_sklearn import fit, predict
from sklearn.neighbors import KNeighborsClassifier as sk_KNeighborsClassifier

from ._classifier import Classifier

Expand All @@ -18,8 +18,10 @@ class KNearestNeighbors(Classifier):
"""

def __init__(self, n_neighbors: int) -> None:
self._classification = KNeighborsClassifier(n_jobs=-1, n_neighbors=n_neighbors)
self.target_name = ""
self._wrapped_classifier = sk_KNeighborsClassifier(
n_jobs=-1, n_neighbors=n_neighbors
)
self._target_name = ""

def fit(self, training_set: TaggedTable) -> None:
"""
Expand All @@ -35,11 +37,10 @@ def fit(self, training_set: TaggedTable) -> None:
LearningError
If the tagged table contains invalid values or if the training failed.
"""
self.target_name = safeds.ml._util_sklearn.fit(
self._classification, training_set
)
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target_values.name

def predict(self, dataset: Table) -> Table:
def predict(self, dataset: Table) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first
Expand All @@ -50,16 +51,16 @@ def predict(self, dataset: Table) -> Table:
Returns
-------
table : Table
table : TaggedTable
A dataset containing the given feature vectors and the predicted target vector.
Raises
------
PredictionError
If prediction with the given dataset failed.
"""
return safeds.ml._util_sklearn.predict(
self._classification,
return predict(
self._wrapped_classifier,
dataset,
self.target_name,
self._target_name,
)
21 changes: 8 additions & 13 deletions src/safeds/ml/classification/_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# noinspection PyProtectedMember
import safeds.ml._util_sklearn
from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml._util_sklearn import fit, predict
from sklearn.linear_model import LogisticRegression as sk_LogisticRegression

from ._classifier import Classifier
Expand All @@ -14,8 +14,8 @@ class LogisticRegression(Classifier):
"""

def __init__(self) -> None:
self._classification = sk_LogisticRegression(n_jobs=-1)
self.target_name = ""
self._wrapped_classifier = sk_LogisticRegression(n_jobs=-1)
self._target_name = ""

def fit(self, training_set: TaggedTable) -> None:
"""
Expand All @@ -31,11 +31,10 @@ def fit(self, training_set: TaggedTable) -> None:
LearningError
If the tagged table contains invalid values or if the training failed.
"""
self.target_name = safeds.ml._util_sklearn.fit(
self._classification, training_set
)
fit(self._wrapped_classifier, training_set)
self._target_name = training_set.target_values.name

def predict(self, dataset: Table) -> Table:
def predict(self, dataset: Table) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first.
Expand All @@ -46,16 +45,12 @@ def predict(self, dataset: Table) -> Table:
Returns
-------
table : Table
table : TaggedTable
A dataset containing the given feature vectors and the predicted target vector.
Raises
------
PredictionError
If prediction with the given dataset failed.
"""
return safeds.ml._util_sklearn.predict(
self._classification,
dataset,
self.target_name,
)
return predict(self._wrapped_classifier, dataset, self._target_name)

0 comments on commit 5d5f5a6

Please sign in to comment.