In [134]:
from torch.distributions import MultivariateNormal
import torch

In [135]:
# Experiment configuration.
class_dim = 512  # length of the mean / variance vectors (dimension of the Gaussian)
K = 10           # number of noise vectors / samples drawn in each comparison

# Seed torch's global RNG so the sampled values are repeatable across
# "Restart Kernel -> Run All" executions of this notebook.
torch.manual_seed(0)

In [136]:
def reparameterize(mu, var):
    """
    Samples z from a multivariate Gaussian with diagonal covariance matrix using the
    reparameterization trick: ``z = mu + sqrt(var) * eps`` with ``eps ~ N(0, I)``.

    Args:
        mu: Tensor holding the mean of the Gaussian.
        var: Tensor holding the diagonal of the covariance matrix (variances, since
            the square root is taken here). Must be broadcastable against ``mu``.

    Returns:
        A new tensor whose shape is the broadcast of ``mu`` and ``var``. Neither
        input is modified.
    """
    std = var.sqrt()

    # Draw the noise with the dtype/device of the inputs (the legacy
    # ``torch.FloatTensor(size)`` constructor is always CPU float32) and at the
    # full broadcast shape, so every output element gets its own noise value.
    noise_shape = torch.broadcast_shapes(mu.shape, std.shape)
    eps = torch.randn(noise_shape, dtype=std.dtype, device=std.device)

    # Out-of-place arithmetic: an in-place ``add_`` cannot grow its result to a
    # larger ``mu`` shape or a wider ``mu`` dtype.
    return mu + eps * std

def reparameterize_witheps(mu, var, eps):
    """
    Reparameterization trick with caller-supplied noise: ``z = mu + sqrt(var) * eps``.

    Args:
        mu: Tensor holding the mean of the Gaussian.
        var: Tensor holding the diagonal of the covariance matrix (variances, since
            the square root is taken here).
        eps: Noise tensor to scale and shift; it is used as given, nothing is sampled.

    Returns:
        A new tensor whose shape is the broadcast of ``mu``, ``var`` and ``eps``.
        None of the inputs is modified.
    """
    std = var.sqrt()
    # Out-of-place so the result may broadcast up to the shape of ``mu``; the
    # in-place ``add_`` form raises when ``mu`` is larger than ``eps * std``.
    return mu + eps * std

In [137]:
# K noise vectors eps ~ N(0, I). sample((K,)) already returns shape (K, class_dim),
# so no reshape is needed.
eps_1 = MultivariateNormal(torch.zeros(class_dim), torch.eye(class_dim)).sample((K,))

mu  = torch.zeros(class_dim)
var = torch.ones(class_dim)

print(mu.shape, var.shape)

# "Unsorted": each row uses noise drawn inside reparameterize(), independent of eps_1
# (eps_1 only determines how many rows are produced).
std_norm_samples = torch.stack([reparameterize(mu, var) for _ in eps_1], dim=0)
# "Sorted": row i is built from eps_1[i]; with mu = 0 and var = 1 it equals eps_1[i].
std_norm_samples_sorted = torch.stack([reparameterize_witheps(mu, var, eps) for eps in eps_1], dim=0)

print(std_norm_samples_sorted.shape)

# `reduction='none'` is the supported spelling of the deprecated `reduce=False`.
# NOTE(review): with log_target=True, kl_div treats BOTH arguments as log-probabilities and
# returns exp(target) * (target - input) element-wise. Raw Gaussian samples are passed here,
# so these numbers measure pointwise sample mismatch rather than a KL divergence between
# distributions -- confirm this is the intended metric.
unsorted_div = torch.nn.functional.kl_div(eps_1, std_norm_samples, reduction='none', log_target=True).sum(-1).mean()
sorted_div = torch.nn.functional.kl_div(eps_1, std_norm_samples_sorted, reduction='none', log_target=True).sum(-1).mean()

print('unsorted_div: ',unsorted_div)
print('sorted_div: ',sorted_div)

torch.Size([512]) torch.Size([512])
torch.Size([10, 512])
unsorted_div:  tensor(848.6182)
sorted_div:  tensor(0.)


In [138]:
# Random mean and random variance, one value per dimension (note: this rebinds the
# `mu` / `var` names used by the previous cell).
mu = torch.rand(class_dim)
var = torch.rand(class_dim)
logvar = torch.log(var)

# Closed-form KL( N(mu, diag(var)) || N(0, I) ) = -1/2 * sum(1 + log(var) - mu^2 - var).
KLD = -0.5 * (1 - var - mu.pow(2) + logvar).sum()
print(KLD)

tensor(240.3409)


In [139]:
# One independent draw per row of eps_1; the rows of eps_1 themselves are not used,
# reparameterize() samples its own noise.
unsorted_samples = torch.stack([reparameterize(mu, var) for _ in eps_1], dim=0)

In [140]:
# Element-wise kl_div between the reference noise and the independently drawn samples,
# summed over the feature dimension and averaged over the rows.
# `reduction='none'` replaces the deprecated `reduce=False` (same behavior, no warning).
unsorted_div = torch.nn.functional.kl_div(eps_1, unsorted_samples, reduction='none', log_target=True).sum(-1).mean()
print(unsorted_div)

tensor(1369.1810)


In [141]:
# Row i reuses the noise vector eps_1[i], shifted by mu and scaled by sqrt(var).
sorted_samples = torch.stack([reparameterize_witheps(mu, var, eps) for eps in eps_1], dim=0)

# Element-wise kl_div between the reference noise and the samples built from that same
# noise, summed over the feature dimension and averaged over the rows.
# `reduction='none'` replaces the deprecated `reduce=False` (same behavior, no warning).
sorted_sampled_div = torch.nn.functional.kl_div(eps_1, sorted_samples, reduction='none', log_target=True).sum(-1).mean()
print(sorted_sampled_div)

tensor(519.1244)
