# N:2 MENT — sample-based solver

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm
from tqdm.notebook import trange

import ment
from ment.utils import unravel

In [None]:
# Global proplot styling applied to every figure in this notebook.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_settings.items():
    pplt.rc[_rc_key] = _rc_value

Settings:

In [None]:
# Random seed and the generator derived from it.
seed = 12345
rng = np.random.default_rng(seed)

# Ground-truth distribution and the extent of the coordinate window.
dist_name = "gaussian_mixture"
ndim = 6  # number of dimensions
xmax = 3.5  # half-width of the (-xmax, xmax) window used along every axis

Define the source distribution.

In [None]:
# Draw a large ground-truth sample from the chosen synthetic distribution.
dist = ment.dist.get_dist(dist_name, seed=seed, ndim=ndim, noise=0.25)
X_true = dist.sample(10**6)

# Optional alternative: overwrite coordinate planes (0, 1) and (2, 3) with
# "galaxy" samples, swapping the two columns of the second plane.
# X_true[:, (0, 1)] = ment.dist.get_dist("galaxy").sample(X_true.shape[0])
# X_true[:, (2, 3)] = ment.dist.get_dist("galaxy").sample(X_true.shape[0])
# X_true[:, (2, 3)] = np.flip(X_true[:, (2, 3)], axis=1)

In [None]:
# Per-axis (min, max) window, reused below for plotting and for the grid sampler.
limits = [(-xmax, xmax)] * ndim

# Corner plot of the ground-truth samples.
grid = psv.CornerGrid(ndim, figwidth=1.5 * ndim)
grid.plot_points(X_true, limits=limits, bins=64, mask=False)
plt.show()

Set up the forward model: one linear transform and one 2D histogram diagnostic for each pair of coordinates.

In [None]:
# Forward-model settings.
nmeas = ndim * (ndim - 1) // 2  # one measurement per pair of coordinates
nbins = 64  # histogram bins per measured axis
blur = 1.0
kde = False
kde_bandwidth = 1.0
axis_meas = (0, 2)  # output axes recorded by each 2D histogram


def _swap_matrix(size, a, b):
    """Return the ``size`` x ``size`` permutation matrix exchanging axes ``a`` and ``b``.

    When ``a == b`` the result is the identity matrix.
    """
    perm = np.identity(size)
    perm[a, a] = perm[b, b] = 0.0
    perm[a, b] = perm[b, a] = 1.0
    return perm


# For every pair (lo, hi) with lo < hi, chain two swaps: axis_meas[0] <-> lo,
# then axis_meas[1] <-> hi. Assuming LinearTransform maps x -> matrix @ x
# (TODO confirm), the histogram axes then see coordinates (lo, hi). The
# hi-major loop order fixes the measurement index used by later cells.
transfer_matrices = []
for hi in range(ndim):
    for lo in range(hi):
        swaps = [_swap_matrix(ndim, dst, src) for dst, src in zip(axis_meas, (lo, hi))]
        # Reverse so the first swap is applied first: swaps[1] @ swaps[0].
        transfer_matrices.append(np.linalg.multi_dot(swaps[::-1]))

transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

# One 2D histogram diagnostic (wrapped in its own list) per transform; every
# diagnostic shares the same bin-edge list for both measured axes.
bin_edges = [np.linspace(-xmax, xmax, nbins + 1)] * len(axis_meas)
diagnostics = [
    [
        ment.diag.Histogram2D(
            axis=axis_meas,
            edges=bin_edges,
            kde=kde,
            kde_bandwidth=kde_bandwidth,
            blur=blur,
        )
    ]
    for _ in transforms
]

Generate the training data by simulating the 2D projections of the ground-truth samples.

In [None]:
# Simulate the measured projections of the ground-truth samples.
projections = ment.sim.forward_with_diag_update(
    X_true,
    transforms,
    diagnostics,
    kde=False,
    blur=False,
)

# Zero out histogram values below `thresh` (in place); set `thresh = None`
# to keep every bin.
thresh = 1.00e-04
if thresh:
    for projection in unravel(projections):
        projection[projection < thresh] = 0.0

In [None]:
@interact(index=(0, nmeas - 1))
def update(index: int):
    """Draw projection number ``index`` (flattened over all transforms) on its bin grid."""
    grid_coords = unravel(diagnostics)[index].coords
    image = unravel(projections)[index]

    _, ax = pplt.subplots()
    ax.pcolormesh(grid_coords[0], grid_coords[1], image.T)

## Model

In [None]:
# Sampler used to draw particles from the MENT model: "grid" or "mcmc".
sampler_name = "mcmc"
nsamp = 1_000_000  # number of samples passed to ment.MENT below
burnin = 500
chains = 1000
# NOTE(review): `c` is not used below; proposal_cov is fixed at 0.1 * I. It
# looks like the ~2.38 / sqrt(d) random-walk Metropolis scaling heuristic —
# confirm whether it was meant to scale proposal_cov.
c = 2.4 / np.sqrt(ndim)

if sampler_name == "grid":
    samp_grid_res = 32
    samp_noise = 0.5
    samp_grid_shape = ndim * [samp_grid_res]
    samp_grid_limits = limits

    sampler = ment.samp.GridSampler(
        grid_limits=samp_grid_limits,
        grid_shape=samp_grid_shape,
        noise=samp_noise,
    )
elif sampler_name == "mcmc":
    proposal_cov = 0.1 * np.eye(ndim)

    # Draw the chain starting points from the notebook's seeded generator
    # `rng` rather than NumPy's unseeded global RNG, so runs are reproducible.
    start = rng.multivariate_normal(np.zeros(ndim), 0.5 * np.eye(ndim), size=chains)

    # start = ment.dist.get_dist("waterbag", ndim=ndim).sample(chains)

    sampler = ment.samp.MetropolisHastingsSampler(
        ndim=ndim,
        chains=chains,
        proposal_cov=proposal_cov,
        start=start,
        burnin=burnin,
        shuffle=True,
        verbose=True,
    )
else:
    raise ValueError(f"Unknown sampler {sampler_name!r}; expected 'grid' or 'mcmc'.")

prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

# MENT model fit to the simulated projections, using the sampler above
# (mode="sample").
model = ment.MENT(
    ndim=ndim,
    projections=projections,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),
    sampler=sampler,
    nsamp=nsamp,
    mode="sample",
    verbose=True,
)

## Training

In [None]:
plot_nsamp = 10**6  # samples drawn by plot_model each time it is called

In [None]:
def plot_model(model):
    """Plot a model's samples and compare simulated and measured projections.

    Draws ``plot_nsamp`` samples from ``model`` and builds two figures:

    1. a corner plot of the samples;
    2. a grid of contour plots, one per measurement, overlaying the model's
       stored projections (black) with the projections simulated from the
       samples through the global ``transforms``/``diagnostics`` (red).

    Parameters
    ----------
    model : ment.MENT
        Model to sample from; its ``projections`` attribute supplies the
        measured data.

    Returns
    -------
    list
        ``[corner_figure, projection_figure]``.
    """
    samples = model.sample(plot_nsamp)

    # Corner plot of the reconstructed samples.
    corner = psv.CornerGrid(ndim, figwidth=(ndim * 1.25), diag_shrink=0.80)
    corner.plot_points(
        samples,
        limits=limits,
        bins=65,
        mask=False,
        cmap="viridis",
    )

    # Measured projections vs. projections simulated from the samples.
    measured_nested = model.projections
    simulated_nested = ment.sim.forward_with_diag_update(
        samples, transforms, diagnostics, kde=False, blur=False
    )
    measured = unravel(measured_nested)
    simulated = unravel(simulated_nested)

    # Shared contour styling; each image is blurred and max-normalized first.
    image_kws = dict(
        process_kws=dict(blur=1.0, norm="max"),
        kind="contour",
        levels=np.linspace(0.01, 1.0, 7),
        lw=0.7,
    )

    ncols = min(nmeas, 7)
    nrows = -(-nmeas // ncols)  # ceiling division
    proj_fig, proj_axs = pplt.subplots(ncols=ncols, nrows=nrows, figwidth=(1.1 * ncols))
    for ax, pair in zip(proj_axs, zip(measured, simulated)):
        # Measured first (black), then simulated (red) on top.
        for values, color in zip(pair, ("black", "red")):
            psv.plot_image(values.T, ax=ax, colors=color, **image_kws)

    return [corner.fig, proj_fig]

In [None]:
# Train the model, passing plot_model as the plotting callback and no evaluation callback.
trainer = ment.train.Trainer(
    model,
    eval_func=None,
    plot_func=plot_model,
    notebook=True,
)
trainer.train(epochs=3, learning_rate=0.80, thresh=1.00e-03)

## Final check

In [None]:
X = model.sample(10**5)  # samples from the trained model for a final check