add kldivloss module (#5155)
* add kldivloss module

* support for larger range of target values

* add doctest

* separate testcase

* delete deprecated argument

* fix zeros device bug

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
YongtaoShi and oneflow-ci-bot committed Jun 15, 2021
1 parent 0c3a90f commit 0eec807
Showing 4 changed files with 257 additions and 34 deletions.
1 change: 1 addition & 0 deletions docs/source/experimental.rst
@@ -95,6 +95,7 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.CrossEntropyLoss
.. autofunction:: oneflow.experimental.nn.L1Loss
.. autofunction:: oneflow.experimental.nn.NLLLoss
.. autofunction:: oneflow.experimental.nn.KLDivLoss
.. autofunction:: oneflow.experimental.nn.MSELoss
.. autofunction:: oneflow.experimental.nn.MarginRankingLoss
.. autofunction:: oneflow.experimental.masked_fill
171 changes: 137 additions & 34 deletions oneflow/python/nn/modules/loss.py
@@ -52,10 +52,7 @@ class CrossEntropyLoss(Module):
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
be applied, ``'mean'``: the weighted mean of the output is taken,
``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in
the meantime, specifying either of those two args will override
:attr:`reduction`. Default: ``'mean'``
``'sum'``: the output will be summed. Default: ``'mean'``
For example:
@@ -193,10 +190,7 @@ class NLLLoss(Module):
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
be applied, ``'mean'``: the weighted mean of the output is taken,
``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in
the meantime, specifying either of those two args will override
:attr:`reduction`. Default: ``'mean'``
``'sum'``: the output will be summed. Default: ``'mean'``
For example:
@@ -297,6 +291,132 @@ def forward(self, input, target):
return res.mean()


@oneflow_export("nn.KLDivLoss")
@experimental_api
class KLDivLoss(Module):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html?highlight=kldivloss#torch.nn.KLDivLoss
The Kullback-Leibler divergence loss measure
`Kullback-Leibler divergence`_ is a useful distance measure for continuous
distributions and is often useful when performing direct regression over
the space of (discretely sampled) continuous output distributions.
As with :class:`~torch.nn.NLLLoss`, the `input` given is expected to contain
*log-probabilities* and is not restricted to a 2D Tensor.
The targets are interpreted as *probabilities* by default, but could be considered
as *log-probabilities* with :attr:`log_target` set to ``True``.
This criterion expects a `target` `Tensor` of the same size as the
`input` `Tensor`.
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
.. math::
l(x,y) = L = \{ l_1,\dots,l_N \}, \quad
l_n = y_n \cdot \left( \log y_n - x_n \right)
where the index :math:`N` spans all dimensions of ``input`` and :math:`L` has the same
shape as ``input``. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then:
.. math::
\ell(x, y) = \begin{cases}
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';} \\
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
\end{cases}
In default :attr:`reduction` mode ``'mean'``, the losses are averaged for each minibatch over observations
**as well as** over dimensions. ``'batchmean'`` mode gives the correct KL divergence where losses
are averaged over batch dimension only. ``'mean'`` mode's behavior will be changed to the same as
``'batchmean'`` in the next major release.
.. _`kullback-leibler divergence`: https://en.wikipedia.org/wiki/Kullback-Leibler_divergence
Args:
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
``'none'``: no reduction will be applied.
``'batchmean'``: the sum of the output will be divided by batchsize. (Note: this module currently only accepts ``'none'``/``'mean'``/``'sum'``, as checked in ``__init__``.)
``'sum'``: the output will be summed.
``'mean'``: the output will be divided by the number of elements in the output.
Default: ``'mean'``
log_target (bool, optional): Specifies whether `target` is passed in the log space.
Default: ``False``
.. note::
:attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use
:attr:`reduction` = ``'batchmean'`` which aligns with KL math definition.
In the next major release, ``'mean'`` will be changed to be the same as ``'batchmean'``.
Shape:
- Input: :math:`(N, *)` where :math:`*` means any number of additional
dimensions
- Target: :math:`(N, *)`, same shape as the input
- Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`,
the same shape as the input
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> input = flow.Tensor([-0.9021705, 0.08798598, 1.04686249], dtype=flow.float32)
>>> target = flow.Tensor([1.22386942, -0.89729659, 0.01615712], dtype=flow.float32)
>>> m = flow.nn.KLDivLoss(reduction="none", log_target=False)
>>> out = m(input, target)
>>> out
tensor([ 1.3514, 0. , -0.0836], dtype=oneflow.float32)
>>> m = flow.nn.KLDivLoss(reduction="mean", log_target=False)
>>> out = m(input, target)
>>> out
tensor([0.4226], dtype=oneflow.float32)
>>> m = flow.nn.KLDivLoss(reduction="sum", log_target=True)
>>> out = m(input, target)
>>> out
tensor([5.7801], dtype=oneflow.float32)
"""

def __init__(self, reduction: str = "mean", log_target: bool = False,) -> None:
super().__init__()
assert reduction in [
"sum",
"none",
"mean",
None,
], "Argument reduction only support 'sum'/'mean'/'none'/None for now!"
self.reduction = reduction
self.log_target = log_target

def forward(self, input: Tensor, target: Tensor) -> Tensor:
if self.log_target:
_kl_div_loss = flow.experimental.exp(target) * (target - input)
else:
_kl_div_out_loss = target * (flow.experimental.log(target) - input)
_zeros = flow.experimental.zeros(
size=_kl_div_out_loss.shape,
dtype=_kl_div_out_loss.dtype,
device=_kl_div_out_loss.device,
)
# `_condition` is 1 where target > 0 and 0 elsewhere (target <= 0).
_condition = flow.experimental.gt(target, 0)
# log(target) is undefined for target <= 0, so the loss at those
# positions is set to 0 to avoid `nan` values.
_kl_div_loss = flow.experimental.where(_condition, _kl_div_out_loss, _zeros)

if self.reduction == "mean":
return flow.experimental.mean(_kl_div_loss)
elif self.reduction == "sum":
return flow.experimental.sum(_kl_div_loss)
else:
return _kl_div_loss
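
For reference, the elementwise computation in `forward` can be cross-checked with a small NumPy sketch. It mirrors the `_np_kldivloss` helper added in the new test file below, and reuses the doctest values, so the printed numbers should match the documented outputs:

import numpy as np

def np_kldivloss(np_input, np_target, log_target=False):
    # `np_input` holds log-probabilities; `np_target` holds probabilities
    # (or log-probabilities when `log_target` is True).
    if log_target:
        loss = np.exp(np_target) * (np_target - np_input)
    else:
        with np.errstate(invalid="ignore", divide="ignore"):
            loss = np_target * (np.log(np_target) - np_input)
        # log(target) is undefined for target <= 0; zero out those positions,
        # just as the module does with flow.experimental.where above.
        loss = np.where(np_target > 0, loss, np.zeros_like(loss))
    return loss

np_input = np.array([-0.9021705, 0.08798598, 1.04686249], dtype=np.float32)
np_target = np.array([1.22386942, -0.89729659, 0.01615712], dtype=np.float32)
print(np.round(np_kldivloss(np_input, np_target), 4))         # -> [ 1.3514  0.     -0.0836]
print(np.round(np_kldivloss(np_input, np_target).mean(), 4))  # -> 0.4226
# reduction='mean' divides by the number of elements; 'batchmean' (the true KL value)
# would divide the sum by the batch size only, hence the note in the docstring.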


@oneflow_export("nn.MSELoss")
@experimental_api
class MSELoss(Module):
@@ -331,21 +451,10 @@ class MSELoss(Module):
The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
Args:
size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
the losses are averaged over each loss element in the batch. Note that for
some losses, there are multiple elements per sample. If the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch. Ignored
when :attr:`reduce` is ``False``. Default: ``True``
reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
losses are averaged or summed over observations for each minibatch depending
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
batch element instead and ignores :attr:`size_average`. Default: ``True``
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
Shape:
- Input: :math:`(N, *)` where :math:`*` means any number of additional
@@ -368,28 +477,22 @@ class MSELoss(Module):
... [-0.49158347, 0.93673637, 0.1324141]], dtype=flow.float32)
>>> m = flow.nn.MSELoss(reduction="none")
>>> out = m(input, target)
>>> print(out.numpy())
[[2.266468 0.50750285 0.61121327]
[0.55887264 4.082267 0.1172941 ]]
>>> out
tensor([[2.2665, 0.5075, 0.6112],
[0.5589, 4.0823, 0.1173]], dtype=oneflow.float32)
>>> m = flow.nn.MSELoss(reduction="mean")
>>> out = m(input, target)
>>> print(out.numpy())
[1.3572696]
>>> out
tensor([1.3573], dtype=oneflow.float32)
>>> m = flow.nn.MSELoss(reduction="sum")
>>> out = m(input, target)
>>> print(out.numpy())
[8.143618]
>>> out
tensor([8.1436], dtype=oneflow.float32)
"""

def __init__(
self, reduction: str = "mean", size_average: bool = True, reduce: bool = True
) -> None:
def __init__(self, reduction: str = "mean") -> None:
super().__init__()
if size_average is False:
raise ValueError("Argument size_average is not supported yet")
if reduce is False:
raise ValueError("Argument reduce is not supported yet")
assert reduction in [
"sum",
"none",
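The `size_average` and `reduce` parameters deleted from `MSELoss.__init__` above fold into the remaining `reduction` argument. A minimal migration sketch, assuming the usual PyTorch-style mapping for the legacy flags (the mapping itself is not spelled out in this diff):

import oneflow.experimental as flow
flow.enable_eager_execution()

# Assumed legacy-flag mapping:
#   reduce=False                     -> reduction="none"
#   size_average=True (old default)  -> reduction="mean"  (still the default)
#   size_average=False               -> reduction="sum"
loss = flow.nn.MSELoss(reduction="sum")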
1 change: 1 addition & 0 deletions oneflow/python/ops/nn_ops.py
@@ -4406,6 +4406,7 @@ def PixelShufflev2Job(input: tp.Numpy.Placeholder(shape=(3, 16, 2, 4), dtype=flo


@oneflow_export("nn.KLDivLoss")
@stable_api
def kldivloss(
input: oneflow._oneflow_internal.BlobDesc,
target: oneflow._oneflow_internal.BlobDesc,
118 changes: 118 additions & 0 deletions oneflow/python/test/modules/test_kldivloss.py
@@ -0,0 +1,118 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import numpy as np

import oneflow.experimental as flow
from test_util import GenArgList


def _np_kldivloss(np_input, np_target, np_log_target):
if np_log_target:
np_kl_div_loss = np.exp(np_target) * (np_target - np_input)
else:
np_kl_div_out_loss = np_target * (np.log(np_target) - np_input)
np_zeros = np.zeros_like(np_kl_div_out_loss, dtype=np.float32)
# np.log(np_target) is undefined where target <= 0, so zero out those positions
# of the loss to avoid `nan` values (matching the module's `where` mask).
np_kl_div_loss = np.where(np_target > 0, np_kl_div_out_loss, np_zeros)

return {
"none": np_kl_div_loss,
"mean": np.mean(np_kl_div_loss),
"sum": np.sum(np_kl_div_loss),
}


def _np_kldivloss_grad(input, target, np_log_target):
elem_cnt = input.size
if np_log_target:
_np_diff = -np.exp(target)
else:
_np_diff = -target
# When np_log_target is False, the loss is zeroed where target <= 0, so the gradient there is zero as well.
_zero_index = np.where(target > 0, 1, 0)
_np_diff = _np_diff * _zero_index

return {
"none": _np_diff,
"mean": _np_diff / elem_cnt,
"sum": _np_diff,
}
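
The expected gradients in `_np_kldivloss_grad` follow from differentiating the elementwise loss with respect to the input, a short worked step for readers checking the helper:

.. math::
    \frac{\partial}{\partial x_n} \left[ y_n \left( \log y_n - x_n \right) \right] = -y_n,
    \qquad
    \frac{\partial}{\partial x_n} \left[ e^{y_n} \left( y_n - x_n \right) \right] = -e^{y_n}

With reduction='mean' the loss is divided by the element count, so each gradient entry is divided by `elem_cnt` as well; 'sum' (and 'none' followed by the `.sum()` call in the backward test) leaves the per-element gradient unchanged. Where the non-log-target loss was zeroed (target <= 0), the gradient is zeroed too, which is what the `_zero_index` mask reproduces.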


def _test_kldivloss_forward(test_case, device, shape, reduction, log_target):
x = np.random.randn(*shape)
y = np.random.randn(*shape)
input = flow.Tensor(
x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
)
target = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))

loss = flow.nn.KLDivLoss(reduction=reduction, log_target=log_target)
loss = loss.to(device)
of_out = loss(input, target)
np_out = _np_kldivloss(x, y, log_target)[reduction]
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))


def _test_kldivloss_backward(test_case, device, shape, reduction, log_target):
x = np.random.randn(*shape)
y = np.random.randn(*shape)
input = flow.Tensor(
x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
)
target = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))

loss = flow.nn.KLDivLoss(reduction=reduction, log_target=log_target)
loss = loss.to(device)
of_out = loss(input, target)

of_out = of_out.sum()
of_out.backward()
np_grad = _np_kldivloss_grad(x, y, log_target)[reduction]
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestKLDivLossModule(flow.unittest.TestCase):
def test_kldivloss(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_kldivloss_forward,
_test_kldivloss_backward,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [
(3, 5),
(10, 9, 21),
(14, 22, 9, 21),
(3, 2, 4, 16, 5),
(1,),
]
arg_dict["reduction"] = ["none", "mean", "sum"]
arg_dict["log_target"] = [False, True]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])


if __name__ == "__main__":
unittest.main()
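
For a quick interactive sanity check of the new module against the NumPy reference defined in this file (a sketch only; it assumes eager execution is available and that `_np_kldivloss` is in scope):

import numpy as np
import oneflow.experimental as flow

flow.enable_eager_execution()

x = np.random.randn(3, 5).astype(np.float32)
y = np.random.randn(3, 5).astype(np.float32)

m = flow.nn.KLDivLoss(reduction="sum", log_target=True)
of_out = m(flow.Tensor(x), flow.Tensor(y))

np_out = _np_kldivloss(x, y, np_log_target=True)["sum"]
print(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))  # expect True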
