Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Nov 8, 2021
1 parent 1793b5c commit febe4b5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 31 deletions.
56 changes: 31 additions & 25 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
STRING_TO_TASK_TYPES,
)
from autoPyTorch.data.base_validator import BaseInputValidator
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.base_dataset import BaseDataset, BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
Expand Down Expand Up @@ -1068,6 +1068,28 @@ def _search(

return self

def _get_fit_dictionary(
self,
dataset_properties: Dict[str, BaseDatasetPropertiesType],
dataset: BaseDataset,
split_id: int = 0
) -> Dict[str, Any]:
X_test = dataset.test_tensors[0].copy() if dataset.test_tensors is not None else None
y_test = dataset.test_tensors[1].copy() if dataset.test_tensors is not None else None
X: Dict[str, Any] = dict({'dataset_properties': dataset_properties,
'backend': self._backend,
'X_train': dataset.train_tensors[0].copy(),
'y_train': dataset.train_tensors[1].copy(),
'X_test': X_test,
'y_test': y_test,
'train_indices': dataset.splits[split_id][0],
'val_indices': dataset.splits[split_id][1],
'split_id': split_id,
'num_run': self._backend.get_next_num_run(),
})
X.update(self.pipeline_options)
return X

def refit(
self,
dataset: BaseDataset,
Expand Down Expand Up @@ -1111,18 +1133,6 @@ def refit(
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
self._backend.save_datamanager(dataset)

X: Dict[str, Any] = dict({'dataset_properties': dataset_properties,
'backend': self._backend,
'X_train': dataset.train_tensors[0],
'y_train': dataset.train_tensors[1],
'X_test': dataset.test_tensors[0] if dataset.test_tensors is not None else None,
'y_test': dataset.test_tensors[1] if dataset.test_tensors is not None else None,
'train_indices': dataset.splits[split_id][0],
'val_indices': dataset.splits[split_id][1],
'split_id': split_id,
'num_run': self._backend.get_next_num_run(),
})
X.update(self.pipeline_options)
if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
self._load_models()

Expand All @@ -1138,6 +1148,10 @@ def refit(
# try to fit the model. If it fails, shuffle the data. This
# could alleviate the problem in algorithms that depend on
# the ordering of the data.
X = self._get_fit_dictionary(
dataset_properties=dataset_properties,
dataset=dataset,
split_id=split_id)
fit_and_suppress_warnings(self._logger, model, X, y=None)

self._clean_logger()
Expand Down Expand Up @@ -1191,18 +1205,10 @@ def fit(self,
pipeline.set_hyperparameters(pipeline_config)

# initialise fit dictionary
X: Dict[str, Any] = dict({'dataset_properties': dataset_properties,
'backend': self._backend,
'X_train': dataset.train_tensors[0],
'y_train': dataset.train_tensors[1],
'X_test': dataset.test_tensors[0] if dataset.test_tensors is not None else None,
'y_test': dataset.test_tensors[1] if dataset.test_tensors is not None else None,
'train_indices': dataset.splits[split_id][0],
'val_indices': dataset.splits[split_id][1],
'split_id': split_id,
'num_run': self._backend.get_next_num_run(),
})
X.update(self.pipeline_options)
X = self._get_fit_dictionary(
dataset_properties=dataset_properties,
dataset=dataset,
split_id=split_id)

fit_and_suppress_warnings(self._logger, pipeline, X, y=None)

Expand Down
12 changes: 6 additions & 6 deletions test/test_pipeline/components/setup/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,12 +483,12 @@ def test_dropout(self, resnet_shape):
backbone = resnet_backbone.build_backbone((100, 5))
dropout_probabilites = [resnet_backbone.config[key] for key in resnet_backbone.config if 'dropout_' in key]
dropout_shape = get_shaped_neuron_counts(
shape=self.config['resnet_shape'],
in_feat=0,
out_feat=0,
max_neurons=self.config["max_dropout"],
layer_count=self.config['num_groups'] + 1,
)[:-1]
shape=resnet_shape,
in_feat=0,
out_feat=0,
max_neurons=max_dropout,
layer_count=num_groups + 1,
)[:-1]
blocks_dropout = []
for block in backbone:
if isinstance(block, torch.nn.Sequential):
Expand Down

0 comments on commit febe4b5

Please sign in to comment.