# N:1 MENT-S — random projections

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import ment
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from tqdm.notebook import tqdm
from tqdm.notebook import trange

from utils import plot_corner_upper_lower

In [None]:
## Notebook-wide proplot defaults, set through the `pplt.rc` configuration mapping.
pplt.rc["cmap.discrete"] = False  # presumably keeps colormaps continuous instead of discretized -- confirm
pplt.rc["cmap.sequential"] = "viridis"  # default sequential colormap
pplt.rc["figure.facecolor"] = "white"  # figure background color
pplt.rc["grid"] = False  # presumably disables axes grid lines -- confirm

Settings:

In [None]:
## Global settings shared by the cells below.
dist_name = "gaussian_mixture"  # name handed to `ment.dist.get_dist`
ndim = 6  # number of coordinates per sample
xmax = 4.0  # plot and histogram ranges are (-xmax, xmax) along every axis
seed = 12345  # seed for the source distribution and for `rng`
rng = np.random.default_rng(seed)  # seeded NumPy random generator

Define the source distribution.

In [None]:
## Build the named source distribution and draw the reference ("true") samples.
dist = ment.dist.get_dist(dist_name, ndim=ndim, seed=seed, noise=0.25)
x_true = dist.sample(1_000_000)  # presumably an array of shape (1_000_000, ndim) -- confirm

In [None]:
## Corner plot of the true samples over (-xmax, xmax) in every dimension.
limits = [(-xmax, xmax) for _ in range(ndim)]

corner_width = 1.5 * ndim
grid = psv.CornerGrid(ndim, figwidth=corner_width)
grid.plot_points(x_true, bins=64, limits=limits, mask=False)
plt.show()

Generate data.

In [None]:
## Measurement settings.
n_meas = 10  # number of random projections to measure
n_bins = 64  # number of bins in each measured histogram
kde = False  # forwarded to `ment.diag.Histogram1D(kde=...)`
kde_bandwidth = 1.0  # forwarded to `ment.diag.Histogram1D(kde_bandwidth=...)`


class ProjectionTransform:
    """Project points onto a fixed direction vector.

    Calling an instance on an array whose rows are points returns the dot
    product of every row with ``direction`` as a single-column array.
    """

    def __init__(self, direction: np.ndarray) -> None:
        # Stored exactly as given; no normalization happens here.
        self.direction = direction

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return the row-wise projections of ``x`` with shape ``(len(x), 1)``."""
        weighted = np.multiply(x, self.direction)
        projected = weighted.sum(axis=1)
        # Keep a trailing axis so the result still looks like an array of points.
        return np.expand_dims(projected, axis=1)
        

## Draw one random unit vector per measurement and wrap each in a projection
## transform. The draws come from the seeded generator `rng` (not the global
## `np.random` state) so that rerunning the notebook reproduces the same
## measurement directions.
directions = rng.normal(size=(n_meas, ndim))
## Normalize every row to unit length.
directions /= np.linalg.norm(directions, axis=1, keepdims=True)
transforms = [ProjectionTransform(direction) for direction in directions]

## Create diagnostics: one 1D histogram (over column `axis_meas` of the
## transformed coordinates) for every transform.
axis_proj = axis_meas = 0
bin_edges = np.linspace(-xmax, xmax, num=n_bins + 1)

## Each transform gets its own single-element list holding a fresh histogram.
diagnostics = [
    [
        ment.diag.Histogram1D(
            axis=axis_meas,
            edges=bin_edges,
            kde=kde,
            kde_bandwidth=kde_bandwidth,
        )
    ]
    for _ in transforms
]

## Generate data: push the true samples through each transform and evaluate
## that transform's diagnostics on the result.
measurements = []
for transform, transform_diagnostics in zip(transforms, diagnostics):
    projected = transform(x_true)
    measurements.append([diag(projected) for diag in transform_diagnostics])

## Model

In [None]:
## Sampler settings. `sampler_name` selects the implementation ("grid" or
## "mcmc"); the constructed sampler object is bound to `sampler`. Keeping the
## name and the object in separate variables means the selection string is
## still available after this cell has run.
sampler_name = "mcmc"
n_samples = 500_000
burnin = 100

if sampler_name == "grid":
    ## Regular grid with `samp_grid_res` points per dimension over the plot limits.
    samp_grid_res = 32
    samp_noise = 0.5
    samp_grid_shape = ndim * [samp_grid_res]
    samp_grid_limits = limits

    sampler = ment.samp.GridSampler(
        grid_limits=samp_grid_limits,
        grid_shape=samp_grid_shape,
        noise=samp_noise,
    )
elif sampler_name == "mcmc":
    sampler = ment.samp.MetropolisHastingsSampler(
        ndim=ndim,
        chains=512,
        proposal_cov=(np.eye(ndim) * 0.25),
        burnin=burnin,
        shuffle=True,
        verbose=True,
    )
else:
    raise ValueError(
        f"Unknown sampler {sampler_name!r}; expected 'grid' or 'mcmc'."
    )

## Assemble the MENT model from the measurements, their transforms and
## diagnostics, a Gaussian prior, and the sampler.
prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

model_kws = {
    "ndim": ndim,
    "measurements": measurements,
    "transforms": transforms,
    "diagnostics": diagnostics,
    "prior": prior,
    "interpolation_kws": {"method": "linear"},
    "sampler": sampler,
    "n_samples": n_samples,
    "mode": "sample",
    "verbose": True,
}
model = ment.MENT(**model_kws)

## Training

In [None]:
learning_rate = 1.0
n_epochs = 2
plot_n_samples = n_samples
plot_n_bins = n_bins

## The true projections do not change between epochs, so histogram them once
## here instead of re-projecting the true samples in every epoch. The
## projection acts row by row, so slicing the samples first gives the same
## rows as slicing the projected array while doing less work.
bin_coords = 0.5 * (bin_edges[:-1] + bin_edges[1:])
hists_true = []
for transform in transforms:
    u_true = transform(x_true[:plot_n_samples])
    hist_true, _ = np.histogram(u_true[:, axis_meas], bins=bin_edges, density=True)
    hists_true.append(hist_true)

## Panel layout for the projection plots: at most 7 panels per row.
ncols = min(7, n_meas)
nrows = int(np.ceil(n_meas / ncols))
figwidth = ncols * 1.5
figheight = nrows * 1.0

start_time = time.time()
for epoch in range(n_epochs + 1):
    print(f"epoch = {epoch}")

    ## Update model. Epoch 0 skips the update and only evaluates the initial model.
    if epoch > 0:
        model.gauss_seidel_step(learning_rate=learning_rate)
        ## Wall-clock time since the loop started.
        print(f"time = {time.time() - start_time:0.3f}")

    ## Sample particles from posterior
    x_pred = model.sample(plot_n_samples)

    ## Plot projections: predicted profile as a red step line, true profile as
    ## black dots, both scaled by the peak of the true profile.
    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figwidth=figwidth, figheight=figheight)
    mean_abs_error = 0.0
    for transform, hist_true, ax in zip(transforms, hists_true, axs):
        u_pred = transform(x_pred)
        hist_pred, _ = np.histogram(u_pred[:, axis_meas], bins=bin_edges, density=True)
        ## The error is accumulated on the unscaled density histograms.
        mean_abs_error += np.mean(np.abs(hist_true - hist_pred))
        scale = np.max(hist_true)
        ax.stairs(hist_pred / scale, bin_edges, lw=1.5, color="red4")
        ax.plot(bin_coords, hist_true / scale, lw=0.0, color="black", marker=".", ms=2.0, zorder=9999)
    axs.format(ymax=1.25)
    plt.show()

    ## Average the per-projection error over all measurements.
    mean_abs_error /= len(transforms)
    print("mean_abs_error =", mean_abs_error)

    ## Plot corner
    axs = plot_corner_upper_lower(x_pred, x_true, n_bins=plot_n_bins, limits=limits)
    plt.show()

## Final check

In [None]:
## Draw a fresh batch of 500,000 samples from the model for the final comparison.
x_pred = model.sample(500_000)

In [None]:
## Final comparison: overlay the model samples on the true samples in one corner plot.
color = "pink"  # color of the scatter points drawn from the model samples
bins = 32
limits = ndim * [(-xmax, xmax)]

grid = psv.CornerGrid(ndim, figwidth=(ndim * 1.5))
## Layer 1: true samples, binned, "mono" colormap, black lines on the diagonal panels.
grid.plot_points(
    x_true, 
    limits=limits,
    bins=bins, 
    mask=False, 
    cmap="mono",
    diag_kws=dict(lw=1.25, color="black"),
)
## Layer 2: all model samples, binned, drawn with alpha=0.0 -- presumably so
## that only the diagonal profiles ("pink5") are visible from this call;
## confirm against the psdist plotting API.
grid.plot_points(
    x_pred[:], 
    limits=limits,
    bins=bins,
    diag_kws=dict(lw=1.25, color="pink5"),
    alpha=0.0,
)
## Layer 3: the first 1,000 model samples as a scatter plot, with diag=False.
grid.plot_points(
    x_pred[:1_000], 
    diag=False,
    kind="scatter",
    c=color,
    s=0.5,
)
grid.set_limits(limits)
## Six hard-coded labels, one per coordinate; this matches ndim = 6.
grid.set_labels([r"$x$", r"$p_x$", r"$y$", r"$p_y$", r"$z$", r"$p_z$"])
plt.show()