Skip to content

Commit

Permalink
add type promotion (#27756)
Browse files Browse the repository at this point in the history
  • Loading branch information
LielinJiang committed Oct 9, 2020
1 parent 9089841 commit b9c7c66
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,20 @@ def test_kl_loss_static_api(self):
pred_loss = paddle.nn.functional.kl_div(input, label)


class TestKLDivLossTypePromotion(unittest.TestCase):
def test_kl_div_promotion(self):

with paddle.fluid.dygraph.guard():
x1 = paddle.rand([5, 20], dtype='float32')
target1 = paddle.rand([5, 20], dtype='float64')

kldiv_criterion = paddle.nn.KLDivLoss()
pred_loss1 = kldiv_criterion(x1, target1)

x2 = paddle.rand([5, 20], dtype='float64')
target2 = paddle.rand([5, 20], dtype='float32')
pred_loss2 = paddle.nn.functional.kl_div(x2, target2)


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,16 @@ def kl_div(input, label, reduction='mean', name=None):
# shape=[5, 20]
"""
# ugly type promotion
if fluid.data_feeder.convert_dtype(
input.dtype) == 'float32' and fluid.data_feeder.convert_dtype(
label.dtype) == 'float64':
input = fluid.layers.cast(input, 'float64')
elif fluid.data_feeder.convert_dtype(
input.dtype) == 'float64' and fluid.data_feeder.convert_dtype(
label.dtype) == 'float32':
label = fluid.layers.cast(label, 'float64')

if paddle.in_dynamic_mode():
out = core.ops.kldiv_loss(input, label, 'reduction', reduction)
return out
Expand Down

0 comments on commit b9c7c66

Please sign in to comment.