Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX weighted loss issue #94

Merged
merged 8 commits into from
Feb 22, 2021
58 changes: 51 additions & 7 deletions autoPyTorch/pipeline/components/training/losses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
"""
Loss functions available in autoPyTorch

Classification:
CrossEntropyLoss: supports multiclass, binary output types
BCEWithLogitsLoss: supports binary output types
Default: CrossEntropyLoss
Regression:
MSELoss: supports continuous output types
L1Loss: supports continuous output types
Default: MSELoss
"""
from typing import Any, Dict, Optional, Type

from torch.nn.modules.loss import (
Expand All @@ -11,21 +23,30 @@
from autoPyTorch.constants import BINARY, CLASSIFICATION_TASKS, CONTINUOUS, MULTICLASS, REGRESSION_TASKS, \
STRING_TO_OUTPUT_TYPES, STRING_TO_TASK_TYPES, TASK_TYPES_TO_STRING


losses = dict(classification=dict(
CrossEntropyLoss=dict(
module=CrossEntropyLoss, supported_output_type=MULTICLASS),
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
module=CrossEntropyLoss, supported_output_types=[MULTICLASS, BINARY]),
BCEWithLogitsLoss=dict(
module=BCEWithLogitsLoss, supported_output_type=BINARY)),
module=BCEWithLogitsLoss, supported_output_types=[BINARY])),
regression=dict(
MSELoss=dict(
module=MSELoss, supported_output_type=CONTINUOUS),
module=MSELoss, supported_output_types=[CONTINUOUS]),
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
L1Loss=dict(
module=L1Loss, supported_output_type=CONTINUOUS)))
module=L1Loss, supported_output_types=[CONTINUOUS])))
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved

default_losses = dict(classification=CrossEntropyLoss, regression=MSELoss)


def get_default(task: int) -> Type[Loss]:
"""
Utility function to get default loss for the task
Args:
task (int):

Returns:
Type[torch.nn.modules.loss._Loss]
"""
if task in CLASSIFICATION_TASKS:
return default_losses['classification']
elif task in REGRESSION_TASKS:
Expand All @@ -35,19 +56,42 @@ def get_default(task: int) -> Type[Loss]:


def get_supported_losses(task: int, output_type: int) -> Dict[str, Type[Loss]]:
"""
Utility function to get supported losses for a given task and output type
Args:
task (int): integer identifier for the task
output_type: integer identifier for the output type of the task

Returns:
Returns a dictionary containing the losses supported for the given
inputs. Key-Name, Value-Module
"""
supported_losses = dict()
if task in CLASSIFICATION_TASKS:
for key, value in losses['classification'].items():
if output_type == value['supported_output_type']:
if output_type in value['supported_output_types']:
supported_losses[key] = value['module']
elif task in REGRESSION_TASKS:
for key, value in losses['regression'].items():
if output_type == value['supported_output_type']:
if output_type in value['supported_output_types']:
supported_losses[key] = value['module']
return supported_losses


def get_loss_instance(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
def get_loss(dataset_properties: Dict[str, Any], name: Optional[str] = None) -> Type[Loss]:
"""
Utility function to get losses for the given dataset properties.
If name is mentioned, checks if the loss is compatible with
the dataset properties and returns the specific loss
Args:
dataset_properties (Dict[str, Any]): Dictionary containing
properties of the dataset. Must contain task_type and
output_type as strings.
name (Optional[str]): name of the specific loss

Returns:
Type[torch.nn.modules.loss._Loss]
"""
assert 'task_type' in dataset_properties, \
"Expected dataset_properties to have task_type got {}".format(dataset_properties.keys())
assert 'output_type' in dataset_properties, \
Expand Down
32 changes: 14 additions & 18 deletions autoPyTorch/pipeline/components/training/trainer/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np

Expand All @@ -10,6 +10,7 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard.writer import SummaryWriter


from autoPyTorch.constants import REGRESSION_TASKS
from autoPyTorch.pipeline.components.training.base_training import autoPyTorchTrainingComponent
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score
Expand Down Expand Up @@ -173,14 +174,13 @@ def prepare(
self,
metrics: List[Any],
model: torch.nn.Module,
criterion: torch.nn.Module,
criterion: Type[torch.nn.Module],
budget_tracker: BudgetTracker,
optimizer: Optimizer,
device: torch.device,
metrics_during_training: bool,
scheduler: _LRScheduler,
task_type: int,
output_type: int,
labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
) -> None:

Expand All @@ -191,19 +191,12 @@ def prepare(
self.metrics = metrics

# Weights for the loss function
weights = None
kwargs: Dict[str, Any] = {}
# if self.weighted_loss:
# weights = self.get_class_weights(output_type, labels)
# if output_type == BINARY:
# kwargs['pos_weight'] = weights
# pass
# else:
# kwargs['weight'] = weights
kwargs = {}
if self.weighted_loss:
kwargs = self.get_class_weights(criterion, labels)

# Setup the loss function
self.criterion = criterion(**kwargs) if weights is not None else criterion()

self.criterion = criterion(**kwargs)
# setup the model
self.model = model.to(device)

Expand Down Expand Up @@ -384,13 +377,16 @@ def compute_metrics(self, outputs_data: np.ndarray, targets_data: np.ndarray
targets_data = torch.cat(targets_data, dim=0).numpy()
return calculate_score(targets_data, outputs_data, self.task_type, self.metrics)

def get_class_weights(self, output_type: int, labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
) -> np.ndarray:
strategy = get_loss_weight_strategy(output_type)
def get_class_weights(self, criterion: Type[torch.nn.Module], labels: Union[np.ndarray, torch.Tensor, pd.DataFrame]
) -> Dict[str, np.ndarray]:
strategy = get_loss_weight_strategy(criterion)
weights = strategy(y=labels)
weights = torch.from_numpy(weights)
weights = weights.float().to(self.device)
return weights
if criterion.__name__ == 'BCEWithLogitsLoss':
return {'pos_weight': weights}
else:
return {'weight': weights}

def data_preparation(self, X: np.ndarray, y: np.ndarray,
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard.writer import SummaryWriter

from autoPyTorch.constants import STRING_TO_OUTPUT_TYPES, STRING_TO_TASK_TYPES
from autoPyTorch.constants import STRING_TO_TASK_TYPES
from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
from autoPyTorch.pipeline.components.base_component import (
ThirdPartyComponents,
autoPyTorchComponent,
find_components,
)
from autoPyTorch.pipeline.components.training.losses import get_loss_instance
from autoPyTorch.pipeline.components.training.losses import get_loss
from autoPyTorch.pipeline.components.training.metrics.utils import get_metrics
from autoPyTorch.pipeline.components.training.trainer.base_trainer import (
BaseTrainerComponent,
Expand Down Expand Up @@ -265,15 +265,14 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> torch.nn.Modu
model=X['network'],
metrics=get_metrics(dataset_properties=X['dataset_properties'],
names=additional_metrics),
criterion=get_loss_instance(X['dataset_properties'],
name=additional_losses),
criterion=get_loss(X['dataset_properties'],
name=additional_losses),
budget_tracker=budget_tracker,
optimizer=X['optimizer'],
device=get_device_from_fit_dictionary(X),
metrics_during_training=X['metrics_during_training'],
scheduler=X['lr_scheduler'],
task_type=STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']],
output_type=STRING_TO_OUTPUT_TYPES[X['dataset_properties']['output_type']],
labels=X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]]
)
total_parameter_count, trainable_parameter_count = self.count_parameters(X['network'])
Expand Down
27 changes: 21 additions & 6 deletions autoPyTorch/utils/implementations.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from typing import Callable, Union
from typing import Any, Callable, Dict, Type, Union

import numpy as np

import torch

from autoPyTorch.constants import BINARY


def get_loss_weight_strategy(output_type: int) -> Callable:
if output_type == BINARY:
def get_loss_weight_strategy(loss: Type[torch.nn.Module]) -> Callable:
"""
Utility function that returns strategy for the given loss
Args:
loss (Type[torch.nn.Module]): type of the loss function
Returns:
(Callable): Relevant Callable strategy
"""
if loss.__name__ in LossWeightStrategyWeightedBinary.get_properties()['supported_losses']:
return LossWeightStrategyWeightedBinary()
else:
elif loss.__name__ in LossWeightStrategyWeighted.get_properties()['supported_losses']:
return LossWeightStrategyWeighted()
else:
raise ValueError("No strategy currently supports the given loss, {}".format(loss.__name__))


class LossWeightStrategyWeighted():
Expand All @@ -34,6 +41,10 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:

return weights

@staticmethod
def get_properties() -> Dict[str, Any]:
return {'supported_losses': ['CrossEntropyLoss']}


class LossWeightStrategyWeightedBinary():
def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
Expand All @@ -46,3 +57,7 @@ def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
weights = counts_zero / np.maximum(counts_one, 1)

return np.array(weights)

@staticmethod
def get_properties() -> Dict[str, Any]:
return {'supported_losses': ['BCEWithLogitsLoss']}
60 changes: 60 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from sklearn.datasets import fetch_openml, make_classification, make_regression

import torch

from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.utils.backend import create
Expand Down Expand Up @@ -357,3 +359,61 @@ def error_search_space_updates():
value_range=[0, 0.5],
default_value=0.2)
return updates


@pytest.fixture
def loss_cross_entropy_multiclass():
dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'multiclass'}
predictions = torch.randn(4, 4, requires_grad=True)
name = 'CrossEntropyLoss'
targets = torch.empty(4, dtype=torch.long).random_(4)
# to ensure we have all classes in the labels
while True:
labels = torch.empty(20, dtype=torch.long).random_(4)
if len(torch.unique(labels)) == 4:
break

return dataset_properties, predictions, name, targets, labels


@pytest.fixture
def loss_cross_entropy_binary():
dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
predictions = torch.randn(4, 2, requires_grad=True)
name = 'CrossEntropyLoss'
targets = torch.empty(4, dtype=torch.long).random_(2)
# to ensure we have all classes in the labels
while True:
labels = torch.empty(20, dtype=torch.long).random_(2)
if len(torch.unique(labels)) == 2:
break
return dataset_properties, predictions, name, targets, labels


@pytest.fixture
def loss_bce():
dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
predictions = torch.empty(4).random_(2)
name = 'BCEWithLogitsLoss'
targets = torch.empty(4).random_(2)
# to ensure we have all classes in the labels
while True:
labels = torch.empty(20, dtype=torch.long).random_(2)
if len(torch.unique(labels)) == 2:
break
return dataset_properties, predictions, name, targets, labels


@pytest.fixture
def loss_mse():
dataset_properties = {'task_type': 'tabular_regression', 'output_type': 'continuous'}
predictions = torch.randn(4)
name = 'MSELoss'
targets = torch.randn(4)
labels = None
return dataset_properties, predictions, name, targets, labels


@pytest.fixture
def loss_details(request):
return request.getfixturevalue(request.param)
1 change: 0 additions & 1 deletion test/test_pipeline/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def prepare_trainer(self,
device=device,
metrics_during_training=True,
task_type=task_type,
output_type=output_type,
labels=y
)
return trainer, model, optimizer, loader, criterion, epochs, logger
Expand Down
Empty file.
Empty file.