[MXNET-1030] Cosine Embedding Loss #12750

Merged · 38 commits · Oct 29, 2018

Changes from 2 commits

Commits
b3b5de0
Cosine Embedding Loss function added
gaurav-gireesh Oct 5, 2018
eb9b9b4
Added unit tests for Cosine Embedding Loss Function
gaurav-gireesh Oct 5, 2018
7fdd85d
Added Latex code for formula for cosine embedding loss
gaurav-gireesh Oct 7, 2018
013a604
Fixing document rendering
gaurav-gireesh Oct 8, 2018
aac12ad
Fixing documentation issue
gaurav-gireesh Oct 8, 2018
1c97924
PR Comments addressed for using F (NDArray or Symbol) to calculate no…
gaurav-gireesh Oct 8, 2018
9766983
Markdown file updated. Added entry for CosineEmbeddingLoss
gaurav-gireesh Oct 8, 2018
f05eb7b
Added a line after .. math:: to fix documentation
gaurav-gireesh Oct 9, 2018
c02e111
Documentation check - pylint fix
gaurav-gireesh Oct 9, 2018
c10f1ef
Formula update
gaurav-gireesh Oct 9, 2018
95dd2a7
Making the formula simpler for correct rendering incrementally - Upda…
gaurav-gireesh Oct 9, 2018
01607b4
Making the formula simpler for correct rendering incrementally - Upda…
gaurav-gireesh Oct 9, 2018
c8bca0b
Making the formula simpler for correct rendering incrementally - Upda…
gaurav-gireesh Oct 9, 2018
5194cd8
Making the formula simpler for correct rendering incrementally - Upda…
gaurav-gireesh Oct 9, 2018
c01f8cb
Making the formula simpler for correct rendering incrementally - Upda…
gaurav-gireesh Oct 9, 2018
4b3fe81
Trigger CI
gaurav-gireesh Oct 9, 2018
78bd725
making the utility function cosine similarity internal
gaurav-gireesh Oct 9, 2018
3b3e117
Added a test case for label = -1, for dissimilar vectors
gaurav-gireesh Oct 10, 2018
2df6953
Refactored names of parameters to the loss functions and updated the …
gaurav-gireesh Oct 10, 2018
5c642cb
PR comments addressed changes in documentation
gaurav-gireesh Oct 10, 2018
4be5104
Added random input vectors and labelled tests
gaurav-gireesh Oct 11, 2018
410a708
Renaming variables
gaurav-gireesh Oct 11, 2018
1f48429
Pylint issues fixed
gaurav-gireesh Oct 11, 2018
b618b61
Merged from upstream master branch + Resolved conflicts
gaurav-gireesh Oct 15, 2018
ed762e5
Resolving conflicts
gaurav-gireesh Oct 15, 2018
89aafbc
Pylint issues fixed
gaurav-gireesh Oct 15, 2018
4a3167b
Style issues fixed trailing whitespaces removed
gaurav-gireesh Oct 15, 2018
d80baac
Review comment addressed, sample_weight added in the parameter
gaurav-gireesh Oct 25, 2018
b36e097
Merge remote-tracking branch 'upstream/master' into cosineloss
gaurav-gireesh Oct 26, 2018
c195ed0
Trigger CI
gaurav-gireesh Oct 26, 2018
308666b
Reordered Parameter description
gaurav-gireesh Oct 26, 2018
16c3ecd
comments addressed - spelling errors
gaurav-gireesh Oct 26, 2018
2dfeaa2
nit comments addressed
gaurav-gireesh Oct 26, 2018
0bd4b24
Trigger CI
gaurav-gireesh Oct 26, 2018
ca030a7
Merge remote-tracking branch 'upstream/master' into cosineloss
gaurav-gireesh Oct 26, 2018
ede1588
Trigger CI
gaurav-gireesh Oct 26, 2018
67572c5
Trigger CI
gaurav-gireesh Oct 26, 2018
55d4b1e
Trigger CI
gaurav-gireesh Oct 27, 2018
52 changes: 51 additions & 1 deletion python/mxnet/gluon/loss.py
@@ -23,7 +23,7 @@
           'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss',
           'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss',
           'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss',
           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'CosineEmbeddingLoss']

from .. import ndarray
from ..base import numeric_types
@@ -706,3 +706,53 @@ def hybrid_forward(self, F, pred, positive, negative):
                     axis=self._batch_axis, exclude=True)
        loss = F.relu(loss + self._margin)
        return _apply_weighting(F, loss, self._weight, None)

class CosineEmbeddingLoss(Loss):
    r"""For a target label 1 or -1, and vectors `pred` and `target`, this function computes the
    cosine distance between the vectors. This can be interpreted as a measure of how similar or
    dissimilar the two input vectors are.
Contributor: nit: interpreted

    `pred`, `target` can have arbitrary shape as long as they have the same number of elements.
Member: Could you document the formula here? Our website supports MathJax, so you can write LaTeX for it.

Contributor Author: @szha Thank you! Incorporating that in the documentation.

Member (@roywei, Oct 9, 2018): Please refer to HuberLoss and use only {cases}; remove {gather}. Maybe the MathJax version you tested with and the one used on the website differ, so some formulas are not recognized. A safe way is to use only constructs that have already rendered correctly; those should be more than enough.

Member: Update document on param names.
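
For reference, a sketch of the formula under discussion, in the .. math:: style the file already uses and restricted to {cases} per the review above (reconstructed from the thread; the exact form in the merged docstring may differ):

    .. math::

        L = \begin{cases}
            1 - \cos(pred_i, target_i) & \text{if } label_i = 1 \\
            \max(0, \cos(pred_i, target_i) - margin) & \text{if } label_i = -1
            \end{cases}

    where :math:`\cos(a, b) = \frac{a \cdot b}{\max(\lVert a \rVert \, \lVert b \rVert, \epsilon)}`.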


    Parameters
    ----------
    weight : float or None
        Global scalar weight for loss.
    batch_axis : int, default 0
        The axis that represents the mini-batch.
    margin : float
        Margin of separation between correct and incorrect pairs.


    Inputs:
        - **pred**: prediction tensor with arbitrary shape.
        - **target**: target tensor with the same shape as pred.
        - **sample_weight**: element-wise weighting tensor. Must be broadcastable
          to the same shape as pred. For example, if pred has shape (64, 10)
          and you want to weigh each sample in the batch separately,
          sample_weight should have shape (64, 1).

Member: This needs to be added to hybrid_forward. Could you also put this after label?

Contributor Author: Addressed this. Thanks.
        - **label**: a 1-D tensor indicating, for each pred/target pair, whether the target label is 1 or -1.
Member: nit: **label**

    Outputs:
        - **loss**: Average loss (shape=(1,1)) of the loss tensor with shape (batch_size,).
    """
    def __init__(self, weight=None, batch_axis=0, margin=0, **kwargs):
        super(CosineEmbeddingLoss, self).__init__(weight, batch_axis, **kwargs)
        self._margin = margin

    def hybrid_forward(self, F, pred, target, label):
        pred = _reshape_like(F, pred, target)
        cos_sim = self.cosine_similarity(F, pred, target)
        y_1 = label == 1
        y_minus_1 = label == -1
        cos_sim_a = (1 - cos_sim) * y_1
        cos_sim_b = F.broadcast_maximum(F.array([0]), y_minus_1 * (cos_sim - self._margin), axis=1)
        return cos_sim_a + cos_sim_b
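
Given the sample_weight thread above, a sketch of how the signature might eventually thread the weight through (this assumes the _apply_weighting helper that the other losses in this file use; it is illustrative, not necessarily the exact merged code):

    def hybrid_forward(self, F, pred, target, label, sample_weight=None):
        pred = _reshape_like(F, pred, target)
        label = label.reshape((-1, 1))  # align the labels with cos_sim's (batch, 1) shape
        cos_sim = self.cosine_similarity(F, pred, target)
        # label == 1 penalizes distance; label == -1 penalizes similarity above the margin.
        zeros = F.zeros((1, 1))  # works for both NDArray and Symbol, unlike F.array
        loss = (label == 1) * (1 - cos_sim) \
               + (label == -1) * F.broadcast_maximum(zeros, cos_sim - self._margin)
        # Apply the optional element-wise weights, as the other losses here do.
        return _apply_weighting(F, loss, self._weight, sample_weight)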

    def cosine_similarity(self, F, F1, F2, axis=-1):
Member: nit: rename F1 and F2 to x, y or x1, x2 to avoid confusion with F. F is for hybridization; F1 and F2 are vectors. Also write a small comment explaining that you are computing the cosine similarity here.

        F1_norm = F1.norm(axis=axis).reshape(-1, 1)
Member: use F.norm(x, axis=axis)... to support hybridization

        F2_norm = F2.norm(axis=axis).reshape(-1, 1)
        F1_dot_F2 = F.sum(F1 * F2, axis=axis).reshape(-1, 1)
        return (F1_dot_F2 / F.broadcast_maximum(F1_norm * F2_norm, F.array([1e-12])))
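
Combining the two suggestions above, a hybridization-safe variant might look like the sketch below (the x/y names and the explanatory comment follow the review asks; the `F is ndarray` check mirrors how the file distinguishes imperative from symbolic mode — illustrative, not the exact merged code):

    def _cosine_similarity(self, F, x, y, axis=-1):
        # Computes the cosine similarity between corresponding rows of x and y,
        # using module-level ops on F so it works for both NDArray and Symbol.
        x_norm = F.norm(x, axis=axis).reshape((-1, 1))
        y_norm = F.norm(y, axis=axis).reshape((-1, 1))
        x_dot_y = F.sum(x * y, axis=axis).reshape((-1, 1))
        if F is ndarray:
            eps_arr = F.array([1e-12])      # epsilon guard against division by zero
        else:
            eps_arr = F.full((1, 1), 1e-12)
        return x_dot_y / F.broadcast_maximum(x_norm * y_norm, eps_arr)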
16 changes: 16 additions & 0 deletions tests/python/unittest/test_loss.py
@@ -348,6 +348,22 @@ def test_triplet_loss():
            optimizer='adam')
    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05

@with_seed()
Contributor: Please add additional tests for label = -1 and hybridization.

Contributor Author: Re "Please add additional tests for label = -1 and hybridization":
  1. Added tests for label 1 and -1 in a randomly generated set of labels.
  2. This function is meant to be a utility similar to cosine_distance.
def test_cosine_loss():
    # For similarity check
    label = mx.nd.array([1])
    pred = mx.nd.array([[1, 1, 1, 1],
Member: Please also rename params in test to input1, input2, label to avoid confusion.
                        [1, 2, 3, 4]])
    target = mx.nd.array([[1, 1, 1, 1],
                          [1, 2, 3, 4]])
    Loss = gluon.loss.CosineEmbeddingLoss()
    loss = Loss(pred, target, label)

    # computing the reference value the numpy way
    numerator = mx.nd.sum(pred * target, keepdims=True, axis=1)
    denominator = mx.nd.sqrt(mx.nd.sum(pred**2, axis=1, keepdims=True)) \
                  * mx.nd.sqrt(mx.nd.sum(target**2, axis=1, keepdims=True))
    assert_almost_equal(loss.asnumpy(), (1 - numerator / denominator).asnumpy())
Member: Why not test using the Module API with mod.score and mod.fit?

Contributor: nitpick: additional empty line not needed

Contributor Author: Yes, it is not required. I am following the other test modules: they have a two-line gap between the last test function and the main block.
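
A sketch of the extra label = -1 case the reviewers asked for (illustrative only, not the merged test; with the default margin of 0, identical vectors labelled -1 should incur a loss of exactly 1):

    @with_seed()
    def test_cosine_loss_dissimilar():
        # label = -1 marks the pair as dissimilar, so identical vectors
        # should be penalized: loss = max(0, cos_sim - margin) = 1.
        label = mx.nd.array([-1])
        pred = mx.nd.array([[1, 2, 3, 4]])
        target = mx.nd.array([[1, 2, 3, 4]])
        Loss = gluon.loss.CosineEmbeddingLoss()
        loss = Loss(pred, target, label)
        assert_almost_equal(loss.asnumpy(), mx.nd.array([[1.]]).asnumpy())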

if __name__ == '__main__':
    import nose