diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index 22caf3fe7d..08ff5f3716 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -9,13 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union - import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss -from monai.utils import LossReduction +from monai.utils import deprecated_arg class ContrastiveLoss(_Loss): @@ -31,19 +29,23 @@ class ContrastiveLoss(_Loss): """ - def __init__( - self, temperature: float = 0.5, batch_size: int = 1, reduction: Union[LossReduction, str] = LossReduction.SUM - ) -> None: + @deprecated_arg(name="reduction", since="0.8", msg_suffix="`reduction` is no longer supported.") + def __init__(self, temperature: float = 0.5, batch_size: int = 1, reduction="sum") -> None: """ Args: temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5. + batch_size: The number of samples. Raises: - AssertionError: When an input of dimension length > 2 is passed - AssertionError: When input and target are of different shapes + ValueError: When an input of dimension length > 2 is passed + ValueError: When input and target are of different shapes + + .. deprecated:: 0.8.0 + + `reduction` is no longer supported. """ - super().__init__(reduction=LossReduction(reduction).value) + super().__init__() self.batch_size = batch_size self.temperature = temperature @@ -53,18 +55,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Args: input: the shape should be B[F]. target: the shape should be B[F]. - - Raises: - ValueError: When ``self.reduction`` is not one of ["sum", "none"]. """ if len(target.shape) > 2 or len(input.shape) > 2: - raise AssertionError( + raise ValueError( f"Either target or input has dimensions greater than 2 where target " f"shape is ({target.shape}) and input shape is ({input.shape})" ) if target.shape != input.shape: - raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") temperature_tensor = torch.tensor(self.temperature).to(input.device) @@ -86,6 +85,4 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) - if self.reduction == LossReduction.SUM.value: - return torch.sum(loss_partial) / (2 * self.batch_size) - raise ValueError(f"Unsupported reduction: {self.reduction}, " f'available options are ["mean", "sum", "none"].') + return torch.sum(loss_partial) / (2 * self.batch_size) diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py index b9caecce65..4586c27b7e 100644 --- a/tests/test_contrastive_loss.py +++ b/tests/test_contrastive_loss.py @@ -61,7 +61,7 @@ def test_result(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = ContrastiveLoss(temperature=0.5, batch_size=1) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) def test_with_cuda(self):