In [102]:
import torch
import math

def mdn_negative_log_likelihood(pi, mu, sigma, target):
    """ Use torch.logsumexp for more stable training 
    
    This is equivalent to the mdn_loss but computed in a numerically stable way

    """
    target = target.unsqueeze(2).expand_as(sigma)
    neg_logprob = -torch.log(sigma) - (math.log(2 * math.pi) / 2) - \
        ((target - mu) / sigma)**2 / 2
    
    # (B, num_heads, num_gaussians)
    inner = torch.log(pi) + torch.sum(neg_logprob, 3) # Sum the log probabilities of (x, y) for each 2D Gaussian

    print(inner.shape)
    return -torch.logsumexp(inner, dim=2)



In [103]:
pi = torch.tensor([[[0.9, 0.05, 0.05], 
                   [0.05, 0.05, 0.9]], 
                   
                   [[0.1, 0.2, 0.7], 
                   [0.7, 0.1, 0.2]]]) # B, num_heads, n_gaussians


mu = torch.tensor([[[1.0, -1.0], [2.0, -2.0], [3, -3]], [[1.0, -1.0], [2.0, -2.0], [3, -3]]]) # B x n_gaussians x 2
sigma = torch.tensor([[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]])

# mu_2 = torch.randn((2, 2, 3, 2))
# sigma_2 = torch.randn((2, 2, 3, 2))


# print(pi.shape)
# print(mu.shape)
mu_2 = mu.repeat(2, 1, 1, 1) # B x num_heads, n_gaussians, 2
sigma_2 = sigma.repeat(2, 1, 1, 1) # B x num_heads, n_gaussians, 2


target = torch.randn((2, 2, 2))
target = torch.tensor([[[1.0, -1.0],
         [3, -3]],

        [[2.0,  -2.0],
         [ 2.0, -2.0]]])
# print(target)


res = mdn_negative_log_likelihood(pi, mu_2, sigma_2, target)
print(res.shape)
print(res)
# print(mu_2)
# print(mu_2.shape)

torch.Size([2, 2, 3])
torch.Size([2, 2])
tensor([[1.9220, 1.9220],
        [2.5425, 2.6793]])


In [114]:
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

pi = torch.tensor([[0.9, 0.07, 0.03],
                #    [0.05, 0.05, 0.9]], 
                   
                #    [[0.1, 0.2, 0.7], 
                   [0.7, 0.1, 0.2]]) # B, n_gaussians


mu = torch.tensor([[[1.0, -1.0], [2.0, -2.0], [3, -3]], [[1.0, -1.0], [2.0, -2.0], [3, -3]]]) # B x n_gaussians x 2
sigma = torch.tensor([[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]])

# print(sigma.shape)

mu_2 = torch.randn((2, 3, 2))
sigma_2 = torch.rand((2, 3, 2)) + 1

torch.manual_seed(0)
# print(sigma_2.shape)

def kl_gaussian(mean_1, sigma_1, mean_2, sigma_2):
    """
    KL divergence between two Gaussian distributions
    """
    var_2 = sigma_2**2
    var_1 = sigma_1**2
    return torch.log(sigma_2 / sigma_1) + (var_2 + (mean_1 - mean_2) ** 2) / (2*var_2) - 0.5

def kl_between_gaussian_mixtures(mix_1, mix_2):

    pi_a, mu_a, sigma_a = mix_1 # B x n_gaussians x 2
    pi_b, mu_b, sigma_b = mix_2

    n_gaussians = mu_b.shape[1]
    kl_all = []
    for a in range(n_gaussians):

        kl_num = []
        for a_p in range(n_gaussians):
            p_dist = Normal(mu_a[:, a, :], sigma_a[:, a, :])
            q_dist = Normal(mu_a[:, a_p, :], sigma_a[:, a_p, :])
            kl_n = kl_divergence(p_dist, q_dist)
            pi = pi_a[:, a_p]
            kl_num.append(torch.einsum('b, bd -> bd', pi, torch.exp(-kl_n)))

        kl_num = torch.stack(kl_num) # num_gaussians x 2
        kl_num = kl_num.sum(dim=0)
        # print(kl_num)
        kl_den = []
        for b_p in range(n_gaussians):
            p_dist = Normal(mu_a[:, a, :], sigma_a[:, a, :]) # b x 2
            q_dist = Normal(mu_b[:, b_p, :], sigma_b[:, b_p, :])
            kl_d = kl_divergence(p_dist, q_dist)
            pi = pi_b[:, b_p]
            kl_den.append(torch.einsum('b, bd -> bd', pi, torch.exp(-kl_d)))

        kl_den = torch.stack(kl_den) # num_gaussians x 2
        kl_den = kl_den.sum(dim=0)
        kl_all.append(kl_num / kl_den)
        # print(kl_num/kl_den)

    # print(kl_all)
    kl_all = torch.log(torch.stack(kl_all)) # num_gaussians x b x dimension
    # print(kl_all, pi_a)
    res = torch.einsum('gbd,bg->gbd', kl_all, pi_a)
    # print(res)
    return torch.sum(res, dim=0)


def kl_gaussian_mixtures_vectorized(mix_1, mix_2):
    pi_a, mu_a, sigma_a = mix_1 # B x n_gaussians x 2
    pi_b, mu_b, sigma_b = mix_2
    n_gaussians = mu_b.shape[1]

    # outer
    mu_a_tog = torch.repeat_interleave(mu_a, n_gaussians, dim=1) # [a, b, c] --> [a, a, b, b, c, c]
    sigma_a_tog = torch.repeat_interleave(sigma_a, n_gaussians, dim=1)

    pi_a_p = pi_a.repeat(1, n_gaussians) # B x n_gaussians**2 [a, b, c] --> [a, b, c, a, b, c]
    mu_a_p = mu_a.repeat(1, n_gaussians, 1) # B x n_gaussians**2 x 2
    sigma_a_p = sigma_a.repeat(1, n_gaussians, 1) # B x n_gaussians**2 x 2

    pi_b_p = pi_b.repeat(1, n_gaussians) # B x n_gaussians**2
    mu_b_p = mu_b.repeat(1, n_gaussians, 1) # B x n_gaussians**2 x 2
    sigma_b_p = sigma_b.repeat(1, n_gaussians, 1) # B x n_gaussians**2 x 2

    p_num = Normal(mu_a_tog, sigma_a_tog)
    q_num = Normal(mu_a_p, sigma_a_p)
    kl_num = kl_divergence(p_num, q_num)
    # print(kl_num)

    kls_num = torch.einsum('bxd,bx->bxd', torch.exp(-kl_num), pi_a_p) # B x n_gaussians**2
    kls_num_reshaped = kls_num.reshape(kl_num.shape[0], n_gaussians, n_gaussians, 2) # B x n_gaussians x n_gaussians x 2
    num = kls_num_reshaped.sum(dim=2) # B x n_gaussians x 2

    p_den = Normal(mu_a_tog, sigma_a_tog)
    q_den = Normal(mu_b_p, sigma_b_p)
    kl_den = kl_divergence(p_den, q_den)
    kls_den = torch.einsum('bxd,bx->bxd', torch.exp(-kl_den), pi_b_p) # B x n_gaussians**2
    kls_den_reshaped = kls_den.reshape(kl_den.shape[0], n_gaussians, n_gaussians, 2) # B x n_gaussians x n_gaussians x 2
    den = kls_den_reshaped.sum(dim=2) # B x n_gaussians x 2

    divided = num / den # B x num_gaussians x 2
    # print(divided.shape)
    res = torch.einsum('bgd,bg->bgd', torch.log(divided), pi_a)
    return res.sum(dim=1)

res = kl_between_gaussian_mixtures((pi, mu, sigma), (pi, mu_2, sigma_2))
# res = kl_between_gaussian_mixtures((pi, mu_2, sigma_2), (pi, mu_2, sigma_2))
print("result", res)

# print(res)
vectorized = kl_gaussian_mixtures_vectorized((pi, mu, sigma), (pi, mu_2, sigma_2))
# vectorized = kl_gaussian_mixtures_vectorized((pi, mu_2, sigma_2), (pi, mu_2, sigma_2))
print("vectorized", vectorized)

result tensor([[0.1340, 0.2358],
        [0.4137, 0.6707]])
vectorized tensor([[0.1340, 0.2358],
        [0.4137, 0.6707]])


In [98]:
a = torch.tensor([1, 1, 2, 2])
b =a.reshape(2, 2)
print(b)
print(b.sum(dim=1))

tensor([[1, 1],
        [2, 2]])
tensor([2, 4])
