# Test grid sampling

In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
import scipy.interpolate
from tqdm import tqdm

import ment

In [None]:
# Proplot rc settings applied to every figure in this notebook. Assigned one
# key at a time, in the order listed.
for _key, _value in {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}.items():
    pplt.rc[_key] = _value

Create a particle distribution.

In [None]:
# Settings
ndim = 6  # number of dimensions of each particle
size = 1_000_000  # number of points in the correlated Gaussian core
n_modes = 5  # number of extra Gaussian blobs (each has size // n_modes points)
seed = 1241

# Create gaussian particle distribution.
rng = np.random.default_rng(seed)
mean = np.zeros(ndim)
cov = np.identity(ndim)
for i in range(ndim):
    for j in range(i):
        # Symmetric random correlations. With |cov[i, j]| <= 0.2 and unit
        # diagonal, Gershgorin's theorem keeps cov positive semi-definite as
        # long as (ndim - 1) * 0.2 <= 1, i.e. ndim <= 6.
        cov[i, j] = cov[j, i] = rng.uniform(-0.2, 0.2)
x = rng.multivariate_normal(mean, cov, size=size)

# Add gaussian blobs. Each blob is collected and everything is stacked once at
# the end; calling np.vstack inside the loop would re-copy the growing array on
# every iteration. The RNG call order (scale, loc, samples) is unchanged, so
# the result is identical.
blobs = [x]
for _ in range(n_modes):
    scale = rng.uniform(0.5, 1.5, size=ndim)
    loc = rng.uniform(-3.0, 3.0, size=ndim)
    blobs.append(rng.normal(loc=loc, scale=scale, size=(size // n_modes, ndim)))
x = np.vstack(blobs)
del blobs  # release the per-blob arrays; only the stacked copy is needed
x = x - np.mean(x, axis=0)  # center the combined distribution at the origin

x_true = np.copy(x)

In [None]:
limits = ps.points.limits(x_true)

# Corner plot of the true particles, binned with 50 bins per axis over `limits`.
truth_plot_kws = dict(bins=50, limits=limits, cmap="mono")
grid = psv.CornerGrid(ndim, corner=True)
grid.plot_points(x_true, **truth_plot_kws)
plt.show()

Compute the ground-truth histogram.

In [None]:
# Normalized (density=True) ndim-dimensional histogram of the true particles,
# with n_bins bins per axis spanning `limits`.
n_bins = 25
hist, edges = np.histogramdd(
    x_true,
    bins=n_bins,
    range=limits,
    density=True,
)
# One coordinate array per axis derived from the bin edges (presumably bin
# centers — confirm against ment.grid.edges_to_coords).
coords = list(map(ment.grid.edges_to_coords, edges))
print("hist.shape =", hist.shape)

Interpolate to obtain a smooth density function.

In [None]:
# Piecewise-linear density over the histogram grid. Points outside the grid
# evaluate to 0.0 instead of raising (bounds_error=False).
prob_func = scipy.interpolate.RegularGridInterpolator(
    points=coords,
    values=hist,
    method="linear",
    bounds_error=False,
    fill_value=0.0,
)

## Grid Sampling (GS)

Sample from the distribution.

In [None]:
# Sampling-grid resolution per dimension. The grid has samp_grid_res ** ndim
# cells (15 ** 6 ≈ 1.1e7 for ndim = 6), which is presumably why ~15 is the
# practical upper limit.
samp_grid_res = 15
grid_shape = [samp_grid_res] * ndim
grid_limits = limits

sampler = ment.samp.GridSampler(
    grid_limits=grid_limits,
    grid_shape=grid_shape,
    noise=0.0,
)

# time.perf_counter is monotonic and high-resolution, unlike time.time, which
# can jump if the system clock is adjusted during the run.
start_time = time.perf_counter()
x_samp = sampler(prob_func, size)

print("time:", time.perf_counter() - start_time)

Plot samples over the histogram projections.

In [None]:
def plot_corner_samp(x_samp: np.ndarray, max_points: int = 1000):
    """Scatter a subset of samples over the ground-truth histogram projections.

    Draws the module-level ``hist`` (on ``coords``) as images in a corner grid,
    then scatters the first ``max_points`` rows of ``x_samp`` in red.

    Parameters
    ----------
    x_samp : np.ndarray
        2D sample array, presumably of shape (n, ndim).
    max_points : int, optional
        Number of leading samples to scatter. The default of 1000 matches the
        previous hard-coded value.

    Returns
    -------
    psv.CornerGrid
        The grid that was drawn on.
    """
    # Reuse the module-level ``limits`` (computed from x_true in an earlier
    # cell) rather than rescanning all of x_true on every call. This also keeps
    # the axes consistent with plot_corner_hist, which uses the same global.
    grid = psv.CornerGrid(ndim, corner=True)
    grid.set_limits(limits)
    grid.plot_image(hist, coords=coords, cmap="mono")
    grid.plot_points(
        x_samp[:max_points, :],
        kind="scatter",
        color="red",
        s=0.5,
        diag_kws=dict(color="red"),
    )
    return grid
    
def plot_corner_hist(x_samp: np.ndarray):
    """Compare ground-truth and sample histograms in a single full corner grid.

    The module-level ``hist`` is drawn with ``lower=False``. The histogram of
    ``x_samp`` is drawn with ``upper=False``, using the same ``n_bins`` and
    ``limits``. Presumably this puts the truth in the upper triangle and the
    samples in the lower one.

    Returns the ``psv.CornerGrid`` that was drawn on.
    """
    truth_kws = dict(coords=coords, lower=False, cmap="mono")
    sample_kws = dict(upper=False, bins=n_bins, limits=limits, cmap="mono")

    corner = psv.CornerGrid(ndim, corner=False)
    corner.set_limits(limits)
    corner.plot_image(hist, **truth_kws)
    corner.plot_points(x_samp, **sample_kws)
    return corner

In [None]:
# Compare the GridSampler output with the ground truth: a red scatter over the
# true projections, then a sample-vs-truth histogram corner plot.
plot_corner_samp(x_samp)
plot_corner_hist(x_samp)

## Slice Grid Sampling (SGS)

In [None]:
samp_res = 20  # sampling-grid resolution per dimension
int_res = 10  # integration resolution per integrated dimension

ndim_proj = 2
ndim_samp = ndim_int = ndim - ndim_proj

grid_shape = [samp_res] * ndim
grid_limits = limits
int_size = int_res ** ndim_int  # integer power of ints; no cast needed

sampler = ment.samp.SliceGridSampler(
    grid_limits=grid_limits,
    grid_shape=grid_shape,
    proj_dim=ndim_proj,
    int_size=int_size,
    int_method="grid",
    int_batches=1,
    noise=0.0,
    verbose=True,
)

# time.perf_counter is monotonic and high-resolution, unlike time.time, which
# can jump if the system clock is adjusted during the run.
start_time = time.perf_counter()
x_samp = sampler(prob_func, size)
print("time:", time.perf_counter() - start_time)

In [None]:
# Compare the SliceGridSampler output with the ground truth: a red scatter over
# the true projections, then a sample-vs-truth histogram corner plot.
plot_corner_samp(x_samp)
plot_corner_hist(x_samp)

## Monte Carlo — Metropolis–Hastings

In [None]:
# Metropolis-Hastings sampling of the same interpolated density. This draws
# 200_000 samples rather than `size`.
mh_kws = dict(ndim=ndim, scale=1.0, burnin=100_000, shuffle=True)
sampler = ment.samp.MetropolisHastingsSampler(**mh_kws)
x_samp = sampler(prob_func, size=200_000)

In [None]:
# Compare the Metropolis-Hastings output with the ground truth: a red scatter
# over the true projections, then a sample-vs-truth histogram corner plot.
plot_corner_samp(x_samp)
plot_corner_hist(x_samp)