-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
28 lines (22 loc) · 1.01 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import math

import torch
import torch.nn.functional as F
class LossDiscriminative(torch.nn.Module):
    """Pairwise discriminative loss over a batch of embedding pairs.

    For each pair the Euclidean distance ``d`` (from ``F.pairwise_distance``)
    is mapped to a "same" probability

        p = 2 / (1 + exp(2 d)) = 1 - tanh(d),

    which equals 1 at ``d = 0`` and tends to 0 as ``d`` grows. Each pair then
    contributes

        -label * (label + 1) / 2 * log(p) - label * (label - 1) / 2 * log(1 - p)

    so a label of ``+1`` contributes ``-log(p)`` (pulls the pair together), a
    label of ``-1`` contributes ``-log(1 - p)`` (pushes the pair apart), and a
    label of ``0`` contributes nothing. The per-pair terms are summed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, output1, output2, label):
        """Return the summed discriminative loss.

        Args:
            output1: First element of each pair. It is passed to
                ``F.pairwise_distance`` together with ``output2``.
                NOTE(review): presumably (batch, dim); confirm with caller.
            output2: Second element of each pair, with a shape compatible
                with ``output1``.
            label: Per-pair label, broadcast against the per-pair distances
                after ``squeeze()`` drops singleton dimensions (so a
                (batch, 1) tensor is accepted). The expected values are
                +1 and -1 (see the class docstring).

        Returns:
            A 0-dim tensor: the sum of the per-pair losses.
        """
        distance = F.pairwise_distance(output1, output2)
        label = label.squeeze()

        # log(p) = log(2) - log(1 + exp(2 d)). softplus evaluates the second
        # term without overflow. Computing exp(2 d) directly hits inf in
        # float32 once d > ~44, which made log(p) = -inf and turned the
        # zero-weighted term of a -1 pair into 0 * -inf = NaN.
        log_p = math.log(2.0) - F.softplus(2 * distance)

        # log(1 - p) = log(tanh(d)). tanh avoids the cancellation in 1 - p for
        # small d. The clamp keeps the value finite when d is exactly 0,
        # which would otherwise again yield 0 * -inf = NaN for +1 pairs.
        tiny = torch.finfo(log_p.dtype).tiny
        log_not_p = torch.log(torch.tanh(distance).clamp_min(tiny))

        # These weights reduce to (1, 0) for label +1 and (0, 1) for label -1.
        same_weight = label * (label + 1) / 2
        diff_weight = label * (label - 1) / 2
        return torch.sum(-same_weight * log_p - diff_weight * log_not_p)
class LossGenerative(torch.nn.Module):
    """Summed reconstruction distance for two (input, output) pairs.

    Every tensor is flattened from dimension 1 onward, so each index along
    dimension 0 becomes one vector. The Euclidean distance is computed
    row-wise between ``input1`` and ``output1`` and between ``input2`` and
    ``output2``. The two distance vectors are added and then summed into a
    single scalar.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2, output1, output2):
        """Return sum over rows of ||input1 - output1|| + ||input2 - output2||.

        Args:
            input1: Reference tensor compared against ``output1``.
            input2: Reference tensor compared against ``output2``.
            output1: Counterpart of ``input1``. It must flatten (from
                dimension 1) to a shape compatible with it.
            output2: Counterpart of ``input2``. It must flatten (from
                dimension 1) to a shape compatible with it.

        Returns:
            A 0-dim tensor holding the summed distances.
        """
        matched = ((input1, output1), (input2, output2))
        # One row-wise distance vector per (reference, counterpart) pair.
        # The reference is deliberately kept as the first argument.
        row_distances = [
            F.pairwise_distance(ref.flatten(start_dim=1), rec.flatten(start_dim=1))
            for ref, rec in matched
        ]
        combined = row_distances[0] + row_distances[1]
        return combined.sum()