In [None]:
import torch

In [None]:
class GaussianMixtureDistribution:
    """Equal-weight mixture of multivariate normal components.

    Args:
        locs: Per-component mean vectors; ``locs[k]`` is the mean of component
            ``k``. Only iterated and indexed, so a sequence of 1-D tensors or a
            2-D tensor both work.
        covs: Per-component covariance matrices, paired with ``locs`` by
            position (extra entries in the longer of the two are ignored by
            ``zip``).
    """

    def __init__(self, locs: torch.Tensor, covs: torch.Tensor) -> None:
        # One MultivariateNormal per (mean, covariance) pair.
        self.dists = [
            torch.distributions.MultivariateNormal(loc, cov)
            for loc, cov in zip(locs, covs)
        ]

        self.ndim = len(locs[0])
        self.nmodes = len(self.dists)

    def sample(self, size: int) -> torch.Tensor:
        """Draw ``size`` samples, split as evenly as possible across components.

        Each component receives ``size // nmodes`` samples and the first
        ``size % nmodes`` components receive one extra, so exactly ``size``
        rows are returned. Rows are grouped by component (component 0 first)
        and are not shuffled.

        Args:
            size: Total number of samples to draw.

        Returns:
            Tensor of shape ``(size, ndim)``; shape ``(0, ndim)`` when
            ``size <= 0``.
        """
        base, extra = divmod(size, self.nmodes)
        chunks = []
        for k, dist in enumerate(self.dists):
            # Spread the remainder over the first `extra` components so that no
            # samples are silently dropped when size % nmodes != 0.
            n_k = base + (1 if k < extra else 0)
            if n_k > 0:
                chunks.append(dist.sample((n_k,)))

        if not chunks:
            # Keep the (N, ndim) layout even when nothing was drawn.
            ref = self.dists[0].loc
            return torch.empty((0, self.ndim), dtype=ref.dtype, device=ref.device)
        return torch.cat(chunks, dim=0)

    def prob(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of the component densities evaluated at ``x``.

        NOTE(review): component densities are summed, not averaged, so this
        integrates to ``nmodes`` rather than 1. Divide by ``nmodes`` for the
        normalized equal-weight mixture density that ``sample`` draws from.

        Args:
            x: Points of shape ``(..., ndim)``.

        Returns:
            Tensor of shape ``x.shape[:-1]``, with the dtype/device produced by
            ``log_prob`` on ``x``.
        """
        # Stacking keeps the result on x's device/dtype instead of a fresh
        # default-device accumulator.
        densities = torch.stack([dist.log_prob(x).exp() for dist in self.dists])
        return densities.sum(dim=0)

In [None]:
# Experiment configuration.
ndim = 2  # dimensionality of each Gaussian component
nmodes = 7  # number of mixture components
seed = 11  # RNG seed so the component parameters are reproducible
xmax = 7.0  # not used in this cell; presumably a plotting extent — confirm

torch.manual_seed(seed)

dist_locs, dist_covs = [], []
for _ in range(nmodes):
    # Mean: each coordinate uniform in [-2.5, 2.5).
    loc = (torch.rand(size=(ndim,)) - 0.5) * 5.0
    # Per-axis standard deviation: uniform in [0.5, 1.5).
    std = torch.rand(size=(ndim,)) + 0.5
    # Axis-aligned (diagonal) covariance with variances std**2.
    cov = torch.diag(std**2)
    dist_locs.append(loc)
    dist_covs.append(cov)