v2.3.0 #653

Merged · 11 commits · Jul 25, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ dist/
*.egg-info/
site/
venv/
**/.vscode
.ipynb_checkpoints
examples/notebooks/dataset
examples/notebooks/CIFAR10_Dataset
1 change: 1 addition & 0 deletions CONTENTS.md
@@ -17,6 +17,7 @@
| [**CosFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#cosfaceloss) | - [CosFace: Large Margin Cosine Loss for Deep Face Recognition](https://arxiv.org/pdf/1801.09414.pdf) <br/> - [Additive Margin Softmax for Face Verification](https://arxiv.org/pdf/1801.05599.pdf)
| [**FastAPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#fastaploss) | [Deep Metric Learning to Rank](http://openaccess.thecvf.com/content_CVPR_2019/papers/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.pdf)
| [**GeneralizedLiftedStructureLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#generalizedliftedstructureloss) | [In Defense of the Triplet Loss for Person Re-Identification](https://arxiv.org/pdf/1703.07737.pdf)
| [**HistogramLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss) | [Learning Deep Embeddings with Histogram Loss](https://arxiv.org/pdf/1611.00822.pdf)
| [**InstanceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#instanceloss) | [Dual-Path Convolutional Image-Text Embeddings with Instance Loss](https://arxiv.org/pdf/1711.05535.pdf)
| [**IntraPairVarianceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#intrapairvarianceloss) | [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf)
| [**LargeMarginSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#largemarginsoftmaxloss) | [Large-Margin Softmax Loss for Convolutional Neural Networks](https://arxiv.org/pdf/1612.02295.pdf)
20 changes: 20 additions & 0 deletions docs/losses.md
@@ -424,6 +424,26 @@ losses.InstanceLoss(gamma=64, **kwargs)
* **gamma**: The cosine similarity matrix is scaled by this amount.


## HistogramLoss
[Learning Deep Embeddings with Histogram Loss](https://arxiv.org/pdf/1611.00822.pdf)
```python
losses.HistogramLoss(n_bins=None, delta=None, **kwargs)
```

**Parameters**:

* **n_bins**: The number of bins used to construct the histogram. Default is 100 when both `n_bins` and `delta` are `None`.
* **delta**: The width of each bin in the uniform partition of the interval [-1, 1] used to construct the histogram. If not set, it is derived from `n_bins` as `delta = 2 / n_bins`. If both are provided, they must satisfy `delta = 2 / n_bins`.

**Default distance**:

- [```CosineSimilarity()```](distances.md#cosinesimilarity)

**Default reducer**:

- This loss returns an **already reduced** loss.
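
**Example usage** (a minimal sketch; the embedding dimensions and label values below are illustrative):

```python
import torch
from pytorch_metric_learning import losses

loss_func = losses.HistogramLoss(n_bins=100)  # equivalently, delta = 2 / 100 = 0.02

embeddings = torch.randn(32, 128)             # batch of 32 embeddings (not yet normalized)
labels = torch.randint(0, 5, size=(32,))      # integer class labels
loss = loss_func(embeddings, labels)          # CosineSimilarity normalizes the embeddings internally
```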


## IntraPairVarianceLoss
[Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf){target=_blank}
```python
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
__version__ = "2.2.0"
__version__ = "2.3.0"
1 change: 1 addition & 0 deletions src/pytorch_metric_learning/losses/__init__.py
@@ -8,6 +8,7 @@
from .cross_batch_memory import CrossBatchMemory
from .fast_ap_loss import FastAPLoss
from .generic_pair_loss import GenericPairLoss
from .histogram_loss import HistogramLoss
from .instance_loss import InstanceLoss
from .intra_pair_variance_loss import IntraPairVarianceLoss
from .large_margin_softmax_loss import LargeMarginSoftmaxLoss
80 changes: 80 additions & 0 deletions src/pytorch_metric_learning/losses/histogram_loss.py
@@ -0,0 +1,80 @@
import torch

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


def filter_pairs(*tensors: torch.Tensor):
    # Treat each column as an (index, index) pair: sort within each column so the pair
    # order is canonical, then drop duplicate columns to keep each unordered pair once.
    t = torch.stack(tensors)
    t, _ = torch.sort(t, dim=0)
    t = torch.unique(t, dim=1)
    return t.tolist()


class HistogramLoss(BaseMetricLossFunction):
    def __init__(self, n_bins: int = None, delta: float = None, **kwargs):
        super().__init__(**kwargs)
        if delta is not None and n_bins is not None:
            assert (
                delta == 2 / n_bins
            ), f"delta and n_bins must satisfy the equation delta = 2/n_bins.\nPassed values are delta={delta} and n_bins={n_bins}"

        if delta is None and n_bins is None:
            n_bins = 100

        self.delta = delta if delta is not None else 2 / n_bins
        self.add_to_recordable_attributes(name="delta", is_stat=True)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        c_f.labels_or_indices_tuple_required(labels, indices_tuple)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
        indices_tuple = lmu.convert_to_triplets(
            indices_tuple, labels, ref_labels, t_per_anchor="all"
        )
        anchor_idx, positive_idx, negative_idx = indices_tuple
        if len(anchor_idx) == 0:
            return self.zero_losses()
        mat = self.distance(embeddings, ref_emb)

        anchor_positive_idx = filter_pairs(anchor_idx, positive_idx)
        anchor_negative_idx = filter_pairs(anchor_idx, negative_idx)
        ap_dists = mat[anchor_positive_idx]
        an_dists = mat[anchor_negative_idx]

        # p_pos and p_neg are soft-histogram density estimates of the positive and
        # negative similarities; phi is the cumulative distribution of the positives,
        # so sum(p_neg * phi) estimates the probability that a random negative pair
        # is more similar than a random positive pair.
        p_pos = self.compute_density(ap_dists)
        phi = torch.cumsum(p_pos, dim=0)

        p_neg = self.compute_density(an_dists)
        return {
            "loss": {
                "losses": torch.sum(p_neg * phi),
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }

    def compute_density(self, distances):
        size = distances.size(0)
        r_star = torch.floor(
            (distances.float() + 1) / self.delta
        )  # Indices of the bins containing the values of the distances
        r_star = c_f.to_device(r_star, tensor=distances, dtype=torch.long)

        delta_ijr_a = (distances + 1 - r_star * self.delta) / self.delta
        delta_ijr_b = ((r_star + 1) * self.delta - 1 - distances) / self.delta
        delta_ijr_a = c_f.to_dtype(delta_ijr_a, tensor=distances)
        delta_ijr_b = c_f.to_dtype(delta_ijr_b, tensor=distances)

        density = torch.zeros(round(1 + 2 / self.delta))
        density = c_f.to_device(density, tensor=distances, dtype=distances.dtype)

        # For each node, sum the contributions of the bins whose ending node is this one
        density.scatter_add_(0, r_star + 1, delta_ijr_a)
        # For each node, sum the contributions of the bins whose starting node is this one
        density.scatter_add_(0, r_star, delta_ijr_b)
        return density / size

    def get_default_distance(self):
        return CosineSimilarity()
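
For intuition about `compute_density`: each similarity is spread linearly between the two histogram nodes that bracket it. Below is a standalone sketch of the same interpolation scheme (not part of the diff), using an illustrative `delta` and hand-picked similarity values:

```python
import torch

# Soft (linearly interpolated) histogram over [-1, 1] with node spacing delta,
# mirroring the scheme in compute_density above. Values and delta are illustrative.
delta = 0.5                                          # nodes at -1.0, -0.5, 0.0, 0.5, 1.0
sims = torch.tensor([0.1, 0.6, -0.3])                # pairwise cosine similarities
r_star = torch.floor((sims + 1) / delta).long()      # index of the bin containing each value
w_right = (sims + 1 - r_star * delta) / delta        # weight sent to the bin's right node
w_left = ((r_star + 1) * delta - 1 - sims) / delta   # weight sent to the bin's left node

density = torch.zeros(round(1 + 2 / delta))
density.scatter_add_(0, r_star + 1, w_right)
density.scatter_add_(0, r_star, w_left)
density /= sims.numel()
print(density, density.sum())                        # per-node mass; sums to 1.0
```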
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/utils/accuracy_calculator.py
@@ -440,7 +440,7 @@ def get_accuracy(
):
raise ValueError(
"When ref_includes_query is True, the first len(query) elements of reference must be equal to query.\n"
"Likewise, the first len(query_labels) elements of reference_lbels must be equal to query_labels.\n"
"Likewise, the first len(query_labels) elements of reference_labels must be equal to query_labels.\n"
)

self.curr_function_dict = self.get_function_dict(include, exclude)
170 changes: 170 additions & 0 deletions tests/losses/test_histogram_loss.py
@@ -0,0 +1,170 @@
import unittest

import torch
from numpy.testing import assert_almost_equal

from pytorch_metric_learning.losses import HistogramLoss
from pytorch_metric_learning.utils import common_functions as c_f

from .. import TEST_DEVICE, TEST_DTYPES


######################################
#######ORIGINAL IMPLEMENTATION########
######################################
# DIRECTLY COPIED from https://github.com/valerystrizh/pytorch-histogram-loss/blob/master/losses.py.
# This code is copied from the official PyTorch implementation
# so that we can make sure our implementation returns the same result.
# Some minor changes were made to avoid errors during testing.
# Every change in the original code is reported and explained.
class OriginalImplementationHistogramLoss(torch.nn.Module):
    def __init__(self, num_steps, cuda=True):
        super(OriginalImplementationHistogramLoss, self).__init__()
        self.step = 2 / (num_steps - 1)
        self.eps = 1 / num_steps
        self.cuda = cuda
        self.t = torch.arange(-1, 1 + self.step, self.step).view(-1, 1)
        self.tsize = self.t.size()[0]
        if self.cuda:
            self.t = self.t.cuda()

    def forward(self, features, classes):
        def histogram(inds, size):
            s_repeat_ = s_repeat.clone()
            inds = c_f.to_device(inds, tensor=s_repeat_floor)  # Added to avoid errors
            self.t = c_f.to_device(
                self.t, tensor=s_repeat_floor
            )  # Added to avoid errors
            indsa = (
                (s_repeat_floor - (self.t - self.step) > -self.eps)
                & (s_repeat_floor - (self.t - self.step) < self.eps)
                & inds
            )
            assert (
                indsa.nonzero().size()[0] == size
            ), "Another number of bins should be used"
            zeros = torch.zeros((1, indsa.size()[1])).to(
                device=indsa.device, dtype=torch.uint8
            )
            if self.cuda:
                zeros = zeros.cuda()
            indsb = torch.cat((indsa, zeros))[1:, :].to(
                dtype=torch.bool
            )  # Added to avoid bug with masks of uint8
            s_repeat_[~(indsb | indsa)] = 0
            # indsa corresponds to the first condition of the second equation of the paper
            self.t = self.t.to(
                dtype=s_repeat_.dtype
            )  # Added to avoid errors when using Half precision
            s_repeat_[indsa] = (s_repeat_ - self.t + self.step)[indsa] / self.step
            # indsb corresponds to the second condition of the second equation of the paper
            s_repeat_[indsb] = (-s_repeat_ + self.t + self.step)[indsb] / self.step

            return s_repeat_.sum(1) / size

        classes_size = classes.size()[0]
        classes_eq = (
            classes.repeat(classes_size, 1)
            == classes.view(-1, 1).repeat(1, classes_size)
        ).data
        dists = torch.mm(features, features.transpose(0, 1))
        assert (
            (dists > 1 + self.eps).sum().item() + (dists < -1 - self.eps).sum().item()
        ) == 0, "L2 normalization should be used"
        s_inds = torch.triu(torch.ones(classes_eq.size()), 1).byte()
        if self.cuda:
            s_inds = s_inds.cuda()
        classes_eq = classes_eq.to(
            device=s_inds.device
        )  # Added to avoid errors when using only cpu
        pos_inds = classes_eq[s_inds].repeat(self.tsize, 1)
        neg_inds = ~classes_eq[s_inds].repeat(self.tsize, 1)
        pos_size = classes_eq[s_inds].sum().item()
        neg_size = (~classes_eq[s_inds]).sum().item()
        s = dists[s_inds].view(1, -1)
        s_repeat = s.repeat(self.tsize, 1)
        s_repeat_floor = (torch.floor(s_repeat.data / self.step) * self.step).float()

        histogram_pos = histogram(pos_inds, pos_size)
        assert_almost_equal(
            histogram_pos.sum().item(),
            1,
            decimal=1,
            err_msg="Not good positive histogram",
            verbose=True,
        )
        histogram_neg = histogram(neg_inds, neg_size)
        assert_almost_equal(
            histogram_neg.sum().item(),
            1,
            decimal=1,
            err_msg="Not good negative histogram",
            verbose=True,
        )
        histogram_pos_repeat = histogram_pos.view(-1, 1).repeat(
            1, histogram_pos.size()[0]
        )
        histogram_pos_inds = torch.tril(
            torch.ones(histogram_pos_repeat.size()), -1
        ).byte()
        if self.cuda:
            histogram_pos_inds = histogram_pos_inds.cuda()
        histogram_pos_repeat[histogram_pos_inds] = 0
        histogram_pos_cdf = histogram_pos_repeat.sum(0)
        loss = torch.sum(histogram_neg * histogram_pos_cdf)

        return loss


class TestHistogramLoss(unittest.TestCase):
    def test_histogram_loss(self):
        batch_size = 32
        embedding_size = 64
        for dtype in TEST_DTYPES:
            num_steps = 5 if dtype == torch.float16 else 21
            num_bins = num_steps - 1
            loss_func = HistogramLoss(n_bins=num_bins)
            original_loss_func = OriginalImplementationHistogramLoss(
                num_steps=num_steps, cuda=False
            )

            # test multiple times
            for _ in range(2):
                embeddings = torch.randn(
                    batch_size,
                    embedding_size,
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)
                labels = torch.randint(0, 5, size=(batch_size,))

                loss = loss_func(embeddings, labels)
                correct_loss = original_loss_func(
                    torch.nn.functional.normalize(embeddings), labels
                )

                rtol = 1e-2 if dtype == torch.float16 else 1e-5
                self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))

                loss.backward()

    def test_with_no_valid_triplets(self):
        loss_func = HistogramLoss(n_bins=4)
        for dtype in TEST_DTYPES:
            embeddings = torch.randn(
                5,
                32,
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)
            labels = torch.LongTensor([0, 1, 2, 3, 4])
            loss = loss_func(embeddings, labels)
            self.assertEqual(loss, 0)
            loss.backward()

    def test_assertion_raises(self):
        with self.assertRaises(AssertionError):
            _ = HistogramLoss(n_bins=1, delta=0.5)

        with self.assertRaises(AssertionError):
            _ = HistogramLoss(n_bins=10, delta=0.4)