-
Notifications
You must be signed in to change notification settings - Fork 1
/
Loss.py
87 lines (64 loc) · 2.96 KB
/
Loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from abc import ABC, abstractmethod
import numpy as np
import torch.nn.functional as F
class Loss(ABC):
    """Abstract base class for contrastive loss objectives.

    Subclasses implement `compute`; instances are callable and simply
    forward every argument to `compute`, defaulting both masks to None.
    """

    @abstractmethod
    def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor:
        """Return the scalar loss for the given anchor/sample batch."""
        pass

    def __call__(self, anchor, sample, pos_mask=None, neg_mask=None, *args, **kwargs) -> torch.FloatTensor:
        # Thin dispatcher: all real work happens in the subclass `compute`.
        return self.compute(anchor, sample, pos_mask, neg_mask, *args, **kwargs)
class JSD(Loss):
    """Jensen-Shannon divergence contrastive objective.

    Positive/negative pairs are selected by 0/1 masks.  Entries zeroed by a
    mask contribute exactly nothing to either sum, because softplus(0) equals
    log(2) and cancels the log(2) constant in each term.
    """

    def __init__(self, discriminator=lambda x, y: x @ y.t()):
        super(JSD, self).__init__()
        # Scoring function mapping (anchor, sample) to a similarity matrix.
        self.discriminator = discriminator

    def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs):
        sim = self.discriminator(anchor, sample)
        n_pos = pos_mask.int().sum()
        n_neg = neg_mask.int().sum()

        # Expected discriminator score over positive pairs
        # (masked entries cancel to zero via softplus(0) == log 2).
        pos_term = (np.log(2) - F.softplus(-sim * pos_mask)).sum() / n_pos

        # Expected score over negative pairs (masked entries cancel likewise).
        masked_neg = sim * neg_mask
        neg_term = (F.softplus(-masked_neg) + masked_neg - np.log(2)).sum() / n_neg

        return neg_term - pos_term
class DebiasedJSD(Loss):
    """Debiased Jensen-Shannon divergence contrastive objective.

    `tau_plus` is the assumed positive-class prior used to correct the
    positive-pair expectation for false negatives; `discriminator` maps
    (anchor, sample) to a similarity matrix.
    """

    def __init__(self, discriminator=lambda x, y: x @ y.t(), tau_plus=0.1):
        super(DebiasedJSD, self).__init__()
        self.discriminator = discriminator
        self.tau_plus = tau_plus

    def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs):
        num_neg = neg_mask.int().sum()
        num_pos = pos_mask.int().sum()
        similarity = self.discriminator(anchor, sample)

        pos_sim = similarity * pos_mask
        E_pos = np.log(2) - F.softplus(- pos_sim)
        E_pos -= (self.tau_plus / (1 - self.tau_plus)) * (F.softplus(-pos_sim) + pos_sim)
        # BUG FIX: unlike plain JSD, the tau_plus correction term does not
        # vanish at masked entries — softplus(0) == log 2 leaves a residual
        # of -tau_plus/(1-tau_plus)*log 2 per non-positive entry, so an
        # unmasked sum carries a spurious constant offset that depends only
        # on the mask sizes.  Restricting the sum to positive entries fixes
        # the reported loss value; gradients are unchanged because masked
        # entries carry no gradient (pos_sim == similarity * pos_mask).
        E_pos = (E_pos * pos_mask).sum() / num_pos

        neg_sim = similarity * neg_mask
        # Masked entries already cancel here: softplus(0) + 0 - log 2 == 0.
        E_neg = (F.softplus(- neg_sim) + neg_sim - np.log(2)) / (1 - self.tau_plus)
        E_neg = E_neg.sum() / num_neg
        return E_neg - E_pos
class HardnessJSD(Loss):
    """Hardness-aware debiased JSD contrastive objective.

    Extends the debiased JSD by reweighting each negative pair with an
    exponential tilt of its (scale-normalized) similarity, controlled by
    `beta`; `tau_plus` is the assumed positive-class prior used for the
    debiasing correction, as in DebiasedJSD.
    """

    def __init__(self, discriminator=lambda x, y: x @ y.t(), tau_plus=0.1, beta=0.05):
        super(HardnessJSD, self).__init__()
        self.discriminator = discriminator  # (anchor, sample) -> similarity matrix
        self.tau_plus = tau_plus            # assumed positive-class prior
        self.beta = beta                    # hardness-tilt strength

    def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs):
        num_neg = neg_mask.int().sum()
        num_pos = pos_mask.int().sum()
        similarity = self.discriminator(anchor, sample)

        pos_sim = similarity * pos_mask
        # Positive expectation with the tau_plus debiasing correction.
        # NOTE(review): the correction term does not vanish at masked
        # entries (softplus(0) == log 2), so this sum also picks up a
        # constant -tau_plus/(1-tau_plus)*log 2 per non-positive entry;
        # gradients are unaffected — confirm whether the offset is intended.
        E_pos = np.log(2) - F.softplus(- pos_sim)
        E_pos -= (self.tau_plus / (1 - self.tau_plus)) * (F.softplus(-pos_sim) + pos_sim)
        E_pos = E_pos.sum() / num_pos

        neg_sim = similarity * neg_mask
        E_neg = F.softplus(- neg_sim) + neg_sim
        # Weight negatives: normalize similarities by their largest
        # magnitude, exponentially tilt by beta, then row-normalize so the
        # weights average to 1 per anchor.
        # NOTE(review): the leading minus means higher-similarity negatives
        # get SMALLER weight after exp — confirm the sign matches the
        # intended hardness scheme.  Also: if every neg_sim entry is 0 the
        # scale is 0 (division by zero), and masked entries get weight
        # exp(0), so they still enter the row mean and E_neg — verify.
        reweight = -2 * neg_sim / max(neg_sim.max(), neg_sim.min().abs())
        reweight = (self.beta * reweight).exp()
        reweight /= reweight.mean(dim=1, keepdim=True)
        E_neg = (reweight * E_neg) / (1 - self.tau_plus) - np.log(2)
        E_neg = E_neg.sum() / num_neg
        return E_neg - E_pos