Skip to content

Commit

Permalink
Merge pull request #689 from mkmenta/remove-big-tensor-nonzero
Browse files Browse the repository at this point in the history
Fixes the "nonzero is not supported for tensors with more than INT_MAX elements" in get_all_triplets_indices
  • Loading branch information
KevinMusgrave committed Apr 1, 2024
2 parents dd40036 + b5383c9 commit cfafd3b
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions src/pytorch_metric_learning/utils/loss_and_miner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,41 @@ def neg_pairs_from_tuple(indices_tuple):


def get_all_triplets_indices(labels, ref_labels=None):
matches, diffs = get_matches_and_diffs(labels, ref_labels)
triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
return torch.where(triplets)
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)

if (all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
< torch.iinfo(torch.int32).max):
# torch.nonzero is not supported for tensors with more than INT_MAX elements
triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
return torch.where(triplets)

all_matches, all_diffs = all_matches.bool(), all_diffs.bool()

# Find anchors with at least a positive and a negative
indices = torch.arange(0, len(labels), device=labels.device)
indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]

# No triplets found
if len(indices) == 0:
return (torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype))

# Compute all triplets
anchors = []
positives = []
negatives = []
for i in indices:
matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
nd = len(diffs)
nm = len(matches)
matches = matches.repeat_interleave(nd)
diffs = diffs.repeat(nm)
anchors.append(torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device))
positives.append(matches)
negatives.append(diffs)
return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)


# sample triplets, with a weighted distribution if weights is specified.
Expand Down

0 comments on commit cfafd3b

Please sign in to comment.