Skip to content

Commit

Permalink
[ADD] Add column transformer (#305)
Browse files Browse the repository at this point in the history
* Match paper libraries-versions

* Update README.md

* Update README.md

* Update README.md

* [FIX] master branch README (#209)

* Enable github actions (#273)

* Update README.md

* Create CITATION.cff

* Added column transformer, changed requirements and added tests

* remove redundant lines

* Remove unwanted change made

* Fix bug in test api and dummy forward pass

* Fix silly bugs

* increase time to pass test

* remove parallel capabilities of traditional learners to resolve bug in docs building

* almost fixed

* Add documentation for tabularfeaturevalidator

* fix flake

* fix silly bug

* address comment from shuhei

* rename enc_columns to transformed_columns in the remaining places

* fix bug in test

* fix mypy

* add shuhei's suggestion

Co-authored-by: chico <francisco.rivera.valverde@gmail.com>
Co-authored-by: Marius Lindauer <marius.rks@googlemail.com>
Co-authored-by: Frank <fh@informatik.uni-freiburg.de>
Co-authored-by: Francisco Rivera Valverde <44504424+franchuterivera@users.noreply.github.com>
  • Loading branch information
5 people committed Nov 4, 2021
1 parent 9002937 commit a11caf4
Show file tree
Hide file tree
Showing 13 changed files with 188 additions and 137 deletions.
3 changes: 2 additions & 1 deletion autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,8 @@ def get_incumbent_results(
if not include_traditional:
# traditional classifiers have trainer_configuration in their additional info
run_history_data = dict(
filter(lambda elem: elem[1].additional_info is not None and elem[1].
filter(lambda elem: elem[1].status == StatusType.SUCCESS and elem[1].
additional_info is not None and elem[1].
additional_info['configuration_origin'] != 'traditional',
run_history_data.items()))
run_history_data = dict(
Expand Down
10 changes: 3 additions & 7 deletions autoPyTorch/data/base_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ class BaseFeatureValidator(BaseEstimator):
List of the column types found by this estimator during fit.
data_type (str):
Class name of the data type provided during fit.
encoder (typing.Optional[BaseEstimator])
Host a encoder object if the data requires transformation (for example,
if provided a categorical column in a pandas DataFrame)
enc_columns (typing.List[str])
List of columns that were encoded.
"""
def __init__(self,
logger: typing.Optional[typing.Union[PicklableClientLogger, logging.Logger
Expand All @@ -51,8 +46,8 @@ def __init__(self,
self.dtypes = [] # type: typing.List[str]
self.column_order = [] # type: typing.List[str]

self.encoder = None # type: typing.Optional[BaseEstimator]
self.enc_columns = [] # type: typing.List[str]
self.column_transformer = None # type: typing.Optional[BaseEstimator]
self.transformed_columns = [] # type: typing.List[str]

self.logger: typing.Union[
PicklableClientLogger, logging.Logger
Expand All @@ -61,6 +56,7 @@ def __init__(self,
# Required for dataset properties
self.num_features = None # type: typing.Optional[int]
self.categories = [] # type: typing.List[typing.List[int]]

self.categorical_columns: typing.List[int] = []
self.numerical_columns: typing.List[int] = []

Expand Down
225 changes: 127 additions & 98 deletions autoPyTorch/data/tabular_feature_validator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import typing
from typing import Dict, List

import numpy as np

Expand All @@ -13,11 +14,108 @@
from sklearn.base import BaseEstimator
from sklearn.compose import ColumnTransformer
from sklearn.exceptions import NotFittedError
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline

from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES


def _create_column_transformer(
    preprocessors: Dict[str, List[BaseEstimator]],
    categorical_columns: List[str],
) -> ColumnTransformer:
    """
    Build a sklearn column transformer that applies the categorical
    preprocessors to the categorical columns only.

    The steps found under the 'categorical' key are chained into a single
    pipeline registered under the name 'categorical_pipeline'. Every column
    not listed in `categorical_columns` is handled with
    `remainder='passthrough'`.

    Args:
        preprocessors (Dict[str, List[BaseEstimator]]):
            Dictionary of preprocessor lists; only the 'categorical'
            entry is read here.
        categorical_columns (List[str]):
            List of names of categorical columns

    Returns:
        ColumnTransformer
    """
    categorical_steps = preprocessors['categorical']

    transformers = [
        ('categorical_pipeline', make_pipeline(*categorical_steps), categorical_columns),
    ]
    return ColumnTransformer(transformers, remainder='passthrough')


def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
    """
    Create the dictionary of preprocessors used for tabular data.

    Only a 'categorical' entry is populated: a constant-strategy imputer
    followed by an ordinal encoder configured with
    `handle_unknown='use_encoded_value'` and `unknown_value=-1`.

    Returns:
        Dict[str, List[BaseEstimator]]
    """
    # Order matters for the resulting pipeline: impute first, then encode.
    imputer = SimpleImputer(strategy='constant', copy=False)
    ordinal_encoder = preprocessing.OrdinalEncoder(
        handle_unknown='use_encoded_value',
        unknown_value=-1,
    )

    return {'categorical': [imputer, ordinal_encoder]}


class TabularFeatureValidator(BaseFeatureValidator):
"""
A subclass of `BaseFeatureValidator` made for tabular data.
It ensures that the dataset provided is of the expected format.
Subsequently, it preprocesses the data by fitting a column
transformer.
Attributes:
categories (List[List[int]]):
List for which an element at each index is a
list containing the ordinal codes of the categories
for the respective categorical column.
transformed_columns (List[str]):
List of columns that were transformed.
column_transformer (Optional[BaseEstimator]):
Hosts an imputer and an encoder object if the data
requires transformation (for example, if provided a
categorical column in a pandas DataFrame)
column_order (List[str]):
List of the features stored in the order that
was fitted.
numerical_columns (List[int]):
List of indices of numerical columns
categorical_columns (List[int]):
List of indices of categorical columns
"""
@staticmethod
def _comparator(cmp1: str, cmp2: str) -> int:
    """Rank feature types so that 'categorical' sorts before 'numerical'.

    Intended for use with `functools.cmp_to_key` when reordering the
    feature types.

    Args:
        cmp1 (str): First feature type to compare
        cmp2 (str): Second feature type to compare

    Raises:
        ValueError: if either argument is neither 'categorical'
            nor 'numerical'

    Returns:
        int: -1 if cmp1 sorts first, 1 if cmp2 sorts first, 0 if equal
    """
    choices = ['categorical', 'numerical']
    for candidate in (cmp1, cmp2):
        if candidate not in choices:
            raise ValueError('The comparator for the column order only accepts {}, '
                             'but got {} and {}'.format(choices, cmp1, cmp2))

    # Position in `choices` is the sort rank of each feature type
    rank = {kind: position for position, kind in enumerate(choices)}
    return rank[cmp1] - rank[cmp2]

def _fit(
self,
X: SUPPORTED_FEAT_TYPES,
Expand Down Expand Up @@ -60,51 +158,38 @@ def _fit(
if not X.select_dtypes(include='object').empty:
X = self.infer_objects(X)

self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)

if len(self.enc_columns) > 0:
X = self.impute_nan_in_categories(X)
assert self.feat_type is not None

self.encoder = ColumnTransformer(
[
("encoder",
preprocessing.OrdinalEncoder(
handle_unknown='use_encoded_value',
unknown_value=-1,
), self.enc_columns)],
remainder="passthrough"
if len(self.transformed_columns) > 0:

preprocessors = get_tabular_preprocessors()
self.column_transformer = _create_column_transformer(
preprocessors=preprocessors,
categorical_columns=self.transformed_columns,
)

# Mypy redefinition
assert self.encoder is not None
self.encoder.fit(X)

# The column transformer reoders the feature types - we therefore need to change
# it as well
# This means columns are shifted to the right
def comparator(cmp1: str, cmp2: str) -> int:
if (
cmp1 == 'categorical' and cmp2 == 'categorical'
or cmp1 == 'numerical' and cmp2 == 'numerical'
):
return 0
elif cmp1 == 'categorical' and cmp2 == 'numerical':
return -1
elif cmp1 == 'numerical' and cmp2 == 'categorical':
return 1
else:
raise ValueError((cmp1, cmp2))
assert self.column_transformer is not None
self.column_transformer.fit(X)

# The column transformer reorders the feature types
# therefore, we need to change the order of columns as well
# This means categorical columns are shifted to the left
self.feat_type = sorted(
self.feat_type,
key=functools.cmp_to_key(comparator)
key=functools.cmp_to_key(self._comparator)
)

encoded_categories = self.column_transformer.\
named_transformers_['categorical_pipeline'].\
named_steps['ordinalencoder'].categories_
self.categories = [
# We fit an ordinal encoder, where all categorical
# columns are shifted to the left
list(range(len(cat)))
for cat in self.encoder.transformers_[0][1].categories_
for cat in encoded_categories
]

for i, type_ in enumerate(self.feat_type):
Expand Down Expand Up @@ -158,7 +243,7 @@ def transform(
self._check_data(X)

# Pandas related transformations
if hasattr(X, "iloc") and self.encoder is not None:
if hasattr(X, "iloc") and self.column_transformer is not None:
if np.any(pd.isnull(X)):
# After above check it means that if there is a NaN
# the whole column must be NaN
Expand All @@ -167,11 +252,7 @@ def transform(
if X[column].isna().all():
X[column] = pd.to_numeric(X[column])

# We also need to fillna on the transformation
# in case test data is provided
X = self.impute_nan_in_categories(X)

X = self.encoder.transform(X)
X = self.column_transformer.transform(X)

# Sparse related transformations
# Not all sparse format support index sorting
Expand Down Expand Up @@ -245,7 +326,7 @@ def _check_data(

# Define the column to be encoded here as the feature validator is fitted once
# per estimator
enc_columns, _ = self._get_columns_to_encode(X)
self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)

column_order = [column for column in X.columns]
if len(self.column_order) > 0:
Expand Down Expand Up @@ -282,13 +363,17 @@ def _get_columns_to_encode(
A set of features that are going to be validated (type and dimensionality
checks) and an encoder fitted in the case the data needs encoding
Returns:
enc_columns (List[str]):
transformed_columns (List[str]):
Columns to encode, if any
feat_type:
Type of each column numerical/categorical
"""

if len(self.transformed_columns) > 0 and self.feat_type is not None:
return self.transformed_columns, self.feat_type

# Register if a column needs encoding
enc_columns = []
transformed_columns = []

# Also, register the feature types for the estimator
feat_type = []
Expand All @@ -297,7 +382,7 @@ def _get_columns_to_encode(
for i, column in enumerate(X.columns):
if X[column].dtype.name in ['category', 'bool']:

enc_columns.append(column)
transformed_columns.append(column)
feat_type.append('categorical')
# Move away from np.issubdtype as it causes
# TypeError: data type not understood in certain pandas types
Expand Down Expand Up @@ -339,7 +424,7 @@ def _get_columns_to_encode(
)
else:
feat_type.append('numerical')
return enc_columns, feat_type
return transformed_columns, feat_type

def list_to_dataframe(
self,
Expand Down Expand Up @@ -429,59 +514,3 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
self.object_dtype_mapping = {column: X[column].dtype for column in X.columns}
self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}")
return X

def impute_nan_in_categories(self, X: pd.DataFrame) -> pd.DataFrame:
    """
    Impute missing values in the columns to encode before encoding.

    Remove once sklearn natively supports it in ordinal encoding.
    Sklearn issue:
    "https://github.com/scikit-learn/scikit-learn/issues/17123"

    For every column in `self.enc_columns` that contains a missing value,
    a placeholder is chosen once and remembered in
    `self.dict_missing_value_per_col`, so later calls (for example on
    test data) fill the column with the same placeholder.

    Arguments:
        X (pd.DataFrame):
            data to be interpreted. The affected columns are
            replaced on this frame.

    Returns:
        pd.DataFrame: the same frame, with missing values filled in
            the columns to encode.
    """

    # To be on the safe side, map always to the same missing value per column.
    # The mapping must survive between calls, so only create it when absent.
    if not hasattr(self, 'dict_missing_value_per_col'):
        self.dict_missing_value_per_col: typing.Dict[str, typing.Any] = {}

    for column in self.enc_columns:
        if not X[column].isna().any():
            continue

        if column not in self.dict_missing_value_per_col:
            # Pick a placeholder of the same kind as the observed values so we do
            # not alter the type of the column, which would cause:
            # TypeError: '<' not supported between instances of 'int' and 'str'
            # in the encoding
            observed = X[column].dropna().values
            sample = observed[0] if len(observed) > 0 else None
            try:
                float(sample)
                can_cast_as_number = True
            except (TypeError, ValueError, OverflowError):
                can_cast_as_number = False

            if can_cast_as_number:
                # In this case, we expect to have a number as category
                # it might be string, but its value represent a number
                missing_value: typing.Union[str, int] = '-1' if isinstance(sample, str) else -1
            else:
                missing_value = 'Missing!'

            # Make sure this missing value is not seen before
            # Do this check for categorical columns
            # else modify the value
            if hasattr(X[column], 'cat'):
                while missing_value in X[column].cat.categories:
                    if isinstance(missing_value, str):
                        missing_value += '0'
                    else:
                        missing_value += missing_value
            self.dict_missing_value_per_col[column] = missing_value

        fill_value = self.dict_missing_value_per_col[column]
        filled = X[column]
        # Register the placeholder as a category only when it is not one yet;
        # adding an existing category raises a ValueError in pandas.
        if hasattr(filled, 'cat') and fill_value not in filled.cat.categories:
            filled = filled.cat.add_categories([fill_value])

        # Write the column back onto the frame (the `inplace` flag of
        # `add_categories` is no longer available in recent pandas versions)
        X[column] = filled.fillna(fill_value)
    return X
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, n_res_inputs: int, n_outputs: int):
def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
    """Add the transformed residual to `x` and apply a ReLU.

    Args:
        x (torch.Tensor): Tensor the residual is added to. It is not
            modified by this call.
        res (torch.Tensor): Residual input, passed through
            `self.shortcut` and then `self.bn` before the addition.

    Returns:
        torch.Tensor: `relu(x + bn(shortcut(res)))`
    """
    shortcut = self.bn(self.shortcut(res))
    # Out-of-place addition, applied exactly once: an in-place `+=` would
    # overwrite the tensor handed in by the caller.
    x = x + shortcut
    return torch.relu(x)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{
"n_estimators" : 300,
"n_jobs" : -1
"n_estimators" : 300
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{
"weights" : "uniform",
"n_jobs" : -1
"weights" : "uniform"
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@
"min_data_in_leaf" : 3,
"feature_fraction" : 0.9,
"boosting_type" : "gbdt",
"learning_rate" : 0.03,
"num_threads" : -1
"learning_rate" : 0.03
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{
"n_estimators" : 300,
"n_jobs" : -1
"n_estimators" : 300
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
{
"n_jobs" : -1
}

0 comments on commit a11caf4

Please sign in to comment.