In [17]:
import torch
from torch.distributions import Categorical

def sample(pi, mu, sigma):
    """Draw one sample per batch element from a mixture of Gaussians (MoG).

    Args:
        pi: mixture weights, shape (B, G); passed to ``Categorical`` as
            probabilities.
        mu: component means, shape (B, G, D).
        sigma: component standard deviations, shape (B, G, D).

    Returns:
        Tensor of shape (B, D). The selected means/standard deviations are
        detached, so no gradient flows back into ``mu`` or ``sigma``.
    """
    out_dim = sigma.size(2)
    # Choose which gaussian we'll sample from, one component per batch element.
    component = Categorical(pi).sample()  # (B,)
    # gather() returns a tensor shaped like its index, so the index must span
    # every output dim; a (B, 1, 1) index would only ever read output dim 0.
    index = component.view(-1, 1, 1).expand(-1, 1, out_dim)  # (B, 1, D)
    std_samples = sigma.detach().gather(1, index).squeeze(1)  # (B, D)
    mean_samples = mu.detach().gather(1, index).squeeze(1)  # (B, D)
    # randn_like keeps the noise on sigma's dtype/device (e.g. CUDA).
    gaussian_noise = torch.randn_like(std_samples)  # (B, D)
    return gaussian_noise * std_samples + mean_samples

def sample_n_times(pi, mu, sigma, n):
    """Draw ``n`` samples per batch element from a mixture of Gaussians (MoG).

    Args:
        pi: mixture weights, shape (B, G); each row is passed to
            ``torch.multinomial`` as sampling weights.
        mu: component means, shape (B, G, D).
        sigma: component standard deviations, shape (B, G, D).
        n: number of samples to draw per batch element.

    Returns:
        Tensor of shape (B, n, D). ``mu``/``sigma`` are not detached, so
        gradients can flow into the selected components.

    Notation: B = batch size, G = number of gaussians, D = output dimension.
    """
    # Choose which gaussian each of the n samples is drawn from.
    components = torch.multinomial(pi, n, replacement=True)  # (B, n)

    def _select(params):
        # (B, G, K) -> (B, n, K): one batched gather replaces a per-row loop
        # of index_select calls; the index is broadcast over the last dim.
        index = components.unsqueeze(-1).expand(-1, -1, params.size(-1))
        return params.gather(1, index)

    mean_samples = _select(mu)  # (B, n, D)
    std_samples = _select(sigma)  # (B, n, D)
    # randn_like keeps the noise on sigma's dtype/device.
    gaussian_noise = torch.randn_like(std_samples)
    return gaussian_noise * std_samples + mean_samples

In [20]:
# Demo: B=2 batch elements, G=3 gaussians, D=2 output dims, 10 samples each.
torch.manual_seed(0)  # make the displayed samples reproducible

pi = torch.tensor([[0.1, 0.2, 0.7],
                   [0.7, 0.1, 0.2]])  # (B, G) = (2, 3)
mu = torch.tensor([[[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]],
                   [[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]]])  # (B, G, D) = (2, 3, 2)
# Standard deviations must be non-negative.
sigma = torch.ones(2, 3, 2)  # (B, G, D) = (2, 3, 2)

res = sample_n_times(pi, mu, sigma, 10)
assert res.shape == (2, 10, 2)
res


tensor([[[ 0.8769, -3.6063],
         [ 0.6461, -1.3874],
         [ 1.4144, -0.9026],
         [ 1.5645, -1.9044],
         [ 2.3986, -2.6693],
         [-0.6374, -0.2439],
         [ 2.3664, -2.6135],
         [ 1.9097, -4.5650],
         [ 3.7312, -2.1519],
         [ 0.7809, -1.3948]],

        [[ 0.8097, -0.3424],
         [ 1.3807, -0.5973],
         [ 1.6847, -2.0185],
         [ 2.4039, -1.6870],
         [-1.1404, -1.2580],
         [ 1.0310, -1.7590],
         [ 1.5485, -0.1587],
         [ 0.2869, -0.1614],
         [ 3.2787, -4.6529],
         [ 0.8590,  2.0920]]])


In [19]:
# Sanity check: index_select with a repeated index duplicates that row.
x = torch.randn(3, 4)
print(x)
indices = torch.zeros(4, dtype=torch.long)  # select row 0 four times
res = x.index_select(0, indices)
print(res)

tensor([[ 0.6579, -0.2727,  0.6533, -0.2790],
        [-0.5052,  1.7635,  0.8555, -0.1423],
        [ 0.9466,  0.3008, -0.0817,  0.5212]])
tensor([[ 0.6579, -0.2727,  0.6533, -0.2790],
        [ 0.6579, -0.2727,  0.6533, -0.2790],
        [ 0.6579, -0.2727,  0.6533, -0.2790],
        [ 0.6579, -0.2727,  0.6533, -0.2790]])
