# 4:2 MENT (4D reconstruction from 2D projections) — random uncoupled phase advances

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from tqdm.notebook import tqdm
from tqdm.notebook import trange

import ment
from ment.utils import unravel

In [None]:
pplt.rc["cmap.discrete"] = False
pplt.rc["cmap.sequential"] = "viridis"
pplt.rc["figure.facecolor"] = "white"
pplt.rc["grid"] = False

## Settings

In [None]:
dist_name = "gaussian_mixture"
ndim = 4
nmeas = 9
nbins = 50
xmax = 3.5
seed = 12345

## Source distribution

In [None]:
# Ground-truth distribution and a large reference sample drawn from it.
dist = ment.dist.get_dist(
    dist_name,
    ndim=ndim,
    noise=0.25,
    seed=seed,
)
X_true = dist.sample(1_000_000)

In [None]:
# Symmetric [-xmax, xmax] window on every axis. It is reused later for the
# sampler grid limits and the integration limits.
limits = [(-xmax, xmax)] * ndim

# Corner plot of the ground-truth samples.
grid = psv.CornerGrid(ndim, figwidth=1.5 * ndim)
grid.plot_points(
    X_true,
    limits=limits,
    bins=75,
    mask=False,
    kind="contourf",
)
plt.show()

## Data generation

In [None]:
# Create transforms (random phase advances). Each measurement gets an
# independent angle for the (0, 1) block and for the (2, 3) block of the
# transfer matrix, drawn uniformly from [0, pi) with a seeded RNG.
rng = np.random.default_rng(seed)
phase_advances = rng.uniform(0.0, np.pi, size=(nmeas, 2))
transfer_matrices = []
for mux, muy in phase_advances:
    matrix = np.eye(ndim)
    # Block-diagonal ("uncoupled") layout: rows/cols 0-1 and 2-3 are filled
    # separately, so this construction assumes ndim == 4.
    matrix[0:2, 0:2] = ment.sim.rotation_matrix(mux)
    matrix[2:4, 2:4] = ment.sim.rotation_matrix(muy)
    transfer_matrices.append(matrix)

transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

# Create diagnostics (x-y histogram over the axes in `axis_proj`).
axis_proj = (0, 2)
bin_edges = len(axis_proj) * [np.linspace(-xmax, xmax, nbins + 1)]

# One single-histogram diagnostic list per transform. Use `axis_proj` rather
# than repeating the tuple so the histogram axes always match `bin_edges` and
# the integration axes that are derived from `axis_proj` when building the model.
diagnostics = [
    [ment.diag.HistogramND(axis=axis_proj, edges=bin_edges)] for _ in transforms
]

# Generate data.
projections = ment.sim.forward(X_true, transforms, diagnostics)

## Reconstruction model

In [None]:
prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

# Grid sampler: a samp_grid_res^ndim grid spanning `limits`, with no added noise.
samp_grid_res = 32
samp_noise = 0.0
samp_grid_shape = ndim * [samp_grid_res]
samp_grid_limits = limits

sampler = ment.samp.GridSampler(
    grid_limits=samp_grid_limits,
    grid_shape=samp_grid_shape,
    noise=samp_noise,
)

# Integration limits cover the axes that the diagnostics do NOT project onto.
# The nesting appears to mirror `diagnostics` ([measurement][diagnostic]).
# Build an independent list per measurement: `[[x]] * n` would alias one shared
# inner list across every measurement, so mutating one entry would change them all.
integration_limits = [limits[axis] for axis in range(ndim) if axis not in axis_proj]
integration_limits = [[list(integration_limits)] for _ in transforms]

model = ment.MENT(
    ndim=ndim,
    projections=projections,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),
    sampler=sampler,
    nsamp=1_000_000,
    # integration_* settings presumably only matter in mode="integrate" — confirm.
    integration_limits=integration_limits,
    integration_size=(15**4),
    integration_batches=1,
    mode="sample",  # {"sample", "integrate"}
    verbose=2,
)

## Training

In [None]:
def plot_model(model):
    """Plot the current MENT reconstruction.

    Draws two figures:
    1. a corner plot of samples from ``model``;
    2. a grid with one panel per projection, overlaying measured (black) and
       simulated (red) contours.

    Relies on the notebook globals ``ndim``, ``limits``, ``transforms``,
    ``diagnostics`` and ``nmeas``.

    Parameters
    ----------
    model : ment.MENT
        Model to sample from; ``model.projections`` supplies the measured data.

    Returns
    -------
    list
        ``[corner_figure, projections_figure]``.
    """

    def _overlay(values, ax, color):
        # Build the kwargs fresh on every call so no dict/array is shared
        # between the two overlays (matches the original per-call literals).
        psv.plot_image(
            values.T,
            ax=ax,
            kind="contour",
            process_kws=dict(blur=0.5, norm="max"),
            colors=color,
            lw=0.7,
            levels=np.linspace(0.01, 1.0, 7),
        )

    figs = []

    # Plot the reconstructed distribution.
    X_pred = model.sample(1_000_000)

    grid = psv.CornerGrid(ndim, figwidth=(ndim * 1.25), diag_shrink=0.80)
    grid.plot_points(
        X_pred,
        limits=limits,
        bins=65,
        mask=False,
        cmap="viridis",
    )
    figs.append(grid.fig)

    # Measured projections (stored on the model) vs. projections simulated by
    # pushing the model samples through the same transforms/diagnostics.
    projections_meas = unravel(model.projections)
    projections_pred = unravel(ment.sim.forward(X_pred, transforms, diagnostics))

    ncols = min(nmeas, 7)
    nrows = int(np.ceil(nmeas / ncols))
    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figwidth=(1.1 * ncols))
    axs_list = list(axs)

    nplotted = 0
    for ax, values_meas, values_pred in zip(axs_list, projections_meas, projections_pred):
        _overlay(values_meas, ax, "black")
        _overlay(values_pred, ax, "red")
        nplotted += 1

    # The grid has nrows * ncols >= nmeas panels (e.g. 9 measurements on a 2x7
    # grid leave 5 spare panels); hide the unused ones instead of showing
    # empty axes.
    for ax in axs_list[nplotted:]:
        ax.set_visible(False)
    figs.append(fig)

    return figs

In [None]:
model.mode = "sample"

trainer = ment.train.Trainer(
    model,
    plot_func=plot_model,
    eval_func=None,
    notebook=True,
)

trainer.train(epochs=3, learning_rate=0.80)