Description
Kullback-Leibler divergence loss
https://docs.pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
Currently the KL divergence loss function is not supported. It can be decomposed into primitive ops that already have support; a runnable sketch of the decomposition follows the reference snippet below.
To avoid underflow issues when computing this quantity, this loss expects the argument input in the log-space. The argument target may also be provided in the log-space if log_target=True.
# To summarise, this function is roughly equivalent to computing
if not log_target:  # default
    loss_pointwise = target * (target.log() - input)
else:
    loss_pointwise = target.exp() * (target - input)

# and then reducing this result depending on the argument reduction as
if reduction == "mean":  # default
    loss = loss_pointwise.mean()
elif reduction == "batchmean":  # mathematically correct
    loss = loss_pointwise.sum() / input.size(0)
elif reduction == "sum":
    loss = loss_pointwise.sum()
else:  # reduction == "none"
    loss = loss_pointwise
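
Below is a minimal runnable sketch of that decomposition (the name kl_div_decomposed is illustrative), built only from elementwise mul/sub/log/exp plus sum/mean reductions. The comparison against F.kl_div assumes a strictly positive target such as a softmax output, since the built-in op treats 0 * log(0) as 0 while a naive target.log() would not.

import torch
import torch.nn.functional as F

def kl_div_decomposed(input, target, reduction="mean", log_target=False):
    # pointwise term, using only elementwise primitives
    if not log_target:
        loss_pointwise = target * (target.log() - input)   # mul, log, sub
    else:
        loss_pointwise = target.exp() * (target - input)   # exp, mul, sub
    # reduction, using only sum / mean / scalar division
    if reduction == "mean":
        return loss_pointwise.mean()
    elif reduction == "batchmean":
        return loss_pointwise.sum() / input.size(0)
    elif reduction == "sum":
        return loss_pointwise.sum()
    return loss_pointwise  # reduction == "none"

input = F.log_softmax(torch.randn(3, 4), dim=1)  # log-probabilities
target = F.softmax(torch.randn(3, 4), dim=1)     # strictly positive probabilities
assert torch.allclose(
    kl_div_decomposed(input, target, reduction="sum"),
    F.kl_div(input, target, reduction="sum"),
)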
Reduction enum mapping
import warnings

def get_enum(reduction: str) -> int:
    if reduction == "none":
        ret = 0
    elif reduction == "mean":
        ret = 1
    elif reduction == "elementwise_mean":
        warnings.warn(
            "reduction='elementwise_mean' is deprecated. "
            "Please use reduction='mean' instead."
        )
        ret = 1
    elif reduction == "sum":
        ret = 2
    else:
        ret = -1  # TODO: remove once JIT exceptions support control flow
        raise ValueError(f"{reduction} is not a valid value for reduction")
    return ret
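
For reference, the mapping above sends the string reductions to the integer enum that the low-level op expects; "batchmean" is intentionally absent and is handled in F.kl_div itself (see the next section):

assert get_enum("none") == 0
assert get_enum("mean") == 1
assert get_enum("sum") == 2
# "batchmean" would raise ValueError here; F.kl_div maps it to the "sum"
# enum and divides by the batch size afterwards, as shown below.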
Reduction: batchmean
# special case for batchmean:
# F.kl_div(input, target, reduction="batchmean") is computed roughly as
def kl_div_batchmean(input, target):
    out = torch.kl_div(input, target, reduction=get_enum("sum"))  # low-level op takes the integer enum
    batch_size = input.shape[0]
    out = out / batch_size
    return out
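
A quick sanity check of this special case (a sketch, assuming log-probability inputs from log_softmax): the "batchmean" result should equal the "sum" result divided by the batch size.

import torch
import torch.nn.functional as F

input = F.log_softmax(torch.randn(4, 5), dim=1)   # log-probabilities
target = F.softmax(torch.randn(4, 5), dim=1)      # probabilities

batchmean = F.kl_div(input, target, reduction="batchmean")
manual = F.kl_div(input, target, reduction="sum") / input.shape[0]
assert torch.allclose(batchmean, manual)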