[TORCH] Add support for Kullback-Leibler divergence loss function #4203

Open
@zahidwx

Description

Kullback-Leibler divergence loss

https://docs.pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html

The KL divergence loss function is currently not supported. It can be decomposed into primitive ops that are already supported.

To avoid underflow issues when computing this quantity, this loss expects the argument input in log-space. The argument target may also be provided in log-space if log_target=True.

# To summarise, this function is roughly equivalent to computing

if not log_target: # default
    loss_pointwise = target * (target.log() - input)
else:
    loss_pointwise = target.exp() * (target - input)

# and then reducing this result depending on the argument reduction as

if reduction == "mean":  # default
    loss = loss_pointwise.mean()
elif reduction == "batchmean":  # mathematically correct
    loss = loss_pointwise.sum() / input.size(0)
elif reduction == "sum":
    loss = loss_pointwise.sum()
else:  # reduction == "none"
    loss = loss_pointwise
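
As a rough illustration of the decomposition above, here is a minimal pure-Python sketch that uses only log, exp, subtract, multiply, sum and divide. The name kl_div_reference and the nested-list (batch, classes) layout are assumptions made for this example, not part of PyTorch or of any existing lowering. Treating 0 * log(0) as 0 is also a choice made here for the sketch.

```python
import math

def kl_div_reference(inp, target, reduction="mean", log_target=False):
    """Hypothetical reference: KL divergence over a 2-D list (batch, classes).

    `inp` holds log-probabilities; `target` holds probabilities, or
    log-probabilities when log_target is True.
    """
    pointwise = []
    for inp_row, tgt_row in zip(inp, target):
        row = []
        for x, t in zip(inp_row, tgt_row):
            if log_target:
                row.append(math.exp(t) * (t - x))
            else:
                # Assumption of this sketch: treat 0 * log(0) as 0.
                row.append(t * (math.log(t) - x) if t > 0 else 0.0)
        pointwise.append(row)

    if reduction == "none":
        return pointwise
    total = sum(sum(row) for row in pointwise)
    if reduction == "sum":
        return total
    if reduction == "batchmean":
        return total / len(inp)  # divide by batch size only
    if reduction == "mean":
        return total / sum(len(row) for row in pointwise)  # divide by element count
    raise ValueError(f"{reduction} is not a valid value for reduction")


probs_p = [[0.5, 0.5], [0.9, 0.1]]     # target distributions
probs_q = [[0.25, 0.75], [0.5, 0.5]]   # model distributions
log_q = [[math.log(v) for v in row] for row in probs_q]
log_p = [[math.log(v) for v in row] for row in probs_p]

total = kl_div_reference(log_q, probs_p, reduction="sum")
batchmean = kl_div_reference(log_q, probs_p, reduction="batchmean")
mean = kl_div_reference(log_q, probs_p, reduction="mean")
total_log = kl_div_reference(log_q, log_p, reduction="sum", log_target=True)
```

Both log_target branches should agree up to floating-point error when the same target is passed in the matching space, and "batchmean" differs from "mean" only in the divisor (batch size rather than total element count).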

Reduction enum mapping

import warnings

def get_enum(reduction: str) -> int:
    if reduction == "none":
        ret = 0
    elif reduction == "mean":
        ret = 1
    elif reduction == "elementwise_mean":
        warnings.warn(
            "reduction='elementwise_mean' is deprecated. "
            "Please use reduction='mean' instead."
        )
        ret = 1
    elif reduction == "sum":
        ret = 2
    else:
        ret = -1  # TODO: remove once JIT exceptions support control flow
        raise ValueError(f"{reduction} is not a valid value for reduction")
    return ret

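Reduction: batchmean

get_enum has no value for "batchmean", so it is handled as a special case: compute the "sum" reduction, then divide by the batch size.

# special case for batchmean (kl_div_batchmean is an illustrative name)
import torch.nn.functional as F

def kl_div_batchmean(input, target, log_target=False):
    out = F.kl_div(input, target, reduction="sum", log_target=log_target)
    batch_size = input.shape[0]
    return out / batch_size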
