Skip to content

Commit

Permalink
Merge pull request #126 from franchuterivera/refactor_development_ben…
Browse files Browse the repository at this point in the history
…chmarkproblems

Fixes to address automlbenchmark problems
  • Loading branch information
ravinkohli committed Mar 8, 2021
2 parents 6da24a8 + c95fbf3 commit a3d40ac
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 25 deletions.
7 changes: 5 additions & 2 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,8 +858,11 @@ def _search(
saveable_trajectory = \
[list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:])
for entry in self.trajectory]
with open(trajectory_filename, 'w') as fh:
json.dump(saveable_trajectory, fh)
try:
with open(trajectory_filename, 'w') as fh:
json.dump(saveable_trajectory, fh)
except Exception as e:
self._logger.warning(f"Cannot save {trajectory_filename} due to {e}...")
except Exception as e:
self._logger.exception(str(e))
raise
Expand Down
143 changes: 125 additions & 18 deletions autoPyTorch/data/tabular_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,13 @@ def _fit(
if len(self.dtypes) != 0:
self.dtypes[list(X.columns).index(column)] = X[column].dtype

if not X.select_dtypes(include='object').empty:
X = self.infer_objects(X)

self.enc_columns, self.feat_type = self._get_columns_to_encode(X)

if len(self.enc_columns) > 0:
# impute missing values before encoding,
# remove once sklearn natively supports
# it in ordinal encoding. Sklearn issue:
# "https://github.com/scikit-learn/scikit-learn/issues/17123)"
for column in self.enc_columns:
if X[column].isna().any():
missing_value: typing.Union[int, str] = -1
# make sure for a string column we give
# string missing value else we give numeric
if type(X[column][0]) == str:
missing_value = str(missing_value)
X[column] = X[column].cat.add_categories([missing_value])
X[column] = X[column].fillna(missing_value)
X = self.impute_nan_in_categories(X)

self.encoder = ColumnTransformer(
[
Expand Down Expand Up @@ -160,6 +151,10 @@ def transform(
if X[column].isna().all():
X[column] = pd.to_numeric(X[column])

# Also remove the object dtype for new data
if not X.select_dtypes(include='object').empty:
X = self.infer_objects(X)

# Check the data here so we catch problems on new test data
self._check_data(X)

Expand All @@ -172,18 +167,32 @@ def transform(
for column in X.columns:
if X[column].isna().all():
X[column] = pd.to_numeric(X[column])

# We also need to fillna on the transformation
# in case test data is provided
X = self.impute_nan_in_categories(X)

X = self.encoder.transform(X)

# Sparse related transformations
# Not all sparse format support index sorting
if scipy.sparse.issparse(X) and hasattr(X, 'sort_indices'):
X.sort_indices()

return sklearn.utils.check_array(
X,
force_all_finite=False,
accept_sparse='csr'
)
try:
X = sklearn.utils.check_array(
X,
force_all_finite=False,
accept_sparse='csr'
)
except Exception as e:
self.logger.exception(f"Conversion failed for input {X.dtypes} {X}"
"This means AutoPyTorch was not able to properly "
"Extract the dtypes of the provided input features. "
"Please try to manually cast it to a supported "
"numerical or categorical values.")
raise e
return X

def _check_data(
self,
Expand Down Expand Up @@ -231,6 +240,10 @@ def _check_data(
# If entered here, we have a pandas dataframe
X = typing.cast(pd.DataFrame, X)

# Handle objects if possible
if not X.select_dtypes(include='object').empty:
X = self.infer_objects(X)

# Define the column to be encoded here as the feature validator is fitted once
# per estimator
enc_columns, _ = self._get_columns_to_encode(X)
Expand All @@ -245,6 +258,7 @@ def _check_data(
)
else:
self.column_order = column_order

dtypes = [dtype.name for dtype in X.dtypes]
if len(self.dtypes) > 0:
if self.dtypes != dtypes:
Expand Down Expand Up @@ -379,3 +393,96 @@ def numpy_array_to_pandas(
pd.DataFrame
"""
return pd.DataFrame(X).infer_objects().convert_dtypes()

def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
"""
In case the input contains object columns, their type is inferred if possible
This has to be done once, so the test and train data are treated equally
Arguments:
X (pd.DataFrame):
data to be interpreted.
Returns:
pd.DataFrame
"""
if hasattr(self, 'object_dtype_mapping'):
# Mypy does not process the has attr. This dict is defined below
for key, dtype in self.object_dtype_mapping.items(): # type: ignore[has-type]
if 'int' in dtype.name:
# In the case train data was interpreted as int
# and test data was interpreted as float, because of 0.0
# for example, honor training data
X[key] = X[key].applymap(np.int64)
else:
try:
X[key] = X[key].astype(dtype.name)
except Exception as e:
# Try inference if possible
self.logger.warning(f"Tried to cast column {key} to {dtype} caused {e}")
pass
else:
X = X.infer_objects()
for column in X.columns:
if not is_numeric_dtype(X[column]):
X[column] = X[column].astype('category')
self.object_dtype_mapping = {column: X[column].dtype for column in X.columns}
self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}")
return X

def impute_nan_in_categories(self, X: pd.DataFrame) -> pd.DataFrame:
"""
impute missing values before encoding,
remove once sklearn natively supports
it in ordinal encoding. Sklearn issue:
"https://github.com/scikit-learn/scikit-learn/issues/17123)"
Arguments:
X (pd.DataFrame):
data to be interpreted.
Returns:
pd.DataFrame
"""

# To be on the safe side, map always to the same missing
# value per column
if not hasattr(self, 'dict_nancol_to_missing'):
self.dict_missing_value_per_col: typing.Dict[str, typing.Any] = {}

# First make sure that we do not alter the type of the column which cause:
# TypeError: '<' not supported between instances of 'int' and 'str'
# in the encoding
for column in self.enc_columns:
if X[column].isna().any():
if column not in self.dict_missing_value_per_col:
try:
float(X[column].dropna().values[0])
can_cast_as_number = True
except Exception:
can_cast_as_number = False
if can_cast_as_number:
# In this case, we expect to have a number as category
# it might be string, but its value represent a number
missing_value: typing.Union[str, int] = '-1' if isinstance(X[column].dropna().values[0],
str) else -1
else:
missing_value = 'Missing!'

# Make sure this missing value is not seen before
# Do this check for categorical columns
# else modify the value
if hasattr(X[column], 'cat'):
while missing_value in X[column].cat.categories:
if isinstance(missing_value, str):
missing_value += '0'
else:
missing_value += missing_value
self.dict_missing_value_per_col[column] = missing_value

# Convert the frame in place
X[column].cat.add_categories([self.dict_missing_value_per_col[column]],
inplace=True)
X.fillna({column: self.dict_missing_value_per_col[column]}, inplace=True)
return X
6 changes: 6 additions & 0 deletions autoPyTorch/evaluation/abstract_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
def get_additional_run_info(self) -> None: # pylint: disable=R0201
return None

def get_pipeline_representation(self) -> Dict[str, str]:
return {
'Preprocessing': 'None',
'Estimator': 'Dummy',
}

@staticmethod
def get_default_pipeline_options() -> Dict[str, Any]:
return {'budget_type': 'epochs',
Expand Down
44 changes: 44 additions & 0 deletions test/test_api/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pickle
import sys
import unittest

import numpy as np

Expand All @@ -21,6 +22,7 @@
CrossValTypes,
HoldoutValTypes,
)
from autoPyTorch.optimizer.smbo import AutoMLSMBO


# Fixtures
Expand Down Expand Up @@ -344,3 +346,45 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
with open(dump_file, 'rb') as f:
restored_estimator = pickle.load(f)
restored_estimator.predict(X_test)


@pytest.mark.parametrize('openml_id', (
1590, # Adult to test NaN in categorical columns
))
def test_tabular_input_support(openml_id, backend):
"""
Make sure we can process inputs with NaN in categorical and Object columns
when the later is possible
"""

# Get the data and check that contents of data-manager make sense
X, y = sklearn.datasets.fetch_openml(
data_id=int(openml_id),
return_X_y=True, as_frame=True
)

# Make sure we are robust against objects
X[X.columns[0]] = X[X.columns[0]].astype(object)

X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
X, y, random_state=1)
# Search for a good configuration
estimator = TabularClassificationTask(
backend=backend,
resampling_strategy=HoldoutValTypes.holdout_validation,
ensemble_size=0,
)

estimator._do_dummy_prediction = unittest.mock.MagicMock()

with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock:
AutoMLSMBOMock.return_value = ({}, {}, 'epochs')
estimator.search(
X_train=X_train, y_train=y_train,
X_test=X_test, y_test=y_test,
optimize_metric='accuracy',
total_walltime_limit=150,
func_eval_time_limit=50,
traditional_per_total_budget=0,
load_models=False,
)
6 changes: 1 addition & 5 deletions test/test_data/test_feature_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_featurevalidator_categorical_nan(input_data_featuretest):
validator.fit(input_data_featuretest)
transformed_X = validator.transform(input_data_featuretest)
assert any(pd.isna(input_data_featuretest))
assert any((-1 in categories) or ('-1' in categories) for categories in
assert any((-1 in categories) or ('-1' in categories) or ('Missing!' in categories) for categories in
validator.encoder.named_transformers_['encoder'].categories_)
assert np.shape(input_data_featuretest) == np.shape(transformed_X)
assert np.issubdtype(transformed_X.dtype, np.number)
Expand Down Expand Up @@ -328,10 +328,6 @@ def test_features_unsupported_calls_are_raised():
validator.fit(
pd.DataFrame({'datetime': [pd.Timestamp('20180310')]})
)
with pytest.raises(ValueError, match="has invalid type object"):
validator.fit(
pd.DataFrame({'string': [TabularFeatureValidator()]})
)
with pytest.raises(ValueError, match=r"AutoPyTorch only supports.*yet, the provided input"):
validator.fit({'input1': 1, 'input2': 2})
with pytest.raises(ValueError, match=r"has unsupported dtype string"):
Expand Down

0 comments on commit a3d40ac

Please sign in to comment.