# 6:2 MENT — Gaussian mixture

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from tqdm.notebook import tqdm
from tqdm.notebook import trange

import ment

In [None]:
# Global plotting defaults for every figure in this notebook, applied in one
# call (keys are set in the order listed).
pplt.rc.update(
    {
        "cmap.discrete": False,
        "cmap.sequential": "viridis",
        "figure.facecolor": "white",
        "grid": False,
    }
)

Settings:

In [None]:
# --- Reproducibility ---
seed = 12345  # shared by the source distribution and the phase-advance RNG

# --- Source distribution ---
dist_name = "gaussian_mixture"  # name looked up via ment.dist.get_dist
ndim = 6  # dimension of the phase space being reconstructed

# --- Measurements ---
n_meas = 9  # number of random transforms (one 2D histogram per transform)
n_bins = 50  # histogram bins along each projected axis
xmax = 3.5  # plots and histograms span [-xmax, xmax] on every axis

Define the source distribution.

In [None]:
# Build the ground-truth distribution, then draw one million reference
# points from it; `x_true` is the sample that gets measured and plotted below.
dist = ment.dist.get_dist(
    dist_name,
    ndim=ndim,
    seed=seed,
)
x_true = dist.sample(10**6)

In [None]:
# Same symmetric (min, max) range on every axis; reused by later cells.
limits = [(-xmax, xmax) for _ in range(ndim)]

# Corner plot of the ground-truth sample.
grid = psv.CornerGrid(ndim, figwidth=(1.5 * ndim))
grid.plot_points(x_true, bins=75, limits=limits, mask=False)
plt.show()

Generate data.

In [None]:
## Create transforms (random phase advances).
# One pair of angles (mux, muy) per measurement, drawn uniformly from [0, pi).
rng = np.random.default_rng(seed)
phase_advances = rng.uniform(0.0, np.pi, size=(n_meas, 2))

# Each transfer matrix is the identity except for a rotation by `mux` in the
# (0, 1) coordinate plane and by `muy` in the (2, 3) plane; the remaining
# coordinates pass through unchanged.
transfer_matrices = []
for (mux, muy) in phase_advances:
    matrix = np.eye(ndim)
    matrix[0:2, 0:2] = ment.sim.rotation_matrix(mux)
    matrix[2:4, 2:4] = ment.sim.rotation_matrix(muy)
    transfer_matrices.append(matrix)

transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

## Create diagnostics (x-y histogram).
# `axis_proj` is the single source of truth for the measured axes: it sets
# both the number of bin-edge arrays and the axes passed to each histogram,
# so the two cannot drift apart.
axis_proj = (0, 2)
bin_edges = [np.linspace(-xmax, xmax, n_bins + 1) for _ in axis_proj]

# One single-element list of diagnostics per transform; a fresh histogram
# object is created for each so that no state is shared between measurements.
diagnostics = [
    [ment.diag.HistogramND(axis=axis_proj, edges=bin_edges)]
    for _ in transforms
]

## Generate data.
# measurements[i][j] is the output of diagnostic j applied to the
# ground-truth sample after transform i.
measurements = []
for index, transform in enumerate(transforms):
    u = transform(x_true)
    measurements.append([diagnostic(u) for diagnostic in diagnostics[index]])

Create the MENT reconstruction model.

In [None]:
# Prior distribution over the full ndim-dimensional space.
# NOTE(review): `scale` is presumably the Gaussian width — confirm against ment.
prior = ment.GaussianPrior(ndim=ndim, scale=2.0)

# Grid sampler configuration: the same resolution on every axis, spanning the
# same limits used for plotting.
samp_grid_res = 10
samp_noise = 1.0
samp_grid_shape = ndim * [samp_grid_res]
samp_grid_limits = limits

sampler = ment.samp.GridSampler(
    grid_limits=samp_grid_limits,
    grid_shape=samp_grid_shape,
    noise=samp_noise,
)

# Axes that are NOT measured by the histogram diagnostics; their limits are
# passed to the model as integration limits.
integration_axes = [axis for axis in range(ndim) if axis not in axis_proj]
integration_axis_limits = [limits[axis] for axis in integration_axes]

# integration_limits[i][j] holds the limits for diagnostic j of transform i.
# Every entry is built as its own list (rather than via `[[...]] * n`, which
# would make all entries alias one shared inner list).
integration_limits = [[list(integration_axis_limits)] for _ in transforms]

# Total number of integration points, written as a per-axis resolution raised
# to the number of integrated axes (15 ** 4 for ndim = 6 with a 2D projection).
# NOTE(review): assumes `integration_size` is a total point count — confirm.
integration_res = 15
integration_size = integration_res ** len(integration_axes)

model = ment.MENT(
    ndim=ndim,
    measurements=measurements,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),

    sampler=sampler,
    n_samples=1_000_000,

    integration_limits=integration_limits,
    integration_size=integration_size,
    integration_batches=1,

    mode="integrate",
    verbose=True,
)

# Training hyperparameters consumed by the training cell.
learning_rate = 0.80
n_epochs = 2

Train the model.

In [None]:
# Reference point for elapsed-time reporting. perf_counter() is monotonic, so
# the reported durations cannot be distorted by system clock adjustments.
start_time = time.perf_counter()

# Plot settings shared by both point clouds; built once since they do not
# change between epochs.
kws = dict(limits=limits, bins=75, mask=True)

# epoch = -1 is a sentinel: no update step is taken, so the first figure shows
# the model before any training. Epochs 0..n_epochs-1 each perform one
# Gauss-Seidel step before plotting.
for epoch in range(-1, n_epochs):
    print(f"epoch = {epoch}")

    if epoch >= 0:
        model.gauss_seidel_step(lr=learning_rate)
        print("time = {:0.3f}".format(time.perf_counter() - start_time))

    x_pred = model.sample(1_000_000)

    # Overlay truth (red) and prediction (blue) on one corner grid. The truth
    # is drawn with lower=False and the prediction with upper=False.
    # NOTE(review): presumably this puts them in opposite triangles of the
    # grid — confirm against psdist.
    grid = psv.CornerGrid(ndim, figwidth=(ndim * 1.25), corner=False)
    grid.plot_points(
        x_true,
        lower=False,
        diag_kws=dict(kind="step", color="red8", lw=1.25),
        cmap=psv.cubehelix_cmap(color="red"),
        **kws
    )
    grid.plot_points(
        x_pred,
        upper=False,
        diag_kws=dict(kind="step", color="blue8", lw=1.25),
        cmap=psv.cubehelix_cmap(color="blue"),
        **kws
    )
    plt.show()

Check model predictions.

In [None]:
# Side-by-side comparison of the model's simulated histogram (left) and the
# measured histogram (right) for every measurement.
diag_index = 0  # only the first diagnostic of each transform is inspected
for index in range(n_meas):
    print(f"index = {index}")

    diagnostic = diagnostics[index][diag_index]
    values_pred = model.simulate(index, diag_index)
    values_meas = measurements[index][diag_index]
    coords = diagnostic.coords

    fig, axs = pplt.subplots(
        ncols=2,
        figwidth=4.0,
        xspineloc="neither",
        yspineloc="neither",
        space=0.0,
    )
    for ax, values in zip(axs, (values_pred, values_meas)):
        # Transpose the values before passing them to pcolormesh with
        # (coords[0], coords[1]).
        ax.pcolormesh(
            coords[0],
            coords[1],
            values.T,
            robust=False,
            cmap="mono",
        )
    axs.format(toplabels=["PRED", "MEAS"])
    plt.show()