Skip to content

Commit

Permalink
feat: return fitted transformer and transformed table from `fit_and_t…
Browse files Browse the repository at this point in the history
…ransform` (#724)

Closes #613

### Summary of Changes

`TableTransformer.fit_and_transform` now returns a tuple containing
* the fitted transformer
* the transformed table.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
lars-reimann and megalinter-bot committed May 5, 2024
1 parent 31ffd12 commit 2960d35
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 103 deletions.
2 changes: 0 additions & 2 deletions src/safeds/data/tabular/transformation/_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def is_fitted(self) -> bool:
"""Whether the transformer is fitted."""
return self._wrapped_transformer is not None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None) -> Imputer:
"""
Learn a transformation for a set of columns in a table.
Expand Down Expand Up @@ -199,7 +198,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> Imputer:

return result

# noinspection PyProtectedMember
def transform(self, table: Table) -> Table:
"""
Apply the learned transformation to a table.
Expand Down
1 change: 0 additions & 1 deletion src/safeds/data/tabular/transformation/_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sklearn.preprocessing import OrdinalEncoder as sk_OrdinalEncoder


# noinspection PyProtectedMember
class LabelEncoder(InvertibleTableTransformer):
"""The LabelEncoder encodes one or more given columns into labels."""

Expand Down
5 changes: 1 addition & 4 deletions src/safeds/data/tabular/transformation/_one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class OneHotEncoder(InvertibleTableTransformer):
>>> from safeds.data.tabular.transformation import OneHotEncoder
>>> table = Table({"col1": ["a", "b", "c", "a"]})
>>> transformer = OneHotEncoder()
>>> transformer.fit_and_transform(table, ["col1"])
>>> transformer.fit_and_transform(table, ["col1"])[1]
col1__a col1__b col1__c
0 1.0 0.0 0.0
1 0.0 1.0 0.0
Expand All @@ -65,7 +65,6 @@ def __init__(self) -> None:
# Maps nan values (str of old column) to corresponding new column name
self._value_to_column_nans: dict[str, str] | None = None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
"""
Learn a transformation for a set of columns in a table.
Expand Down Expand Up @@ -150,7 +149,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:

return result

# noinspection PyProtectedMember
def transform(self, table: Table) -> Table:
"""
Apply the learned transformation to a table.
Expand Down Expand Up @@ -238,7 +236,6 @@ def transform(self, table: Table) -> Table:
# Apply sorting and return:
return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

# noinspection PyProtectedMember
def inverse_transform(self, transformed_table: Table) -> Table:
"""
Undo the learned transformation.
Expand Down
42 changes: 14 additions & 28 deletions src/safeds/data/tabular/transformation/_table_transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from safeds._utils import _structural_hash

Expand All @@ -26,8 +26,13 @@ def __hash__(self) -> int:
removed = self.get_names_of_removed_columns() if self.is_fitted else []
return _structural_hash(self.__class__.__qualname__, self.is_fitted, added, changed, removed)

@property
@abstractmethod
def is_fitted(self) -> bool:
"""Whether the transformer is fitted."""

@abstractmethod
def fit(self, table: Table, column_names: list[str] | None) -> TableTransformer:
def fit(self, table: Table, column_names: list[str] | None) -> Self:
"""
Learn a transformation for a set of columns in a table.
Expand Down Expand Up @@ -117,16 +122,11 @@ def get_names_of_removed_columns(self) -> list[str]:
If the transformer has not been fitted yet.
"""

@property
@abstractmethod
def is_fitted(self) -> bool:
"""Whether the transformer is fitted."""

def fit_and_transform(self, table: Table, column_names: list[str] | None = None) -> Table:
def fit_and_transform(self, table: Table, column_names: list[str] | None = None) -> tuple[Self, Table]:
"""
Learn a transformation for a set of columns in a table and apply the learned transformation to the same table.
The table is not modified. If you also need the fitted transformer, use `fit` and `transform` separately.
Neither the transformer nor the table are modified.
Parameters
----------
Expand All @@ -137,33 +137,19 @@ def fit_and_transform(self, table: Table, column_names: list[str] | None = None)
Returns
-------
fitted_transformer:
The fitted transformer.
transformed_table:
The transformed table.
"""
return self.fit(table, column_names).transform(table)
fitted_transformer = self.fit(table, column_names)
transformed_table = fitted_transformer.transform(table)
return fitted_transformer, transformed_table


class InvertibleTableTransformer(TableTransformer):
"""A `TableTransformer` that can also undo the learned transformation after it has been applied."""

@abstractmethod
def fit(self, table: Table, column_names: list[str] | None) -> InvertibleTableTransformer:
"""
Learn a transformation for a set of columns in a table.
Parameters
----------
table:
The table used to fit the transformer.
column_names:
The list of columns from the table used to fit the transformer. If `None`, all columns are used.
Returns
-------
fitted_transformer:
The fitted transformer.
"""

@abstractmethod
def inverse_transform(self, transformed_table: Table) -> Table:
"""
Expand Down
10 changes: 7 additions & 3 deletions tests/safeds/data/tabular/transformation/test_discretizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,15 @@ class TestFitAndTransform:
],
ids=["None", "col1"],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert Discretizer().fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = Discretizer().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert transformed_table == expected

@pytest.mark.parametrize(
("table", "number_of_bins", "expected"),
Expand Down Expand Up @@ -243,7 +245,9 @@ def test_should_return_transformed_table_with_correct_number_of_bins(
number_of_bins: int,
expected: Table,
) -> None:
assert Discretizer(number_of_bins).fit_and_transform(table, ["col1"]) == expected
fitted_transformer, transformed_table = Discretizer(number_of_bins).fit_and_transform(table, ["col1"])
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
Expand Down
28 changes: 13 additions & 15 deletions tests/safeds/data/tabular/transformation/test_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,29 +413,27 @@ class TestFitAndTransform:
"other value to replace",
],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
strategy: Imputer.Strategy,
value_to_replace: float | str | None,
expected: Table,
) -> None:
if isinstance(strategy, _Mode):
with warnings.catch_warnings():
warnings.filterwarnings(
action="ignore",
message=r"There are multiple most frequent values in a column given to the Imputer\..*",
category=UserWarning,
)
assert (
Imputer(strategy, value_to_replace=value_to_replace).fit_and_transform(table, column_names)
== expected
)
else:
assert (
Imputer(strategy, value_to_replace=value_to_replace).fit_and_transform(table, column_names) == expected
with warnings.catch_warnings():
warnings.filterwarnings(
action="ignore",
message=r"There are multiple most frequent values in a column given to the Imputer\..*",
category=UserWarning,
)
fitted_transformer, transformed_table = Imputer(
strategy,
value_to_replace=value_to_replace,
).fit_and_transform(table, column_names)

assert fitted_transformer.is_fitted
assert transformed_table == expected

@pytest.mark.parametrize("strategy", strategies(), ids=lambda x: x.__class__.__name__)
def test_should_not_change_original_table(self, strategy: Imputer.Strategy) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ class TestFitAndTransform:
],
ids=["no_column_names", "with_column_names"],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert LabelEncoder().fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = LabelEncoder().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,15 @@ class TestFitAndTransform:
"column with nans",
],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert OneHotEncoder().fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = OneHotEncoder().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
Expand Down
15 changes: 11 additions & 4 deletions tests/safeds/data/tabular/transformation/test_range_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,15 @@ class TestFitAndTransform:
],
ids=["one_column", "two_columns"],
)
def test_should_return_transformed_table(
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert RangeScaler().fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = RangeScaler().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert transformed_table == expected

@pytest.mark.parametrize(
("table", "column_names", "expected"),
Expand Down Expand Up @@ -186,13 +188,18 @@ def test_should_return_transformed_table(
],
ids=["one_column", "two_columns"],
)
def test_should_return_transformed_table_with_correct_range(
def test_should_return_fitted_transformer_and_transformed_table_with_correct_range(
self,
table: Table,
column_names: list[str] | None,
expected: Table,
) -> None:
assert RangeScaler(minimum=-10.0, maximum=10.0).fit_and_transform(table, column_names) == expected
fitted_transformer, transformed_table = RangeScaler(minimum=-10.0, maximum=10.0).fit_and_transform(
table,
column_names,
)
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
Expand Down
32 changes: 8 additions & 24 deletions tests/safeds/data/tabular/transformation/test_standard_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def test_should_return_true_after_fitting(self) -> None:
assert fitted_transformer.is_fitted


class TestFitAndTransformOnMultipleTables:
class TestFitAndTransform:
@pytest.mark.parametrize(
("fit_and_transform_table", "only_transform_table", "column_names", "expected_1", "expected_2"),
("table", "column_names", "expected"),
[
(
Table(
Expand All @@ -122,43 +122,27 @@ class TestFitAndTransformOnMultipleTables:
"col2": [0.0, 0.0, 1.0, 1.0],
},
),
Table(
{
"col1": [2],
"col2": [2],
},
),
None,
Table(
{
"col1": [-1.0, -1.0, 1.0, 1.0],
"col2": [-1.0, -1.0, 1.0, 1.0],
},
),
Table(
{
"col1": [3.0],
"col2": [3.0],
},
),
),
],
ids=["two_columns"],
)
def test_should_return_transformed_tables(
def test_should_return_fitted_transformer_and_transformed_table(
self,
fit_and_transform_table: Table,
only_transform_table: Table,
table: Table,
column_names: list[str] | None,
expected_1: Table,
expected_2: Table,
expected: Table,
) -> None:
s = StandardScaler().fit(fit_and_transform_table, column_names)
assert s.fit_and_transform(fit_and_transform_table, column_names) == expected_1
assert s.transform(only_transform_table) == expected_2

fitted_transformer, transformed_table = StandardScaler().fit_and_transform(table, column_names)
assert fitted_transformer.is_fitted
assert_that_tables_are_close(transformed_table, expected)

class TestFitAndTransform:
def test_should_not_change_original_table(self) -> None:
table = Table(
{
Expand Down

0 comments on commit 2960d35

Please sign in to comment.