Merge pull request #653 from KevinMusgrave/dev
v2.3.0
Showing 8 changed files with 275 additions and 2 deletions.
@@ -1 +1 @@
-__version__ = "2.2.0"
+__version__ = "2.3.0"
@@ -0,0 +1,80 @@
import torch

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


def filter_pairs(*tensors: torch.Tensor):
    t = torch.stack(tensors)
    t, _ = torch.sort(t, dim=0)
    t = torch.unique(t, dim=1)
    return t.tolist()


class HistogramLoss(BaseMetricLossFunction):
    def __init__(self, n_bins: int = None, delta: float = None, **kwargs):
        super().__init__(**kwargs)
        if delta is not None and n_bins is not None:
            assert (
                delta == 2 / n_bins
            ), f"delta and n_bins must satisfy the equation delta = 2/n_bins.\nPassed values are delta={delta} and n_bins={n_bins}"

        if delta is None and n_bins is None:
            n_bins = 100

        self.delta = delta if delta is not None else 2 / n_bins
        self.add_to_recordable_attributes(name="delta", is_stat=True)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        c_f.labels_or_indices_tuple_required(labels, indices_tuple)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
        indices_tuple = lmu.convert_to_triplets(
            indices_tuple, labels, ref_labels, t_per_anchor="all"
        )
        anchor_idx, positive_idx, negative_idx = indices_tuple
        if len(anchor_idx) == 0:
            return self.zero_losses()
        mat = self.distance(embeddings, ref_emb)

        anchor_positive_idx = filter_pairs(anchor_idx, positive_idx)
        anchor_negative_idx = filter_pairs(anchor_idx, negative_idx)
        ap_dists = mat[anchor_positive_idx]
        an_dists = mat[anchor_negative_idx]

        p_pos = self.compute_density(ap_dists)
        phi = torch.cumsum(p_pos, dim=0)

        p_neg = self.compute_density(an_dists)
        return {
            "loss": {
                "losses": torch.sum(p_neg * phi),
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }

    def compute_density(self, distances):
        size = distances.size(0)
        r_star = torch.floor(
            (distances.float() + 1) / self.delta
        )  # Indices of the bins containing the values of the distances
        r_star = c_f.to_device(r_star, tensor=distances, dtype=torch.long)

        delta_ijr_a = (distances + 1 - r_star * self.delta) / self.delta
        delta_ijr_b = ((r_star + 1) * self.delta - 1 - distances) / self.delta
        delta_ijr_a = c_f.to_dtype(delta_ijr_a, tensor=distances)
        delta_ijr_b = c_f.to_dtype(delta_ijr_b, tensor=distances)

        density = torch.zeros(round(1 + 2 / self.delta))
        density = c_f.to_device(density, tensor=distances, dtype=distances.dtype)

        # For each node sum the contributions of the bins whose ending node is this one
        density.scatter_add_(0, r_star + 1, delta_ijr_a)
        # For each node sum the contributions of the bins whose starting node is this one
        density.scatter_add_(0, r_star, delta_ijr_b)
        return density / size

    def get_default_distance(self):
        return CosineSimilarity()
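For context, compute_loss above builds soft histograms (compute_density) of positive-pair and negative-pair similarities and returns torch.sum(p_neg * phi), which estimates the probability that a negative pair is more similar than a positive pair. Below is a minimal usage sketch, not part of this commit; it assumes the library is installed and follows the same calling convention as the tests further down (when neither n_bins nor delta is passed, n_bins defaults to 100).

import torch
from pytorch_metric_learning.losses import HistogramLoss

loss_func = HistogramLoss(n_bins=100)                # same as the default; delta = 2 / 100
embeddings = torch.randn(32, 64, requires_grad=True) # default distance is CosineSimilarity (see get_default_distance)
labels = torch.randint(0, 5, size=(32,))             # integer class labels, one per embedding
loss = loss_func(embeddings, labels)
loss.backward()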
@@ -0,0 +1,170 @@
import unittest

import torch
from numpy.testing import assert_almost_equal

from pytorch_metric_learning.losses import HistogramLoss
from pytorch_metric_learning.utils import common_functions as c_f

from .. import TEST_DEVICE, TEST_DTYPES


######################################
#######ORIGINAL IMPLEMENTATION########
######################################
# DIRECTLY COPIED from https://github.com/valerystrizh/pytorch-histogram-loss/blob/master/losses.py.
# This code is copied from the official PyTorch implementation
# so that we can make sure our implementation returns the same result.
# Some minor changes were made to avoid errors during testing.
# Every change in the original code is reported and explained.
class OriginalImplementationHistogramLoss(torch.nn.Module):
    def __init__(self, num_steps, cuda=True):
        super(OriginalImplementationHistogramLoss, self).__init__()
        self.step = 2 / (num_steps - 1)
        self.eps = 1 / num_steps
        self.cuda = cuda
        self.t = torch.arange(-1, 1 + self.step, self.step).view(-1, 1)
        self.tsize = self.t.size()[0]
        if self.cuda:
            self.t = self.t.cuda()

    def forward(self, features, classes):
        def histogram(inds, size):
            s_repeat_ = s_repeat.clone()
            inds = c_f.to_device(inds, tensor=s_repeat_floor)  # Added to avoid errors
            self.t = c_f.to_device(
                self.t, tensor=s_repeat_floor
            )  # Added to avoid errors
            indsa = (
                (s_repeat_floor - (self.t - self.step) > -self.eps)
                & (s_repeat_floor - (self.t - self.step) < self.eps)
                & inds
            )
            assert (
                indsa.nonzero().size()[0] == size
            ), "Another number of bins should be used"
            zeros = torch.zeros((1, indsa.size()[1])).to(
                device=indsa.device, dtype=torch.uint8
            )
            if self.cuda:
                zeros = zeros.cuda()
            indsb = torch.cat((indsa, zeros))[1:, :].to(
                dtype=torch.bool
            )  # Added to avoid bug with masks of uint8
            s_repeat_[~(indsb | indsa)] = 0
            # indsa corresponds to the first condition of the second equation of the paper
            self.t = self.t.to(
                dtype=s_repeat_.dtype
            )  # Added to avoid errors when using Half precision
            s_repeat_[indsa] = (s_repeat_ - self.t + self.step)[indsa] / self.step
            # indsb corresponds to the second condition of the second equation of the paper
            s_repeat_[indsb] = (-s_repeat_ + self.t + self.step)[indsb] / self.step

            return s_repeat_.sum(1) / size

        classes_size = classes.size()[0]
        classes_eq = (
            classes.repeat(classes_size, 1)
            == classes.view(-1, 1).repeat(1, classes_size)
        ).data
        dists = torch.mm(features, features.transpose(0, 1))
        assert (
            (dists > 1 + self.eps).sum().item() + (dists < -1 - self.eps).sum().item()
        ) == 0, "L2 normalization should be used"
        s_inds = torch.triu(torch.ones(classes_eq.size()), 1).byte()
        if self.cuda:
            s_inds = s_inds.cuda()
        classes_eq = classes_eq.to(
            device=s_inds.device
        )  # Added to avoid errors when using only cpu
        pos_inds = classes_eq[s_inds].repeat(self.tsize, 1)
        neg_inds = ~classes_eq[s_inds].repeat(self.tsize, 1)
        pos_size = classes_eq[s_inds].sum().item()
        neg_size = (~classes_eq[s_inds]).sum().item()
        s = dists[s_inds].view(1, -1)
        s_repeat = s.repeat(self.tsize, 1)
        s_repeat_floor = (torch.floor(s_repeat.data / self.step) * self.step).float()

        histogram_pos = histogram(pos_inds, pos_size)
        assert_almost_equal(
            histogram_pos.sum().item(),
            1,
            decimal=1,
            err_msg="Not good positive histogram",
            verbose=True,
        )
        histogram_neg = histogram(neg_inds, neg_size)
        assert_almost_equal(
            histogram_neg.sum().item(),
            1,
            decimal=1,
            err_msg="Not good negative histogram",
            verbose=True,
        )
        histogram_pos_repeat = histogram_pos.view(-1, 1).repeat(
            1, histogram_pos.size()[0]
        )
        histogram_pos_inds = torch.tril(
            torch.ones(histogram_pos_repeat.size()), -1
        ).byte()
        if self.cuda:
            histogram_pos_inds = histogram_pos_inds.cuda()
        histogram_pos_repeat[histogram_pos_inds] = 0
        histogram_pos_cdf = histogram_pos_repeat.sum(0)
        loss = torch.sum(histogram_neg * histogram_pos_cdf)

        return loss


class TestHistogramLoss(unittest.TestCase):
    def test_histogram_loss(self):
        batch_size = 32
        embedding_size = 64
        for dtype in TEST_DTYPES:
            num_steps = 5 if dtype == torch.float16 else 21
            num_bins = num_steps - 1
            loss_func = HistogramLoss(n_bins=num_bins)
            original_loss_func = OriginalImplementationHistogramLoss(
                num_steps=num_steps, cuda=False
            )

            # test multiple times
            for _ in range(2):
                embeddings = torch.randn(
                    batch_size,
                    embedding_size,
                    requires_grad=True,
                    dtype=dtype,
                ).to(TEST_DEVICE)
                labels = torch.randint(0, 5, size=(batch_size,))

                loss = loss_func(embeddings, labels)
                correct_loss = original_loss_func(
                    torch.nn.functional.normalize(embeddings), labels
                )

                rtol = 1e-2 if dtype == torch.float16 else 1e-5
                self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))

                loss.backward()

    def test_with_no_valid_triplets(self):
        loss_func = HistogramLoss(n_bins=4)
        for dtype in TEST_DTYPES:
            embeddings = torch.randn(
                5,
                32,
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)
            labels = torch.LongTensor([0, 1, 2, 3, 4])
            loss = loss_func(embeddings, labels)
            self.assertEqual(loss, 0)
            loss.backward()

    def test_assertion_raises(self):
        with self.assertRaises(AssertionError):
            _ = HistogramLoss(n_bins=1, delta=0.5)

        with self.assertRaises(AssertionError):
            _ = HistogramLoss(n_bins=10, delta=0.4)
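A short note on the two assertions checked in test_assertion_raises: the constructor ties the parameters together as delta = 2 / n_bins, since the bins tile the cosine-similarity range [-1, 1]. The values below are illustrative only, not part of this commit.

HistogramLoss(n_bins=4)              # ok: delta defaults to 2 / 4 = 0.5
HistogramLoss(n_bins=4, delta=0.5)   # ok: 0.5 == 2 / 4
HistogramLoss(n_bins=10, delta=0.4)  # AssertionError: 0.4 != 2 / 10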