# KDE vs. histogram

In [None]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import ultraplot as uplt

import mentflow as mf
from mentflow.utils import grab

In [None]:
# Global plot styling for this notebook, applied one key at a time through
# the ultraplot rc object.
_rc_overrides = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_overrides.items():
    uplt.rc[_rc_key] = _rc_value

## Settings

In [None]:
# Name handed to `mf.distributions.get_distribution` to pick the distribution.
data_name = "swissroll"

# Number of samples drawn for each column of the comparison figures
# (stored as floats; converted with `int(...)` where the samples are drawn).
sizes = [1e3, 1e4, 1e5, 1e6]

# Binning: `n_bins` equal-width bins spanning [-xmax, xmax] on each axis.
xmax = 3.0
n_bins = 75

# Bandwidth passed to the histogram diagnostics.
bandwidth = 0.5

# Noise options forwarded to the 1D histogram diagnostic.
noise = True
noise_scale = 0.1

# Torch device on which the diagnostic and samples are placed.
device = "cpu"

## 1D

In [None]:
# 1D comparison: for each sample size, overlay the diagnostic's output with
# `kde` switched off ("hist") and on ("kde") along axis 0.
edges = torch.linspace(-xmax, xmax, n_bins + 1)
hist_kwargs = dict(
    axis=0,
    edges=edges,
    bandwidth=bandwidth,
    noise=noise,
    noise_scale=noise_scale,
)
diagnostic = mf.diagnostics.Histogram1D(**hist_kwargs)
diagnostic = diagnostic.to(device)

distribution = mf.distributions.get_distribution(data_name)

# One (label, color) pair per value of the `kde` flag: 0 -> hist, 1 -> kde.
curve_styles = [("hist", "blue8"), ("kde", "red8")]

fig, axs = uplt.subplots(ncols=len(sizes), figsize=(6.0, 1.25))
for ax, size in zip(axs, sizes):
    # Sizes are stored as floats, so cast before sampling.
    x = distribution.sample(int(size))
    x = x.type(torch.float32)
    x = x.to(device)

    for i, (label, color) in enumerate(curve_styles):
        diagnostic.kde = i
        histogram = diagnostic(x)
        coords = grab(diagnostic.coords)
        values = grab(histogram)
        ax.plot(coords, values, label=label, color=color)

    ax.format(title=f"n = {size:0.2e}")

# A single legend on the last panel covers all panels.
axs[-1].legend(loc="r", ncols=1, framealpha=0.0, handlelength=1.5)
plt.show()

## 2D

In [None]:
# 2D comparison: for each sample size (columns), show the diagnostic's output
# over axes (0, 1) with `kde` switched off (top row, "hist") and on (bottom
# row, "kde"). Reuses `distribution` created in the 1D cell above.
edges = 2 * [torch.linspace(-xmax, xmax, n_bins + 1)]
diagnostic = mf.diagnostics.Histogram2D(
    axis=(0, 1),
    edges=edges,
    bandwidth=(bandwidth, bandwidth),
)
# Honor the `device` setting, matching the 1D cell; previously the 2D
# diagnostic and its samples silently ignored it.
# NOTE(review): assumes Histogram2D supports `.to(device)` like Histogram1D.
diagnostic = diagnostic.to(device)

# One colormap per row so the two estimates are visually distinct.
cmaps = [
    uplt.Colormap("div", left=0.5),
    uplt.Colormap("div_r", left=0.5),
]

fig, axs = uplt.subplots(ncols=len(sizes), figwidth=6.0, nrows=2)
for j, size in enumerate(sizes):
    # Sizes are stored as floats, so cast before sampling.
    x = distribution.sample(int(size))
    x = x.type(torch.float32).to(device)
    for i in [0, 1]:
        diagnostic.kde = i
        histogram = diagnostic(x)
        # Transpose so the first histogram axis runs along the plot's x axis.
        axs[i, j].pcolormesh(
            grab(diagnostic.edges_x),
            grab(diagnostic.edges_y),
            grab(histogram.T),
            cmap=cmaps[i],
        )
    axs[0, j].format(title=f"n = {size:0.2e}")
axs.format(leftlabels=["hist", "kde"])
plt.show()