# Test sampling algorithms

In [None]:
import math

import matplotlib.pyplot as plt
import torch

import ment_torch as ment

In [None]:
# Apply the project's matplotlib style sheet; the file path is resolved
# relative to the notebook's working directory.
plt.style.use(style="style.mplstyle")

## Create distribution

In [None]:
class Distribution2D:
    """Base class for analytic densities that can be evaluated on a regular grid.

    Subclasses implement :meth:`prob`. :meth:`prob_grid` evaluates it at the
    cell centers of a regular grid. Despite the name, the grid logic works for
    any ``ndim``; the default is 2.
    """

    def __init__(self, ndim: int = 2) -> None:
        # Take the dimension explicitly instead of reading a notebook-level
        # global that may not exist (or may differ) when the class is built.
        self.ndim = ndim

    def prob(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (possibly unnormalized) density at each point in ``x``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def prob_grid(
        self,
        shape: tuple[int, ...],
        limits: list[tuple[float, float]],
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Evaluate :meth:`prob` at the cell centers of a regular grid.

        Args:
            shape: Number of grid cells along each axis (``self.ndim`` entries).
            limits: ``(min, max)`` bounds along each axis (``self.ndim`` entries).

        Returns:
            ``(values, coords)``. ``values`` has shape ``shape`` and is indexed
            in ``"ij"`` order (axis ``k`` follows ``coords[k]``). ``coords[k]``
            holds the ``shape[k]`` cell-center coordinates along axis ``k``.

        Raises:
            ValueError: If ``shape`` or ``limits`` does not have ``self.ndim``
                entries.
        """
        if len(shape) != self.ndim or len(limits) != self.ndim:
            raise ValueError(
                f"expected {self.ndim} entries in shape and limits, "
                f"got {len(shape)} and {len(limits)}"
            )
        # n cells need n + 1 edges; cell centers are the edge midpoints.
        edges = [torch.linspace(lo, hi, n + 1) for (lo, hi), n in zip(limits, shape)]
        coords = [0.5 * (e[:-1] + e[1:]) for e in edges]
        # Flatten the "ij" mesh into an (N, ndim) batch of points, evaluate,
        # then restore the grid layout.
        mesh = torch.meshgrid(*coords, indexing="ij")
        points = torch.stack([m.ravel() for m in mesh], dim=-1)
        values = self.prob(points).reshape(shape)
        return values, coords

In [None]:
class RingDistribution(Distribution2D):
    """Ring-shaped 2D density, modulated along the first coordinate.

    The unnormalized log-density is
    ``sin(pi * x1) - 2 * (x1**2 + x2**2 - 2)**2``. The quadratic penalty is
    smallest on the circle ``x1**2 + x2**2 = 2``, and the sine term (bounded
    in [-1, 1]) varies the weight around it.
    """

    def prob(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``exp(log_density)`` for points whose last axis holds (x1, x2)."""
        first, second = x[..., 0], x[..., 1]
        radius_sq = first**2 + second**2
        log_density = torch.sin(torch.pi * first) - 2.0 * (radius_sq - 2.0) ** 2
        return log_density.exp()

In [None]:
# class GaussianMixtureDistribution(Distribution2D):
#     def __init__(self, locs: torch.Tensor, covs: torch.Tensor) -> None:
#         super().__init__()
        
#         self.dists = []
#         for loc, cov in zip(locs, covs):
#             dist = torch.distributions.MultivariateNormal(loc, cov)
#             self.dists.append(dist)

#         self.ndim = len(locs[0])
#         self.nmodes = len(self.dists)
        
#     def sample(self, size: int) -> torch.Tensor:
#         sizes = torch.ones(self.nmodes) * (size // self.nmodes)
        
#         indices = torch.arange(self.nmodes)
#         if self.nmodes > 1:
#             indices = indices[sizes > 0]

#         x = torch.empty(0, device=sizes.device)
#         for i in indices:
#             dist = self.dists[i]
#             size = int(sizes[i])
#             x_k = dist.sample((size,))
#             x = torch.cat((x, x_k), dim=0)
#         return x

#     def prob(self, x: torch.Tensor) -> None:
#         p = torch.zeros(x.shape[0])
#         for dist in self.dists:
#             p += torch.exp(dist.log_prob(x))
#         return p

In [None]:
# Experiment settings.
ndim = 2
nmodes = 7
seed = 11
# xmax = 7.0  # wider alternative plotting window
xmax = 3.0

torch.manual_seed(seed)

# Random means and diagonal covariances for an nmodes-component Gaussian
# mixture. Only the commented-out mixture target below would use them. They
# are still drawn, in the same order, because removing the draws would change
# the global torch RNG state seen by later cells (presumably the samplers too).
dist_locs = []
dist_covs = []
for _ in range(nmodes):
    loc = 5.0 * (torch.rand(size=(ndim,)) - 0.5)  # uniform in [-2.5, 2.5)
    std = 0.5 + 1.0 * torch.rand(size=(ndim,))    # uniform in [0.5, 1.5)
    cov = torch.eye(ndim) * std**2                # diagonal covariance
    dist_locs.append(loc)
    dist_covs.append(cov)

# dist = GaussianMixtureDistribution(locs=dist_locs, covs=dist_covs)
dist = RingDistribution()

In [None]:
# Evaluate the true density on a 128x128 grid covering [-xmax, xmax]^2.
grid_limits = [(-xmax, xmax)] * 2
grid_shape = (128, 128)
grid_values, grid_coords = dist.prob_grid(grid_shape, grid_limits)

# prob_grid returns "ij"-indexed values (first axis = x), while pcolormesh
# expects C[y, x], hence the transpose.
fig, ax = plt.subplots(figsize=(3, 3))
ax.pcolormesh(grid_coords[0], grid_coords[1], grid_values.T)
plt.show()

## Sample

In [None]:
def plot_samples(
    x_pred: torch.Tensor,
    *,
    limits: "list[tuple[float, float]] | None" = None,
    coords: "list[torch.Tensor] | None" = None,
    values: "torch.Tensor | None" = None,
    bins: int = 80,
) -> tuple:
    """Plot a 2D histogram of samples next to the reference density.

    Args:
        x_pred: Samples whose first two columns are the (x1, x2) coordinates
            (assumed shape ``(n, 2)``; TODO confirm for other samplers).
        limits: Histogram range ``[(xmin, xmax), (ymin, ymax)]``. Defaults to
            the notebook-level ``grid_limits``.
        coords: Cell-center coordinates ``[x_centers, y_centers]`` of the
            reference grid. Defaults to the notebook-level ``grid_coords``.
        values: Reference density on that grid, "ij"-indexed. Defaults to the
            notebook-level ``grid_values``.
        bins: Number of histogram bins per axis.

    Returns:
        ``(fig, axs)`` from ``plt.subplots``. ``axs[0]`` shows the sample
        histogram ("PRED") and ``axs[1]`` the reference density ("TRUE").
    """
    # Fall back to the notebook's reference grid so existing calls work unchanged.
    if limits is None:
        limits = grid_limits
    if coords is None:
        coords = grid_coords
    if values is None:
        values = grid_values

    # hist2d converts its inputs with NumPy, which needs detached host memory.
    x_pred = torch.as_tensor(x_pred).detach().cpu()

    fig, axs = plt.subplots(ncols=2, figsize=(6.0, 2.75), sharex=True, sharey=True)
    axs[0].hist2d(x_pred[:, 0], x_pred[:, 1], bins=bins, range=limits)
    # values is "ij"-indexed; pcolormesh expects C[y, x].
    axs[1].pcolormesh(coords[0], coords[1], values.T)
    axs[0].set_title("PRED", fontsize="medium")
    axs[1].set_title("TRUE", fontsize="medium")
    return fig, axs

In [None]:
# Number of samples requested from each sampler below.
size = 256 * 1_000

### Grid Sampling

In [None]:
# Grid-based sampler over the same limits/shape as the reference image
# (presumably it samples from dist.prob evaluated on that grid; see ment docs).
sampler = ment.GridSampler(limits=grid_limits, shape=grid_shape)
x_pred = sampler(dist.prob, size=size)

In [None]:
# Compare the grid-sampler histogram against the true density.
fig, axs = plot_samples(x_pred=x_pred)
plt.show()

### Metropolis-Hastings Sampling

In [None]:
# Metropolis-Hastings with 50 parallel chains and proposal covariance 0.05 * I
# (presumably a Gaussian random-walk proposal; verify against ment).
sampler = ment.MetropolisHastingsSampler(
    ndim=ndim,
    proposal_cov=0.05 * torch.eye(ndim),
    chains=50,
    verbose=1,
)
x_pred = sampler(dist.prob, size=size)

In [None]:
# Show what the sampler recorded in `results`; the notebook displays a cell's
# trailing expression. Contents are defined by ment (presumably run
# statistics such as acceptance rates; verify).
sampler.results

In [None]:
# Compare the Metropolis-Hastings histogram against the true density.
fig, axs = plot_samples(x_pred=x_pred)
plt.show()