Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Datamanager in memory #382

Merged
merged 4 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
84 changes: 53 additions & 31 deletions autoPyTorch/evaluation/abstract_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,34 +433,16 @@ def __init__(self, backend: Backend,
self.backend: Backend = backend
self.queue = queue

self.datamanager: BaseDataset = self.backend.load_datamanager()

assert self.datamanager.task_type is not None, \
"Expected dataset {} to have task_type got None".format(self.datamanager.__class__.__name__)
self.task_type = STRING_TO_TASK_TYPES[self.datamanager.task_type]
self.output_type = STRING_TO_OUTPUT_TYPES[self.datamanager.output_type]
self.issparse = self.datamanager.issparse

self.include = include
self.exclude = exclude
self.search_space_updates = search_space_updates

self.X_train, self.y_train = self.datamanager.train_tensors

if self.datamanager.val_tensors is not None:
self.X_valid, self.y_valid = self.datamanager.val_tensors
else:
self.X_valid, self.y_valid = None, None

if self.datamanager.test_tensors is not None:
self.X_test, self.y_test = self.datamanager.test_tensors
else:
self.X_test, self.y_test = None, None

self.metric = metric

self.seed = seed

self._init_datamanager_info()

# Flag to save target for ensemble
self.output_y_hat_optimization = output_y_hat_optimization

Expand Down Expand Up @@ -497,12 +479,6 @@ def __init__(self, backend: Backend,
else:
raise ValueError('task {} not available'.format(self.task_type))
self.predict_function = self._predict_proba
self.dataset_properties = self.datamanager.get_dataset_properties(
get_dataset_requirements(info=self.datamanager.get_required_dataset_info(),
include=self.include,
exclude=self.exclude,
search_space_updates=self.search_space_updates
))

self.additional_metrics: Optional[List[autoPyTorchMetric]] = None
metrics_dict: Optional[Dict[str, List[str]]] = None
Expand Down Expand Up @@ -542,6 +518,53 @@ def __init__(self, backend: Backend,
self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(dict_repr(self.fit_dictionary)))
self.logger.debug("Search space updates :{}".format(self.search_space_updates))

def _init_datamanager_info(
self,
) -> None:
"""
Initialises instance attributes that come from the datamanager.
For example,
X_train, y_train, etc.
"""

datamanager: BaseDataset = self.backend.load_datamanager()

assert datamanager.task_type is not None, \
"Expected dataset {} to have task_type got None".format(datamanager.__class__.__name__)
self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type]
self.output_type = STRING_TO_OUTPUT_TYPES[datamanager.output_type]
self.issparse = datamanager.issparse

self.X_train, self.y_train = datamanager.train_tensors

if datamanager.val_tensors is not None:
self.X_valid, self.y_valid = datamanager.val_tensors
else:
self.X_valid, self.y_valid = None, None

if datamanager.test_tensors is not None:
self.X_test, self.y_test = datamanager.test_tensors
else:
self.X_test, self.y_test = None, None

self.resampling_strategy = datamanager.resampling_strategy

self.num_classes: Optional[int] = getattr(datamanager, "num_classes", None)

self.dataset_properties = datamanager.get_dataset_properties(
get_dataset_requirements(info=datamanager.get_required_dataset_info(),
include=self.include,
exclude=self.exclude,
search_space_updates=self.search_space_updates
))
self.splits = datamanager.splits
if self.splits is None:
raise AttributeError(f"create_splits on {datamanager.__class__.__name__} must be called "
f"before the instantiation of {self.__class__.__name__}")

# delete datamanager from memory
del datamanager

def _init_fit_dictionary(
self,
logger_port: int,
Expand Down Expand Up @@ -988,21 +1011,20 @@ def _ensure_prediction_array_sizes(self, prediction: np.ndarray,
(np.ndarray):
The formatted prediction
"""
assert self.datamanager.num_classes is not None, "Called function on wrong task"
num_classes: int = self.datamanager.num_classes
assert self.num_classes is not None, "Called function on wrong task"

if self.output_type == MULTICLASS and \
prediction.shape[1] < num_classes:
prediction.shape[1] < self.num_classes:
if Y_train is None:
raise ValueError('Y_train must not be None!')
classes = list(np.unique(Y_train))

mapping = dict()
for class_number in range(num_classes):
for class_number in range(self.num_classes):
if class_number in classes:
index = classes.index(class_number)
mapping[index] = class_number
new_predictions = np.zeros((prediction.shape[0], num_classes),
new_predictions = np.zeros((prediction.shape[0], self.num_classes),
dtype=np.float32)

for index in mapping:
Expand Down
9 changes: 2 additions & 7 deletions autoPyTorch/evaluation/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,12 @@ def __init__(
search_space_updates=search_space_updates
)

if not isinstance(self.datamanager.resampling_strategy, (NoResamplingStrategyTypes)):
resampling_strategy = self.datamanager.resampling_strategy
if not isinstance(self.resampling_strategy, (NoResamplingStrategyTypes)):
raise ValueError(
f'resampling_strategy for TestEvaluator must be in '
f'NoResamplingStrategyTypes, but got {resampling_strategy}'
f'NoResamplingStrategyTypes, but got {self.resampling_strategy}'
)

self.splits = self.datamanager.splits
if self.splits is None:
raise AttributeError("create_splits must be called in {}".format(self.datamanager.__class__.__name__))
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved

def fit_predict_and_loss(self) -> None:

split_id = 0
Expand Down
8 changes: 2 additions & 6 deletions autoPyTorch/evaluation/train_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,12 @@ def __init__(self, backend: Backend, queue: Queue,
search_space_updates=search_space_updates
)

if not isinstance(self.datamanager.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
resampling_strategy = self.datamanager.resampling_strategy
if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
raise ValueError(
f'resampling_strategy for TrainEvaluator must be in '
f'(CrossValTypes, HoldoutValTypes), but got {resampling_strategy}'
f'(CrossValTypes, HoldoutValTypes), but got {self.resampling_strategy}'
)

self.splits = self.datamanager.splits
if self.splits is None:
raise AttributeError("Must have called create_splits on {}".format(self.datamanager.__class__.__name__))
self.num_folds: int = len(self.splits)
self.Y_targets: List[Optional[np.ndarray]] = [None] * self.num_folds
self.Y_train_targets: np.ndarray = np.ones(self.y_train.shape) * np.NaN
Expand Down
14 changes: 1 addition & 13 deletions autoPyTorch/optimizer/smbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from smac.utils.io.traj_logging import TrajEntry

from autoPyTorch.automl_common.common.utils.backend import Backend
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
DEFAULT_RESAMPLING_PARAMETERS,
Expand Down Expand Up @@ -194,9 +193,8 @@ def __init__(self,
super(AutoMLSMBO, self).__init__()
# data related
self.dataset_name = dataset_name
self.datamanager: Optional[BaseDataset] = None
self.metric = metric
self.task: Optional[str] = None

self.backend = backend
self.all_supported_metrics = all_supported_metrics

Expand Down Expand Up @@ -252,21 +250,11 @@ def __init__(self,
self.initial_configurations = initial_configurations \
if len(initial_configurations) > 0 else None

def reset_data_manager(self) -> None:
if self.datamanager is not None:
del self.datamanager
self.datamanager = self.backend.load_datamanager()

if self.datamanager is not None and self.datamanager.task_type is not None:
self.task = self.datamanager.task_type

def run_smbo(self, func: Optional[Callable] = None
) -> Tuple[RunHistory, List[TrajEntry], str]:

self.watcher.start_task('SMBO')
self.logger.info("Started run of SMBO")
# == first things first: load the datamanager
self.reset_data_manager()

# == Initialize non-SMBO stuff
# first create a scenario
Expand Down