In [66]:
import torch
from torch.distributions import Categorical
from torch.distributions.gumbel import Gumbel

def sample_mixture_gumbel(pi, mu, sigma, temp=0.1, *, return_sample=False, verbose=True):
    """Draw a relaxed, backpropagatable sample from a mixture of Gaussians.

    The component choice is relaxed with the Gumbel-softmax trick: Gumbel(0, 1)
    noise is added to ``log(pi)`` and the result is pushed through a
    temperature-scaled softmax, giving soft selection weights per batch row.
    Every component is sampled with the reparameterisation ``mu + eps * sigma``
    and the per-component samples are blended with those weights.

    Args:
        pi: Mixture weights, shape (B, G). Only ``log(pi)`` enters the softmax,
            so rows do not need to be normalised.
        mu: Component means, shape (B, G, D).
        sigma: Component standard deviations, shape (B, G, D).
        temp: Softmax temperature. Smaller values push the weights towards a
            one-hot selection.
        return_sample: When True, return the blended sample of shape (B, D).
            The default (False) returns the selection weights, which is what
            this function has always returned.
        verbose: When True (the default, matching the previous behaviour),
            print the intermediate tensors and their shapes.

    Returns:
        The Gumbel-softmax weights of shape (B, G) by default, or the blended
        mixture sample of shape (B, D) when ``return_sample`` is True.

    Raises:
        ValueError: If ``temp`` is a Python number that is not positive.
    """
    # A zero temperature divides by zero and a negative one inverts the
    # ranking of the components; neither is a meaningful relaxation.
    if isinstance(temp, (int, float)) and temp <= 0:
        raise ValueError(f"temp must be positive, got {temp}")

    # The Gumbel noise is drawn without gradient; gradients reach pi only
    # through log(pi).
    m = Gumbel(torch.zeros_like(pi), torch.ones_like(pi))
    g = m.sample()
    gumbel_softmax = torch.softmax((torch.log(pi) + g) / temp, dim=-1)  # (B, G)

    # Reparameterised draw from every component, not just the selected one.
    eps = torch.randn_like(sigma)
    samples = mu + (eps * sigma)  # (B, G, D)

    # Weighted sum over the component axis: (B, G, D) x (B, G) -> (B, D).
    gumbel_weighted = torch.einsum('bgd,bg->bd', samples, gumbel_softmax)

    if verbose:
        print(samples.shape)
        print(gumbel_softmax.shape)
        print(samples)
        print(gumbel_softmax)
        print(gumbel_weighted)
        print(gumbel_weighted.shape)

    if return_sample:
        return gumbel_weighted
    return gumbel_softmax
    
    

In [67]:
# Mixture weights, one row per batch element: (B, n_gaussians).
# NOTE(review): the first row sums to 1.25 rather than 1 -- confirm this is intended.
pi = torch.tensor([
    [0.05, 0.5, 0.7],
    [0.7, 0.1, 0.2],
])

# Alternative configurations kept for reference:
#   two components:
#     pi = torch.tensor([[0.05, 0.95], [0.5, 0.5]])
#   three 1-D components, (B, n_gaussians, 1):
#     mu = torch.tensor([[[1.0], [2.0], [3]], [[1.0], [2.0], [3.0]]])
#     sigma = torch.tensor([[[1.0], [2.0], [3.0]], [[1.0], [2.0], [3.0]]])
#   per-dimension sigma for the 2-D case:
#     sigma = torch.tensor([[[1.0, -1.0], [1.0, -1.0], [1.0, -1.0]],
#                           [[1.0, -1.0], [1.0, -1.0], [1.0, -1.0]]])

# Component means, (B, n_gaussians, D) with D = 2 here.
mu = torch.tensor([
    [[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]],
    [[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]],
])
# Tiny standard deviation so each draw sits very close to its component mean.
sigma = torch.full((2, 3, 2), 0.001)

# Running sum, per batch row, of the index of the highest-weighted component.
a = torch.zeros(2)
for i in range(1):
    res = sample_mixture_gumbel(pi, mu, sigma)
    b = res.argmax(dim=1)
    a = a + b

torch.Size([2, 3, 2])
torch.Size([2, 3])
tensor([[[ 1.0011, -1.0004],
         [ 1.9999, -1.9996],
         [ 3.0007, -2.9999]],

        [[ 1.0003, -0.9988],
         [ 1.9992, -2.0002],
         [ 3.0006, -2.9991]]])
tensor([[1.0000e+00, 8.0938e-08, 4.2895e-07],
        [1.0000e+00, 3.4458e-20, 2.4563e-18]])
tensor([[ 1.0011, -1.0005],
        [ 1.0003, -0.9988]])
torch.Size([2, 2])
