# N:1 MENT — MCMC

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from tqdm.notebook import tqdm
from tqdm.notebook import trange

import ment

In [None]:
# Global proplot style for every figure in this notebook: no discrete
# colormap levels, viridis as the default sequential colormap, a white
# figure background, and no grid lines.
for _rc_key, _rc_value in (
    ("cmap.discrete", False),
    ("cmap.sequential", "viridis"),
    ("figure.facecolor", "white"),
    ("grid", False),
):
    pplt.rc[_rc_key] = _rc_value
del _rc_key, _rc_value

Settings:

In [None]:
# Experiment settings.
dist_name = "gaussian_mixture"  # name passed to ment.dist.get_dist
ndim = 4  # number of dimensions of the distribution
seed = 12345  # seed for the source distribution and data generation
xmax = 3.5  # half-width of the plotting / binning window [-xmax, xmax]

Define the source distribution.

In [None]:
# Source ("true") distribution and a large reference sample drawn from it.
# The sample is used both to generate the measurements and for comparison plots.
dist = ment.dist.get_dist(
    dist_name,
    ndim=ndim,
    seed=seed,
)
x_true = dist.sample(1_000_000)

In [None]:
# Square window [-xmax, xmax] on every axis; reused by later plots and the sampler.
limits = [(-xmax, xmax) for _ in range(ndim)]

# Corner plot of the ground-truth samples.
grid = psv.CornerGrid(ndim, figwidth=1.5 * ndim)
grid.plot_points(x_true, bins=64, limits=limits, mask=False)
plt.show()

Generate data.

In [None]:
## Measure 1D marginals.
# Measurement `i` swaps coordinate `i` with `axis_meas`, then histograms
# coordinate `axis_meas`. It therefore records the 1D marginal of x_i.
# For i == axis_meas the transform is the identity.
axis_meas = 0
n_bins = 50
n_meas = ndim


def _swap_matrix(i: int, j: int, n: int) -> np.ndarray:
    """Return the n x n permutation matrix that exchanges coordinates i and j."""
    order = np.arange(n)
    order[[i, j]] = order[[j, i]]
    return np.identity(n)[order]


transfer_matrices = [_swap_matrix(i, axis_meas, ndim) for i in range(n_meas)]
transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

## Create diagnostics: one 1D histogram of coordinate `axis_meas` per transform.
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)
diagnostics = [
    [ment.diag.Histogram1D(axis=axis_meas, edges=bin_edges, kde=False, kde_bandwidth=1.0)]
    for _ in transforms
]

## Generate data: transform the true samples once per measurement, then apply
## each of that measurement's diagnostics to the transformed samples.
measurements = []
for transform, diagnostic_list in zip(transforms, diagnostics):
    u = transform(x_true)
    measurements.append([diagnostic(u) for diagnostic in diagnostic_list])

## Model

In [None]:
# How MENT draws samples: "grid" (ment.samp.GridSampler) or
# "mcmc" (ment.samp.MetropolisHastingsSampler). The name is kept separate
# from `sampler`, which holds the sampler object built from it.
sampler_name = "mcmc"

if sampler_name == "grid":
    samp_grid_res = 32
    samp_noise = 0.5
    samp_grid_shape = ndim * [samp_grid_res]
    samp_grid_limits = limits

    sampler = ment.samp.GridSampler(
        grid_limits=samp_grid_limits,
        grid_shape=samp_grid_shape,
        noise=samp_noise,
    )
elif sampler_name == "mcmc":
    sampler = ment.samp.MetropolisHastingsSampler(
        ndim=ndim, scale=1.0, burnin=10_000, shuffle=True, verbose=True,
    )
else:
    # Fail loudly on a typo instead of silently falling back to MCMC.
    raise ValueError(f"Unknown sampler name: {sampler_name!r}")

prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

n_samples = 25_000

model = ment.MENT(
    ndim=ndim,
    measurements=measurements,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),

    sampler=sampler,
    n_samples=n_samples,

    mode="sample",
    verbose=True,
)

## Training

In [None]:
def plot_points(x_pred: np.ndarray, x_true: np.ndarray, n_bins: int):
    """Compare model samples with true samples on a full corner grid.

    True samples are drawn in red with ``lower=False`` (upper triangle and
    diagonal). Model samples are drawn in blue with ``upper=False`` (lower
    triangle and diagonal). Both use step histograms on the diagonal.

    Args:
        x_pred: Model samples, an (n, d) array.
        x_true: Reference samples with the same d. Only the first
            ``len(x_pred)`` rows are plotted so both halves use equally
            many points.
        n_bins: Number of histogram bins per axis.

    Returns:
        The axes of the corner grid.

    Note:
        Plot limits come from the notebook-level ``limits`` variable.
    """
    n = x_pred.shape[0]
    # Take the dimensionality from the data rather than the global `ndim`.
    n_dims = x_pred.shape[1]

    grid = psv.CornerGrid(n_dims, figwidth=(n_dims * 1.25), corner=False)
    kws = dict(limits=limits, bins=n_bins, mask=True)
    grid.plot_points(
        x_true[:n],
        lower=False,
        diag_kws=dict(kind="step", color="red8", lw=1.25),
        cmap=psv.cubehelix_cmap(color="red"),
        **kws
    )
    grid.plot_points(
        x_pred[:n],
        upper=False,
        diag_kws=dict(kind="step", color="blue8", lw=1.25),
        cmap=psv.cubehelix_cmap(color="blue"),
        **kws
    )
    return grid.axs

In [None]:
learning_rate = 1.0
n_epochs = 2
plot_n_samples = 128_000
plot_n_bins = n_bins

start_time = time.time()

# Epoch 0 only plots the untrained model. Each later epoch first applies one
# Gauss-Seidel update. The printed time is wall-clock time since `start_time`,
# so it also includes the sampling and plotting of earlier epochs.
for epoch in range(n_epochs + 1):
    print(f"epoch = {epoch}")

    if epoch:
        model.gauss_seidel_step(learning_rate=learning_rate)
        print(f"time = {time.time() - start_time:0.3f}")

    # Compare samples from the current model against the true samples.
    x_pred = model.sample(plot_n_samples)
    axs = plot_points(x_pred, x_true, n_bins=plot_n_bins)
    plt.show()

## Final check

In [None]:
# Large final draw from the trained model, used by the comparison plot below.
x_pred = model.sample(int(5e5))

In [None]:
# Final comparison:
# - true distribution as monochrome 2D histograms with black diagonal lines,
# - a small pink scatter of model samples off the diagonal,
# - the model's diagonal histograms in pink.
color = "pink"
bins = 32

grid = psv.CornerGrid(ndim, figwidth=(ndim * 1.5))
grid.plot_points(
    x_true,
    limits=limits,
    bins=bins,
    mask=False,
    cmap="mono",
    diag_kws=dict(lw=1.25, color="black"),
)
grid.plot_points(
    x_pred[:1_000],
    diag=False,
    kind="scatter",
    c=color,
    s=0.5,
)
# Use the same limits as the true-distribution histograms so the diagonal
# histograms share the same binning range and can be compared directly.
grid.plot_points(
    x_pred,
    limits=limits,
    diag_kws=dict(lw=1.25, color="pink5"),
    bins=bins,
    alpha=0.0,  # presumably hides the off-diagonal histograms; only diagonals show
)

# Phase-space names cover up to 6 dimensions; use generic names beyond that.
axis_labels = [r"$x$", r"$p_x$", r"$y$", r"$p_y$", r"$z$", r"$p_z$"]
if ndim > len(axis_labels):
    axis_labels = [rf"$x_{{{k + 1}}}$" for k in range(ndim)]
grid.set_labels(axis_labels[:ndim])
plt.show()