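# Histogram sampling

This notebook builds a 6D point distribution, bins it into a regular (dense) histogram and a sparse histogram, draws samples from each, and compares the two sets of samples.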

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt

import psdist as ps
import psdist.visualization as psv

In [None]:
# Global proplot display settings, applied in this order: disable discrete
# colormap levels, use viridis as the default sequential colormap, use a white
# figure background, and turn grid lines off.
rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for rc_key, rc_value in rc_settings.items():
    pplt.rc[rc_key] = rc_value

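Create a 6D test distribution: a correlated Gaussian core plus four Gaussian clusters with random centers and widths, shifted so that the overall mean is zero.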

In [None]:
# Build a 6D test distribution: a correlated Gaussian core plus four
# axis-aligned Gaussian clusters with random centers and widths.
ndim = 6
size = 10_000
dims = ["x", "px", "y", "py", "z", "pz"]

state = np.random.default_rng(1241)

# Unit-variance covariance with small random correlations. tril_indices yields
# the (row, col) pairs in row-major order, the same order as a nested
# `for i` / `for j < i` loop, so the random draws are consumed identically.
cov = np.identity(ndim)
for row, col in zip(*np.tril_indices(ndim, k=-1)):
    cov[row, col] = cov[col, row] = state.uniform(-0.2, 0.2)

blocks = [state.multivariate_normal(np.zeros(ndim), cov, size=size)]
for _ in range(4):
    # Draw the scale before the location to keep the RNG stream order fixed.
    cluster_scale = state.uniform(0.5, 1.5, size=ndim)
    cluster_loc = state.uniform(-3.0, 3.0, size=ndim)
    blocks.append(state.normal(loc=cluster_loc, scale=cluster_scale, size=(size, ndim)))
points = np.vstack(blocks)

# Center the combined distribution at the origin.
points -= points.mean(axis=0)

# Keep an independent copy of the ground-truth points.
points_true = points.copy()

In [None]:
# Per-dimension limits of the true points; reused below for the histograms.
limits = ps.points.limits(points_true)

# Corner plot of the true distribution with 50 bins per axis.
true_cmap = pplt.Colormap("mono", left=0.05)
grid = psv.points.corner(
    points_true,
    bins=50,
    limits=limits,
    grid_kws={"figwidth": 7.0},
    cmap=true_cmap,
)
plt.show()

Compute a regular (dense) histogram with `n_bins` bins per dimension over the same limits.

In [None]:
# Dense histogram on an n_bins**ndim grid spanning `limits`. With 30 bins in
# 6 dimensions this is 30**6 (about 7.3e8) float64 bins, roughly 5.8 GB.
n_bins = 30
hist, edges = np.histogramdd(points_true, n_bins, range=limits)
# Per-axis coordinates derived from the bin edges (presumably bin centers);
# used when plotting the histogram as an image.
coords = ps.utils.coords_list_from_edges_list(edges)
print("hist.size =", hist.size)

Draw `n_samples` points from the regular histogram.

In [None]:
# Number of points drawn from each histogram in this notebook.
n_samples = 100_000
points_samp = ps.image.sample(hist, size=n_samples, edges=edges)

In [None]:
# Dense histogram image overlaid with the first 500 samples drawn from it,
# shown as red points. diag_kws presumably styles the diagonal (1D) panels.
# Use `ndim` rather than a hardcoded 6 so the grid matches the data.
grid = psv.CornerGrid(d=ndim, figwidth=7.0)
grid.plot_image(hist, coords=coords, cmap=pplt.Colormap("mono"))
grid.plot_points(
    points_samp[:500, :],
    kind="scatter",
    color="red",
    s=0.5,
    diag_kws=dict(color="red"),
)
# Render the figure and keep the plot_points return value from being echoed.
plt.show()

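Compute a sparse histogram on the same grid. It stores only the nonzero bins (their indices and counts) together with the bin edges.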

In [None]:
# Sparse histogram of the true points on the same n_bins grid and limits as the
# dense histogram. Histogram the immutable copy `points_true` rather than the
# working name `points`, which is not guaranteed to still hold the true
# distribution when cells are re-run out of order.
nonzero_indices, nonzero_counts, nonzero_edges = ps.points.sparse_histogram(
    points_true, bins=n_bins, limits=limits
)
print("sparse_hist.size =", len(nonzero_counts))

Draw `n_samples` points from the sparse histogram.

In [None]:
# Draw the same number of points from the sparse histogram.
points_samp_sparse = ps.image.sample_sparse(
    size=n_samples,
    edges=nonzero_edges,
    indices=nonzero_indices,
    values=nonzero_counts,
)

In [None]:
# Dense histogram image overlaid with the first 500 samples drawn from the
# sparse histogram, shown as red points. diag_kws presumably styles the
# diagonal (1D) panels.
# Use `ndim` rather than a hardcoded 6 so the grid matches the data.
grid = psv.CornerGrid(d=ndim, figwidth=7.0)
grid.plot_image(hist, coords=coords, cmap=pplt.Colormap("mono"))
grid.plot_points(
    points_samp_sparse[:500, :],
    kind="scatter",
    color="red",
    s=0.5,
    diag_kws=dict(color="red"),
)
# Render the figure and keep the plot_points return value from being echoed.
plt.show()

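Compare the two sample sets by overlaying their contours: samples from the regular histogram are solid blue, and samples from the sparse histogram are dashed red.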

In [None]:
# Overlay contours of both sample sets on one corner grid. The loop variable is
# `samples`, not `points`, so this cell does not overwrite the module-level
# `points` that other cells may read.
grid = psv.CornerGrid(d=ndim, figwidth=7.0)
sample_styles = [
    (points_samp, "blue5", "-"),
    (points_samp_sparse, "red5", "--"),
]
for samples, color, ls in sample_styles:
    grid.plot_points(
        samples,
        bins=30,
        autolim_kws=dict(pad=-0.10),
        kind="contour",
        process_kws=dict(norm="max", blur=1.0),
        diag_kws=dict(color=color, ls=ls),
        # Six evenly spaced levels strictly between 0 and 1 (1/7 ... 6/7); this
        # assumes norm="max" scales each 2D histogram to a peak of 1.
        levels=np.linspace(0.0, 1.0, 7, endpoint=False)[1:],
        colors=color,
        lw=1.0,
        ls=ls,
    )
plt.show()