In [3]:
import torch
import torch.nn as nn
from icecream import ic
class Divergence(nn.Module):
    """
    Jensen-Shannon divergence, used to measure ranking consistency between similarity lists obtained from examples with two different dropout masks
    """
    def __init__(self, beta_):
        super(Divergence, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)
        self.eps = 1e-7
        self.beta_ = beta_

    def forward(self, p: torch.tensor, q: torch.tensor):
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        # ic(p)
        # ic(q.shape)
        m = (0.5 * (p + q)).log().clamp(min=self.eps)
        return 0.5 * (self.kl(m, p.log()) + self.kl(m, q.log()))

In [5]:
p = torch.randn(4,4)
q = torch.randn(4,4)
div = Divergence(0.5)
p.softmax(dim=-1)


tensor([[0.5281, 0.1457, 0.0907, 0.2356],
        [0.3567, 0.0649, 0.4409, 0.1375],
        [0.1604, 0.5561, 0.1401, 0.1434],
        [0.1053, 0.2656, 0.3845, 0.2445]])