Skip to content

Commit

Permalink
perf: implement one hot encoder and imputer using polars (#768)
Browse files Browse the repository at this point in the history
### Summary of Changes

The one hot encoder and imputer are now also implemented using polars,
providing better performance.

Tests should pass again now. We'll maximize coverage over the coming
days.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
lars-reimann and megalinter-bot committed May 15, 2024
1 parent 6fbe537 commit e993c17
Show file tree
Hide file tree
Showing 25 changed files with 428 additions and 579 deletions.
2 changes: 2 additions & 0 deletions .mega-linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ JSON_PRETTIER_FILE_EXTENSIONS:
- .html
# - .md

PYTHON_RUFF_CONFIG_FILE: pyproject.toml

# Commands
PRE_COMMANDS:
- command: npm i @lars-reimann/prettier-config
106 changes: 53 additions & 53 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions src/safeds/data/labeled/containers/_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def shuffle(self) -> ImageDataset[T]:

class _TableAsTensor:
def __init__(self, table: Table) -> None:
import polars as pl
import torch

_init_default_device()
Expand All @@ -298,7 +299,7 @@ def __init__(self, table: Table) -> None:
if table.number_of_rows == 0:
self._tensor = torch.empty((0, table.number_of_columns), dtype=torch.float32).to(_get_device())
else:
self._tensor = table._data_frame.to_torch().to(_get_device())
self._tensor = table._data_frame.to_torch(dtype=pl.Float32).to(_get_device())

if not torch.all(self._tensor.sum(dim=1) == torch.ones(self._tensor.size(dim=0))):
raise ValueError(
Expand Down Expand Up @@ -345,6 +346,7 @@ def _to_table(self) -> Table:

class _ColumnAsTensor:
def __init__(self, column: Column) -> None:
import polars as pl
import torch

_init_default_device()
Expand All @@ -360,9 +362,9 @@ def __init__(self, column: Column) -> None:
# TODO: should not one-hot-encode the target. label encoding without order is sufficient. should also not
# be done automatically?
self._one_hot_encoder = OneHotEncoder().fit(column_as_table, [self._column_name])
self._tensor = torch.Tensor(self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch()).to(
_get_device(),
)
self._tensor = torch.Tensor(
self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch(dtype=pl.Float32),
).to(_get_device())

def __eq__(self, other: object) -> bool:
import torch
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/data/labeled/containers/_tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class TabularDataset(Dataset):
Examples
--------
>>> from safeds.data.labeled.containers import TabularDataset
>>> from safeds.data.tabular.containers import Table
>>> table = Table(
... {
... "id": [1, 2, 3],
Expand Down
9 changes: 1 addition & 8 deletions src/safeds/data/tabular/containers/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,14 +1009,7 @@ def mode(
>>> from safeds.data.tabular.containers import Column
>>> column = Column("test", [3, 1, 2, 1, 3])
>>> column.mode()
+------+
| test |
| --- |
| i64 |
+======+
| 1 |
| 3 |
+------+
[1, 3]
"""
import polars as pl

Expand Down
17 changes: 8 additions & 9 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def __eq__(self, other: object) -> bool:
if self is other:
return True

return self._data_frame.frame_equal(other._data_frame)
return self._data_frame.equals(other._data_frame)

def __hash__(self) -> int:
return _structural_hash(self.schema, self.number_of_rows)
Expand Down Expand Up @@ -859,7 +859,7 @@ def rename_column(self, old_name: str, new_name: str) -> Table:
def replace_column(
self,
old_name: str,
new_columns: Column | list[Column],
new_columns: Column | list[Column] | Table,
) -> Table:
"""
Return a new table with a column replaced by zero or more columns.
Expand All @@ -871,7 +871,7 @@ def replace_column(
old_name:
The name of the column to replace.
new_columns:
The new column or columns.
The new columns.
Returns
-------
Expand Down Expand Up @@ -922,11 +922,13 @@ def replace_column(
| 9 | 12 | 6 |
+-----+-----+-----+
"""
_check_columns_exist(self, old_name)
_check_columns_dont_exist(self, [column.name for column in new_columns], old_name=old_name)

if isinstance(new_columns, Column):
new_columns = [new_columns]
elif isinstance(new_columns, Table):
new_columns = new_columns.to_columns()

_check_columns_exist(self, old_name)
_check_columns_dont_exist(self, [column.name for column in new_columns], old_name=old_name)

if len(new_columns) == 0:
return self.remove_columns(old_name)
Expand Down Expand Up @@ -1033,9 +1035,6 @@ def remove_duplicate_rows(self) -> Table:
| 2 | 5 |
+-----+-----+
"""
if self.number_of_columns == 0:
return self # Workaround for https://github.com/pola-rs/polars/issues/16207

return Table._from_polars_lazy_frame(
self._lazy_frame.unique(maintain_order=True),
)
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/data/tabular/plotting/_table_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TablePlotter:
Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table("test", [1, 2, 3])
>>> table = Table({"test": [1, 2, 3]})
>>> plotter = table.plot
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ class InvertibleTableTransformer(TableTransformer):
@abstractmethod
def inverse_transform(self, transformed_table: Table) -> Table:
"""
Undo the learned transformation.
Undo the learned transformation as well as possible.
The table is not modified.
Column order and types may differ from the original table. Likewise, some values might not be restored.
**Note:** The given table is not modified.
Parameters
----------
Expand Down
21 changes: 12 additions & 9 deletions src/safeds/data/tabular/transformation/_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ def __init__(self, *, partial_order: list[Any] | None = None) -> None:
self._partial_order = partial_order

# Internal state
self._mapping: dict[str, dict[Any, int]] | None = None
self._inverse_mapping: dict[str, dict[int, Any]] | None = None
self._mapping: dict[str, dict[Any, int]] | None = None # Column name -> value -> label
self._inverse_mapping: dict[str, dict[int, Any]] | None = None # Column name -> label -> value

def __hash__(self) -> int:
return _structural_hash(
super().__hash__(),
self._partial_order,
# Leave out the internal state for faster hashing
)

# ------------------------------------------------------------------------------------------------------------------
Expand All @@ -61,7 +62,7 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
table:
The table used to fit the transformer.
column_names:
The list of columns from the table used to fit the transformer. If `None`, all columns are used.
The list of columns from the table used to fit the transformer. If `None`, all non-numeric columns are used.
Returns
-------
Expand All @@ -76,14 +77,13 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
If the table contains 0 rows.
"""
if column_names is None:
column_names = table.column_names
column_names = [name for name in table.column_names if not table.get_column_type(name).is_numeric]
else:
_check_columns_exist(table, column_names)
_warn_if_columns_are_numeric(table, column_names)

if table.number_of_rows == 0:
raise ValueError("The LabelEncoder cannot transform the table because it contains 0 rows")

_warn_if_columns_are_numeric(table, column_names)
raise ValueError("The LabelEncoder cannot be fitted because the table contains 0 rows")

# Learn the transformation
mapping = {}
Expand Down Expand Up @@ -142,7 +142,10 @@ def transform(self, table: Table) -> Table:

_check_columns_exist(table, self._column_names)

columns = [pl.col(name).replace(self._mapping[name], return_dtype=pl.UInt32) for name in self._column_names]
columns = [
pl.col(name).replace(self._mapping[name], default=None, return_dtype=pl.UInt32)
for name in self._column_names
]

return Table._from_polars_lazy_frame(
table._lazy_frame.with_columns(columns),
Expand Down Expand Up @@ -186,7 +189,7 @@ def inverse_transform(self, transformed_table: Table) -> Table:
operation="inverse-transform with a LabelEncoder",
)

columns = [pl.col(name).replace(self._inverse_mapping[name]) for name in self._column_names]
columns = [pl.col(name).replace(self._inverse_mapping[name], default=None) for name in self._column_names]

return Table._from_polars_lazy_frame(
transformed_table._lazy_frame.with_columns(columns),
Expand Down

0 comments on commit e993c17

Please sign in to comment.