diff --git a/losses.py b/losses.py index 8ab79c21..630935da 100644 --- a/losses.py +++ b/losses.py @@ -12,11 +12,12 @@ class ContrastiveLoss(nn.Module): def __init__(self, margin): super(ContrastiveLoss, self).__init__() self.margin = margin + self.eps = 1e-9 def forward(self, output1, output2, target, size_average=True): distances = (output2 - output1).pow(2).sum(1) # squared distances losses = 0.5 * (target.float() * distances + - (1 + -1 * target).float() * F.relu(self.margin - distances.sqrt()).pow(2)) + (1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2)) return losses.mean() if size_average else losses.sum()