add nll loss API #26019
Conversation
Thanks for your contribution!
Force-pushed from 693432b to 60e7df1 (compare)
python/paddle/nn/functional/loss.py
Outdated
if x_dims != 2 and x_dims != 4:
    x = paddle.reshape(x, shape=[n, c, 1, -1])
    label = paddle.reshape(label, shape=[n, 1, -1])
    out_shape = [n] + x_shape[2:]
Should this part be split into static-graph and dynamic-graph branches?
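To make the reshape logic quoted above concrete, here is an illustrative pure-Python sketch of the shape bookkeeping. The `nll_shapes` helper is hypothetical (not part of the PR); it shows how an input whose rank is neither 2 nor 4 is flattened into the 4-D case the kernel handles, and how the unreduced output keeps the original spatial dims:

```python
def nll_shapes(x_shape):
    """Return (kernel input shape, kernel label shape, output shape)."""
    n, c = x_shape[0], x_shape[1]
    spatial = 1
    for d in x_shape[2:]:
        spatial *= d                      # collapse all spatial dims into one
    x_kernel = [n, c, 1, spatial]         # the 4-D shape the kernel accepts
    label_kernel = [n, 1, spatial]
    out_shape = [n] + list(x_shape[2:])   # unreduced loss keeps spatial dims
    return x_kernel, label_kernel, out_shape

# e.g. a 5-D input [N, C, D, H, W]
print(nll_shapes([2, 3, 4, 5, 6]))
```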
python/paddle/nn/functional/loss.py
Outdated
    ignore_index, 'reduction', reduction)
if x_dims != 2 and x_dims != 4 and reduction == 'none':
    out = paddle.reshape(out, shape=out_shape)
Use core.ops here.
python/paddle/nn/layer/loss.py
Outdated
@@ -505,18 +506,18 @@ class NLLLoss(fluid.dygraph.Layer):
        \\end{cases}

    Parameters:
-       input (Variable): Input tensor, the data type is float32, float64.
+       x (Variable): Input tensor, the data type is float32, float64.
Variable -> Tensor. Change every Variable below to Tensor as well.
python/paddle/nn/functional/loss.py
Outdated
raise ValueError(
    "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
    "'none', but received %s, which is not allowed." % reduction)
x_shape = list(x.shape)
Need to confirm how a Tensor's shape attribute is accessed here.
Force-pushed from 826b71b to ee9ee56 (compare)
Force-pushed from f659ddf to 493292a (compare)
python/paddle/nn/functional/loss.py
Outdated
Parameters:
    x (Tensor): Input tensor, the data type is float32, float64.
    label (Tensor): Label tensor, the data type is int64_t.
int64_t -> int64
Copy this following the documentation convention.
python/paddle/nn/functional/loss.py
Outdated
    and does not contribute to the input gradient.
reduction (str, optional): Indicate how to average the loss,
    the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
    If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
The 'sum' case should also be described.
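For reference, the three reduction modes the docstring describes can be sketched in plain Python. The `reduce_loss` helper is illustrative only, not the PR's code:

```python
def reduce_loss(losses, reduction='mean'):
    if reduction == 'none':
        return losses                      # per-sample losses, unreduced
    if reduction == 'sum':
        return sum(losses)                 # total loss over the batch
    if reduction == 'mean':
        return sum(losses) / len(losses)   # average loss over the batch
    raise ValueError(
        "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
        "'none', but received %s, which is not allowed." % reduction)

losses = [0.5, 1.5, 1.0]
print(reduce_loss(losses, 'none'))  # [0.5, 1.5, 1.0]
print(reduce_loss(losses, 'sum'))   # 3.0
print(reduce_loss(losses, 'mean'))  # 1.0
```

Note that when a class `weight` is supplied, NLL-loss implementations typically make `'mean'` divide by the total applied weight rather than the sample count; the sketch above covers only the unweighted case.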
python/paddle/nn/functional/loss.py
Outdated
prog,
feed={"x": x_np,
      "label": label_np},
fetch_list=[res])
The static-graph example can be removed.
python/paddle/nn/functional/loss.py
Outdated
label = paddle.reshape(label, shape=[n, 1, -1])
out_shape = [n] + x_shape[2:]

helper = LayerHelper('nll_loss', **locals())
Move this part earlier.
reduction (str, optional): Indicate how to average the loss,
    the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
    If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
    Default is ``'mean'``.
ignore_index (int64, optional): Specifies a target value that is ignored
    and does not contribute to the input gradient.
Add a name parameter.
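The `ignore_index` behavior documented above can be sketched as follows. The `nll_mean` helper is hypothetical (not the PR's kernel); it assumes `log_probs` holds per-class log-probabilities, and shows that ignored targets contribute no loss and are excluded from the `'mean'` denominator:

```python
def nll_mean(log_probs, labels, ignore_index=-100):
    total, count = 0.0, 0
    for lp, y in zip(log_probs, labels):
        if y == ignore_index:
            continue                 # ignored target: no loss, no gradient
        total += -lp[y]              # NLL picks the log-prob of the true class
        count += 1
    return total / count if count else 0.0

lp = [[-0.1, -2.3], [-1.2, -0.4], [-0.7, -0.9]]
print(nll_mean(lp, [0, -100, 1]))   # averages over the 2 kept samples only
```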
python/paddle/nn/layer/loss.py
Outdated
Returns:
-   The tensor variable storing the nll_loss.
+   The callable object which calculates negative log likelihood loss when
+   get the input `x` and `label` .
Remove the x- and label-related information.
Force-pushed from 493292a to 08cee15 (compare)
python/paddle/nn/functional/loss.py
Outdated
    and does not contribute to the input gradient.
reduction (str, optional): Indicate how to average the loss,
    the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
    If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
The formatting looks wrong; plain "If reduction is" would do. Same for the occurrences below.
LGTM
LGTM
python/paddle/nn/functional/loss.py
Outdated
    "'none', but received %s, which is not allowed." % reduction)

if in_dygraph_mode():
    x_shape = list(core.ops.shape(x))
You can write `x_shape = x.shape` directly here.
done
python/paddle/nn/layer/loss.py
Outdated
        "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
        "'none', but received %s, which is not allowed." % reduction)
super(NLLLoss, self).__init__()
self.weight = weight
Maybe `self._weight` is better for a private member, etc.
done
python/paddle/nn/functional/loss.py
Outdated
x (Tensor): Input tensor, the data type is float32, float64.
label (Tensor): Label tensor, the data type is int64.
weight (Tensor, optional): Weight tensor, a manual rescaling weight given
    to each class. If given, it has to be a Tensor of size `C`. Otherwise,
"it has to be a Tensor of size `C`": better describe this clearly. Does it mean x should be of format [N, C, H, W]?
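To illustrate why the weight tensor has size `C`: it is indexed by class, so each sample's loss is rescaled by `weight[label]`. The sketch below is hypothetical, not the PR's kernel, and assumes the common NLL convention that `'mean'` divides by the summed applied weights rather than the sample count:

```python
def weighted_nll_mean(log_probs, labels, weight):
    total_loss, total_weight = 0.0, 0.0
    for lp, y in zip(log_probs, labels):
        w = weight[y]                # weight is indexed by class, hence size C
        total_loss += -w * lp[y]     # rescale this sample's loss
        total_weight += w
    return total_loss / total_weight # weighted mean over the batch

lp = [[-0.5, -1.0], [-2.0, -0.2]]
print(weighted_nll_mean(lp, [0, 1], weight=[1.0, 3.0]))
```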
@zhiqiu Already described the shape of `x`, `label`, and `weight` more clearly.
Force-pushed from 0c5698e to 349c5e8 (compare)
Force-pushed from 349c5e8 to ede9a87 (compare)
LGTM
LGTM
PR types
New features
PR changes
APIs
Describe
Add the nll_loss function, encapsulate it in the NLLLoss class, and optimize the performance of nll_loss in dygraph mode.
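To summarize the semantics the PR adds, here is a pure-Python reference sketch of the functional API for the 2-D case. Names and defaults mirror the `nll_loss` signature discussed in the review, but this is illustrative code, not Paddle's implementation; `x` is assumed to hold log-probabilities of shape [N, C] and `label` class indices of shape [N]:

```python
def nll_loss_ref(x, label, weight=None, ignore_index=-100, reduction='mean'):
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(
            "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
            "'none', but received %s, which is not allowed." % reduction)
    losses, weights = [], []
    for row, y in zip(x, label):
        if y == ignore_index:
            losses.append(0.0)       # ignored target contributes nothing
            weights.append(0.0)
            continue
        w = weight[y] if weight is not None else 1.0
        losses.append(-w * row[y])   # pick log-prob of the true class
        weights.append(w)
    if reduction == 'none':
        return losses
    if reduction == 'sum':
        return sum(losses)
    tw = sum(weights)                # 'mean' divides by total applied weight
    return sum(losses) / tw if tw else 0.0

x = [[-0.2, -1.8], [-1.5, -0.3]]
print(nll_loss_ref(x, [0, 1]))                    # mean of the two losses
print(nll_loss_ref(x, [0, 1], reduction='sum'))   # their total
```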