
[Hotfix] Adapt LR scheduler to epoch wise #212

47 changes: 30 additions & 17 deletions autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py
@@ -1,4 +1,4 @@
import typing
from typing import Any, Callable, Dict, Optional, Tuple

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import (
@@ -11,13 +11,16 @@
import torch

from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
from autoPyTorch.pipeline.components.training.trainer.base_trainer import BaseTrainerComponent
from autoPyTorch.pipeline.components.training.trainer.base_trainer import (
BaseTrainerComponent,
_NewLossParameters
)
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


class MixUpTrainer(BaseTrainerComponent):
def __init__(self, alpha: float, weighted_loss: bool = False,
random_state: typing.Optional[np.random.RandomState] = None):
random_state: Optional[np.random.RandomState] = None):
"""
This class handles the training of a network for a single given epoch.

@@ -29,45 +32,55 @@ def __init__(self, alpha: float, weighted_loss: bool = False,
self.weighted_loss = weighted_loss
self.alpha = alpha

def data_preparation(self, X: np.ndarray, y: np.ndarray,
) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
def _data_preprocessing(self, X: torch.Tensor, y: torch.Tensor,
) -> Tuple[torch.Tensor, _NewLossParameters]:
"""
Depending on the trainer choice, data fed to the network might be pre-processed
in a different way. That is, in standard training we provide the data to the
network as we receive it from the loader. Some regularization techniques, like mixup,
alter the data.

Args:
X (np.ndarray): The batch training features
y (np.ndarray): The batch training labels
X (torch.Tensor): The batch training features
y (torch.Tensor): The batch training labels

Returns:
np.ndarray: that processes data
typing.Dict[str, np.ndarray]: arguments to the criterion function
torch.Tensor: the processed data
_NewLossParameters: arguments to the new loss function
"""
lam = self.random_state.beta(self.alpha, self.alpha) if self.alpha > 0. else 1.
batch_size = X.size()[0]
index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
batch_size = X.shape[0]
index = torch.randperm(batch_size).to(self.device)

mixed_x = lam * X + (1 - lam) * X[index, :]
y_a, y_b = y, y[index]
return mixed_x, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
return mixed_x, _NewLossParameters(y_a=y_a, y_b=y_b, lam=lam)

def _get_new_loss_fn(
Contributor: Can we try to make this change in another PR? I think this does not directly help to solve the LR schedule.

Contributor Author: The renaming? OK, I will do it in another PR.

self,
new_loss_params: _NewLossParameters
) -> Callable:

y_a = new_loss_params.y_a
y_b = new_loss_params.y_b
lam = new_loss_params.lam

if lam > 1.0 or lam < 0:
raise ValueError("The mixup coefficient `lam` must be [0, 1], but got {:.2f}.".format(lam))

def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float = 1.0
) -> typing.Callable:
return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

@staticmethod
def get_properties(dataset_properties: typing.Optional[typing.Dict[str, typing.Any]] = None
) -> typing.Dict[str, str]:
def get_properties(dataset_properties: Optional[Dict[str, Any]] = None
) -> Dict[str, str]:
return {
'shortname': 'MixUpTrainer',
'name': 'MixUp Regularized Trainer',
}

@staticmethod
def get_hyperparameter_search_space(
dataset_properties: typing.Optional[typing.Dict] = None,
dataset_properties: Optional[Dict] = None,
alpha: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="alpha",
value_range=(0, 1),
default_value=0.2),
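For reference, a minimal standalone sketch of what the new `_data_preprocessing` / `_get_new_loss_fn` pair computes for mixup. `NewLossParameters` below is an illustrative stand-in for the `_NewLossParameters` container imported from `base_trainer`, whose exact definition is not shown in this diff:

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import numpy as np
import torch


@dataclass
class NewLossParameters:
    # Illustrative stand-in for base_trainer._NewLossParameters
    y_a: torch.Tensor
    y_b: Optional[torch.Tensor] = None
    lam: float = 1.0


def mixup_batch(X: torch.Tensor, y: torch.Tensor, alpha: float,
                device: torch.device) -> Tuple[torch.Tensor, NewLossParameters]:
    # Sample the mixing coefficient from Beta(alpha, alpha), as the trainer does
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0.0 else 1.0
    index = torch.randperm(X.shape[0]).to(device)
    mixed_x = lam * X + (1 - lam) * X[index, :]
    return mixed_x, NewLossParameters(y_a=y, y_b=y[index], lam=lam)


def mixup_loss_fn(params: NewLossParameters) -> Callable:
    # Weighted sum of the criterion over both label sets, mirroring _get_new_loss_fn
    return lambda criterion, pred: (params.lam * criterion(pred, params.y_a)
                                    + (1 - params.lam) * criterion(pred, params.y_b))


# Roughly how a train step would use the two pieces:
#   mixed_x, params = mixup_batch(X, y, alpha=0.2, device=X.device)
#   loss = mixup_loss_fn(params)(criterion, model(mixed_x))

StandardTrainer below implements the same interface in its identity form: it returns the batch unchanged and a closure that applies the plain criterion.
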
40 changes: 24 additions & 16 deletions autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py
@@ -1,18 +1,23 @@
import typing
from typing import Any, Callable, Dict, Optional, Tuple

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import CategoricalHyperparameter

import numpy as np

import torch

from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
from autoPyTorch.pipeline.components.training.trainer.base_trainer import BaseTrainerComponent
from autoPyTorch.pipeline.components.training.trainer.base_trainer import (
BaseTrainerComponent,
_NewLossParameters
)
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


class StandardTrainer(BaseTrainerComponent):
def __init__(self, weighted_loss: bool = False,
random_state: typing.Optional[np.random.RandomState] = None):
random_state: Optional[np.random.RandomState] = None):
"""
This class handles the training of a network for a single given epoch.

@@ -23,39 +28,42 @@ def __init__(self, weighted_loss: bool = False,
super().__init__(random_state=random_state)
self.weighted_loss = weighted_loss

def data_preparation(self, X: np.ndarray, y: np.ndarray,
) -> typing.Tuple[np.ndarray, typing.Dict[str, np.ndarray]]:
def _data_preprocessing(self, X: torch.Tensor, y: torch.Tensor,
) -> Tuple[torch.Tensor, _NewLossParameters]:
"""
Depending on the trainer choice, data fed to the network might be pre-processed
in a different way. That is, in standard training we provide the data to the
network as we receive it from the loader. Some regularization techniques, like mixup,
alter the data.

Args:
X (np.ndarray): The batch training features
y (np.ndarray): The batch training labels
X (torch.Tensor): The batch training features
y (torch.Tensor): The batch training labels

Returns:
np.ndarray: that processes data
typing.Dict[str, np.ndarray]: arguments to the criterion function
torch.Tensor: the processed data
_NewLossParameters: arguments to the new loss function
"""
return X, {'y_a': y}
return X, _NewLossParameters(y_a=y)

def criterion_preparation(self, y_a: np.ndarray, y_b: np.ndarray = None, lam: float = 1.0
) -> typing.Callable:
def _get_new_loss_fn(
self,
new_loss_params: _NewLossParameters
) -> Callable:
y_a = new_loss_params.y_a
return lambda criterion, pred: criterion(pred, y_a)

@staticmethod
def get_properties(dataset_properties: typing.Optional[typing.Dict[str, typing.Any]] = None
) -> typing.Dict[str, str]:
def get_properties(dataset_properties: Optional[Dict[str, Any]] = None
) -> Dict[str, str]:
return {
'shortname': 'StandardTrainer',
'name': 'StandardTrainer',
'name': 'Standard Trainer',
}

@staticmethod
def get_hyperparameter_search_space(
dataset_properties: typing.Optional[typing.Dict] = None,
dataset_properties: Optional[Dict] = None,
weighted_loss: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="weighted_loss",
value_range=(True, False),
default_value=True),
34 changes: 32 additions & 2 deletions autoPyTorch/pipeline/components/training/trainer/__init__.py
@@ -30,6 +30,7 @@
BaseTrainerComponent,
BudgetTracker,
RunSummary,
StepIntervalUnit
)
from autoPyTorch.utils.common import FitRequirement, get_device_from_fit_dictionary
from autoPyTorch.utils.logging_ import get_named_client_logger
@@ -45,6 +46,34 @@ def add_trainer(trainer: BaseTrainerComponent) -> None:
_addons.add_component(trainer)


class _X():
__slots__ = (
'additional_losses',
'additional_metrics',
'backend',
'budget_type',
'dataset_properties',
'device',
'early_stopping',
'epochs',
'logger_port',
'lr_scheduler',
'metrics_during_training',
'network',
'num_run',
'optimizer',
'runtime',
'split_id',
'step_unit',
'test_data_loader',
'torch_num_threads',
'train_data_loader',
'use_tensorboard_logger',
'val_data_loader',
'y_train',
)
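
One subtlety with a tuple-of-strings `__slots__` declaration like the one above: if the comma between two entries is dropped, Python silently concatenates the adjacent string literals instead of raising an error, and one attribute quietly disappears. A minimal, autoPyTorch-independent demonstration:

# Adjacent string literals are implicitly concatenated by Python,
# so a missing comma merges two slot names into one.
broken = (
    'epochs'
    'logger_port',
)
fixed = (
    'epochs',
    'logger_port',
)
assert broken == ('epochslogger_port',)
assert fixed == ('epochs', 'logger_port')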


class TrainerChoice(autoPyTorchChoice):
"""This class is an interface to the PyTorch trainer.

@@ -250,7 +279,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
# Support additional user metrics
additional_metrics = X['additional_metrics'] if 'additional_metrics' in X else None
additional_losses = X['additional_losses'] if 'additional_losses' in X else None
self.choice.prepare(
self.choice.set_training_params(
model=X['network'],
metrics=get_metrics(dataset_properties=X['dataset_properties'],
names=additional_metrics),
@@ -262,7 +291,8 @@
metrics_during_training=X['metrics_during_training'],
scheduler=X['lr_scheduler'],
task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]]
labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]],
step_unit=X.get('step_unit', StepIntervalUnit.batch)
)
total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
self.run_summary = RunSummary(
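
The `step_unit` value threaded through `set_training_params` above is what lets the trainer advance the LR scheduler either once per batch or once per epoch. A minimal sketch of the idea, assuming a `StepIntervalUnit` enum with `batch` and `epoch` members (the real enum is imported from `base_trainer` and its definition is not shown in this diff):

from enum import Enum

import torch
from torch.utils.data import DataLoader, TensorDataset


class StepIntervalUnit(Enum):
    # Assumed shape of the enum imported from base_trainer
    batch = 'batch'
    epoch = 'epoch'


def train_one_epoch(model, loader, optimizer, scheduler, criterion,
                    step_unit: StepIntervalUnit) -> None:
    """Step the LR scheduler either after every batch or once per epoch."""
    model.train()
    for X, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        if step_unit == StepIntervalUnit.batch:
            scheduler.step()  # per-batch schedulers, e.g. cyclic schedules
    if step_unit == StepIntervalUnit.epoch:
        scheduler.step()      # per-epoch schedulers, e.g. StepLR, CosineAnnealingLR


# Tiny usage example on synthetic data
model = torch.nn.Linear(8, 2)
loader = DataLoader(TensorDataset(torch.randn(32, 8),
                                  torch.randint(0, 2, (32,))),
                    batch_size=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
train_one_epoch(model, loader, optimizer, scheduler,
                torch.nn.CrossEntropyLoss(), StepIntervalUnit.epoch)

Defaulting to `StepIntervalUnit.batch` in `X.get('step_unit', StepIntervalUnit.batch)` means pipelines that do not set the key fall back to per-batch stepping.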