Skip to content

Commit

Permalink
fix: OneHotEncoder no longer creates duplicate column names (#271)
Browse files Browse the repository at this point in the history
Closes #201.

### Summary of Changes

Changed OneHotEncoder to manually implement the encoding.
(Breaking) Changed the format of newly generated columns to use two
underscores as separator. In case of naming conflicts, a hash sign (`#`)
and a unique number will be appended to the column name.

---------

Co-authored-by: zzril <>
Co-authored-by: ilkajw <123072184+ilkajw@users.noreply.github.com>
Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
3 people committed May 10, 2023
1 parent 8db5914 commit f604666
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 92 deletions.
8 changes: 4 additions & 4 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,10 +1136,10 @@ def transform_table(self, transformer: TableTransformer) -> Table:
>>> table = Table({"col1": [1, 2, 1], "col2": [1, 2, 4]})
>>> fitted_transformer = transformer.fit(table, None)
>>> table.transform_table(fitted_transformer)
col1_1 col1_2 col2_1 col2_2 col2_4
0 1.0 0.0 1.0 0.0 0.0
1 0.0 1.0 0.0 1.0 0.0
2 1.0 0.0 0.0 0.0 1.0
col1__1 col1__2 col2__1 col2__2 col2__4
0 1.0 0.0 1.0 0.0 0.0
1 0.0 1.0 0.0 1.0 0.0
2 1.0 0.0 0.0 0.0 1.0
"""
return transformer.transform(self)

Expand Down
142 changes: 87 additions & 55 deletions src/safeds/data/tabular/transformation/_one_hot_encoder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import pandas as pd
from sklearn.preprocessing import OneHotEncoder as sk_OneHotEncoder
from collections import Counter
from typing import Any

from safeds.data.tabular.containers import Table
from safeds.data.tabular.containers import Column, Table
from safeds.data.tabular.transformation._table_transformer import (
InvertibleTableTransformer,
)
from safeds.exceptions import TransformerNotFittedError, UnknownColumnNameError
from safeds.exceptions import TransformerNotFittedError, UnknownColumnNameError, ValueNotPresentWhenFittedError


class OneHotEncoder(InvertibleTableTransformer):
Expand All @@ -27,12 +27,12 @@ class OneHotEncoder(InvertibleTableTransformer):
The one-hot encoding of this table is:
| col1_a | col1_b | col1_c |
|--------|--------|--------|
| 1 | 0 | 0 |
| 0 | 1 | 0 |
| 0 | 0 | 1 |
| 1 | 0 | 0 |
| col1__a | col1__b | col1__c |
|---------|---------|---------|
| 1 | 0 | 0 |
| 0 | 1 | 0 |
| 0 | 0 | 1 |
| 1 | 0 | 0 |
The name "one-hot" comes from the fact that each row has exactly one 1 in it, and the rest of the values are 0s.
One-hot encoding is closely related to dummy variable / indicator variables, which are used in statistics.
Expand All @@ -44,16 +44,18 @@ class OneHotEncoder(InvertibleTableTransformer):
>>> table = Table({"col1": ["a", "b", "c", "a"]})
>>> transformer = OneHotEncoder()
>>> transformer.fit_and_transform(table, ["col1"])
col1_a col1_b col1_c
0 1.0 0.0 0.0
1 0.0 1.0 0.0
2 0.0 0.0 1.0
3 1.0 0.0 0.0
col1__a col1__b col1__c
0 1.0 0.0 0.0
1 0.0 1.0 0.0
2 0.0 0.0 1.0
3 1.0 0.0 0.0
"""

def __init__(self) -> None:
self._wrapped_transformer: sk_OneHotEncoder | None = None
# Maps each old column to (list of) new columns created from it:
self._column_names: dict[str, list[str]] | None = None
# Maps concrete values (tuples of old column and value) to corresponding new column names:
self._value_to_column: dict[tuple[str, Any], str] | None = None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
Expand Down Expand Up @@ -84,15 +86,28 @@ def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
data = table._data.copy()
data.columns = table.column_names

wrapped_transformer = sk_OneHotEncoder()
wrapped_transformer.fit(data[column_names])

result = OneHotEncoder()
result._wrapped_transformer = wrapped_transformer
result._column_names = {
column: [f"{column}_{element}" for element in table.get_column(column).get_unique_values()]
for column in column_names
}

result._column_names = {}
result._value_to_column = {}

# Keep track of number of occurrences of column names;
# initially all old column names appear exactly once:
name_counter = Counter(data.columns)

# Iterate through all columns to-be-changed:
for column in column_names:
result._column_names[column] = []
for element in table.get_column(column).get_unique_values():
base_name = f"{column}__{element}"
name_counter[base_name] += 1
new_column_name = base_name
# Check if newly created name matches some other existing column name:
if name_counter[base_name] > 1:
new_column_name += f"#{name_counter[base_name]}"
# Update dictionary entries:
result._column_names[column] += [new_column_name]
result._value_to_column[(column, element)] = new_column_name

return result

Expand All @@ -119,37 +134,49 @@ def transform(self, table: Table) -> Table:
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
if self._column_names is None or self._value_to_column is None:
raise TransformerNotFittedError

# Input table does not contain all columns used to fit the transformer
missing_columns = set(self._column_names.keys()) - set(table.column_names)
if len(missing_columns) > 0:
raise UnknownColumnNameError(list(missing_columns))

original = table._data.copy()
original.columns = table.schema.column_names

one_hot_encoded = pd.DataFrame(
self._wrapped_transformer.transform(original[self._column_names.keys()]).toarray(),
)
one_hot_encoded.columns = self._wrapped_transformer.get_feature_names_out()

unchanged = original.drop(self._column_names.keys(), axis=1)

res = Table._from_pandas_dataframe(pd.concat([unchanged, one_hot_encoded], axis=1))
encoded_values = {}
for new_column_name in self._value_to_column.values():
encoded_values[new_column_name] = [0.0 for _ in range(table.number_of_rows)]

for old_column_name in self._column_names:
for i in range(table.number_of_rows):
value = table.get_column(old_column_name).get_value(i)
try:
new_column_name = self._value_to_column[(old_column_name, value)]
except KeyError:
# This happens when a column in the to-be-transformed table contains a new value that was not
# already present in the table the OneHotEncoder was fitted on.
raise ValueNotPresentWhenFittedError(value, old_column_name) from None
encoded_values[new_column_name][i] = 1.0

for new_column in self._column_names[old_column_name]:
table = table.add_column(Column(new_column, encoded_values[new_column]))

# New columns may not be sorted:
column_names = []

for name in table.column_names:
if name not in self._column_names.keys():
column_names.append(name)
else:
column_names.extend(
[f_name for f_name in self._wrapped_transformer.get_feature_names_out() if f_name.startswith(name)],
[f_name for f_name in self._value_to_column.values() if f_name.startswith(name)],
)
res = res.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

return res
# Drop old, non-encoded columns:
# (Don't do this earlier - we need the old column names for sorting,
# plus we need to prevent the table from possibly having 0 columns temporarily.)
table = table.remove_columns(list(self._column_names.keys()))

# Apply sorting and return:
return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

# noinspection PyProtectedMember
def inverse_transform(self, transformed_table: Table) -> Table:
Expand All @@ -174,21 +201,24 @@ def inverse_transform(self, transformed_table: Table) -> Table:
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
if self._column_names is None or self._value_to_column is None:
raise TransformerNotFittedError

data = transformed_table._data.copy()
data.columns = transformed_table.column_names
original_columns = {}
for original_column_name in self._column_names:
original_columns[original_column_name] = [None for _ in range(transformed_table.number_of_rows)]

for original_column_name, value in self._value_to_column:
constructed_column = self._value_to_column[(original_column_name, value)]
for i in range(transformed_table.number_of_rows):
if transformed_table.get_column(constructed_column)[i] == 1.0:
original_columns[original_column_name][i] = value

decoded = pd.DataFrame(
self._wrapped_transformer.inverse_transform(
transformed_table.keep_only_columns(self._wrapped_transformer.get_feature_names_out())._data,
),
columns=list(self._column_names.keys()),
)
unchanged = data.drop(self._wrapped_transformer.get_feature_names_out(), axis=1)
table = transformed_table

for column_name, encoded_column in original_columns.items():
table = table.add_column(Column(column_name, encoded_column))

res = Table._from_pandas_dataframe(pd.concat([unchanged, decoded], axis=1))
column_names = [
(
name
Expand All @@ -201,11 +231,13 @@ def inverse_transform(self, transformed_table: Table) -> Table:
][0]
]
)
for name in transformed_table.column_names
for name in table.column_names
]
res = res.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

return res
# Drop the one-hot encoded columns (the decoded original columns replace them):
table = table.remove_columns(list(self._value_to_column.values()))

return table.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

def is_fitted(self) -> bool:
"""
Expand All @@ -216,4 +248,4 @@ def is_fitted(self) -> bool:
is_fitted : bool
Whether the transformer is fitted.
"""
return self._wrapped_transformer is not None
return self._column_names is not None and self._value_to_column is not None
2 changes: 2 additions & 0 deletions src/safeds/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SchemaMismatchError,
TransformerNotFittedError,
UnknownColumnNameError,
ValueNotPresentWhenFittedError,
)
from safeds.exceptions._ml import (
DatasetContainsTargetError,
Expand All @@ -29,6 +30,7 @@
"SchemaMismatchError",
"TransformerNotFittedError",
"UnknownColumnNameError",
"ValueNotPresentWhenFittedError",
# ML exceptions
"DatasetContainsTargetError",
"DatasetMissesFeaturesError",
Expand Down
7 changes: 7 additions & 0 deletions src/safeds/exceptions/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ class TransformerNotFittedError(Exception):

def __init__(self) -> None:
super().__init__("The transformer has not been fitted yet.")


class ValueNotPresentWhenFittedError(Exception):
    """
    Exception raised when attempting to one-hot-encode a table containing values not present in the fitting phase.

    Parameters
    ----------
    value : object
        The value that was not present in the table the transformer was fitted on.
    column : str
        The name of the column in which the value was encountered.
    """

    def __init__(self, value: object, column: str) -> None:
        # `value` is only interpolated into the message, so any type is accepted, not just strings.
        super().__init__(f"Value not present in the table the transformer was fitted on: \n{value} in column {column}.")
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def test_should_not_change_transformed_table() -> None:

expected = Table(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col1__a": [1.0, 0.0, 0.0, 0.0],
"col1__b": [0.0, 1.0, 1.0, 0.0],
"col1__c": [0.0, 0.0, 0.0, 1.0],
},
)

Expand Down
24 changes: 12 additions & 12 deletions tests/safeds/data/tabular/containers/_table/test_transform_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
None,
Table(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col1__a": [1.0, 0.0, 0.0, 0.0],
"col1__b": [0.0, 1.0, 1.0, 0.0],
"col1__c": [0.0, 0.0, 0.0, 1.0],
},
),
),
Expand All @@ -32,9 +32,9 @@
["col1"],
Table(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col1__a": [1.0, 0.0, 0.0, 0.0],
"col1__b": [0.0, 1.0, 1.0, 0.0],
"col1__c": [0.0, 0.0, 0.0, 1.0],
"col2": ["a", "b", "b", "c"],
},
),
Expand All @@ -49,12 +49,12 @@
["col1", "col2"],
Table(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col2_a": [1.0, 0.0, 0.0, 0.0],
"col2_b": [0.0, 1.0, 1.0, 0.0],
"col2_c": [0.0, 0.0, 0.0, 1.0],
"col1__a": [1.0, 0.0, 0.0, 0.0],
"col1__b": [0.0, 1.0, 1.0, 0.0],
"col1__c": [0.0, 0.0, 0.0, 1.0],
"col2__a": [1.0, 0.0, 0.0, 0.0],
"col2__b": [0.0, 1.0, 1.0, 0.0],
"col2__c": [0.0, 0.0, 0.0, 1.0],
},
),
),
Expand Down

0 comments on commit f604666

Please sign in to comment.