In [None]:
import torch
import matplotlib.pyplot as plt

import ment

# Apply the notebook's matplotlib style sheet (path is relative to the current working directory).
plt.style.use("./style.mplstyle")

In [None]:
class GaussianMixtureDistribution:
    """Equal-weight mixture of multivariate normal distributions.

    Used in this notebook both as the target density handed to the MCMC
    samplers (via ``prob``) and as a source of exact reference samples
    (via ``sample``).
    """

    def __init__(self, locs: torch.Tensor, covs: torch.Tensor) -> None:
        """Build one ``MultivariateNormal`` component per ``(loc, cov)`` pair.

        Args:
            locs: Component means; a sequence of ``(ndim,)`` tensors or a
                ``(nmodes, ndim)`` tensor.
            covs: Component covariance matrices, one per entry of ``locs``.

        Raises:
            ValueError: If no components are given, or if ``locs`` and
                ``covs`` have different lengths.
        """
        if len(locs) == 0:
            raise ValueError("at least one mixture component is required")
        if len(locs) != len(covs):
            # zip() would otherwise silently drop the unmatched components.
            raise ValueError(
                f"got {len(locs)} locs but {len(covs)} covs; lengths must match"
            )

        self.dists = [
            torch.distributions.MultivariateNormal(loc, cov)
            for loc, cov in zip(locs, covs)
        ]
        self.ndim = len(locs[0])
        self.nmodes = len(self.dists)

    def sample(self, size: int) -> torch.Tensor:
        """Draw exactly ``size`` samples from the mixture.

        Samples are split as evenly as possible across modes. Every mode gets
        ``size // nmodes`` samples, and the first ``size % nmodes`` modes get
        one extra, which is consistent with the equal mixture weights. Rows
        are grouped by mode and are not shuffled.

        Args:
            size: Total number of samples to draw.

        Returns:
            Tensor of shape ``(size, ndim)``, or ``(0, ndim)`` when
            ``size <= 0``.
        """
        base, remainder = divmod(size, self.nmodes)
        counts = [base + (1 if k < remainder else 0) for k in range(self.nmodes)]
        chunks = [
            component.sample((count,))
            for component, count in zip(self.dists, counts)
            if count > 0
        ]
        if not chunks:
            # Keep the result 2-D so callers can still index x[:, 0].
            return self.dists[0].loc.new_empty((0, self.ndim))
        return torch.cat(chunks, dim=0)

    def prob(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the (normalized) equal-weight mixture density at ``x``.

        Args:
            x: Points of shape ``(..., ndim)``.

        Returns:
            Tensor of shape ``x.shape[:-1]``: the mean of the component
            densities. It is built from differentiable ops, so autograd
            can flow through it.
        """
        densities = torch.stack([torch.exp(d.log_prob(x)) for d in self.dists])
        return densities.mean(dim=0)

In [None]:
# Target: a random equal-weight Gaussian mixture in `ndim` dimensions.
ndim = 2
nmodes = 7
seed = 11

torch.manual_seed(seed)

dist_locs = []
dist_covs = []
for _ in range(nmodes):
    # Mode centers: uniform in [-2.5, 2.5) along each coordinate.
    loc = 5.0 * (torch.rand(size=(ndim,)) - 0.5)
    # Per-axis standard deviations: uniform in [0.5, 1.5).
    std = torch.rand(size=(ndim,)) + 0.5
    # Axis-aligned (diagonal) covariance with variances std**2.
    cov = torch.diag(std**2)
    dist_locs.append(loc)
    dist_covs.append(cov)

dist = GaussianMixtureDistribution(locs=dist_locs, covs=dist_covs)

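## Metropolis-Hastings (MH)

Each sampler below targets the mixture density `dist.prob`. In each figure, the left panel histograms the sampler's output and the right panel histograms exact samples drawn with `dist.sample`.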

In [None]:
chains = 10
# Initial chain states, drawn near the origin (std 0.25 per axis).
start = torch.randn((chains, ndim)) * 0.25
# Isotropic proposal covariance (variance 0.1 per axis).
proposal_cov = torch.eye(ndim) * 0.1

sampler = ment.MetropolisHastingsSampler(
    ndim=ndim,
    start=start,
    proposal_cov=proposal_cov,
    verbose=0,
)

size = 50_000
x_pred = sampler(dist.prob, size=size)
x_true = dist.sample(size=size)

xmax = 9.0  # half-width of the square plotting window

# Left: sampler output; right: exact samples from the mixture.
fig, axs = plt.subplots(figsize=(6.0, 3.0), ncols=2, sharex=True, sharey=True)
for ax, x, title in zip(axs, [x_pred, x_true], ["MH samples", "True samples"]):
    ax.hist2d(
        x[:, 0],
        x[:, 1],
        bins=64,
        range=[[-xmax, xmax], [-xmax, xmax]],
        density=True,
        cmap="viridis",
    )
    ax.set_title(title)
    ax.set_xlabel("$x_1$")
axs[0].set_ylabel("$x_2$")
plt.show()

## Hamiltonian Monte Carlo (HMC)

In [None]:
chains = 10
step_size = 0.21
steps_per_samp = 10

sampler = ment.HamiltonianMonteCarloSampler(
    ndim=ndim,
    # NOTE(review): the MH cell scales its start by 0.25, this one by
    # 0.25**2 (= 0.0625); confirm the tighter initial spread is intentional.
    start=torch.randn((chains, ndim)) * 0.25**2,
    step_size=step_size,
    steps_per_samp=steps_per_samp,
    burnin=10,
    verbose=1,
)

size = 50_000
x_pred = sampler(dist.prob, size=size)
x_true = dist.sample(size=size)

xmax = 9.0  # half-width of the square plotting window

# Left: sampler output; right: exact samples from the mixture.
fig, axs = plt.subplots(figsize=(6.0, 3.0), ncols=2, sharex=True, sharey=True)
for ax, x, title in zip(axs, [x_pred, x_true], ["HMC samples", "True samples"]):
    ax.hist2d(
        x[:, 0],
        x[:, 1],
        bins=64,
        range=[[-xmax, xmax], [-xmax, xmax]],
        density=True,
        cmap="viridis",
    )
    ax.set_title(title)
    ax.set_xlabel("$x_1$")
axs[0].set_ylabel("$x_2$")
plt.show()