# Test KDE

In [1]:
import os
import sys
import time

import numpy as np
import psdist as ps
import psdist.plot as psv
import ultraplot as plt
from ipywidgets import interact
from tqdm.notebook import tqdm
from tqdm.notebook import trange

import ment

In [2]:
# Global ultraplot rc overrides applied before any figure in this notebook is
# created. Applied in this order, one key at a time.
_rc_overrides = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_overrides.items():
    plt.rc[_rc_key] = _rc_value
# Keep the notebook namespace identical to setting the keys one by one.
del _rc_overrides, _rc_key, _rc_value

In [3]:
# Experiment configuration shared by the cells below.
dist_name: str = "gaussian-mixture"  # name passed to ment.dist.get_dist
ndim: int = 6  # dimension of the test distribution; also the set of selectable axes
xmax: float = 3.5  # histogram edges span [-xmax, xmax] along each plotted axis
seed: int = 12345  # seed passed to ment.dist.get_dist

In [4]:
# Test distribution, plus a large (10**6-point) sample from it. The 1D cell
# below histograms this sample as the black reference curve.
dist = ment.dist.get_dist(
    dist_name,
    ndim=ndim,
    seed=seed,
)
x_true = dist.sample(10**6)

## 1D

In [5]:
@interact(
    axis=list(range(ndim)),
    log_n_samp=(3.0, 5.0),
    # Keyword must match the parameter name below, otherwise interact ignores
    # it and the slider range is never applied.
    bandwidth_frac=(0.01, 4.00, 0.01),
    bins=(10, 128),
)
def update(axis: int, log_n_samp: float, bandwidth_frac: float = 1.0, bins: int = 50):
    """Plot 1D histogram estimates of one coordinate, without and with KDE.

    Draws ``int(10**log_n_samp)`` samples from ``dist`` and draws three curves:
    - ``ment.diag.Histogram1D`` output with ``kde=False``;
    - ``ment.diag.Histogram1D`` output with ``kde=True``;
    - in black, a density-normalized histogram of the large reference sample
      ``x_true`` on the same edges.

    Parameters
    ----------
    axis : int
        Index of the coordinate to histogram.
    log_n_samp : float
        Base-10 logarithm of the number of samples drawn from ``dist``.
    bandwidth_frac : float
        Forwarded to ``Histogram1D`` as ``kde_bandwidth_frac``.
    bins : int
        Number of equal-width bins spanning ``[-xmax, xmax]``.
    """
    n_samp = int(10.0**log_n_samp)
    x = dist.sample(n_samp)

    edges = np.linspace(-xmax, xmax, bins + 1)

    fig, ax = plt.subplots(figsize=(3.0, 1.5))
    for kde in (False, True):
        diagnostic = ment.diag.Histogram1D(
            axis=axis, edges=edges, kde=kde, kde_bandwidth_frac=bandwidth_frac
        )
        values = diagnostic(x)
        ax.stairs(values, edges, lw=1.5)
    ax.hist(
        x_true[:, axis],
        edges,
        density=True,
        color="black",
        alpha=1.0,
        histtype="step",
        lw=1.5,
    )
    plt.show()

interactive(children=(Dropdown(description='axis', options=(0, 1, 2, 3, 4, 5), value=0), FloatSlider(value=4.0…

## 2D

In [6]:
@interact(
    axis1=list(range(ndim)),
    axis2=list(range(ndim)),
    log_n_samp=(3.0, 5.0),
    # Keyword must match the parameter name below, otherwise interact ignores
    # it and the slider range is never applied.
    bandwidth_frac=(0.01, 3.00, 0.01),
    bins=(10, 128),
)
def update(
    axis1: int = 0,
    axis2: int = 1,
    log_n_samp: float = 4.0,
    bandwidth_frac: float = 1.0,
    bins: int = 50,
):
    """Plot 2D histogram estimates of a coordinate pair, without and with KDE.

    Draws ``int(10**log_n_samp)`` samples from ``dist``. It then shows the
    ``ment.diag.HistogramND`` output for axes ``(axis1, axis2)``:
    - left panel with ``kde=False``;
    - right panel with ``kde=True``.

    Parameters
    ----------
    axis1, axis2 : int
        Indices of the two coordinates to histogram (horizontal, vertical).
    log_n_samp : float
        Base-10 logarithm of the number of samples drawn from ``dist``.
    bandwidth_frac : float
        Forwarded to ``HistogramND`` as ``kde_bandwidth_frac``.
    bins : int
        Number of equal-width bins per axis spanning ``[-xmax, xmax]``.
    """
    axis = (axis1, axis2)

    n_samp = int(10.0**log_n_samp)
    x = dist.sample(n_samp)

    # Same edges on both axes.
    edges = 2 * [np.linspace(-xmax, xmax, bins + 1)]

    fig, axs = plt.subplots(figsize=None, ncols=2)
    for i, ax in enumerate(axs):
        diagnostic = ment.diag.HistogramND(
            axis=axis, edges=edges, kde=bool(i), kde_bandwidth_frac=bandwidth_frac
        )
        values = diagnostic(x)
        # Transpose so the first histogram index runs along the horizontal axis.
        ax.pcolormesh(edges[0], edges[1], values.T)
    plt.show()

interactive(children=(Dropdown(description='axis1', options=(0, 1, 2, 3, 4, 5), value=0), Dropdown(description…