Skip to content

Commit 227959b

Browse files
authoredSep 4, 2020
added emb similarity (Lightning-AI#3349)
1 parent 7bd2f94 commit 227959b

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed
 

‎docs/source/metrics.rst

+6
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ dice_score (F)
315315
.. autofunction:: pytorch_lightning.metrics.functional.dice_score
316316
:noindex:
317317

318+
embedding_similarity (F)
319+
^^^^^^^^^^^^^^^^^^^^^^^^
320+
321+
.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity
322+
:noindex:
323+
318324
f1_score (F)
319325
^^^^^^^^^^^^
320326

‎pytorch_lightning/metrics/functional/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,6 @@
2929
rmsle,
3030
ssim
3131
)
32+
from pytorch_lightning.metrics.functional.self_supervised import (
33+
embedding_similarity
34+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
3+
4+
def embedding_similarity(
5+
batch: torch.Tensor,
6+
similarity: str = 'cosine',
7+
reduction: str = 'none',
8+
zero_diagonal: bool = True
9+
) -> torch.Tensor:
10+
"""
11+
Computes representation similarity
12+
13+
Example:
14+
15+
>>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
16+
>>> embedding_similarity(embeddings)
17+
tensor([[0.0000, 1.0000, 0.9759],
18+
[1.0000, 0.0000, 0.9759],
19+
[0.9759, 0.9759, 0.0000]])
20+
21+
Args:
22+
batch: (batch, dim)
23+
similarity: 'dot' or 'cosine'
24+
reduction: 'none', 'sum', 'mean' (all along dim -1)
25+
zero_diagonal: if True, the diagonals are set to zero
26+
27+
Return:
28+
A square matrix (batch, batch) with the similarity scores between all elements
29+
If sum or mean are used, then returns (b, 1) with the reduced value for each row
30+
"""
31+
if similarity == 'cosine':
32+
norm = torch.norm(batch, p=2, dim=1)
33+
batch = batch / norm.unsqueeze(1)
34+
35+
sqr_mtx = batch.mm(batch.transpose(1, 0))
36+
37+
if zero_diagonal:
38+
sqr_mtx = sqr_mtx.fill_diagonal_(0)
39+
40+
if reduction == 'mean':
41+
sqr_mtx = sqr_mtx.mean(dim=-1)
42+
43+
return sqr_mtx
44+
45+
46+
if __name__ == '__main__':
47+
a = torch.rand(3, 5)
48+
49+
print(embedding_similarity(a, 'cosine'))

0 commit comments

Comments
 (0)
Failed to load comments.