diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index d74d40fe37..b3c0f57c6e 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -111,8 +111,9 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
-        self.weight = weight
-        self.register_buffer("class_weight", torch.ones(1))
+        weight = torch.as_tensor(weight) if weight is not None else None
+        self.register_buffer("class_weight", weight)
+        self.class_weight: None | torch.Tensor
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -189,13 +190,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
 
-        if self.weight is not None and target.shape[1] != 1:
+        num_of_classes = target.shape[1]
+        if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
-            num_of_classes = target.shape[1]
-            if isinstance(self.weight, (float, int)):
-                self.class_weight = torch.as_tensor([self.weight] * num_of_classes)
+            if self.class_weight.ndim == 0:
+                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
             else:
-                self.class_weight = torch.as_tensor(self.weight)
                 if self.class_weight.shape[0] != num_of_classes:
                     raise ValueError(
                         """the length of the `weight` sequence should be the same as the number of classes.
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index fbd0e6efb8..98c1a071b6 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -113,7 +113,9 @@ def __init__(
         self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
-        self.register_buffer("class_weight", torch.ones(1))
+        weight = torch.as_tensor(weight) if weight is not None else None
+        self.register_buffer("class_weight", weight)
+        self.class_weight: None | torch.Tensor
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -162,13 +164,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         else:
             loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
-        if self.weight is not None:
+        num_of_classes = target.shape[1]
+        if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
-            num_of_classes = target.shape[1]
-            if isinstance(self.weight, (float, int)):
-                self.class_weight = torch.as_tensor([self.weight] * num_of_classes)
+            if self.class_weight.ndim == 0:
+                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
             else:
-                self.class_weight = torch.as_tensor(self.weight)
                 if self.class_weight.shape[0] != num_of_classes:
                     raise ValueError(
                         """the length of the `weight` sequence should be the same as the number of classes.
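
A minimal usage sketch of what the refactor enables, not part of the diff itself: because `weight` is now converted with `torch.as_tensor` in `__init__` and registered as the `class_weight` buffer, it moves with the module across devices via `.to(...)` and appears in the module's `state_dict`, instead of being re-converted from `self.weight` on every `forward` call. The example assumes MONAI's public `DiceLoss` with the patch above applied; the shapes and weight values are illustrative only.

```python
import torch
from monai.losses import DiceLoss

# `weight` is turned into a tensor in __init__ and stored as the
# `class_weight` buffer (None is stored when no weight is given).
loss_fn = DiceLoss(softmax=True, weight=[0.5, 1.0, 2.0])
print(loss_fn.class_weight)  # tensor([0.5000, 1.0000, 2.0000])

# Buffers travel with the module, so no manual device handling is needed.
if torch.cuda.is_available():
    loss_fn = loss_fn.to("cuda")
    print(loss_fn.class_weight.device)  # cuda:0

# One-hot style inputs: (batch, classes, spatial...)
device = loss_fn.class_weight.device
pred = torch.randn(2, 3, 16, 16, device=device)
target = torch.randint(0, 2, (2, 3, 16, 16), device=device).float()
print(loss_fn(pred, target))
```

The remaining `forward`-time logic only has to handle shape checks: a 0-dim `class_weight` (a scalar `weight` argument) is broadcast to one entry per class, and a 1-dim `class_weight` must already match `target.shape[1]`.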