
Commit

Merge pull request #653 from KevinMusgrave/dev
v2.3.0
KevinMusgrave authored Jul 25, 2023
2 parents 8e84386 + ea0946a commit bf8f2ab
Showing 8 changed files with 275 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ dist/
*.egg-info/
site/
venv/
**/.vscode
.ipynb_checkpoints
examples/notebooks/dataset
examples/notebooks/CIFAR10_Dataset
1 change: 1 addition & 0 deletions CONTENTS.md
@@ -17,6 +17,7 @@
| [**CosFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#cosfaceloss) | - [CosFace: Large Margin Cosine Loss for Deep Face Recognition](https://arxiv.org/pdf/1801.09414.pdf) <br/> - [Additive Margin Softmax for Face Verification](https://arxiv.org/pdf/1801.05599.pdf)
| [**FastAPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#fastaploss) | [Deep Metric Learning to Rank](http://openaccess.thecvf.com/content_CVPR_2019/papers/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.pdf)
| [**GeneralizedLiftedStructureLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#generalizedliftedstructureloss) | [In Defense of the Triplet Loss for Person Re-Identification](https://arxiv.org/pdf/1703.07737.pdf)
| [**HistogramLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss) | [Learning Deep Embeddings with Histogram Loss](https://arxiv.org/pdf/1611.00822.pdf)
| [**InstanceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#instanceloss) | [Dual-Path Convolutional Image-Text Embeddings with Instance Loss](https://arxiv.org/pdf/1711.05535.pdf)
| [**IntraPairVarianceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#intrapairvarianceloss) | [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf)
| [**LargeMarginSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#largemarginsoftmaxloss) | [Large-Margin Softmax Loss for Convolutional Neural Networks](https://arxiv.org/pdf/1612.02295.pdf)
20 changes: 20 additions & 0 deletions docs/losses.md
@@ -424,6 +424,26 @@ losses.InstanceLoss(gamma=64, **kwargs)
* **gamma**: The cosine similarity matrix is scaled by this amount.


## HistogramLoss
[Learning Deep Embeddings with Histogram Loss](https://arxiv.org/pdf/1611.00822.pdf)
```python
losses.HistogramLoss(n_bins=None, delta=None)
```

**Parameters**:

* **n_bins**: The number of bins used to construct the histogram. Default is 100 when both `n_bins` and `delta` are `None`.
* **delta**: The width of each bin in the uniform partition of the interval [-1, 1] used to construct the histogram. If not set, it is derived from `n_bins` as `delta = 2 / n_bins`.

**Default distance**:

- [```CosineSimilarity()```](distances.md#cosinesimilarity)

**Default reducer**:

- This loss returns an **already reduced** loss.
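
A minimal usage sketch, assuming the standard `(embeddings, labels)` call signature shared by the other losses (the batch sizes and label range below are illustrative):

```python
import torch
from pytorch_metric_learning import losses

# delta is inferred as 2 / n_bins, so these two constructions are equivalent:
loss_func = losses.HistogramLoss(n_bins=100)
# loss_func = losses.HistogramLoss(delta=0.02)

embeddings = torch.randn(32, 128)          # batch of 32 embeddings
labels = torch.randint(0, 5, size=(32,))   # integer class labels
loss = loss_func(embeddings, labels)
```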


## IntraPairVarianceLoss
[Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf){target=_blank}
```python
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
__version__ = "2.2.0"
__version__ = "2.3.0"
1 change: 1 addition & 0 deletions src/pytorch_metric_learning/losses/__init__.py
@@ -8,6 +8,7 @@
from .cross_batch_memory import CrossBatchMemory
from .fast_ap_loss import FastAPLoss
from .generic_pair_loss import GenericPairLoss
from .histogram_loss import HistogramLoss
from .instance_loss import InstanceLoss
from .intra_pair_variance_loss import IntraPairVarianceLoss
from .large_margin_softmax_loss import LargeMarginSoftmaxLoss
80 changes: 80 additions & 0 deletions src/pytorch_metric_learning/losses/histogram_loss.py
@@ -0,0 +1,80 @@
import torch

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


def filter_pairs(*tensors: torch.Tensor):
    # Treat each (i, j) index pair as unordered and drop duplicates,
    # so every anchor-positive / anchor-negative pair is counted once.
    t = torch.stack(tensors)
    t, _ = torch.sort(t, dim=0)
    t = torch.unique(t, dim=1)
    return t.tolist()


class HistogramLoss(BaseMetricLossFunction):
def __init__(self, n_bins: int = None, delta: float = None, **kwargs):
super().__init__(**kwargs)
if delta is not None and n_bins is not None:
assert (
delta == 2 / n_bins
), f"delta and n_bins must satisfy the equation delta = 2/n_bins.\nPassed values are delta={delta} and n_bins={n_bins}"

if delta is None and n_bins is None:
n_bins = 100

self.delta = delta if delta is not None else 2 / n_bins
self.add_to_recordable_attributes(name="delta", is_stat=True)

def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
c_f.labels_or_indices_tuple_required(labels, indices_tuple)
c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
indices_tuple = lmu.convert_to_triplets(
indices_tuple, labels, ref_labels, t_per_anchor="all"
)
anchor_idx, positive_idx, negative_idx = indices_tuple
if len(anchor_idx) == 0:
return self.zero_losses()
mat = self.distance(embeddings, ref_emb)

anchor_positive_idx = filter_pairs(anchor_idx, positive_idx)
anchor_negative_idx = filter_pairs(anchor_idx, negative_idx)
ap_dists = mat[anchor_positive_idx]
an_dists = mat[anchor_negative_idx]

p_pos = self.compute_density(ap_dists)
phi = torch.cumsum(p_pos, dim=0)

p_neg = self.compute_density(an_dists)
return {
"loss": {
"losses": torch.sum(p_neg * phi),
"indices": None,
"reduction_type": "already_reduced",
}
}

def compute_density(self, distances):
size = distances.size(0)
r_star = torch.floor(
(distances.float() + 1) / self.delta
) # Indices of the bins containing the values of the distances
r_star = c_f.to_device(r_star, tensor=distances, dtype=torch.long)

delta_ijr_a = (distances + 1 - r_star * self.delta) / self.delta
delta_ijr_b = ((r_star + 1) * self.delta - 1 - distances) / self.delta
delta_ijr_a = c_f.to_dtype(delta_ijr_a, tensor=distances)
delta_ijr_b = c_f.to_dtype(delta_ijr_b, tensor=distances)

density = torch.zeros(round(1 + 2 / self.delta))
density = c_f.to_device(density, tensor=distances, dtype=distances.dtype)

# For each node sum the contributions of the bins whose ending node is this one
density.scatter_add_(0, r_star + 1, delta_ijr_a)
# For each node sum the contributions of the bins whose starting node is this one
density.scatter_add_(0, r_star, delta_ijr_b)
return density / size

def get_default_distance(self):
return CosineSimilarity()
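
A standalone sketch of the soft-histogram binning that `compute_density` performs and of how the loss combines the two densities. The helper name `soft_histogram` and the toy similarity values are illustrative only; the binning mirrors the `r_star` / `delta_ijr` computation above.

```python
import torch

def soft_histogram(similarities: torch.Tensor, delta: float) -> torch.Tensor:
    # Nodes t_0, ..., t_R partition [-1, 1] with spacing delta (R = 2 / delta).
    n_nodes = round(1 + 2 / delta)
    r_star = torch.floor((similarities + 1) / delta).long()  # bin index of each value
    weight_upper = (similarities + 1 - r_star * delta) / delta        # linear share for node r_star + 1
    weight_lower = ((r_star + 1) * delta - 1 - similarities) / delta  # linear share for node r_star
    density = torch.zeros(n_nodes)
    density.scatter_add_(0, r_star + 1, weight_upper)
    density.scatter_add_(0, r_star, weight_lower)
    return density / similarities.numel()

delta = 2 / 100
p_pos = soft_histogram(torch.tensor([0.90, 0.85, 0.80]), delta)   # toy anchor-positive similarities
p_neg = soft_histogram(torch.tensor([-0.20, 0.05, 0.10]), delta)  # toy anchor-negative similarities
# Estimated probability that a random negative pair is more similar than a random positive pair.
loss = torch.sum(p_neg * torch.cumsum(p_pos, dim=0))
```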
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/utils/accuracy_calculator.py
@@ -440,7 +440,7 @@ def get_accuracy(
):
raise ValueError(
"When ref_includes_query is True, the first len(query) elements of reference must be equal to query.\n"
"Likewise, the first len(query_labels) elements of reference_lbels must be equal to query_labels.\n"
"Likewise, the first len(query_labels) elements of reference_labels must be equal to query_labels.\n"
)

self.curr_function_dict = self.get_function_dict(include, exclude)
170 changes: 170 additions & 0 deletions tests/losses/test_histogram_loss.py
@@ -0,0 +1,170 @@
import unittest

import torch
from numpy.testing import assert_almost_equal

from pytorch_metric_learning.losses import HistogramLoss
from pytorch_metric_learning.utils import common_functions as c_f

from .. import TEST_DEVICE, TEST_DTYPES


######################################
#######ORIGINAL IMPLEMENTATION########
######################################
# DIRECTLY COPIED from https://github.com/valerystrizh/pytorch-histogram-loss/blob/master/losses.py.
# This code is copied from the official PyTorch implementation
# so that we can make sure our implementation returns the same result.
# Some minor changes were made to avoid errors during testing.
# Every change in the original code is reported and explained.
class OriginalImplementationHistogramLoss(torch.nn.Module):
def __init__(self, num_steps, cuda=True):
super(OriginalImplementationHistogramLoss, self).__init__()
self.step = 2 / (num_steps - 1)
self.eps = 1 / num_steps
self.cuda = cuda
self.t = torch.arange(-1, 1 + self.step, self.step).view(-1, 1)
self.tsize = self.t.size()[0]
if self.cuda:
self.t = self.t.cuda()

def forward(self, features, classes):
def histogram(inds, size):
s_repeat_ = s_repeat.clone()
inds = c_f.to_device(inds, tensor=s_repeat_floor) # Added to avoid errors
self.t = c_f.to_device(
self.t, tensor=s_repeat_floor
) # Added to avoid errors
indsa = (
(s_repeat_floor - (self.t - self.step) > -self.eps)
& (s_repeat_floor - (self.t - self.step) < self.eps)
& inds
)
assert (
indsa.nonzero().size()[0] == size
), "Another number of bins should be used"
zeros = torch.zeros((1, indsa.size()[1])).to(
device=indsa.device, dtype=torch.uint8
)
if self.cuda:
zeros = zeros.cuda()
indsb = torch.cat((indsa, zeros))[1:, :].to(
dtype=torch.bool
) # Added to avoid bug with masks of uint8
s_repeat_[~(indsb | indsa)] = 0
# indsa corresponds to the first condition of the second equation of the paper
self.t = self.t.to(
dtype=s_repeat_.dtype
) # Added to avoid errors when using Half precision
s_repeat_[indsa] = (s_repeat_ - self.t + self.step)[indsa] / self.step
# indsb corresponds to the second condition of the second equation of the paper
s_repeat_[indsb] = (-s_repeat_ + self.t + self.step)[indsb] / self.step

return s_repeat_.sum(1) / size

classes_size = classes.size()[0]
classes_eq = (
classes.repeat(classes_size, 1)
== classes.view(-1, 1).repeat(1, classes_size)
).data
dists = torch.mm(features, features.transpose(0, 1))
assert (
(dists > 1 + self.eps).sum().item() + (dists < -1 - self.eps).sum().item()
) == 0, "L2 normalization should be used"
s_inds = torch.triu(torch.ones(classes_eq.size()), 1).byte()
if self.cuda:
s_inds = s_inds.cuda()
classes_eq = classes_eq.to(
device=s_inds.device
) # Added to avoid errors when using only cpu
pos_inds = classes_eq[s_inds].repeat(self.tsize, 1)
neg_inds = ~classes_eq[s_inds].repeat(self.tsize, 1)
pos_size = classes_eq[s_inds].sum().item()
neg_size = (~classes_eq[s_inds]).sum().item()
s = dists[s_inds].view(1, -1)
s_repeat = s.repeat(self.tsize, 1)
s_repeat_floor = (torch.floor(s_repeat.data / self.step) * self.step).float()

histogram_pos = histogram(pos_inds, pos_size)
assert_almost_equal(
histogram_pos.sum().item(),
1,
decimal=1,
err_msg="Not good positive histogram",
verbose=True,
)
histogram_neg = histogram(neg_inds, neg_size)
assert_almost_equal(
histogram_neg.sum().item(),
1,
decimal=1,
err_msg="Not good negative histogram",
verbose=True,
)
histogram_pos_repeat = histogram_pos.view(-1, 1).repeat(
1, histogram_pos.size()[0]
)
histogram_pos_inds = torch.tril(
torch.ones(histogram_pos_repeat.size()), -1
).byte()
if self.cuda:
histogram_pos_inds = histogram_pos_inds.cuda()
histogram_pos_repeat[histogram_pos_inds] = 0
histogram_pos_cdf = histogram_pos_repeat.sum(0)
loss = torch.sum(histogram_neg * histogram_pos_cdf)

return loss


class TestHistogramLoss(unittest.TestCase):
def test_histogram_loss(self):
batch_size = 32
embedding_size = 64
for dtype in TEST_DTYPES:
num_steps = 5 if dtype == torch.float16 else 21
num_bins = num_steps - 1
loss_func = HistogramLoss(n_bins=num_bins)
original_loss_func = OriginalImplementationHistogramLoss(
num_steps=num_steps, cuda=False
)

# test multiple times
for _ in range(2):
embeddings = torch.randn(
batch_size,
embedding_size,
requires_grad=True,
dtype=dtype,
).to(TEST_DEVICE)
labels = torch.randint(0, 5, size=(batch_size,))

loss = loss_func(embeddings, labels)
correct_loss = original_loss_func(
torch.nn.functional.normalize(embeddings), labels
)

rtol = 1e-2 if dtype == torch.float16 else 1e-5
self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))

loss.backward()

def test_with_no_valid_triplets(self):
loss_func = HistogramLoss(n_bins=4)
for dtype in TEST_DTYPES:
embeddings = torch.randn(
5,
32,
requires_grad=True,
dtype=dtype,
).to(TEST_DEVICE)
labels = torch.LongTensor([0, 1, 2, 3, 4])
loss = loss_func(embeddings, labels)
self.assertEqual(loss, 0)
loss.backward()

def test_assertion_raises(self):
with self.assertRaises(AssertionError):
_ = HistogramLoss(n_bins=1, delta=0.5)

with self.assertRaises(AssertionError):
_ = HistogramLoss(n_bins=10, delta=0.4)
