In [None]:
# | default_exp losses/class_balanced_cross_entropy_loss

# Imports

In [None]:
# | export


from collections import Counter

import torch
from torch import nn
from torch.nn import functional as F

from vision_architectures.docstrings import populate_docstring
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field

# Config

In [None]:
# | export


class ClassBalancedCrossEntropyLossConfig(CustomBaseModel):
    # Configuration for ClassBalancedCrossEntropyLoss; the loss constructor builds it via model_validate.

    # Number of classes C. One prevalence estimate (and hence one weight) is tracked per class index in [0, C-1].
    # NOTE(review): `...` presumably marks this field as required (pydantic convention) — confirm for this Field wrapper.
    num_classes: int = Field(..., description="Number of classes to weight cross entropy loss.")
    # Share of the previous estimate kept in each EMA update:
    #   prevalence <- ema_decay * prevalence + (1 - ema_decay) * batch_prevalence
    # Half life is ln(0.5) / ln(ema_decay) steps, i.e. ~69 steps for 0.99.
    # NOTE(review): no bounds are declared here; values outside [0, 1] would make the EMA oscillate or diverge.
    ema_decay: float = Field(
        0.99, description="Exponential moving average decay. By default 0.99 is used which has a half life of ~69 steps"
    )

# The loss function

In [None]:
# | export


class ClassBalancedCrossEntropyLoss(nn.Module):
    @populate_docstring
    def __init__(self, config: ClassBalancedCrossEntropyLossConfig | dict | None = None, **kwargs):
        """Class-balanced cross-entropy loss with running prevalence estimation.

        This loss reweights the standard multi-class cross-entropy by the inverse of
        the observed class prevalences in the training data. Class prevalences are
        estimated online via an exponential moving average (EMA) using class counts
        from the incoming targets.

        Notes:
            - Targets must be discrete integer class indices in [0, num_classes-1].
              Probabilistic/soft labels are not supported.
            - Target values outside [0, num_classes-1] (e.g. the default ignore_index
              of -100), and values equal to an ignore_index passed to forward, are not
              counted towards the prevalences.
            - For classes that haven't been observed yet, their prevalence is treated
              as NaN and replaced in the weight vector by the mean of observed weights
              (then clamped within 3 standard deviations to avoid extreme values).
              Until at least one class has been observed, all weights are 1.
            - The running prevalences live in a plain Python list, so they are not
              part of the module's state_dict and are not saved with checkpoints.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}

        Raises:
            ValueError: If num_classes is smaller than 1 or ema_decay is outside [0, 1].
        """
        super().__init__()

        if config is None:
            config = {}
        elif isinstance(config, ClassBalancedCrossEntropyLossConfig):
            # Allow passing a config instance: turn it into a dict so kwargs can override its fields
            config = config.model_dump()
        self.config = ClassBalancedCrossEntropyLossConfig.model_validate(config | kwargs)

        if self.config.num_classes < 1:
            raise ValueError(f"num_classes must be at least 1, got {self.config.num_classes}")
        if not 0.0 <= self.config.ema_decay <= 1.0:
            raise ValueError(f"ema_decay must be in [0, 1], got {self.config.ema_decay}")

        # Running per-class prevalence estimates (fraction of counted target elements).
        # None means the class has not been observed yet, so no estimate exists for it.
        self.class_prevalences: list[float | None] = [None] * self.config.num_classes

    def _update_class_prevalences(self, target: torch.Tensor):
        """Update the running class-prevalence estimates from a target tensor.

        The method counts class occurrences in the provided target, converts counts
        to per-batch prevalences, then updates the internal EMA-tracked prevalence
        vector for each class. A class seen for the first time is initialized with
        its batch prevalence; classes never seen stay None.

        Args:
            target: A tensor of integer class indices with any shape, typically
                (N,) or (N, ...) for segmentation. Values outside
                [0, num_classes-1] are ignored. If no value is inside that range,
                the estimates are left unchanged.
        """
        num_classes = self.config.num_classes

        # Keep only valid class indices: negative values (e.g. cross_entropy's default ignore_index
        # of -100) and values >= num_classes must not inflate the total count.
        flat_target = target.detach().flatten().long()
        valid_target = flat_target[(flat_target >= 0) & (flat_target < num_classes)]

        # Count on the target's own device; only num_classes integers are transferred to the host
        class_counts = torch.bincount(valid_target, minlength=num_classes).tolist()

        total_count = sum(class_counts)
        if total_count == 0:
            # Nothing to learn from this batch (e.g. every element ignored): don't decay the estimates
            return

        # Update current prevalences using EMA to allow for distribution shift in data
        decay = self.config.ema_decay
        for class_id, count in enumerate(class_counts):
            new_prevalence = count / total_count
            old_prevalence = self.class_prevalences[class_id]
            if old_prevalence is None:
                # First encounter initializes the estimate; an unseen class stays None
                if new_prevalence > 0.0:
                    self.class_prevalences[class_id] = new_prevalence
            else:
                self.class_prevalences[class_id] = old_prevalence * decay + new_prevalence * (1 - decay)

    def get_class_prevalences(self, device=torch.device("cpu")) -> torch.Tensor:
        """Return the current vector of class prevalences as a tensor.

        For classes that haven't been observed yet, the corresponding entry will be
        NaN. This method does not perform any imputation or normalization beyond
        returning the current EMA state.

        Args:
            device: The device on which to place the returned tensor.

        Returns:
            Tensor of shape (num_classes,) with dtype float32 containing per-class
            prevalence estimates in [0, 1] or NaN for unseen classes.
        """
        # Unseen classes (None) are represented as NaN so downstream code can use nan-aware ops
        prevalences = [torch.nan if prevalence is None else prevalence for prevalence in self.class_prevalences]
        return torch.tensor(prevalences, dtype=torch.float32, device=device)

    def get_class_weights(self, device=torch.device("cpu")) -> torch.Tensor:
        """Compute per-class weights as the inverse of prevalences with safeguards.

        Steps:
            1) Convert current EMA prevalences to a tensor with NaNs for unseen classes.
            2) Take the inverse to obtain raw weights (higher weight for rarer classes).
            3) Replace NaNs by the mean of observed weights to avoid biasing toward
               unseen classes, then clamp to mean ± 3·std to prevent extreme values
               (std is taken as 0 when only one class has been observed).
            4) Renormalize weights to sum to num_classes (so the average weight is 1).

        If no class has been observed yet, uniform weights (all ones) are returned.

        Args:
            device: The device on which to place the returned tensor.

        Returns:
            Tensor of shape (num_classes,) containing normalized class weights.
        """
        num_classes = self.config.num_classes

        # Inverse prevalence: rarer classes get larger weights; unseen classes stay NaN.
        # get_class_prevalences already returns a tensor, so no re-wrapping (avoids a copy-construct warning).
        weights = 1 / self.get_class_prevalences(device)

        observed = weights[~weights.isnan()]
        if observed.numel() == 0:
            # No estimate at all yet: fall back to plain (unweighted) cross entropy instead of all-NaN weights
            return torch.ones(num_classes, dtype=torch.float32, device=device)

        # Substitute nan values with mean and clamp weights to a limit
        # Assumption: all classes are visited at least once in the dataset
        mu = observed.mean()
        # The unbiased std of a single value is NaN, which would turn the clamp bounds into NaN
        std = observed.std() if observed.numel() > 1 else torch.zeros_like(mu)
        weights = weights.nan_to_num(mu.item())
        weights = weights.clamp(mu - 3 * std, mu + 3 * std)

        # Normalize weights to sum to num_classes (average weight of 1)
        return num_classes * weights / weights.sum()

    def forward(
        self, input: torch.Tensor, target: torch.Tensor, return_class_weights: bool = False, *args, **kwargs
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Compute class-balanced cross entropy.

        This updates the internal class-prevalence EMA using the provided targets,
        then computes cross-entropy with a weight vector derived from the current
        prevalences.

        Args:
            input: Logits of shape (N, C, ...) where C == num_classes. Any extra spatial
                dimensions (e.g., H, W, D) are supported as long as `target` has the
                non-channel shape expected by torch.nn.functional.cross_entropy.
            target: Integer class indices with shape matching input without the channel
                dimension, e.g., (N, ...) with values in [0, C-1].
            return_class_weights: If True, also return the per-class weight tensor used
                for this call.
            *args, **kwargs: Forwarded to F.cross_entropy. Pass options by keyword
                (e.g., reduction='mean', ignore_index=...): extra positional values bind
                to F.cross_entropy's `weight` parameter and raise a TypeError, as does
                passing `weight` explicitly. Targets equal to `ignore_index` are not
                counted towards the class prevalences.

        Returns:
            If return_class_weights is False: the loss (a scalar for reduction 'mean'
            or 'sum', per-element losses for reduction 'none').
            If True: a tuple (loss, class_weights).
        """
        # Only count target elements that actually contribute to the loss
        ignore_index = kwargs.get("ignore_index")
        counted_target = target if ignore_index is None else target[target != ignore_index]

        # Update class prevalences
        self._update_class_prevalences(counted_target)

        # Get class weights on the same device as the logits
        class_weights = self.get_class_weights(input.device)

        # Calculate loss
        loss = F.cross_entropy(input, target, weight=class_weights, *args, **kwargs)

        if return_class_weights:
            return loss, class_weights
        return loss

In [None]:
# Regular test: targets are drawn from all 3 classes


test = ClassBalancedCrossEntropyLoss(num_classes=3)

for _ in range(10):
    example_input = torch.rand(4, 3)
    example_target = torch.randint(0, 3, (4,))
    loss, class_weights = test(example_input, example_target, return_class_weights=True)
    print(loss, class_weights)
    # Weights must stay finite and are normalized to sum to num_classes (average weight 1)
    assert torch.isfinite(loss) and torch.isfinite(class_weights).all()
    assert torch.isclose(class_weights.sum(), torch.tensor(3.0))

tensor(0.9960) tensor([1.2000, 0.6000, 1.2000])
tensor(1.0988) tensor([1.1916, 0.6048, 1.2036])
tensor(1.0598) tensor([1.1869, 0.6145, 1.1986])
tensor(1.1092) tensor([1.1871, 0.6143, 1.1986])
tensor(1.2541) tensor([1.1906, 0.6191, 1.1903])
tensor(1.0118) tensor([1.2024, 0.6188, 1.1787])
tensor(0.9824) tensor([1.2174, 0.6233, 1.1593])
tensor(0.9926) tensor([1.2087, 0.6282, 1.1631])
tensor(0.9627) tensor([1.2087, 0.6279, 1.1635])
tensor(1.1370) tensor([1.2117, 0.6325, 1.1558])


  weights = 1 / torch.tensor(class_prevalences, dtype=torch.float32, device=device)


In [None]:
# Test where last two classes are never encountered

test = ClassBalancedCrossEntropyLoss(num_classes=5)

for _ in range(10):
    example_input = torch.rand(4, 5)
    example_target = torch.randint(0, 3, (4,))
    loss, class_weights = test(example_input, example_target, return_class_weights=True)
    print(loss, class_weights)
    # Weights must stay finite and are normalized to sum to num_classes (average weight 1)
    assert torch.isfinite(loss) and torch.isfinite(class_weights).all()
    assert torch.isclose(class_weights.sum(), torch.tensor(5.0))
    # Never-seen classes 3 and 4 both receive the same imputed weight
    assert class_weights[3] == class_weights[4]

tensor(1.5845) tensor([0.5000, 1.0000, 1.5000, 1.0000, 1.0000])
tensor(1.5643) tensor([0.4267, 1.2802, 1.2931, 1.0000, 1.0000])
tensor(1.5043) tensor([0.4310, 1.2717, 1.2973, 1.0000, 1.0000])
tensor(1.7012) tensor([0.4354, 1.2761, 1.2885, 1.0000, 1.0000])
tensor(1.7786) tensor([0.4428, 1.2725, 1.2846, 1.0000, 1.0000])
tensor(1.6650) tensor([0.4439, 1.2595, 1.2966, 1.0000, 1.0000])
tensor(1.7746) tensor([0.4515, 1.2685, 1.2800, 1.0000, 1.0000])
tensor(1.6959) tensor([0.4462, 1.2712, 1.2826, 1.0000, 1.0000])
tensor(1.5825) tensor([0.4537, 1.2551, 1.2912, 1.0000, 1.0000])
tensor(1.6506) tensor([0.4612, 1.2517, 1.2870, 1.0000, 1.0000])


  weights = 1 / torch.tensor(class_prevalences, dtype=torch.float32, device=device)


# nbdev

In [None]:
# Write the cells tagged `# | export` to the module named by `# | default_exp` (nbdev)
!nbdev_export