# Distributions

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt

import psdist as ps
import psdist.visualization as psv

In [None]:
# Notebook-wide proplot defaults: non-discrete colormaps, "mono" as the
# sequential colormap, white figure background, and no grid lines.
for _key, _value in {
    "cmap.discrete": False,
    "cmap.sequential": "mono",
    "figure.facecolor": "white",
    "grid": False,
}.items():
    pplt.rc[_key] = _value
del _key, _value

## 2D

In [None]:
# Distribution names passed to `ps.distributions.get_distribution(name=...)`.
names = ["gaussian", "waterbag"]

In [None]:
def plot_dist_2d(dist, n=10000, xmax=3.5, res=200, bins=100):
    """Plot samples from a 2D distribution next to its gridded probability.

    The left panel shows ``n`` samples from ``dist.sample``. The right panel
    shows ``dist.prob`` evaluated on a ``res x res`` regular grid. Both panels
    cover the same square window ``[-xmax, xmax]`` on each axis.

    Parameters
    ----------
    dist : distribution object
        Must provide ``sample(n)`` and ``prob(points)``. It must be
        two-dimensional, because the probabilities are reshaped to
        ``(res, res)``.
    n : int
        Number of samples drawn for the left panel.
    xmax : float
        Half-width of the plotting window on both axes.
    res : int
        Number of grid points per axis at which ``dist.prob`` is evaluated.
    bins : int
        Passed as ``bins`` to ``psv.points.plot2d``. The default of 100 is
        the value that was previously hard-coded.

    Returns
    -------
    axs
        The two-panel axes returned by ``pplt.subplots(ncols=2, ...)``.
    """
    x = dist.sample(n)

    # One shared window so the sample and probability panels line up.
    limits = 2 * [(-xmax, xmax)]
    coords = 2 * [np.linspace(-xmax, xmax, res)]
    x_grid = ps.image.get_grid_points(coords)
    # NOTE(review): assumes get_grid_points orders points so that a C-order
    # reshape to (res, res) matches `coords` axis order; confirm.
    prob = dist.prob(x_grid).reshape((res, res))

    _, axs = pplt.subplots(
        ncols=2,
        figheight=2.0,
        space=0.0,
        xspineloc="neither",
        yspineloc="neither",
    )
    psv.points.plot2d(x, bins=bins, limits=limits, ax=axs[0], mask=False)
    psv.image.plot2d(prob, coords=coords, ax=axs[1])
    return axs

In [None]:
# For each 2D distribution, show samples alongside the gridded probability.
for name in names:
    print(name)
    dist = ps.distributions.get_distribution(ndim=2, name=name)
    axs = plot_dist_2d(dist, res=200, xmax=4.0, n=10000)
    plt.show()

## N-D (3D, 4D, 5D)

In [None]:
def plot_dist_corner(dist, n=100000, xmax=3.5, res=45, bins=75):
    """Compare samples and gridded probabilities of an N-D distribution on a corner plot.

    Samples are drawn with ``upper=False, diag=True``. ``dist.prob`` evaluated
    on a regular grid is drawn with ``lower=False, diag=True``. As a result,
    both appear on the diagonal panels. All panels share the window
    ``[-xmax, xmax]`` on every axis, and tick labels are hidden.

    Parameters
    ----------
    dist : distribution object
        Must provide ``ndim``, ``sample(n)`` and ``prob(points)``.
    n : int
        Number of samples drawn.
    xmax : float
        Half-width of the plotting window on every axis.
    res : int
        Grid points per axis. The probability grid has ``res**ndim`` entries,
        so memory grows exponentially with ``ndim``. Keep ``res`` small for
        ``ndim`` >= 4.
    bins : int
        Passed as ``bins`` to ``grid.plot_points``. The default of 75 is the
        value that was previously hard-coded.

    Returns
    -------
    grid
        The ``psv.CornerGrid`` that was drawn on.
    """
    ndim = dist.ndim

    x = dist.sample(n)

    limits = ndim * [(-xmax, xmax)]
    coords = ndim * [np.linspace(-xmax, xmax, res)]
    x_grid = ps.image.get_grid_points(coords)
    # The reshape requires exactly res**ndim probability values.
    prob = dist.prob(x_grid).reshape((res,) * ndim)

    grid = psv.CornerGrid(ndim, figwidth=(ndim * 1.25), corner=False)
    grid.plot_points(x, bins=bins, limits=limits, upper=False, diag=True)
    grid.plot_image(prob, coords=coords, lower=False, diag=True)
    grid.axs.format(xticklabels=[], yticklabels=[])
    return grid

In [None]:
# Corner plots for each distribution in 3, 4 and 5 dimensions. The coarse
# res keeps the res**ndim probability grid tractable.
for ndim in (3, 4, 5):
    print(f"ndim={ndim}")
    for name in names:
        print(name)
        dist = ps.distributions.get_distribution(ndim=ndim, name=name)
        grid = plot_dist_corner(dist, res=25)
        plt.show()