-
Notifications
You must be signed in to change notification settings - Fork 70
/
contrastive_head.py
38 lines (30 loc) · 1.04 KB
/
contrastive_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
from ..registry import HEADS
@HEADS.register_module
class ContrastiveHead(nn.Module):
    """Head for contrastive learning.

    The positive and negative similarities are concatenated into a single
    logits matrix (positive in column 0), scaled by ``1 / temperature`` and
    scored with cross-entropy against an all-zero target.

    Args:
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Default: 0.1.
    """

    def __init__(self, temperature=0.1):
        super(ContrastiveHead, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, pos, neg):
        """Forward head.

        Args:
            pos (Tensor): Nx1 positive similarity.
            neg (Tensor): Nxk negative similarity.

        Returns:
            dict[str, Tensor]: A dictionary of loss components, holding the
                cross-entropy loss under the key ``'loss_contra'``.
        """
        N = pos.size(0)
        # Column 0 holds the positive similarity, so the target class of
        # every row is index 0.
        logits = torch.cat((pos, neg), dim=1) / self.temperature
        # Allocate the labels on the same device as the inputs instead of
        # unconditionally calling ``.cuda()``: this keeps the head usable on
        # CPU and on whichever GPU the inputs actually live on.
        labels = torch.zeros((N, ), dtype=torch.long, device=pos.device)
        return {'loss_contra': self.criterion(logits, labels)}