-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-1030] Cosine Embedding Loss #12750
Changes from 2 commits
b3b5de0
eb9b9b4
7fdd85d
013a604
aac12ad
1c97924
9766983
f05eb7b
c02e111
c10f1ef
95dd2a7
01607b4
c8bca0b
5194cd8
c01f8cb
4b3fe81
78bd725
3b3e117
2df6953
5c642cb
4be5104
410a708
1f48429
b618b61
ed762e5
89aafbc
4a3167b
d80baac
b36e097
c195ed0
308666b
16c3ecd
2dfeaa2
0bd4b24
ca030a7
ede1588
67572c5
55d4b1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ | |
'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss', | ||
'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss', | ||
'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss', | ||
'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss'] | ||
'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'CosineEmbeddingLoss'] | ||
|
||
from .. import ndarray | ||
from ..base import numeric_types | ||
|
@@ -706,3 +706,53 @@ def hybrid_forward(self, F, pred, positive, negative): | |
axis=self._batch_axis, exclude=True) | ||
loss = F.relu(loss + self._margin) | ||
return _apply_weighting(F, loss, self._weight, None) | ||
|
||
class CosineEmbeddingLoss(Loss):
    r"""For a target label 1 or -1, vectors `pred` and `target`, the function
    computes the cosine distance between the vectors. This can be interpreted
    as how similar/dissimilar two input vectors are.

    .. math::

        L = \sum_i \begin{cases}
                1 - \cos(pred_i, target_i) & \text{if } label_i = 1 \\
                \max(0, \cos(pred_i, target_i) - margin) & \text{if } label_i = -1
            \end{cases}

    where :math:`\cos(a, b) = \frac{a \cdot b}
    {\max(\lVert a \rVert \cdot \lVert b \rVert, 10^{-12})}`.

    `pred` and `target` can have arbitrary shape as long as they have the same
    number of elements; `pred` is reshaped to `target`'s shape before the
    similarity is computed along the last axis.

    Parameters
    ----------
    weight : float or None
        Global scalar weight for loss.
    batch_axis : int, default 0
        The axis that represents mini-batch.
    margin : float, default 0
        Margin of separation between correct and incorrect pair.


    Inputs:
        - **pred**: prediction tensor with arbitrary shape
        - **target**: target tensor with the same number of elements as pred.
        - **label**: a 1-D tensor indicating for each pair of `pred` and
          `target`, whether the target label is 1 (similar) or -1 (dissimilar).
        - **sample_weight**: element-wise weighting tensor. Must be broadcastable
          to the same shape as the per-sample loss. For example, if the loss has
          shape (64, 1) and you want to weigh each sample in the batch
          separately, sample_weight should have shape (64, 1).

    Outputs:
        - **loss**: loss tensor with shape (batch_size, 1).
    """
    def __init__(self, weight=None, batch_axis=0, margin=0, **kwargs):
        super(CosineEmbeddingLoss, self).__init__(weight, batch_axis, **kwargs)
        self._margin = margin

    def hybrid_forward(self, F, pred, target, label, sample_weight=None):
        pred = _reshape_like(F, pred, target)
        # cos_sim has shape (batch, 1); reshape label so the masks below
        # broadcast row-wise against it instead of relying on implicit rules.
        label = label.reshape((-1, 1))
        cos_sim = self.cosine_similarity(F, pred, target)
        # Masks selecting similar (+1) and dissimilar (-1) pairs.
        y_1 = label == 1
        y_minus_1 = label == -1
        cos_sim_a = (1 - cos_sim) * y_1
        # F.array exists only on the ndarray frontend; use F.zeros on the
        # symbol frontend so the block stays hybridizable.
        if F is ndarray:
            zero = F.array([0])
        else:
            zero = F.zeros((1, 1))
        cos_sim_b = F.broadcast_maximum(
            zero, y_minus_1 * (cos_sim - self._margin), axis=1)
        loss = cos_sim_a + cos_sim_b
        return _apply_weighting(F, loss, self._weight, sample_weight)

    def cosine_similarity(self, F, x, y, axis=-1):
        # Row-wise cosine similarity along `axis`, with the denominator
        # clamped to 1e-12 to avoid division by zero on null vectors.
        # Module-level F.norm/F.sum (not x.norm) keeps this hybridizable.
        x_norm = F.norm(x, axis=axis).reshape((-1, 1))
        y_norm = F.norm(y, axis=axis).reshape((-1, 1))
        x_dot_y = F.sum(x * y, axis=axis).reshape((-1, 1))
        if F is ndarray:
            eps_arr = F.array([1e-12])
        else:
            eps_arr = F.full((1, 1), 1e-12)
        return x_dot_y / F.broadcast_maximum(x_norm * y_norm, eps_arr)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -348,6 +348,22 @@ def test_triplet_loss(): | |
optimizer='adam') | ||
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05 | ||
|
||
@with_seed() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add additional tests for label = -1 and hybridization. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
def test_cosine_loss():
    """Checks CosineEmbeddingLoss against an ndarray-computed reference
    for both the similar (label=1) and dissimilar (label=-1) branches."""
    pred = mx.nd.array([[1, 1, 1, 1],
                        [1, 2, 3, 4]])
    target = mx.nd.array([[1, 1, 1, 1],
                          [1, 2, 3, 4]])
    cosine_loss = gluon.loss.CosineEmbeddingLoss()

    # Reference cosine similarity computed the numpy way.
    numerator = mx.nd.sum(pred * target, keepdims=True, axis=1)
    denominator = mx.nd.sqrt(mx.nd.sum(pred**2, axis=1, keepdims=True)) \
        * mx.nd.sqrt(mx.nd.sum(target**2, axis=1, keepdims=True))
    cos_sim = numerator / denominator

    # Similarity check: loss should be 1 - cos_sim.
    label = mx.nd.array([1])
    loss = cosine_loss(pred, target, label)
    assert_almost_equal(loss.asnumpy(), (1 - cos_sim).asnumpy())

    # Dissimilarity check: with the default margin of 0 and strictly
    # positive cos_sim here, loss should be max(0, cos_sim) == cos_sim.
    label = mx.nd.array([-1])
    loss = cosine_loss(pred, target, label)
    assert_almost_equal(loss.asnumpy(), cos_sim.asnumpy())
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not test using the Module API with |
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: additional empty line not needed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. It is not required. Following the other test modules. They have a 2 lines gap after the last test function and main function. |
||
if __name__ == '__main__': | ||
import nose | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: interpreted