Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX weighted loss issue #94

Merged
merged 8 commits into from
Feb 22, 2021
15 changes: 8 additions & 7 deletions autoPyTorch/pipeline/components/training/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
from autoPyTorch.constants import BINARY, CLASSIFICATION_TASKS, CONTINUOUS, MULTICLASS, REGRESSION_TASKS, \
STRING_TO_OUTPUT_TYPES, STRING_TO_TASK_TYPES, TASK_TYPES_TO_STRING


# Registry of the available loss functions, grouped by task family.
# Each entry maps the loss name to its torch module class and to the tuple of
# output types it supports; get_supported_losses tests membership in that tuple.
losses = dict(
    classification=dict(
        CrossEntropyLoss=dict(
            module=CrossEntropyLoss, supported_output_types=(MULTICLASS, BINARY)),
        BCEWithLogitsLoss=dict(
            module=BCEWithLogitsLoss, supported_output_types=(BINARY,)),
    ),
    regression=dict(
        MSELoss=dict(
            module=MSELoss, supported_output_types=(CONTINUOUS,)),
        L1Loss=dict(
            module=L1Loss, supported_output_types=(CONTINUOUS,)),
    ),
)

# Default loss class for each task family.
default_losses = dict(classification=CrossEntropyLoss, regression=MSELoss)

Expand All @@ -38,16 +39,16 @@ def get_supported_losses(task: int, output_type: int) -> Dict[str, Type[Loss]]:
supported_losses = dict()
if task in CLASSIFICATION_TASKS:
for key, value in losses['classification'].items():
if output_type == value['supported_output_type']:
if output_type in value['supported_output_types']:
supported_losses[key] = value['module']
elif task in REGRESSION_TASKS:
for key, value in losses['regression'].items():
if output_type == value['supported_output_type']:
if output_type in value['supported_output_types']:
supported_losses[key] = value['module']
return supported_losses


def get_loss_instance(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
def get_loss(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
assert 'task_type' in dataset_properties, \
"Expected dataset_properties to have task_type got {}".format(dataset_properties.keys())
assert 'output_type' in dataset_properties, \
Expand Down
23 changes: 10 additions & 13 deletions autoPyTorch/pipeline/components/training/trainer/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np

Expand All @@ -11,7 +11,6 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard.writer import SummaryWriter

from autoPyTorch.constants import BINARY
from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
from autoPyTorch.utils.implementations import get_loss_weight_strategy
Expand Down Expand Up @@ -175,14 +174,13 @@ def prepare(
self,
metrics: List[Any],
model: torch.nn.Module,
criterion: torch.nn.Module,
criterion: Type[torch.nn.Module],
budget_tracker: BudgetTracker,
optimizer: Optimizer,
device: torch.device,
metrics_during_training: bool,
scheduler: _LRScheduler,
task_type: int,
output_type: int,
labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
) -> None:

Expand All @@ -196,11 +194,7 @@ def prepare(
weights = None
kwargs = {}
if self.weighted_loss:
weights = self.get_class_weights(output_type, labels)
if output_type == BINARY:
kwargs['pos_weight'] = weights
else:
kwargs['weight'] = weights
kwargs = self.get_class_weights(criterion, labels)

criterion = criterion(**kwargs) if weights is not None else criterion()

Expand Down Expand Up @@ -376,13 +370,16 @@ def compute_metrics(self, outputs_data: np.ndarray, targets_data: np.ndarray
targets_data = torch.cat(targets_data, dim=0)
return calculate_score(targets_data, outputs_data, self.task_type, self.metrics)

def get_class_weights(
    self,
    criterion: Type[torch.nn.Module],
    labels: Union[np.ndarray, torch.Tensor, pd.DataFrame],
) -> Dict[str, torch.Tensor]:
    """
    Build the class-weighting keyword arguments for the given loss class.

    Args:
        criterion (Type[torch.nn.Module]): loss class (not an instance) to be weighted
        labels (Union[np.ndarray, torch.Tensor, pd.DataFrame]): targets handed to the
            weighting strategy as ``y``

    Returns:
        Dict[str, torch.Tensor]: ``{'pos_weight': weights}`` when the loss class is named
            ``BCEWithLogitsLoss`` and ``{'weight': weights}`` otherwise; the weights are
            float tensors placed on ``self.device``.

    Raises:
        ValueError: propagated from ``get_loss_weight_strategy`` when no strategy
            supports ``criterion``.
    """
    strategy = get_loss_weight_strategy(criterion)
    # The strategy yields a numpy array; convert it to a float tensor on the training device.
    weights = torch.from_numpy(strategy(y=labels))
    weights = weights.type(torch.FloatTensor).to(self.device)
    # BCEWithLogitsLoss takes its weighting through `pos_weight`, other losses through `weight`.
    weight_key = 'pos_weight' if criterion.__name__ == 'BCEWithLogitsLoss' else 'weight'
    return {weight_key: weights}

def data_preparation(self, X: np.ndarray, y: np.ndarray,
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard.writer import SummaryWriter

from autoPyTorch.constants import STRING_TO_OUTPUT_TYPES, STRING_TO_TASK_TYPES
from autoPyTorch.constants import STRING_TO_TASK_TYPES
from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
from autoPyTorch.pipeline.components.base_component import (
ThirdPartyComponents,
autoPyTorchComponent,
find_components,
)
from autoPyTorch.pipeline.components.training.losses import get_loss_instance
from autoPyTorch.pipeline.components.training.losses import get_loss
from autoPyTorch.pipeline.components.training.metrics.utils import get_metrics
from autoPyTorch.pipeline.components.training.trainer.base_trainer import (
BaseTrainerComponent,
Expand Down Expand Up @@ -266,15 +266,14 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> torch.nn.Modu
model=X['network'],
metrics=get_metrics(dataset_properties=X['dataset_properties'],
names=additional_metrics),
criterion=get_loss_instance(X['dataset_properties'],
name=additional_losses),
criterion=get_loss(X['dataset_properties'],
name=additional_losses),
budget_tracker=budget_tracker,
optimizer=X['optimizer'],
device=self.get_device(X),
metrics_during_training=X['metrics_during_training'],
scheduler=X['lr_scheduler'],
task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
output_type=STRING_TO_OUTPUT_TYPES[X['dataset_properties']['output_type']],
labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]]
)
total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
Expand Down
27 changes: 21 additions & 6 deletions autoPyTorch/utils/implementations.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from typing import Callable, Union
from typing import Any, Callable, Dict, Type, Union

import numpy as np

import torch

from autoPyTorch.constants import BINARY


def get_loss_weight_strategy(loss: Type[torch.nn.Module]) -> Callable:
    """
    Utility function that returns strategy for the given loss

    Args:
        loss (Type[torch.nn.Module]): type of the loss function

    Returns:
        (Callable): Relevant Callable strategy

    Raises:
        ValueError: if no strategy lists the loss name among its supported losses
    """
    # Strategies are probed in order, so the binary strategy keeps precedence.
    for strategy_cls in (LossWeightStrategyWeightedBinary, LossWeightStrategyWeighted):
        if loss.__name__ in strategy_cls.get_properties()['supported_losses']:
            return strategy_cls()
    raise ValueError("No strategy currently supports the given loss, {}".format(loss.__name__))


class LossWeightStrategyWeighted():
Expand All @@ -34,6 +41,10 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:

return weights

@staticmethod
def get_properties() -> Dict[str, Any]:
return {'supported_losses': ['CrossEntropyLoss']}


class LossWeightStrategyWeightedBinary():
def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
Expand All @@ -46,3 +57,7 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
weights = counts_zero / np.maximum(counts_one, 1)

return np.array(weights)

@staticmethod
def get_properties() -> Dict[str, Any]:
return {'supported_losses': ['BCEWithLogitsLoss']}
47 changes: 47 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from sklearn.datasets import fetch_openml, make_classification

import torch

from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.utils.backend import create
Expand Down Expand Up @@ -357,3 +359,48 @@ def error_search_space_updates():
value_range=[0, 0.5],
default_value=0.2)
return updates


@pytest.fixture
def loss_cross_entropy_multiclass():
    """
    CrossEntropyLoss scenario with four classes.

    Returns:
        tuple: (dataset_properties, predictions, name, targets, labels)
    """
    # Random tensors are drawn in the order predictions, targets, labels.
    predictions = torch.randn(4, 4, requires_grad=True)
    targets = torch.empty(4, dtype=torch.long).random_(4)
    labels = torch.empty(20, dtype=torch.long).random_(4)
    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'multiclass'}
    return dataset_properties, predictions, 'CrossEntropyLoss', targets, labels


@pytest.fixture
def loss_cross_entropy_binary():
    """
    CrossEntropyLoss scenario on a binary output type, using two-column predictions.

    Returns:
        tuple: (dataset_properties, predictions, name, targets, labels)
    """
    # Random tensors are drawn in the order predictions, targets, labels.
    predictions = torch.randn(4, 2, requires_grad=True)
    targets = torch.empty(4, dtype=torch.long).random_(2)
    labels = torch.empty(20, dtype=torch.long).random_(2)
    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
    return dataset_properties, predictions, 'CrossEntropyLoss', targets, labels


@pytest.fixture
def loss_bce():
    """
    BCEWithLogitsLoss scenario on a binary output type.

    Returns:
        tuple: (dataset_properties, predictions, name, targets, labels)
    """
    # Random tensors are drawn in the order predictions, targets, labels.
    predictions = torch.empty(4).random_(2)
    targets = torch.empty(4).random_(2)
    labels = torch.empty(20, dtype=torch.long).random_(2)
    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
    return dataset_properties, predictions, 'BCEWithLogitsLoss', targets, labels


@pytest.fixture
def loss_mse():
    """
    MSELoss scenario on a continuous regression output; no labels are supplied.

    Returns:
        tuple: (dataset_properties, predictions, name, targets, labels) with labels=None
    """
    # Random tensors are drawn in the order predictions, targets.
    predictions = torch.randn(4)
    targets = torch.randn(4)
    dataset_properties = {'task_type': 'tabular_regression', 'output_type': 'continuous'}
    return dataset_properties, predictions, 'MSELoss', targets, None


@pytest.fixture
def loss_details(request):
    """Look up the fixture whose name is given in ``request.param`` and return its value."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
3 changes: 0 additions & 3 deletions test/test_pipeline/components/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def test_evaluate(self):
metrics_during_training=True,
scheduler=None,
task_type=self.task_type,
output_type=self.output_type,
labels=self.y
)

Expand Down Expand Up @@ -176,7 +175,6 @@ def test_epoch_training(self):
device=self.device,
metrics_during_training=True,
task_type=self.task_type,
output_type=self.output_type,
labels=self.y
)

Expand Down Expand Up @@ -210,7 +208,6 @@ def test_epoch_training(self):
device=self.device,
metrics_during_training=True,
task_type=self.task_type,
output_type=self.output_type,
labels=self.y
)

Expand Down
48 changes: 21 additions & 27 deletions test/test_pipeline/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import torch
from torch import nn

from autoPyTorch.constants import STRING_TO_OUTPUT_TYPES
from autoPyTorch.pipeline.components.training.losses import get_loss_instance
from autoPyTorch.pipeline.components.training.losses import get_loss
from autoPyTorch.utils.implementations import get_loss_weight_strategy


Expand All @@ -14,7 +13,7 @@
'continuous'])
def test_get_no_name(output_type):
dataset_properties = {'task_type': 'tabular_classification', 'output_type': output_type}
loss = get_loss_instance(dataset_properties)
loss = get_loss(dataset_properties)
assert isinstance(loss(), nn.Module)


Expand All @@ -23,7 +22,7 @@ def test_get_no_name(output_type):
def test_get_name(output_type_name):
output_type, name = output_type_name
dataset_properties = {'task_type': 'tabular_classification', 'output_type': output_type}
loss = get_loss_instance(dataset_properties, name)()
loss = get_loss(dataset_properties, name)()
assert isinstance(loss, nn.Module)
assert str(loss) == f"{name}()"

Expand All @@ -32,29 +31,24 @@ def test_get_name_error():
dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'multiclass'}
name = 'BCELoss'
with pytest.raises(ValueError, match=r"Invalid name entered for task [a-z]+_[a-z]+, "):
get_loss_instance(dataset_properties, name)
get_loss(dataset_properties, name)


@pytest.mark.parametrize('weighted', [True, False])
@pytest.mark.parametrize('loss_details', ['loss_cross_entropy_multiclass',
                                          'loss_cross_entropy_binary',
                                          'loss_bce',
                                          'loss_mse'], indirect=True)
def test_losses(weighted, loss_details):
    """Each loss scenario must produce a tensor score, with and without class weights."""
    dataset_properties, predictions, name, targets, labels = loss_details
    loss = get_loss(dataset_properties=dataset_properties, name=name)

    # Constructor kwargs stay empty unless class weighting applies (classification only).
    kwargs = {}
    if bool(weighted) and 'classification' in dataset_properties['task_type']:
        strategy = get_loss_weight_strategy(loss)
        weights = torch.from_numpy(strategy(y=labels))
        weights = weights.type(torch.FloatTensor)
        # BCEWithLogitsLoss is weighted through `pos_weight`, other losses through `weight`.
        weight_key = 'pos_weight' if loss.__name__ == 'BCEWithLogitsLoss' else 'weight'
        kwargs = {weight_key: weights}

    criterion = loss(**kwargs)
    score = criterion(predictions, targets)
    assert isinstance(score, torch.Tensor)
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved