Skip to content

Commit

Permalink
[ADD] Test evaluator (#368)
Browse files Browse the repository at this point in the history
* add test evaluator

* add no resampling and other changes for test evaluator

* finalise changes for test_evaluator, TODO: tests

* add tests for new functionality

* fix flake and mypy

* add documentation for the evaluator

* add NoResampling to fit_pipeline

* raise error when trying to construct ensemble with noresampling

* fix tests

* reduce fit_pipeline accuracy check

* Apply suggestions from code review

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>

* address comments from shuhei

* fix bug in base data loader

* fix bug in data loader for val set

* fix bugs introduced in suggestions

* fix flake

* fix bug in test preprocessing

* fix bug in test data loader

* merge tests for evaluators and change listcomp in get_best_epoch

* rename resampling strategies

* add test for get dataset

Co-authored-by: nabenabe0928 <47781922+nabenabe0928@users.noreply.github.com>
  • Loading branch information
ravinkohli and nabenabe0928 committed Jan 25, 2022
1 parent c0fb82e commit 6554702
Show file tree
Hide file tree
Showing 22 changed files with 817 additions and 120 deletions.
34 changes: 25 additions & 9 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
)
from autoPyTorch.data.base_validator import BaseInputValidator
from autoPyTorch.datasets.base_dataset import BaseDataset, BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
HoldoutValTypes,
NoResamplingStrategyTypes,
ResamplingStrategies,
)
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
from autoPyTorch.evaluation.abstract_evaluator import fit_and_suppress_warnings
Expand Down Expand Up @@ -145,6 +150,13 @@ class BaseTask(ABC):
name and Value is an Iterable of the names of the components
to exclude. All except these components will be present in
the search space.
resampling_strategy resampling_strategy (RESAMPLING_STRATEGIES),
(default=HoldoutValTypes.holdout_validation):
strategy to split the training data.
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
required for the chosen resampling strategy. If None, uses
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
in ```datasets/resampling_strategy.py```.
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
Search space updates that can be used to modify the search
space of particular components or choice modules of the pipeline
Expand All @@ -166,11 +178,15 @@ def __init__(
include_components: Optional[Dict[str, Any]] = None,
exclude_components: Optional[Dict[str, Any]] = None,
backend: Optional[Backend] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
task_type: Optional[str] = None
) -> None:

if isinstance(resampling_strategy, NoResamplingStrategyTypes) and ensemble_size != 0:
raise ValueError("`NoResamplingStrategy` cannot be used for ensemble construction")

self.seed = seed
self.n_jobs = n_jobs
self.n_threads = n_threads
Expand Down Expand Up @@ -280,7 +296,7 @@ def _get_dataset_input_validator(
y_train: Union[List, pd.DataFrame, np.ndarray],
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
resampling_strategy: Optional[ResamplingStrategies] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
dataset_name: Optional[str] = None,
) -> Tuple[BaseDataset, BaseInputValidator]:
Expand All @@ -298,7 +314,7 @@ def _get_dataset_input_validator(
Testing feature set
y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
Testing target set
resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
Strategy to split the training data. if None, uses
HoldoutValTypes.holdout_validation.
resampling_strategy_args (Optional[Dict[str, Any]]):
Expand All @@ -322,7 +338,7 @@ def get_dataset(
y_train: Union[List, pd.DataFrame, np.ndarray],
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
resampling_strategy: Optional[ResamplingStrategies] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
dataset_name: Optional[str] = None,
) -> BaseDataset:
Expand All @@ -338,7 +354,7 @@ def get_dataset(
Testing feature set
y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
Testing target set
resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
Strategy to split the training data. if None, uses
HoldoutValTypes.holdout_validation.
resampling_strategy_args (Optional[Dict[str, Any]]):
Expand Down Expand Up @@ -1360,7 +1376,7 @@ def fit_pipeline(
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
dataset_name: Optional[str] = None,
resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes]] = None,
resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes]] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
run_time_limit_secs: int = 60,
memory_limit: Optional[int] = None,
Expand Down Expand Up @@ -1395,7 +1411,7 @@ def fit_pipeline(
be provided to track the generalization performance of each stage.
dataset_name (Optional[str]):
Name of the dataset, if None, random value is used.
resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
Strategy to split the training data. if None, uses
HoldoutValTypes.holdout_validation.
resampling_strategy_args (Optional[Dict[str, Any]]):
Expand Down Expand Up @@ -1657,7 +1673,7 @@ def predict(
# Mypy assert
assert self.ensemble_ is not None, "Load models should error out if no ensemble"

if isinstance(self.resampling_strategy, HoldoutValTypes):
if isinstance(self.resampling_strategy, (HoldoutValTypes, NoResamplingStrategyTypes)):
models = self.models_
elif isinstance(self.resampling_strategy, CrossValTypes):
models = self.cv_models_
Expand Down
17 changes: 12 additions & 5 deletions autoPyTorch/api/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
HoldoutValTypes,
ResamplingStrategies,
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
Expand Down Expand Up @@ -64,8 +64,15 @@ class TabularClassificationTask(BaseTask):
name and Value is an Iterable of the names of the components
to exclude. All except these components will be present in
the search space.
resampling_strategy resampling_strategy (RESAMPLING_STRATEGIES),
(default=HoldoutValTypes.holdout_validation):
strategy to split the training data.
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
required for the chosen resampling strategy. If None, uses
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
in ```datasets/resampling_strategy.py```.
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
search space updates that can be used to modify the search
Search space updates that can be used to modify the search
space of particular components or choice modules of the pipeline
"""
def __init__(
Expand All @@ -83,7 +90,7 @@ def __init__(
delete_output_folder_after_terminate: bool = True,
include_components: Optional[Dict[str, Any]] = None,
exclude_components: Optional[Dict[str, Any]] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
backend: Optional[Backend] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
Expand Down Expand Up @@ -153,7 +160,7 @@ def _get_dataset_input_validator(
y_train: Union[List, pd.DataFrame, np.ndarray],
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
resampling_strategy: Optional[ResamplingStrategies] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
dataset_name: Optional[str] = None,
) -> Tuple[TabularDataset, TabularInputValidator]:
Expand All @@ -170,7 +177,7 @@ def _get_dataset_input_validator(
Testing feature set
y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
Testing target set
resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
Strategy to split the training data. if None, uses
HoldoutValTypes.holdout_validation.
resampling_strategy_args (Optional[Dict[str, Any]]):
Expand Down
17 changes: 12 additions & 5 deletions autoPyTorch/api/tabular_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
HoldoutValTypes,
ResamplingStrategies,
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
Expand Down Expand Up @@ -64,8 +64,15 @@ class TabularRegressionTask(BaseTask):
name and Value is an Iterable of the names of the components
to exclude. All except these components will be present in
the search space.
resampling_strategy resampling_strategy (RESAMPLING_STRATEGIES),
(default=HoldoutValTypes.holdout_validation):
strategy to split the training data.
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
required for the chosen resampling strategy. If None, uses
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
in ```datasets/resampling_strategy.py```.
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
search space updates that can be used to modify the search
Search space updates that can be used to modify the search
space of particular components or choice modules of the pipeline
"""

Expand All @@ -84,7 +91,7 @@ def __init__(
delete_output_folder_after_terminate: bool = True,
include_components: Optional[Dict[str, Any]] = None,
exclude_components: Optional[Dict[str, Any]] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
backend: Optional[Backend] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
Expand Down Expand Up @@ -154,7 +161,7 @@ def _get_dataset_input_validator(
y_train: Union[List, pd.DataFrame, np.ndarray],
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
resampling_strategy: Optional[ResamplingStrategies] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
dataset_name: Optional[str] = None,
) -> Tuple[TabularDataset, TabularInputValidator]:
Expand All @@ -171,7 +178,7 @@ def _get_dataset_input_validator(
Testing feature set
y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
Testing target set
resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
Strategy to split the training data. if None, uses
HoldoutValTypes.holdout_validation.
resampling_strategy_args (Optional[Dict[str, Any]]):
Expand Down
39 changes: 28 additions & 11 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
DEFAULT_RESAMPLING_PARAMETERS,
HoldOutFunc,
HoldOutFuncs,
HoldoutValTypes
HoldoutValTypes,
NoResamplingFunc,
NoResamplingFuncs,
NoResamplingStrategyTypes,
ResamplingStrategies
)
from autoPyTorch.utils.common import FitRequirement

Expand Down Expand Up @@ -78,7 +82,7 @@ def __init__(
dataset_name: Optional[str] = None,
val_tensors: Optional[BaseDatasetInputType] = None,
test_tensors: Optional[BaseDatasetInputType] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy: ResamplingStrategies = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
shuffle: Optional[bool] = True,
seed: Optional[int] = 42,
Expand All @@ -95,8 +99,7 @@ def __init__(
validation data
test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
test data
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
(default=HoldoutValTypes.holdout_validation):
resampling_strategy (RESAMPLING_STRATEGIES: default=HoldoutValTypes.holdout_validation):
strategy to split the training data.
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
required for the chosen resampling strategy. If None, uses
Expand All @@ -109,16 +112,18 @@ def __init__(
val_transforms (Optional[torchvision.transforms.Compose]):
Additional Transforms to be applied to the validation/test data
"""
self.dataset_name = dataset_name

if self.dataset_name is None:
if dataset_name is None:
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
else:
self.dataset_name = dataset_name

if not hasattr(train_tensors[0], 'shape'):
type_check(train_tensors, val_tensors)
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
self.cross_validators: Dict[str, CrossValFunc] = {}
self.holdout_validators: Dict[str, HoldOutFunc] = {}
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
self.random_state = np.random.RandomState(seed=seed)
self.shuffle = shuffle
self.resampling_strategy = resampling_strategy
Expand All @@ -143,6 +148,8 @@ def __init__(
# Make sure cross validation splits are created once
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
self.no_resampling_validators = NoResamplingFuncs.get_no_resampling_validators(*NoResamplingStrategyTypes)

self.splits = self.get_splits_from_resampling_strategy()

# We also need to be able to transform the data, be it for pre-processing
Expand Down Expand Up @@ -210,7 +217,7 @@ def __len__(self) -> int:
def _get_indices(self) -> np.ndarray:
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))

def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], Optional[List[int]]]]:
"""
Creates a set of splits based on a resampling strategy provided
Expand Down Expand Up @@ -241,6 +248,9 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
num_splits=cast(int, num_splits),
)
)
elif isinstance(self.resampling_strategy, NoResamplingStrategyTypes):
splits.append((self.no_resampling_validators[self.resampling_strategy.name](self.random_state,
self._get_indices()), None))
else:
raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
return splits
Expand Down Expand Up @@ -312,22 +322,29 @@ def create_holdout_val_split(
self.random_state, val_share, self._get_indices(), **kwargs)
return train, val

def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
def get_dataset(self, split_id: int, train: bool) -> Dataset:
"""
The above split methods employ the Subset to internally subsample the whole dataset.
During training, we need access to one of those splits. This is a handy function
to provide training data to fit a pipeline
Args:
split (int): The desired subset of the dataset to split and use
split_id (int): which split id to get from the splits
train (bool): whether the dataset is required for training or evaluating.
Returns:
Dataset: the reduced dataset to be used for testing
"""
# Subset creates a dataset. Splits is a (train_indices, test_indices) tuple
return (TransformSubset(self, self.splits[split_id][0], train=True),
TransformSubset(self, self.splits[split_id][1], train=False))
if split_id >= len(self.splits): # old version: split_id > len(self.splits)
raise IndexError(f"self.splits index out of range, got split_id={split_id}"
f" (>= num_splits={len(self.splits)})")
indices = self.splits[split_id][int(not train)] # 0: for training, 1: for evaluation
if indices is None:
raise ValueError("Specified fold (or subset) does not exist")

return TransformSubset(self, indices, train=train)

def replace_data(self, X_train: BaseDatasetInputType,
X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
Expand Down
7 changes: 5 additions & 2 deletions autoPyTorch/datasets/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
HoldoutValTypes,
NoResamplingStrategyTypes
)

IMAGE_DATASET_INPUT = Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]
Expand All @@ -39,7 +40,7 @@ class ImageDataset(BaseDataset):
validation data
test (Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]):
testing data
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
resampling_strategy (Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]),
(default=HoldoutValTypes.holdout_validation):
strategy to split the training data.
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
Expand All @@ -57,7 +58,9 @@ def __init__(self,
train: IMAGE_DATASET_INPUT,
val: Optional[IMAGE_DATASET_INPUT] = None,
test: Optional[IMAGE_DATASET_INPUT] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy: Union[CrossValTypes,
HoldoutValTypes,
NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
shuffle: Optional[bool] = True,
seed: Optional[int] = 42,
Expand Down

0 comments on commit 6554702

Please sign in to comment.