In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn.functional as F

In [99]:
########################
### Hyper-parameters ###
########################
M = 3           # total number of distributions in the mixture: one-hot targets + (M - 1) predictions
batch_size = 4
dim = 5         # number of classes (used as num_classes for the one-hot targets)
assert M >= 2, "need at least one prediction besides the targets"

#######################
### Generated Input ###
#######################
# M - 1 random predictions; each one is seeded individually so the cell is reproducible.
pred_list, prob_list = [], []
for seed in range(1, M):
    torch.manual_seed(seed)
    pred = torch.randn((batch_size, dim))
    pred_list.append(pred)
    prob_list.append(F.softmax(pred, dim=1))

# One class index per sample; must stay in sync with batch_size.
target_labels = torch.tensor([0, 2, 4, 2])
assert len(target_labels) == batch_size, "need exactly one target label per sample"
targets = F.one_hot(target_labels, num_classes=dim).float()

# One mixture weight per distribution (targets first, then each prediction).
# They must sum to 1 so the weighted mixture is itself a probability distribution.
weights = torch.tensor([0.1, 0.3, 0.6])
assert len(weights) == M, "need one weight per distribution (targets + M - 1 predictions)"
assert torch.isclose(weights.sum(), torch.tensor(1.0)), "mixture weights must sum to 1"

print(f"prob_list (len={len(prob_list)}, shape={prob_list[0].shape})")
for prob in prob_list:
    print(prob)
print('')
print(f"targets ({targets.shape})")
print(targets, end='\n\n')
print(f"weights ({weights.shape})")
print(weights, end='\n\n')

prob_list (len=2, shape=torch.Size([4, 5]))
tensor([[0.0574, 0.1247, 0.1373, 0.0528, 0.6277],
        [0.1449, 0.0585, 0.2544, 0.1761, 0.3661],
        [0.6641, 0.1031, 0.1206, 0.0611, 0.0511],
        [0.0913, 0.0468, 0.1456, 0.1371, 0.5792]])
tensor([[0.0787, 0.5575, 0.0605, 0.0735, 0.2297],
        [0.1096, 0.3341, 0.2793, 0.0512, 0.2259],
        [0.0477, 0.1337, 0.1687, 0.4489, 0.2009],
        [0.6147, 0.0508, 0.0308, 0.0894, 0.2143]])

targets (torch.Size([4, 5]))
tensor([[1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.]])

weights (torch.Size([3]))
tensor([0.1000, 0.3000, 0.6000])



In [100]:
# Stack the reference targets together with the predictions into a single tensor:
# distribs = [targets, p_1, ..., p_{M-1}], where every entry has shape (batch_size, dim).
assert all(prob.shape == targets.shape for prob in prob_list), \
    "every prediction must have the same (batch_size, dim) shape as targets"
distribs = torch.stack([targets, *prob_list], dim=0)     # (M, batch_size, dim)

print(f"distribs ({distribs.shape})")
print(distribs)

distribs (torch.Size([3, 4, 5]))
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000]],

        [[0.0574, 0.1247, 0.1373, 0.0528, 0.6277],
         [0.1449, 0.0585, 0.2544, 0.1761, 0.3661],
         [0.6641, 0.1031, 0.1206, 0.0611, 0.0511],
         [0.0913, 0.0468, 0.1456, 0.1371, 0.5792]],

        [[0.0787, 0.5575, 0.0605, 0.0735, 0.2297],
         [0.1096, 0.3341, 0.2793, 0.0512, 0.2259],
         [0.0477, 0.1337, 0.1687, 0.4489, 0.2009],
         [0.6147, 0.0508, 0.0308, 0.0894, 0.2143]]])


In [109]:
# Mixture distribution m = sum_j π_j * d_j for j = 1, ..., M, where π_j is the j-th weight
# and d_j is the j-th entry of distribs (the targets first, then each prediction).
assert len(weights) == len(distribs)
# Broadcast each scalar weight over its (batch_size, dim) slice, then reduce over the
# distribution axis. This gives the same values as stacking weights[i] * distribs[i] and summing.
weighted_distrib = (weights.view(-1, 1, 1) * distribs).sum(dim=0)     # (batch_size, dim)
# Clamp away exact zeros before the log so it never produces -inf.
LOG_EPS = 1e-7
weighted_distrib_log = weighted_distrib.clamp(LOG_EPS, 1.0).log()     # (batch_size, dim)
print(f"weighted_distrib ({weighted_distrib.shape})")
print(weighted_distrib, end='\n\n')
print(f"weighted_distrib_log ({weighted_distrib_log.shape})")
print(weighted_distrib_log, end='\n\n')

weighted_distrib (torch.Size([4, 5]))
tensor([[0.1645, 0.3719, 0.0775, 0.0599, 0.3261],
        [0.1092, 0.2180, 0.3439, 0.0835, 0.2454],
        [0.2278, 0.1112, 0.1374, 0.2877, 0.2359],
        [0.3962, 0.0446, 0.1621, 0.0948, 0.3023]])

weighted_distrib_log (torch.Size([4, 5]))
tensor([[-1.8050, -0.9890, -2.5575, -2.8143, -1.1204],
        [-2.2145, -1.5233, -1.0674, -2.4826, -1.4051],
        [-1.4791, -2.1967, -1.9848, -1.2458, -1.4444],
        [-0.9258, -3.1109, -1.8193, -2.3563, -1.1962]])



In [110]:
# Weighted generalized JSD: sum_i π_i * KL(d_i || m) for i = 1, ..., M,
# where m is the mixture whose log is weighted_distrib_log.
# F.kl_div(input, target) takes `input` as log-probabilities and `target` as probabilities,
# and computes KL(target || exp(input)) element-wise before reducing.
# NOTE: reduction='mean' averages over all batch_size * dim elements rather than per sample
# ('batchmean' is the textbook KL). 'mean' is what reproduces the recorded tensor(0.0622).
reduction = 'mean'
weighted_kld_list = [
    weights[i] * F.kl_div(weighted_distrib_log, distribs[i], reduction=reduction)
    for i in range(len(weights))
]
# Every weighted term has the same shape (a scalar for 'mean'), so they stack and sum cleanly.
weighted_jsd = torch.stack(weighted_kld_list, dim=0).sum(dim=0)

print(f"weighted_jsd ({weighted_jsd.shape})")
print(weighted_jsd)

weighted_jsd (torch.Size([]))
tensor(0.0622)


In [112]:
# Cross-check the hand-derived weighted JSD against the library implementation.
from mmdet.models.losses.ai28.divergence import WeightedGeneralizedJSD
wgjsd = WeightedGeneralizedJSD(weights)
library_jsd = wgjsd(prob_list, targets)
assert torch.allclose(library_jsd, weighted_jsd, atol=1e-6), (
    f"library JSD {library_jsd} != manual JSD {weighted_jsd}")
library_jsd

tensor(0.0622)