# 4:2 MENT: reconstructing a 4D distribution from 2D projections

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.visualization as psv

import ment

In [None]:
# Global proplot styling for every figure in this notebook: discrete colormap
# levels off, viridis as the default sequential colormap, white figure
# background, and no grid lines.
_rc_overrides = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_overrides.items():
    pplt.rc[_rc_key] = _rc_value

Settings:

In [None]:
# Ground-truth distribution.
dist_name = "gaussian_mixture"  # name passed to ment.dist.get_dist
ndim = 4                        # phase-space dimension
seed = 145                      # seeds the distribution and the random phase advances

# Measurements.
n_meas = 32                     # number of transforms (one measurement each)
n_bins = 70                     # histogram bins per projected axis
xmax = 3.5                      # histograms and plots span [-xmax, xmax] on every axis

Define the source distribution.

In [None]:
# Build the ground-truth distribution and draw a large reference sample from it.
# (An optional `scale=1.0` keyword to get_dist is currently disabled.)
dist = ment.dist.get_dist(dist_name, ndim=ndim, seed=seed)
x_true = dist.sample(1_000_000)

# The same [-xmax, xmax] window on every axis; reused by the sampler grid and plots below.
limits = [(-xmax, xmax) for _ in range(ndim)]

# Corner plot of the true samples.
grid = psv.CornerGrid(ndim, figwidth=1.5 * ndim)
grid.plot_points(x_true, mask=False, bins=75, limits=limits);

Generate data.

In [None]:
# Create transforms: independent rotations (phase advances drawn uniformly from
# [0, pi)) in the x-x' and y-y' planes, one pair per measurement. The block
# indexing below assumes coordinates are ordered (x, x', y, y') in the first
# four dimensions.
#
# Deterministic alternative (regular grid of phase advances):
# phase_advances_x = np.linspace(0.0, np.pi, int(np.sqrt(n_meas)), endpoint=False)
# phase_advances_y = phase_advances_x
# phase_advances = [(mux, muy) for mux in phase_advances_x for muy in phase_advances_y]
rng = np.random.default_rng(seed)
phase_advances = rng.uniform(0.0, np.pi, size=(n_meas, 2))

transfer_matrices = []
for mux, muy in phase_advances:
    matrix = np.eye(ndim)
    matrix[0:2, 0:2] = ment.sim.rotation_matrix(mux)
    matrix[2:4, 2:4] = ment.sim.rotation_matrix(muy)
    transfer_matrices.append(matrix)

transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

# Create diagnostics: one 2D histogram of the (x, y) projection per transform.
# `axis_proj` is the single source of truth for the projected axes, so the
# bin edges and the histogram axes cannot drift apart.
axis_proj = (0, 2)
bin_edges = len(axis_proj) * [np.linspace(-xmax, xmax, n_bins + 1)]
diagnostics = [
    [ment.diag.HistogramND(axis=axis_proj, bin_edges=bin_edges)]
    for _ in transforms
]

# Generate measurement data: push the true samples through each transform once,
# then evaluate every diagnostic attached to that transform on the result.
measurements = []
for transform, transform_diagnostics in zip(transforms, diagnostics):
    u = transform(x_true)
    measurements.append([diagnostic(u) for diagnostic in transform_diagnostics])

Create reconstruction model.

In [None]:
# Prior distribution.
# NOTE(review): scale = 2 * xmax presumably matches the full width of the
# [-xmax, xmax] window; confirm against ment.UniformPrior.
prior = ment.UniformPrior(ndim=ndim, scale=(2.0 * xmax))

# Grid-based sampler covering the same limits as the plots.
samp_grid_res = 32  # grid points per dimension
samp_noise = 1.0
samp_grid_shape = [samp_grid_res for _ in range(ndim)]
samp_grid_limits = limits

sampler = ment.samp.GridSampler(
    grid_shape=samp_grid_shape,
    grid_limits=samp_grid_limits,
    noise=samp_noise,
)

# MENT solver built from the simulated measurements, transforms and diagnostics.
model = ment.MENT(
    ndim=ndim,
    transforms=transforms,
    diagnostics=diagnostics,
    measurements=measurements,
    prior=prior,
    sampler=sampler,
    interpolation={"method": "linear"},
    n_samples=200_000,
    verbose=True,
)

# Training hyperparameters for the Gauss-Seidel loop.
n_epochs = 10
learning_rate = 0.90

In [None]:
# Rough cost benchmark: time a single small draw from the (untrained) model.
# perf_counter is a monotonic, high-resolution clock meant for measuring
# durations; time.time() can jump if the system clock is adjusted.
t0 = time.perf_counter()
model.sample(10000)
print("time = {:0.3f}".format(time.perf_counter() - t0))

In [None]:
# Run the Gauss-Seidel iterations. Epoch -1 performs no update, so the initial
# (untrained) model is plotted first. After each epoch, sample the model and
# overlay it on the true distribution: true samples in the upper triangle (red),
# model samples in the lower triangle (blue), 1D profiles as lines on the diagonal.

# Shared histogram settings. They are applied to BOTH overlays so the true and
# reconstructed distributions use identical limits and binning and can be
# compared panel by panel.
kws = dict(limits=limits, bins=75, mask=True)

start_time = time.perf_counter()

for epoch in range(-1, n_epochs):
    print(f"epoch = {epoch}")
    if epoch >= 0:
        model.gauss_seidel_step(lr=learning_rate)
        # Cumulative time since the loop started, including the sampling and
        # plotting done in earlier epochs.
        print("time = {:0.3f}".format(time.perf_counter() - start_time))

    x = model.sample(1_000_000)

    grid = psv.CornerGrid(ndim, figwidth=(ndim * 1.5), corner=False)
    grid.plot_points(
        x_true,
        lower=False,
        diag_kws=dict(kind="line", color="red8", lw=1.1),
        cmap=pplt.Colormap("reds"),
        **kws,
    )
    grid.plot_points(
        x,
        upper=False,
        diag_kws=dict(kind="line", color="blue8", lw=1.1),
        cmap=pplt.Colormap("blues"),
        **kws,
    )
    plt.show()