In [None]:
class AttentionNCE(nn.Module):
    """Contrastive loss whose positives and negatives are softmax-reweighted.

    Every anchor ``q[i]`` comes with ``n_pos`` positive keys ``k[i]``; the
    keys belonging to all other anchors act as its negatives.  Raw dot-product
    scores are turned into attention weights with a softmax (scaled by
    ``d_pos`` / ``d_neg``), and those weights modulate the terms that enter
    an InfoNCE-style ratio ``pos / (pos + neg)``.

    Args:
        n_anchors: Number of anchors (rows of ``q``) expected by ``forward``.
        dim: Feature dimension of anchors and keys.
        n_pos: Number of positive keys per anchor.
        d_pos: Divisor applied to the positive scores before their softmax.
        d_neg: Divisor applied to the negative scores before their softmax.
        temperature: Divisor applied to the similarities before ``exp``.
    """

    def __init__(self, n_anchors, dim, n_pos, d_pos=1., d_neg=1., temperature=0.5):
        super().__init__()
        self.n_anchors = n_anchors
        self.dim = dim
        self.n_pos = n_pos
        self.d_pos = d_pos
        self.d_neg = d_neg
        self.temperature = temperature

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]":
        """Compute the loss for one batch of anchors and their positive keys.

        Args:
            q: Anchors, shape ``(n_anchors, dim)``.
            k: Positive keys per anchor, shape ``(n_anchors, n_pos, dim)``.

        Returns:
            A 4-tuple ``(loss, pos_terms, neg_terms, scores)`` where ``loss``
            is a scalar, ``pos_terms`` and ``neg_terms`` have shape
            ``(n_anchors,)`` and ``scores`` holds the raw dot products with
            shape ``(n_anchors, n_anchors * n_pos)``.

        Raises:
            AssertionError: If ``q`` or ``k`` does not have the shape the
                module was configured with.
        """
        assert q.shape == (self.n_anchors, self.dim), (
            f"q must have shape {(self.n_anchors, self.dim)}, got {tuple(q.shape)}"
        )
        assert k.shape == (self.n_anchors, self.n_pos, self.dim), (
            f"k must have shape {(self.n_anchors, self.n_pos, self.dim)}, "
            f"got {tuple(k.shape)}"
        )

        n_anchors = self.n_anchors
        n_pos = self.n_pos
        dim = self.dim

        # Dot product of every anchor with every key: (n_anchors, n_anchors*n_pos).
        scores = torch.matmul(q, k.reshape(-1, dim).t())

        # Column j of `scores` belongs to anchor j // n_pos.  A key is positive
        # for row i exactly when it belongs to anchor i.  The mask is created on
        # the inputs' device rather than relying on a module-level `device`.
        anchor_ids = torch.arange(n_anchors, device=scores.device)
        key_owner = anchor_ids.repeat_interleave(n_pos)
        pos_mask = anchor_ids.unsqueeze(1) == key_owner.unsqueeze(0)

        # Boolean indexing walks row-major, so each row keeps its own entries.
        raw_pos = scores[pos_mask].reshape(n_anchors, n_pos)
        raw_neg = scores[~pos_mask].reshape(n_anchors, (n_anchors - 1) * n_pos)

        pos_weights = torch.softmax(raw_pos / self.d_pos, dim=-1)
        # Negative weights are rescaled by n_anchors after the softmax.
        neg_weights = torch.softmax(raw_neg / self.d_neg, dim=-1) * n_anchors

        # Attention-pooled positive key per anchor: (n_anchors, dim).
        h_pos = (pos_weights.unsqueeze(-1) * k).sum(-2)
        pos_terms = torch.exp((q * h_pos).sum(-1) / self.temperature)

        # For a negative key k_j with weight w: (w * k_j) . q_i == w * (k_j . q_i),
        # and k_j . q_i is already available in `raw_neg`, so the per-anchor
        # loop collapses into a single element-wise expression.
        neg_terms = torch.exp(neg_weights * raw_neg / self.temperature).sum(-1)

        attentionNCE_loss = -torch.log(pos_terms / (pos_terms + neg_terms)).mean()
        return attentionNCE_loss, pos_terms, neg_terms, scores
        
        