# N:2 MENT — sample-based solver

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import ment
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from tqdm.notebook import tqdm
from tqdm.notebook import trange

from utils import plot_corner_upper_lower

In [None]:
# Global proplot styling: continuous sequential colormaps (viridis),
# white figure background and no grid lines.
for _rc_key, _rc_value in (
    ("cmap.discrete", False),
    ("cmap.sequential", "viridis"),
    ("figure.facecolor", "white"),
    ("grid", False),
):
    pplt.rc[_rc_key] = _rc_value

Settings:

In [None]:
# Reproducibility.
seed = 12345
rng = np.random.default_rng(seed=seed)

# Problem definition: base distribution name, number of dimensions, and the
# half-width of the symmetric window [-xmax, xmax] used for limits and bins.
dist_name = "kv"
ndim = 6
xmax = 3.5

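Define the source distribution and draw one million ground-truth samples. Coordinates 0–1 are overwritten with two-spirals samples and coordinates 2–3 with swiss-roll samples.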

In [None]:
# Ground truth: draw 1e6 samples from the `dist_name` distribution, then
# overwrite coordinates 0-1 with a two-spirals sample and 2-3 with a
# swiss-roll sample of the same size.
dist = ment.dist.get_dist(
    dist_name,
    ndim=ndim,
    seed=seed,
    noise=0.5,
)
x_true = dist.sample(1_000_000)
# NOTE(review): the two overrides below receive no `seed`; unless get_dist
# seeds them internally, the ground truth may differ between runs.
x_true[:, :2] = ment.dist.get_dist("two-spirals", noise=0.2).sample(len(x_true))
x_true[:, 2:4] = ment.dist.get_dist("swissroll").sample(len(x_true))

In [None]:
# Same symmetric window on every axis.
limits = [(-xmax, xmax)] * ndim

# Corner plot of the ground-truth samples.
grid = psv.CornerGrid(ndim, figwidth=1.5 * ndim)
grid.plot_points(x_true, bins=64, limits=limits, mask=False)
plt.show()

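Generate data: for every pair of coordinates, record a 2-D histogram of the ground-truth samples projected onto that pair.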

In [None]:
# Histogram resolution and optional KDE smoothing used by the diagnostics.
n_bins = 50
kde = False
kde_bandwidth = 1.0

## Measure 2D marginals
# One measurement per coordinate pair (j, i) with j < i: a permutation moves
# coordinate j into slot 0 and coordinate i into slot 2, and the diagnostic
# then histograms slots `axis_meas`. That gives ndim * (ndim - 1) / 2 pairs.
axis_meas = (0, 2)
n_meas = ndim * (ndim - 1) // 2

transfer_matrices = []
for i in range(ndim):
    for j in range(i):
        matrices = []
        for k, l in zip(axis_meas, (j, i)):
            # Permutation matrix swapping coordinates k and l (identity if k == l).
            matrix = np.identity(ndim)
            matrix[k, k] = matrix[l, l] = 0.0
            matrix[k, l] = matrix[l, k] = 1.0
            matrices.append(matrix)
        # Product M2 @ M1, i.e. the (0 <-> j) swap acts first, assuming
        # LinearTransform applies `matrix @ x` per sample -- TODO confirm.
        transfer_matrices.append(np.linalg.multi_dot(matrices[::-1]))

transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

## Create diagnostics (x-y histogram).
axis_proj = axis_meas
bin_edges = len(axis_proj) * [np.linspace(-xmax, xmax, n_bins + 1)]

# One Histogram2D per transform, histogramming exactly the axes the transfer
# matrices move the measured coordinates into.
diagnostics = [
    [
        ment.diag.Histogram2D(
            axis=axis_proj,
            edges=bin_edges,
            kde=kde,
            kde_bandwidth=kde_bandwidth,
        )
    ]
    for _ in transforms
]

## Generate data.
# Measurements are always taken from raw (un-smoothed) histograms; each
# diagnostic's configured KDE flag is restored afterwards, even on error.
measurements = []
for transform, diagnostic_list in zip(transforms, diagnostics):
    u = transform(x_true)
    measurements.append([])
    for diagnostic in diagnostic_list:
        diagnostic.kde = False
        try:
            measurements[-1].append(diagnostic(u))
        finally:
            diagnostic.kde = kde

## Model

In [None]:
# Which particle sampler MENT uses in "sample" mode:
#   "grid" -> ment.samp.GridSampler
#   "mcmc" -> ment.samp.MetropolisHastingsSampler
# The selector is kept separate from `sampler`, which holds the sampler object.
sampler_name = "mcmc"
n_samples = 1_000_000
burnin = 1000

if sampler_name == "grid":
    samp_grid_res = 32
    samp_noise = 0.5
    samp_grid_shape = ndim * [samp_grid_res]
    samp_grid_limits = limits

    sampler = ment.samp.GridSampler(
        grid_limits=samp_grid_limits,
        grid_shape=samp_grid_shape,
        noise=samp_noise,
    )
elif sampler_name == "mcmc":
    sampler = ment.samp.MetropolisHastingsSampler(
        ndim=ndim,
        chains=512,
        proposal_cov=(np.eye(ndim) * 0.25),
        burnin=burnin,
        shuffle=True,
        verbose=True,
    )
else:
    raise ValueError(
        f"Unknown sampler_name {sampler_name!r}; expected 'grid' or 'mcmc'."
    )

prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

model = ment.MENT(
    ndim=ndim,
    measurements=measurements,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),
    sampler=sampler,
    n_samples=n_samples,
    mode="sample",
    verbose=True,
)

## Training

In [None]:
learning_rate = 1.0
n_epochs = 2
plot_n_samples = n_samples
plot_n_bins = n_bins

# The "MEAS" panels depend only on x_true, the transform and the number of
# predicted rows, so each is computed once and reused across epochs instead
# of re-transforming and re-histogramming all true samples every epoch.
values_true_cache = {}

start_time = time.time()
for epoch in range(n_epochs + 1):
    print(f"epoch = {epoch}")

    ## Update model (epoch 0 only evaluates the initial model).
    if epoch > 0:
        model.gauss_seidel_step(learning_rate=learning_rate)
        # Wall time since the loop started (includes sampling and plotting).
        print("time = {:0.3f}".format(time.time() - start_time))

    ## Sample particles from posterior
    x_pred = model.sample(plot_n_samples)

    # Plot corner
    plot_corner_upper_lower(x_pred, x_true, n_bins=plot_n_bins, limits=limits)
    plt.show()

    ## Plot measured vs. simulated projections side-by-side.
    for index in range(n_meas):
        u_pred = transforms[index](x_pred)
        values_pred, edges = np.histogramdd(u_pred[:, axis_meas], bin_edges, density=True)

        # Compare against the same number of true samples as predicted ones.
        n_rows = u_pred.shape[0]
        key = (index, n_rows)
        if key not in values_true_cache:
            u_true = transforms[index](x_true)[:n_rows]
            values_true_cache[key], _ = np.histogramdd(
                u_true[:, axis_meas], bin_edges, density=True
            )
        values_true = values_true_cache[key]

        fig, axs = pplt.subplots(ncols=2, figwidth=3.0, xspineloc="neither", yspineloc="neither", space=0.0)
        axs[0].pcolormesh(edges[0], edges[1], values_pred.T, cmap="mono")
        axs[1].pcolormesh(edges[0], edges[1], values_true.T, cmap="mono")
        axs.format(suptitle=f"index={index}", toplabels=["PRED", "MEAS"])
        plt.show()

## Final check

In [None]:
# Draw a large (10**6) sample from the trained model for the final comparison.
x_pred = model.sample(10**6)

In [None]:
color = "pink"
bins = 32

# Axis labels for the (x, px, y, py, z, pz) ordering. For ndim <= 6 the list
# is truncated to ndim entries; for larger ndim generic labels are used so
# there is one label per axis.
phase_space_labels = [r"$x$", r"$p_x$", r"$y$", r"$p_y$", r"$z$", r"$p_z$"]
if ndim <= len(phase_space_labels):
    axis_labels = phase_space_labels[:ndim]
else:
    axis_labels = [rf"$x_{{{i + 1}}}$" for i in range(ndim)]

grid = psv.CornerGrid(ndim, figwidth=(ndim * 1.5))
# Ground truth: 2-D histograms off the diagonal, black 1-D profiles on it.
grid.plot_points(
    x_true,
    limits=limits,
    bins=bins,
    mask=False,
    cmap="mono",
    diag_kws=dict(lw=1.25, color="black"),
)
# A small scatter subset of the model samples overlaid on the off-diagonal panels.
grid.plot_points(
    x_pred[:1_000],
    diag=False,
    kind="scatter",
    c=color,
    s=0.5,
)
# All model samples: alpha=0 hides the off-diagonal panels, leaving only the
# 1-D diagonal profiles drawn in pink.
grid.plot_points(
    x_pred,
    diag_kws=dict(lw=1.25, color="pink5"),
    bins=bins,
    alpha=0.0,
)
grid.set_labels(axis_labels)
plt.show()