From 373d598a79cab6b2acff17629e55a46a7ca434fb Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Mon, 22 Aug 2022 20:57:50 +0800 Subject: [PATCH 01/10] 2022-08-30_update nn.layer.loss nn.functional.loss, test_file --- .../tests/unittests/test_multimarginloss.py | 338 ++++++++++++++++++ python/paddle/nn/functional/loss.py | 110 ++++++ python/paddle/nn/layer/loss.py | 101 ++++++ 3 files changed, 549 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_multimarginloss.py diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py new file mode 100644 index 0000000000000..ddbd1e853bc58 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -0,0 +1,338 @@ +# -*- coding: utf-8 -* +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import unittest + + +def call_MultiMarginLoss_layer( + input, + label, + p=1, + margin=1.0, + weight=None, + reduction='mean', +): + triplet_margin_loss = paddle.nn.MultiMarginLoss(p=p, + margin=margin, + weight=weight, + reduction=reduction) + res = triplet_margin_loss( + input=input, + label=label, + ) + return res + + +def call_MultiMarginLoss_functional( + input, + label, + p=1, + margin=1.0, + weight=None, + reduction='mean', +): + res = paddle.nn.functional.multi_margin_loss(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + return res + + +def test_static(place, + input_np, + label_np, + p=1, + margin=1.0, + weight=None, + reduction='mean', + functional=False): + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + input = paddle.static.data(name='input', + shape=input_np.shape, + dtype='float64') + label = paddle.static.data(name='label', + shape=label_np.shape, + dtype='float64') + feed_dict = { + "input": input_np, + "label": label_np, + } + + if functional: + res = call_MultiMarginLoss_functional(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + else: + res = call_MultiMarginLoss_layer(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + + exe = paddle.static.Executor(place) + static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) + return static_result + + +def test_dygraph(place, + input, + label, + p=1, + margin=1.0, + weight=None, + reduction='mean', + functional=False): + paddle.disable_static() + input = paddle.to_tensor(input) + label = paddle.to_tensor(label) + + if functional: + dy_res = call_MultiMarginLoss_functional(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + else: + dy_res = call_MultiMarginLoss_layer(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + dy_result = dy_res.numpy() + paddle.enable_static() + 
return dy_result + + +def calc_multi_margin_loss( + input, + label, + p=1, + margin=1.0, + weight=None, + reduction='mean', +): + label = label.reshape(-1, 1) + index_sample = [] + for i in range(len(label)): + index_sample.append(input[i, label[i]]) + index_sample = np.array(index_sample).reshape(-1, 1) + + if weight is None: + expected = np.mean(np.maximum(margin + input - index_sample, 0.0)**p, + axis=1) - margin**p / input.shape[1] + else: + weight = weight.reshape(-1, 1) + expected = np.mean(np.maximum(weight * (margin + input - index_sample), 0.0) ** p, axis=1) - margin ** p / \ + input.shape[1] + + if reduction == 'mean': + expected = np.mean(expected) + elif reduction == 'sum': + expected = np.sum(expected) + else: + expected = expected + + return expected + + +class TestMultiMarginLoss(unittest.TestCase): + + def test_MultiMarginLoss(self): + shape = (2, 2) + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label = np.random.uniform(0, 2, size=(2, )).astype(np.float64) + + places = [paddle.CPUPlace()] + if paddle.device.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + reductions = ['sum', 'mean', 'none'] + for place in places: + for reduction in reductions: + expected = calc_multi_margin_loss(input=input, + label=label, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + input=input, + label=label, + reduction=reduction, + ) + + static_result = test_static( + place=place, + input_np=input, + label_np=label, + reduction=reduction, + ) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place=place, + input_np=input, + label_np=label, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + input=input, + label=label, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_MultiMarginLoss_error(self): + paddle.disable_static() + self.assertRaises(ValueError, + paddle.nn.loss.MultiMarginLoss, + reduction="unsupport reduction") + input = paddle.to_tensor([[0.1, 0.3]], dtype='float32') + label = paddle.to_tensor([0.0], dtype='float32') + self.assertRaises(ValueError, + paddle.nn.functional.multi_margin_loss, + input=input, + label=label, + reduction="unsupport reduction") + paddle.enable_static() + + def test_MultiMarginLoss_dimension(self): + paddle.disable_static() + + input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32') + label = paddle.to_tensor([0.0, 1.0, 2.0], dtype='float32') + + self.assertRaises( + ValueError, + paddle.nn.functional.multi_margin_loss, + input=input, + label=label, + ) + MMLoss = paddle.nn.loss.MultiMarginLoss() + self.assertRaises( + ValueError, + MMLoss, + input=input, + label=label, + ) + paddle.enable_static() + + def test_MultiMarginLoss_p(self): + p = 2 + shape = (2, 2) + reduction = 'mean' + place = paddle.CPUPlace() + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label = np.random.uniform(0, 2, size=(2, )).astype(np.float64) + expected = calc_multi_margin_loss(input=input, + p=p, + label=label, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + p=p, + input=input, + label=label, + reduction=reduction, + ) + + static_result = test_static( + place=place, + p=p, + input_np=input, + label_np=label, + 
reduction=reduction, + ) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place=place, + p=p, + input_np=input, + label_np=label, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + p=p, + input=input, + label=label, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_MultiMarginLoss_weight(self): + shape = (2, 2) + reduction = 'mean' + place = paddle.CPUPlace() + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label = np.random.uniform(0, 2, size=(2, )).astype(np.float64) + weight = np.random.uniform(0, 2, size=(2, )).astype(np.float64) + expected = calc_multi_margin_loss(input=input, + label=label, + weight=weight, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + input=input, + label=label, + weight=weight, + reduction=reduction, + ) + + static_result = test_static( + place=place, + input_np=input, + label_np=label, + weight=weight, + reduction=reduction, + ) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place=place, + input_np=input, + label_np=label, + weight=weight, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + input=input, + label=label, + weight=weight, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 7fb4ccb233b2c..eb009c92be5a0 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3386,6 +3386,116 @@ def triplet_margin_loss(input, return loss +def multi_margin_loss(input, + label, + p: int = 1, + margin: float = 1.0, + weight=None, + reduction='mean', + name=None): + r""" + Measures a multi-class classification hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) + and output :math:`y` (which is a 1D tensor of target class indices,:math:`0 \leq y \leq \text{x.size}(1)-1`): + + For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar + output :math:`y` is: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` + and :math:`i \neq y`. + + Optionally, you can give non-equal weighting on the classes by passing + a 1D :attr:`weight` tensor into the constructor. + + The loss function then becomes: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i \max(0, w[y] * (\text{margin} - x[y] + x[i]))^p}{\text{x.size}(0)} + + + Parameters: + input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes. + + label (Tensor): Label tensor, the data type is float32 or float64. The shape of label is (N,) + + p (int, Optional): The norm degree for pairwise distance. Default: :math:`1`. 
+ + margin (float, Optional): Default: :math:`1`. + + weight (Tensor,optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size C and the data type is float32, float64. + Default is ``'None'`` . + + + reduction (str, Optional):Indicate how to average the loss by batch_size. + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default: ``'mean'`` + + name (str, Optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Output: Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32) + positive= paddle.to_tensor([1, 2, 1], dtype=paddle.float32) + loss = F.multi_margin_loss(input, label, margin=1.0, reduction='none') + print(loss) + + + loss = F.multi_margin_loss(input, label, margin=1.0, reduction='mean') + print(loss) + + """ + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "'reduction' in 'multi_margin_loss' should be 'sum', 'mean' or 'none', " + "but received {}.".format(reduction)) + + if not _non_static_mode(): + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'multi_margin_loss') + check_variable_and_dtype(label, 'positive', ['float32', 'float64'], + 'multi_margin_loss') + if not (input.shape[0] == label.shape[0]): + raise ValueError("The label's shape is wrong") + label = label.reshape((-1, 1)) + index_sample = paddle.index_sample(input, label) + if weight is not None: + if not _non_static_mode(): + check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], + 'multi_margin_loss') + if not (input.shape[0] == weight.shape[0]): + raise ValueError("The weight's shape is wrong ") + + weight = weight.reshape((-1, 1)) + loss = paddle.mean(paddle.pow( + paddle.clip(weight * (margin + index_sample - input), min=0.0), p), + axis=1) - margin**p / input.shape[1] + else: + loss = paddle.mean(paddle.pow( + paddle.clip(margin + index_sample - input, min=0.0), p), + axis=1) - margin**p / input.shape[1] + + if reduction == 'mean': + return paddle.mean(loss, name=name) + elif reduction == 'sum': + return paddle.sum(loss, name=name) + elif reduction == 'none': + return loss + + def soft_margin_loss(input, label, reduction='mean', name=None): """ The API measures the soft margin loss between input predictions ``input`` diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 1ff37afa1412e..211bf92e41085 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1689,6 +1689,107 @@ def forward(self, input, positive, negative): name=self.name) +class MultiMarginLoss(Layer): + r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) + between input :math:`x` (a 2D mini-batch `Tensor`) + and output :math:`y` (which is a 1D tensor of target class indices,:math:`0 \leq y \leq \text{x.size}(1)-1`): + + For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar + output :math:`y` is: + + .. 
math:: + \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` + and :math:`i \neq y`. + + Optionally, you can give non-equal weighting on the classes by passing + a 1D :attr:`weight` tensor into the constructor. + + The loss function then becomes: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i \max(0, w[y] * (\text{margin} - x[y] + x[i]))^p}{\text{x.size}(0)} + + Parameters: + + p (int, Optional):The norm degree for pairwise distance. Default: :math:`1`. + + margin (float, Optional):Default: :math:`1`. + + weight (Tensor,optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size C and the data type is float32, float64. + Default is ``'None'`` . + + reduction (str, optional): Indicate how to average the loss by batch_size, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default: ``'mean'`` + + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Call parameters: + input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1. + label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input. + + Shape: + input: N-D Tensor, the shape is [N, C], N is batch size and `C` means number of classes, available dtype is float32, float64. The sum operationoperates over all the elements. + label: N-D Tensor, the shape is [N,]. + output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input. + + Returns: + A callable object of MultiMarginLoss. + + Examples: + .. 
code-block:: python + + import paddle + import paddle.nn as nn + + input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32) + label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32) + + multi_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='none') + loss = multi_margin_loss(input, label) + print(loss) + # Tensor([3.49625897, 0.71111226, 0.43989015]) + + multi_margin_loss = nn.MultiMarginLoss(reduction='mean') + loss = multi_margin_loss(input, label) + print(loss) + # Tensor([1.54908717]) + """ + + def __init__(self, + p: int = 1, + margin: float = 1.0, + weight=None, + reduction="mean", + name=None): + super(MultiMarginLoss, self).__init__() + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "'reduction' in 'MultiLabelSoftMarginloss' should be 'sum', 'mean' or 'none', " + "but received {}.".format(reduction)) + self.p = p + self.margin = margin + self.weight = weight + self.reduction = reduction + self.name = name + + def forward(self, input, label): + return F.multi_margin_loss(input, + label, + p=self.p, + margin=self.margin, + weight=self.weight, + reduction=self.reduction, + name=self.name) + + class SoftMarginLoss(Layer): r""" Creates a criterion that measures a two-class soft margin loss between input predictions ``input`` From 9355a4385c7c71b802dc7a846a3eea80f80fb206 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Mon, 22 Aug 2022 20:57:50 +0800 Subject: [PATCH 02/10] 2022-08-30_update nn.layer.loss nn.functional.loss, test_file --- .../tests/unittests/test_multimarginloss.py | 34 ++++++++++++------- python/paddle/nn/__init__.py | 2 ++ python/paddle/nn/functional/__init__.py | 2 ++ python/paddle/nn/functional/loss.py | 14 +++----- python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/loss.py | 16 +++------ 6 files changed, 37 insertions(+), 32 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py index ddbd1e853bc58..442c4603fe7ee 100644 --- a/python/paddle/fluid/tests/unittests/test_multimarginloss.py +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -59,7 +59,7 @@ def test_static(place, label_np, p=1, margin=1.0, - weight=None, + weight_np=None, reduction='mean', functional=False): prog = paddle.static.Program() @@ -67,15 +67,20 @@ def test_static(place, with paddle.static.program_guard(prog, startup_prog): input = paddle.static.data(name='input', shape=input_np.shape, - dtype='float64') + dtype=input_np.dtype) label = paddle.static.data(name='label', shape=label_np.shape, - dtype='float64') + dtype=label_np.dtype) feed_dict = { "input": input_np, "label": label_np, } - + weight = None + if weight_np is not None: + weight = paddle.static.data(name='weight', + shape=weight_np.shape, + dtype=weight_np.dtype) + feed_dict['weight'] = weight_np if functional: res = call_MultiMarginLoss_functional(input=input, label=label, @@ -108,6 +113,8 @@ def test_dygraph(place, input = paddle.to_tensor(input) label = paddle.to_tensor(label) + if weight is not None: + weight = paddle.to_tensor(weight) if functional: dy_res = call_MultiMarginLoss_functional(input=input, label=label, @@ -164,7 +171,8 @@ class TestMultiMarginLoss(unittest.TestCase): def test_MultiMarginLoss(self): shape = (2, 2) input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) - label = np.random.uniform(0, 2, size=(2, )).astype(np.float64) + label = np.random.uniform(0, input.shape[1], + size=(2, 
)).astype(np.int32) places = [paddle.CPUPlace()] if paddle.device.is_compiled_with_cuda(): @@ -209,10 +217,10 @@ def test_MultiMarginLoss(self): def test_MultiMarginLoss_error(self): paddle.disable_static() self.assertRaises(ValueError, - paddle.nn.loss.MultiMarginLoss, + paddle.nn.MultiMarginLoss, reduction="unsupport reduction") input = paddle.to_tensor([[0.1, 0.3]], dtype='float32') - label = paddle.to_tensor([0.0], dtype='float32') + label = paddle.to_tensor([0], dtype='int32') self.assertRaises(ValueError, paddle.nn.functional.multi_margin_loss, input=input, @@ -224,7 +232,7 @@ def test_MultiMarginLoss_dimension(self): paddle.disable_static() input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32') - label = paddle.to_tensor([0.0, 1.0, 2.0], dtype='float32') + label = paddle.to_tensor([0, 1], dtype='int32') self.assertRaises( ValueError, @@ -247,7 +255,8 @@ def test_MultiMarginLoss_p(self): reduction = 'mean' place = paddle.CPUPlace() input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) - label = np.random.uniform(0, 2, size=(2, )).astype(np.float64) + label = np.random.uniform(0, input.shape[1], + size=(2, )).astype(np.int64) expected = calc_multi_margin_loss(input=input, p=p, label=label, @@ -292,7 +301,8 @@ def test_MultiMarginLoss_weight(self): reduction = 'mean' place = paddle.CPUPlace() input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) - label = np.random.uniform(0, 2, size=(2, )).astype(np.float64) + label = np.random.uniform(0, input.shape[1], + size=(2, )).astype(np.int64) weight = np.random.uniform(0, 2, size=(2, )).astype(np.float64) expected = calc_multi_margin_loss(input=input, label=label, @@ -311,7 +321,7 @@ def test_MultiMarginLoss_weight(self): place=place, input_np=input, label_np=label, - weight=weight, + weight_np=weight, reduction=reduction, ) self.assertTrue(np.allclose(static_result, expected)) @@ -320,7 +330,7 @@ def test_MultiMarginLoss_weight(self): static_functional = test_static(place=place, input_np=input, label_np=label, - weight=weight, + weight_np=weight, reduction=reduction, functional=True) dy_functional = test_dygraph(place=place, diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index e47fa8c3c5480..331131d6e2319 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -109,6 +109,7 @@ from .layer.loss import SmoothL1Loss # noqa: F401 from .layer.loss import HingeEmbeddingLoss # noqa: F401 from .layer.loss import CosineEmbeddingLoss # noqa: F401 +from .layer.loss import MultiMarginLoss from .layer.loss import TripletMarginWithDistanceLoss from .layer.loss import TripletMarginLoss from .layer.loss import SoftMarginLoss @@ -319,6 +320,7 @@ def weight_norm(*args): 'Identity', 'CosineEmbeddingLoss', 'RReLU', + 'MultiMarginLoss', 'TripletMarginWithDistanceLoss', 'TripletMarginLoss', 'SoftMarginLoss', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 701997e0d0ab5..bf0554d78d8b3 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -92,6 +92,7 @@ from .loss import ctc_loss # noqa: F401 from .loss import hinge_embedding_loss # noqa: F401 from .loss import cosine_embedding_loss # noqa: F401 +from .loss import multi_margin_loss from .loss import multi_label_soft_margin_loss from .loss import triplet_margin_with_distance_loss from .loss import triplet_margin_loss @@ -241,5 +242,6 @@ 'rrelu', 'triplet_margin_with_distance_loss', 'triplet_margin_loss', + 'multi_margin_loss', 
'soft_margin_loss', ] diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 4e71dbdb033da..201ae64f3b8d9 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3492,7 +3492,7 @@ def multi_margin_loss(input, Parameters: input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes. - label (Tensor): Label tensor, the data type is float32 or float64. The shape of label is (N,) + label (Tensor): Label tensor, the data type is int32 or int64. The shape of label is (N,) p (int, Optional): The norm degree for pairwise distance. Default: :math:`1`. @@ -3523,14 +3523,10 @@ def multi_margin_loss(input, import paddle.nn.functional as F input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32) - positive= paddle.to_tensor([1, 2, 1], dtype=paddle.float32) + label = paddle.to_tensor([1, 2, 1], dtype=paddle.int32) loss = F.multi_margin_loss(input, label, margin=1.0, reduction='none') print(loss) - - loss = F.multi_margin_loss(input, label, margin=1.0, reduction='mean') - print(loss) - """ if reduction not in ['sum', 'mean', 'none']: raise ValueError( @@ -3540,7 +3536,7 @@ def multi_margin_loss(input, if not _non_static_mode(): check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'multi_margin_loss') - check_variable_and_dtype(label, 'positive', ['float32', 'float64'], + check_variable_and_dtype(label, 'label', ['int32', 'int64'], 'multi_margin_loss') if not (input.shape[0] == label.shape[0]): raise ValueError("The label's shape is wrong") @@ -3555,11 +3551,11 @@ def multi_margin_loss(input, weight = weight.reshape((-1, 1)) loss = paddle.mean(paddle.pow( - paddle.clip(weight * (margin + index_sample - input), min=0.0), p), + paddle.clip(weight * (margin - index_sample + input), min=0.0), p), axis=1) - margin**p / input.shape[1] else: loss = paddle.mean(paddle.pow( - paddle.clip(margin + index_sample - input, min=0.0), p), + paddle.clip(margin - index_sample + input, min=0.0), p), axis=1) - margin**p / input.shape[1] if reduction == 'mean': diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 45cb652332b3a..1acea10d6755c 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -83,6 +83,7 @@ from .loss import TripletMarginWithDistanceLoss from .loss import TripletMarginLoss from .loss import SoftMarginLoss +from .loss import MultiMarginLoss from .norm import BatchNorm1D # noqa: F401 from .norm import BatchNorm2D # noqa: F401 from .norm import BatchNorm3D # noqa: F401 diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 211bf92e41085..b0ccf54838a30 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1732,11 +1732,11 @@ class MultiMarginLoss(Layer): For more information, please refer to :ref:`api_guide_Name`. Call parameters: - input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1. - label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input. + input (Tensor): Input tensor, the data type is float32 or float64. + label (Tensor): Label tensor, 0<= label < input.shape[1], the data type is int32 or int64. 
Shape: - input: N-D Tensor, the shape is [N, C], N is batch size and `C` means number of classes, available dtype is float32, float64. The sum operationoperates over all the elements. + input: N-D Tensor, the shape is [N, C], N is batch size and `C` means number of classes. label: N-D Tensor, the shape is [N,]. output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input. @@ -1750,17 +1750,11 @@ class MultiMarginLoss(Layer): import paddle.nn as nn input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32) - label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32) - - multi_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='none') - loss = multi_margin_loss(input, label) - print(loss) - # Tensor([3.49625897, 0.71111226, 0.43989015]) + label = paddle.to_tensor([0, 1, 2], dtype=paddle.int32) multi_margin_loss = nn.MultiMarginLoss(reduction='mean') loss = multi_margin_loss(input, label) print(loss) - # Tensor([1.54908717]) """ def __init__(self, @@ -1772,7 +1766,7 @@ def __init__(self, super(MultiMarginLoss, self).__init__() if reduction not in ['sum', 'mean', 'none']: raise ValueError( - "'reduction' in 'MultiLabelSoftMarginloss' should be 'sum', 'mean' or 'none', " + "'reduction' in 'MultiMarginLoss' should be 'sum', 'mean' or 'none', " "but received {}.".format(reduction)) self.p = p self.margin = margin From a317e417c80229e0e0d40750f7372aa3361a4428 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Tue, 6 Sep 2022 09:46:48 +0800 Subject: [PATCH 03/10] fix: test_file --- python/paddle/fluid/tests/unittests/test_multimarginloss.py | 4 ++-- python/paddle/nn/functional/loss.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py index 442c4603fe7ee..0316c4e32c588 100644 --- a/python/paddle/fluid/tests/unittests/test_multimarginloss.py +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -232,7 +232,7 @@ def test_MultiMarginLoss_dimension(self): paddle.disable_static() input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32') - label = paddle.to_tensor([0, 1], dtype='int32') + label = paddle.to_tensor([0, 1, 1], dtype='int32') self.assertRaises( ValueError, @@ -240,7 +240,7 @@ def test_MultiMarginLoss_dimension(self): input=input, label=label, ) - MMLoss = paddle.nn.loss.MultiMarginLoss() + MMLoss = paddle.nn.MultiMarginLoss() self.assertRaises( ValueError, MMLoss, diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 201ae64f3b8d9..52d67d779728f 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3539,7 +3539,8 @@ def multi_margin_loss(input, check_variable_and_dtype(label, 'label', ['int32', 'int64'], 'multi_margin_loss') if not (input.shape[0] == label.shape[0]): - raise ValueError("The label's shape is wrong") + raise ValueError( + "The label's shape[0] should be equal to input's shape[0]") label = label.reshape((-1, 1)) index_sample = paddle.index_sample(input, label) if weight is not None: @@ -3547,7 +3548,8 @@ def multi_margin_loss(input, check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], 'multi_margin_loss') if not (input.shape[0] == weight.shape[0]): - raise ValueError("The weight's shape is wrong ") + raise ValueError( + "The weight's shape[0] should be equal to weight's shape[0]") weight = weight.reshape((-1, 1)) loss = 
paddle.mean(paddle.pow( From eceeaab82709fdf4d3b6a0cd843a05cfbb21891b Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Wed, 14 Sep 2022 23:06:51 +0800 Subject: [PATCH 04/10] fix: test_file, docs, multi_margin_loss --- .../tests/unittests/test_multimarginloss.py | 126 ++++++++++++++++-- python/paddle/nn/functional/loss.py | 42 +++--- python/paddle/nn/layer/loss.py | 26 ++-- 3 files changed, 152 insertions(+), 42 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py index 0316c4e32c588..1e91d9bae9936 100644 --- a/python/paddle/fluid/tests/unittests/test_multimarginloss.py +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -101,6 +101,58 @@ def test_static(place, return static_result +def test_static_data_shape(place, + input_np, + label_np, + wrong_label_shape=None, + weight_np=None, + wrong_weight_shape=None, + functional=False): + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + input = paddle.static.data(name='input', + shape=input_np.shape, + dtype=input_np.dtype) + if wrong_label_shape is None: + label_shape = label_np.shape + else: + label_shape = wrong_label_shape + label = paddle.static.data(name='label', + shape=label_shape, + dtype=label_np.dtype) + feed_dict = { + "input": input_np, + "label": label_np, + } + weight = None + if weight_np is not None: + if wrong_weight_shape is None: + weight_shape = weight_np.shape + else: + weight_shape = wrong_weight_shape + weight = paddle.static.data(name='weight', + shape=weight_shape, + dtype=weight_np.dtype) + feed_dict['weight'] = weight_np + if functional: + res = call_MultiMarginLoss_functional( + input=input, + label=label, + weight=weight, + ) + else: + res = call_MultiMarginLoss_layer( + input=input, + label=label, + weight=weight, + ) + + exe = paddle.static.Executor(place) + static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) + return static_result + + def test_dygraph(place, input, label, @@ -152,9 +204,8 @@ def calc_multi_margin_loss( expected = np.mean(np.maximum(margin + input - index_sample, 0.0)**p, axis=1) - margin**p / input.shape[1] else: - weight = weight.reshape(-1, 1) expected = np.mean(np.maximum(weight * (margin + input - index_sample), 0.0) ** p, axis=1) - margin ** p / \ - input.shape[1] + input.shape[1] if reduction == 'mean': expected = np.mean(expected) @@ -169,10 +220,12 @@ def calc_multi_margin_loss( class TestMultiMarginLoss(unittest.TestCase): def test_MultiMarginLoss(self): - shape = (2, 2) + batch_size = 5 + num_classes = 2 + shape = (batch_size, num_classes) input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) label = np.random.uniform(0, input.shape[1], - size=(2, )).astype(np.int32) + size=(batch_size, )).astype(np.int64) places = [paddle.CPUPlace()] if paddle.device.is_compiled_with_cuda(): @@ -251,12 +304,14 @@ def test_MultiMarginLoss_dimension(self): def test_MultiMarginLoss_p(self): p = 2 - shape = (2, 2) + batch_size = 5 + num_classes = 2 + shape = (batch_size, num_classes) reduction = 'mean' place = paddle.CPUPlace() input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) label = np.random.uniform(0, input.shape[1], - size=(2, )).astype(np.int64) + size=(batch_size, )).astype(np.int64) expected = calc_multi_margin_loss(input=input, p=p, label=label, @@ -297,13 +352,16 @@ def test_MultiMarginLoss_p(self): self.assertTrue(np.allclose(dy_functional, 
expected)) def test_MultiMarginLoss_weight(self): - shape = (2, 2) + batch_size = 5 + num_classes = 2 + shape = (batch_size, num_classes) reduction = 'mean' place = paddle.CPUPlace() input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) label = np.random.uniform(0, input.shape[1], - size=(2, )).astype(np.int64) - weight = np.random.uniform(0, 2, size=(2, )).astype(np.float64) + size=(batch_size, )).astype(np.int64) + weight = np.random.uniform(0, 2, + size=(num_classes, )).astype(np.float64) expected = calc_multi_margin_loss(input=input, label=label, weight=weight, @@ -343,6 +401,56 @@ def test_MultiMarginLoss_weight(self): self.assertTrue(np.allclose(static_functional, dy_functional)) self.assertTrue(np.allclose(dy_functional, expected)) + def test_MultiMarginLoss_static_data_shape(self): + batch_size = 5 + num_classes = 2 + shape = (batch_size, num_classes) + place = paddle.CPUPlace() + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label = np.random.uniform(0, input.shape[1], + size=(batch_size, )).astype(np.int64) + weight = np.random.uniform(0, 2, + size=(num_classes, )).astype(np.float64) + + self.assertRaises( + ValueError, + test_static_data_shape, + place=place, + input_np=input, + label_np=label, + wrong_label_shape=(10, ), + functional=True, + ) + self.assertRaises( + ValueError, + test_static_data_shape, + place=place, + input_np=input, + label_np=label, + wrong_label_shape=(10, ), + functional=False, + ) + self.assertRaises( + ValueError, + test_static_data_shape, + place=place, + input_np=input, + label_np=label, + weight_np=weight, + wrong_weight_shape=(3, ), + functional=True, + ) + self.assertRaises( + ValueError, + test_static_data_shape, + place=place, + input_np=input, + label_np=label, + weight_np=weight, + wrong_weight_shape=(3, ), + functional=False, + ) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 52d67d779728f..1d3c3ca0304ff 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3468,25 +3468,25 @@ def multi_margin_loss(input, reduction='mean', name=None): r""" - Measures a multi-class classification hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) - and output :math:`y` (which is a 1D tensor of target class indices,:math:`0 \leq y \leq \text{x.size}(1)-1`): + Measures a multi-class classification hinge loss (margin-based loss) between input :math:`input` (a 2D mini-batch `Tensor`, in shape (N, C), + where C is number of classes) and label :math:`label` (which is a 1D tensor of target class indices,:math:`0 \leq label \leq \text{C}-1`): - For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar - output :math:`y` is: + For ith mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar + output :math:`label_i` is: .. math:: - \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + \text{loss}(input_i, label_i) = \frac{\sum^C_j \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}} - where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` - and :math:`i \neq y`. + where :math:`input_i \in \left\{0, \; \cdots , \; \text{C} - 1\right\}` + and :math:`j \neq label_i`. Optionally, you can give non-equal weighting on the classes by passing a 1D :attr:`weight` tensor into the constructor. 
- The loss function then becomes: + The loss function for ith mini-batch then becomes: .. math:: - \text{loss}(x, y) = \frac{\sum_i \max(0, w[y] * (\text{margin} - x[y] + x[i]))^p}{\text{x.size}(0)} + \text{loss}(input_i, label_i) = \frac{\sum^C_j \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}} Parameters: @@ -3494,17 +3494,17 @@ def multi_margin_loss(input, label (Tensor): Label tensor, the data type is int32 or int64. The shape of label is (N,) - p (int, Optional): The norm degree for pairwise distance. Default: :math:`1`. + p (int, Optional): The power num. Default: :math:`1`. margin (float, Optional): Default: :math:`1`. weight (Tensor,optional): a manual rescaling weight given to each class. - If given, has to be a Tensor of size C and the data type is float32, float64. + If given, has to be a Tensor of shape (C,) and the data type is float32, float64. Default is ``'None'`` . - reduction (str, Optional):Indicate how to average the loss by batch_size. - the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + reduction (str, Optional):Indicate how to calculate the loss by batch_size. + the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`reduction` is ``'sum'``, the summed loss is returned. @@ -3514,7 +3514,7 @@ def multi_margin_loss(input, For more information, please refer to :ref:`api_guide_Name`. Returns: - Output: Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative. + Output: Tensor. The tensor variable storing the multi_margin_loss of input and label. Examples: .. code-block:: python @@ -3540,25 +3540,27 @@ def multi_margin_loss(input, 'multi_margin_loss') if not (input.shape[0] == label.shape[0]): raise ValueError( - "The label's shape[0] should be equal to input's shape[0]") + "The label's shape[0] should be equal to input's shape[0], but received input's shape[0] {} and label's shape[0]:{}. " + .format(input.shape[0], label.shape[0])) label = label.reshape((-1, 1)) index_sample = paddle.index_sample(input, label) if weight is not None: if not _non_static_mode(): check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], 'multi_margin_loss') - if not (input.shape[0] == weight.shape[0]): + if not (input.shape[1] == weight.shape[0]): raise ValueError( - "The weight's shape[0] should be equal to weight's shape[0]") + "The weight's shape[0] should be equal to input's shape[1]" + "but received weight's shape[0]: {} and input's shape[1]: {}". 
+ format(weight.shape[0], input.shape[1])) - weight = weight.reshape((-1, 1)) loss = paddle.mean(paddle.pow( paddle.clip(weight * (margin - index_sample + input), min=0.0), p), - axis=1) - margin**p / input.shape[1] + axis=1) - margin**p / paddle.shape(input)[1] else: loss = paddle.mean(paddle.pow( paddle.clip(margin - index_sample + input, min=0.0), p), - axis=1) - margin**p / input.shape[1] + axis=1) - margin**p / paddle.shape(input)[1] if reduction == 'mean': return paddle.mean(loss, name=name) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index b0ccf54838a30..df07061e73612 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1690,26 +1690,26 @@ def forward(self, input, positive, negative): class MultiMarginLoss(Layer): - r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) - between input :math:`x` (a 2D mini-batch `Tensor`) - and output :math:`y` (which is a 1D tensor of target class indices,:math:`0 \leq y \leq \text{x.size}(1)-1`): + r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between input :math:`input` (a 2D mini-batch `Tensor`, in shape (N, C), + where C is number of classes) and label :math:`label` (which is a 1D tensor of target class indices,:math:`0 \leq label \leq \text{C}-1`): - For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar - output :math:`y` is: + For ith mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar + output :math:`label_i` is: .. math:: - \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + \text{loss}(input_i, label_i) = \frac{\sum^C_j \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}} - where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` - and :math:`i \neq y`. + where :math:`input_i \in \left\{0, \; \cdots , \; \text{C} - 1\right\}` + and :math:`j \neq label_i`. Optionally, you can give non-equal weighting on the classes by passing a 1D :attr:`weight` tensor into the constructor. - The loss function then becomes: + The loss function for ith mini-batch then becomes: .. math:: - \text{loss}(x, y) = \frac{\sum_i \max(0, w[y] * (\text{margin} - x[y] + x[i]))^p}{\text{x.size}(0)} + \text{loss}(input_i, label_i) = \frac{\sum^C_j \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}} + Parameters: @@ -1718,11 +1718,11 @@ class MultiMarginLoss(Layer): margin (float, Optional):Default: :math:`1`. weight (Tensor,optional): a manual rescaling weight given to each class. - If given, has to be a Tensor of size C and the data type is float32, float64. + If given, has to be a Tensor of shape (C,) and the data type is float32, float64. Default is ``'None'`` . - reduction (str, optional): Indicate how to average the loss by batch_size, - the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + reduction (str, optional): Indicate how to calculate the loss by batch_size, + the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`reduction` is ``'sum'``, the summed loss is returned. 
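
---
For reference, the unweighted loss that the docs and tests above converge on can be sketched in a few lines of NumPy. This is an illustrative reference only, not the Paddle implementation; the function name and signature are made up for the example, and it assumes margin > 0 so that the j == label_i term inside the sum equals margin**p and can be subtracted out:

    import numpy as np

    def ref_multi_margin(input, label, p=1, margin=1.0, reduction='mean'):
        # input: (N, C) scores; label: (N,) integer class indices
        picked = input[np.arange(input.shape[0]), label].reshape(-1, 1)  # input_i[label_i]
        hinge = np.maximum(margin - picked + input, 0.0) ** p            # one term per class j
        # drop the j == label_i term (equal to margin**p when margin > 0), average over C
        loss = (hinge.sum(axis=1) - margin ** p) / input.shape[1]
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss

This agrees with calc_multi_margin_loss in the test file: np.mean(hinge, axis=1) - margin**p / C is the same quantity written as a mean instead of a sum.
---
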
From f49bfc346a5bca846e0154d3255d4d0947e942f0 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Thu, 22 Sep 2022 21:29:24 +0800 Subject: [PATCH 05/10] fix: doc weight function --- .../tests/unittests/test_multimarginloss.py | 10 +++---- python/paddle/nn/functional/loss.py | 28 +++++++++++-------- python/paddle/nn/layer/loss.py | 11 ++++---- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py index 1e91d9bae9936..c9e4012742142 100644 --- a/python/paddle/fluid/tests/unittests/test_multimarginloss.py +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -194,16 +194,14 @@ def calc_multi_margin_loss( weight=None, reduction='mean', ): - label = label.reshape(-1, 1) - index_sample = [] - for i in range(len(label)): - index_sample.append(input[i, label[i]]) - index_sample = np.array(index_sample).reshape(-1, 1) - + index_sample = np.array([input[i, label[i]] + for i in range(label.size)]).reshape(-1, 1) if weight is None: expected = np.mean(np.maximum(margin + input - index_sample, 0.0)**p, axis=1) - margin**p / input.shape[1] else: + weight = np.array([weight[label[i]] + for i in range(label.size)]).reshape(-1, 1) expected = np.mean(np.maximum(weight * (margin + input - index_sample), 0.0) ** p, axis=1) - margin ** p / \ input.shape[1] diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 734572b23bf66..194298de96d82 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3471,22 +3471,21 @@ def multi_margin_loss(input, Measures a multi-class classification hinge loss (margin-based loss) between input :math:`input` (a 2D mini-batch `Tensor`, in shape (N, C), where C is number of classes) and label :math:`label` (which is a 1D tensor of target class indices,:math:`0 \leq label \leq \text{C}-1`): - For ith mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar + For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar output :math:`label_i` is: .. math:: - \text{loss}(input_i, label_i) = \frac{\sum^C_j \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}} + \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}} - where :math:`input_i \in \left\{0, \; \cdots , \; \text{C} - 1\right\}` - and :math:`j \neq label_i`. + where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`. Optionally, you can give non-equal weighting on the classes by passing a 1D :attr:`weight` tensor into the constructor. - The loss function for ith mini-batch then becomes: + The loss function for i-th sample then becomes: .. math:: - \text{loss}(input_i, label_i) = \frac{\sum^C_j \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}} + \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}} Parameters: @@ -3540,8 +3539,9 @@ def multi_margin_loss(input, 'multi_margin_loss') if not (input.shape[0] == label.shape[0]): raise ValueError( - "The label's shape[0] should be equal to input's shape[0], but received input's shape[0] {} and label's shape[0]:{}. 
" - .format(input.shape[0], label.shape[0])) + "The label's shape[0] should be equal to input's shape[0], " + "but received input's shape[0] {} and label's shape[0]:{}. ".format( + input.shape[0], label.shape[0])) label = label.reshape((-1, 1)) index_sample = paddle.index_sample(input, label) if weight is not None: @@ -3553,10 +3553,14 @@ def multi_margin_loss(input, "The weight's shape[0] should be equal to input's shape[1]" "but received weight's shape[0]: {} and input's shape[1]: {}". format(weight.shape[0], input.shape[1])) - - loss = paddle.mean(paddle.pow( - paddle.clip(weight * (margin - index_sample + input), min=0.0), p), - axis=1) - margin**p / paddle.shape(input)[1] + weight = weight.reshape((-1, weight.shape[0])) + weight = paddle.repeat_interleave(weight, label.shape[0], 0) + weight = paddle.index_sample(weight, label) + loss = paddle.mean( + paddle.pow( + paddle.clip(weight * + (margin - index_sample + input), min=0.0), p), + axis=1) - weight * (margin**p / paddle.shape(input)[1]) else: loss = paddle.mean(paddle.pow( paddle.clip(margin - index_sample + input, min=0.0), p), diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index bae7fc8aceaa8..fdde20c93c413 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1693,22 +1693,21 @@ class MultiMarginLoss(Layer): r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between input :math:`input` (a 2D mini-batch `Tensor`, in shape (N, C), where C is number of classes) and label :math:`label` (which is a 1D tensor of target class indices,:math:`0 \leq label \leq \text{C}-1`): - For ith mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar + For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar output :math:`label_i` is: .. math:: - \text{loss}(input_i, label_i) = \frac{\sum^C_j \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}} + \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}} - where :math:`input_i \in \left\{0, \; \cdots , \; \text{C} - 1\right\}` - and :math:`j \neq label_i`. + where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`. Optionally, you can give non-equal weighting on the classes by passing a 1D :attr:`weight` tensor into the constructor. - The loss function for ith mini-batch then becomes: + The loss function for i-th sample then becomes: .. 
math:: - \text{loss}(input_i, label_i) = \frac{\sum^C_j \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}} + \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}} Parameters: From 0a042c8b524d7e3e2a0c14c0e2e6bb2d5c59065a Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Fri, 23 Sep 2022 22:35:27 +0800 Subject: [PATCH 06/10] fix: test_multi_margin_loss --- python/paddle/fluid/tests/unittests/test_multimarginloss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py index c9e4012742142..87a920da230cd 100644 --- a/python/paddle/fluid/tests/unittests/test_multimarginloss.py +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -202,8 +202,8 @@ def calc_multi_margin_loss( else: weight = np.array([weight[label[i]] for i in range(label.size)]).reshape(-1, 1) - expected = np.mean(np.maximum(weight * (margin + input - index_sample), 0.0) ** p, axis=1) - margin ** p / \ - input.shape[1] + expected = np.mean(np.maximum(weight * (margin + input - index_sample), 0.0) ** p, axis=1) - weight*(margin ** p / \ + input.shape[1]) if reduction == 'mean': expected = np.mean(expected) From 8cb585b6a53e1c07d5e99d2c84a44b295ad8353d Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Tue, 27 Sep 2022 23:39:58 +0800 Subject: [PATCH 07/10] fix: weight np.testing.assert_allclose --- .../tests/unittests/test_multimarginloss.py | 38 +++++++++---------- python/paddle/nn/functional/loss.py | 4 +- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py index 87a920da230cd..d09d8779fa148 100644 --- a/python/paddle/fluid/tests/unittests/test_multimarginloss.py +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -150,7 +150,7 @@ def test_static_data_shape(place, exe = paddle.static.Executor(place) static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) - return static_result + return static_result[0] def test_dygraph(place, @@ -248,9 +248,9 @@ def test_MultiMarginLoss(self): label_np=label, reduction=reduction, ) - self.assertTrue(np.allclose(static_result, expected)) - self.assertTrue(np.allclose(static_result, dy_result)) - self.assertTrue(np.allclose(dy_result, expected)) + np.testing.assert_allclose(static_result, expected) + np.testing.assert_allclose(static_result, dy_result) + np.testing.assert_allclose(dy_result, expected) static_functional = test_static(place=place, input_np=input, label_np=label, @@ -261,9 +261,9 @@ def test_MultiMarginLoss(self): label=label, reduction=reduction, functional=True) - self.assertTrue(np.allclose(static_functional, expected)) - self.assertTrue(np.allclose(static_functional, dy_functional)) - self.assertTrue(np.allclose(dy_functional, expected)) + np.testing.assert_allclose(static_functional, expected) + np.testing.assert_allclose(static_functional, dy_functional) + np.testing.assert_allclose(dy_functional, expected) def test_MultiMarginLoss_error(self): paddle.disable_static() @@ -330,9 +330,9 @@ def test_MultiMarginLoss_p(self): label_np=label, reduction=reduction, ) - self.assertTrue(np.allclose(static_result, expected)) - self.assertTrue(np.allclose(static_result, dy_result)) - self.assertTrue(np.allclose(dy_result, 
expected)) + np.testing.assert_allclose(static_result, expected) + np.testing.assert_allclose(static_result, dy_result) + np.testing.assert_allclose(dy_result, expected) static_functional = test_static(place=place, p=p, input_np=input, @@ -345,9 +345,9 @@ def test_MultiMarginLoss_p(self): label=label, reduction=reduction, functional=True) - self.assertTrue(np.allclose(static_functional, expected)) - self.assertTrue(np.allclose(static_functional, dy_functional)) - self.assertTrue(np.allclose(dy_functional, expected)) + np.testing.assert_allclose(static_functional, expected) + np.testing.assert_allclose(static_functional, dy_functional) + np.testing.assert_allclose(dy_functional, expected) def test_MultiMarginLoss_weight(self): batch_size = 5 @@ -380,9 +380,9 @@ def test_MultiMarginLoss_weight(self): weight_np=weight, reduction=reduction, ) - self.assertTrue(np.allclose(static_result, expected)) - self.assertTrue(np.allclose(static_result, dy_result)) - self.assertTrue(np.allclose(dy_result, expected)) + np.testing.assert_allclose(static_result, expected) + np.testing.assert_allclose(static_result, dy_result) + np.testing.assert_allclose(dy_result, expected) static_functional = test_static(place=place, input_np=input, label_np=label, @@ -395,9 +395,9 @@ def test_MultiMarginLoss_weight(self): weight=weight, reduction=reduction, functional=True) - self.assertTrue(np.allclose(static_functional, expected)) - self.assertTrue(np.allclose(static_functional, dy_functional)) - self.assertTrue(np.allclose(dy_functional, expected)) + np.testing.assert_allclose(static_functional, expected) + np.testing.assert_allclose(static_functional, dy_functional) + np.testing.assert_allclose(dy_functional, expected) def test_MultiMarginLoss_static_data_shape(self): batch_size = 5 diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 194298de96d82..a3e5f3ab4ffba 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3553,9 +3553,7 @@ def multi_margin_loss(input, "The weight's shape[0] should be equal to input's shape[1]" "but received weight's shape[0]: {} and input's shape[1]: {}". 
format(weight.shape[0], input.shape[1])) - weight = weight.reshape((-1, weight.shape[0])) - weight = paddle.repeat_interleave(weight, label.shape[0], 0) - weight = paddle.index_sample(weight, label) + weight = paddle.gather(weight, label, axis=0).reshape((-1, 1)) loss = paddle.mean( paddle.pow( paddle.clip(weight * From 43c6226cc0dcdeb573a316e4829c10001054f462 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Wed, 28 Sep 2022 23:36:58 +0800 Subject: [PATCH 08/10] fix: test_file --- python/paddle/fluid/tests/unittests/test_multimarginloss.py | 4 ++-- python/paddle/nn/layer/loss.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py index d09d8779fa148..1eff1deb69295 100644 --- a/python/paddle/fluid/tests/unittests/test_multimarginloss.py +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -98,7 +98,7 @@ def test_static(place, exe = paddle.static.Executor(place) static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) - return static_result + return static_result[0] def test_static_data_shape(place, @@ -150,7 +150,7 @@ def test_static_data_shape(place, exe = paddle.static.Executor(place) static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) - return static_result[0] + return static_result def test_dygraph(place, diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 19d8d08c5a913..0f0b2883239eb 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1705,7 +1705,7 @@ class MultiMarginLoss(Layer): margin (float, Optional):Default: :math:`1`. - weight (Tensor,optional): a manual rescaling weight given to each class. + weight (Tensor,optional): a manual rescaling weight given to each class. If given, has to be a Tensor of shape (C,) and the data type is float32, float64. Default is ``'None'`` . 
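
---
The key behavioral fix in patch 07 is how the class weight enters the loss. The earlier weight.reshape((-1, 1)) treated weight as per-sample (its length was checked against the batch size), while the final code selects weight[label_i] for each sample with paddle.gather, so weight is genuinely per-class with shape (C,). A minimal dygraph illustration of that selection step (tensor values are made up):

    import paddle

    weight = paddle.to_tensor([0.2, 1.0, 1.8])          # per-class weights, shape (C,) = (3,)
    label = paddle.to_tensor([2, 0, 1], dtype='int64')  # class index per sample, shape (N,) = (3,)
    # as in the patch: pick weight[label_i] for each sample, shaped for broadcasting over classes
    w = paddle.gather(weight, label, axis=0).reshape((-1, 1))
    # w -> [[1.8], [0.2], [1.0]]

Note that the same per-sample w also scales the subtracted self-term, w * (margin**p / C). The weighted tests only exercise p = 1, where max(0, w * margin)**p equals w * margin**p, so the cancellation of the j == label_i term stays exact.
---
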
From 518c53cc9ab5f221a28dc76ba1a0358f6a1476b0 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Thu, 29 Sep 2022 21:17:05 +0800 Subject: [PATCH 09/10] fix: en_doc --- python/paddle/nn/functional/loss.py | 3 +-- python/paddle/nn/layer/loss.py | 11 +++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 2e92ae4563489..f7563733c76f9 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3477,8 +3477,7 @@ def multi_margin_loss(input, reduction='mean', name=None): r""" - Measures a multi-class classification hinge loss (margin-based loss) between input :math:`input` (a 2D mini-batch `Tensor`, in shape (N, C), - where C is number of classes) and label :math:`label` (which is a 1D tensor of target class indices,:math:`0 \leq label \leq \text{C}-1`): + Measures a multi-class classification hinge loss between input :math:`input` and label :math:`label`: For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar output :math:`label_i` is: diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 0f0b2883239eb..c6368855837fd 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1679,8 +1679,8 @@ def forward(self, input, positive, negative): class MultiMarginLoss(Layer): - r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between input :math:`input` (a 2D mini-batch `Tensor`, in shape (N, C), - where C is number of classes) and label :math:`label` (which is a 1D tensor of target class indices,:math:`0 \leq label \leq \text{C}-1`): + r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between + input :math:`input` and label :math:`label`: For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar output :math:`label_i` is: @@ -1721,11 +1721,14 @@ class MultiMarginLoss(Layer): Call parameters: input (Tensor): Input tensor, the data type is float32 or float64. + label (Tensor): Label tensor, 0<= label < input.shape[1], the data type is int32 or int64. Shape: - input: N-D Tensor, the shape is [N, C], N is batch size and `C` means number of classes. - label: N-D Tensor, the shape is [N,]. + input: 2-D Tensor, the shape is [N, C], N is batch size and `C` means number of classes. + + label: 1-D Tensor, the shape is [N,]. + output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input. Returns: From ef91530a11e33e1059e34e85868425fe4c1e9f3e Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Mon, 10 Oct 2022 21:19:23 +0800 Subject: [PATCH 10/10] 2022-10-10 --- python/paddle/nn/layer/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index c6368855837fd..2a540eb0006d0 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1729,7 +1729,7 @@ class MultiMarginLoss(Layer): label: 1-D Tensor, the shape is [N,]. - output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input. + output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the label. Returns: A callable object of MultiMarginLoss.
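
---
With the series applied, a quick end-to-end check of the documented formula in dygraph mode; the expected value below is worked out by hand from the math above, not captured from a run:

    import paddle
    import paddle.nn.functional as F

    input = paddle.to_tensor([[1., 2., 3.]])
    label = paddle.to_tensor([0], dtype='int64')

    # defaults: p=1, margin=1.0, reduction='mean'
    loss = F.multi_margin_loss(input, label)
    # loss_0 = (max(0, 1 - 1 + 2) + max(0, 1 - 1 + 3)) / 3 = 5/3, about 1.6667

The layer form, paddle.nn.MultiMarginLoss(reduction='mean')(input, label), computes the same value.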