# 2:1 MENT — toy problem

In [None]:
import os
import time

import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv

import ment

In [None]:
# Global proplot style options, applied in order.
for _rc_key, _rc_value in {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}.items():
    pplt.rc[_rc_key] = _rc_value

Settings:

In [None]:
# Source distribution: name, dimension, and RNG seed passed to ment.dist.get_dist.
dist_name = "galaxy"
ndim = 2
seed = 0

# Measurement settings.
n_meas = 6  # number of rotated 1D projections
n_bins = 80  # histogram bins per projection
xmax = 6.0  # half-width of the [-xmax, xmax] window used for histograms and plots

Define the source distribution.

In [None]:
# Ground-truth source distribution and a reference sample of 10**6 particles
# drawn from it; the measurements are generated from this sample.
dist = ment.dist.get_dist(
    dist_name,
    ndim=ndim,
    seed=seed,
    normalize=True,
)
x_true = dist.sample(10**6)

In [None]:
# Plotting window, one (min, max) pair per dimension. The grid sampler reuses
# these limits, so they must have the same length as its ndim-based grid shape.
limits = ndim * [(-xmax, xmax)]

# Plot the true distribution with a linear (left) and log (right) color scale.
# (The original cell also computed an unused np.histogramdd of all 10**6 points
# here, which overwrote `edges`; that dead computation is removed.)
fig, axs = pplt.subplots(ncols=2)
for ax, norm in zip(axs, [None, "log"]):
    psv.plot_points(x_true, limits=limits, bins=75, offset=1.0, norm=norm, colorbar=True, ax=ax)
plt.show()

Create the measurement data: a 1D histogram of the true distribution after each of `n_meas` rotations.

In [None]:
# n_meas phase advances evenly spaced over [0, pi), and the shared bin edges
# for every 1D histogram.
phase_advances = np.linspace(0.0, np.pi, n_meas, endpoint=False)
edges = np.linspace(-xmax, xmax, n_bins + 1)

# For each phase advance: build the rotation transform, attach a 1D histogram
# of axis 0, and record the histogram of the transformed true sample.
# Each transform has exactly one diagnostic, so diagnostics[i] and
# measurements[i] are one-element lists.
transfer_matrices, transforms, diagnostics, measurements = [], [], [], []
for index, phase_advance in enumerate(phase_advances):
    matrix = ment.sim.rotation_matrix(phase_advance)
    transform = ment.sim.LinearTransform(matrix)
    diagnostic = ment.diag.Histogram1D(axis=0, edges=edges)
    u = transform(x_true)

    transfer_matrices.append(matrix)
    transforms.append(transform)
    diagnostics.append([diagnostic])
    measurements.append([diagnostic(u)])

Set up the MENT reconstruction model.

In [None]:
# Gaussian prior with the same dimension as the source distribution
# (previously hard-coded to 2, which would silently mismatch if ndim changed).
prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

# Grid sampler on the plotting window, 200 points per dimension; it is handed
# to ment.MENT below (presumably used by model.sample — confirm in ment docs).
samp_grid_limits = limits
samp_grid_shape = ndim * [200]
sampler = ment.samp.GridSampler(grid_limits=samp_grid_limits, grid_shape=samp_grid_shape)

# One entry per transform, each holding the single interval (-xmax, xmax).
# A fresh list is built per transform, so entries do not alias one object.
integration_limits = [[(-xmax, xmax)] for _ in transforms]
integration_size = 200
integration_batches = 10

model = ment.MENT(
    ndim=ndim,
    measurements=measurements,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),
    sampler=sampler,
    n_samples=1_000_000,
    integration_limits=integration_limits,
    integration_size=integration_size,
    integration_batches=integration_batches,
    verbose=True,
    mode="integrate",  # {"integrate", "sample"}
)

Train the model via Gauss-Seidel iterations.

In [None]:
# Settings
learning_rate = 0.90
n_epochs = 6

# Layout of the projection-comparison grid; it does not change between epochs.
ncols = min(n_meas, 7)
nrows = int(np.ceil(n_meas / ncols))
figwidth = 1.5 * ncols
figheight = 1.25 * nrows

# Epoch -1 only plots the initial (untrained) model; every later epoch first
# takes one Gauss-Seidel step with the given learning rate.
for epoch in range(-1, n_epochs):
    print("epoch =", epoch)

    if epoch >= 0:
        model.gauss_seidel_step(learning_rate)

    # Sample particles
    x_pred = model.sample(1_000_000)

    # Plot 2D density with a linear (left) and log (right) color scale.
    fig, axs = pplt.subplots(ncols=2)
    for ax, norm in zip(axs, [None, "log"]):
        psv.plot_points(x_pred, limits=limits, bins=75, offset=1.0, norm=norm, colorbar=True, ax=ax)
    plt.show()

    # Simulate every measurement once per epoch and reuse the result for both
    # the linear and log plots. The diagnostic is looked up BEFORE it is called
    # (the original called a stale `diagnostic` left over from an earlier loop).
    # Both curves are divided by the measured peak so they share a y-axis;
    # plain division returns new float arrays instead of mutating in place.
    profiles = []
    for index, transform in enumerate(transforms):
        diagnostic = diagnostics[index][0]
        values_meas = measurements[index][0]
        values_pred = diagnostic(transform(x_pred))
        scale = np.max(values_meas)
        profiles.append((diagnostic.coords, values_meas / scale, values_pred / scale))

    # Plot measured (red) vs. simulated (black) projections, linear then log.
    for log in [False, True]:
        fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figwidth=figwidth, figheight=figheight, sharex=True, sharey=True)
        for index, (coords, values_meas, values_pred) in enumerate(profiles):
            ax = axs[index]
            ax.plot(coords, values_meas, color="red3")
            ax.plot(coords, values_pred, color="black")
            ax.format(ymax=1.25, xlim=(-xmax, xmax))
            if log:
                ax.format(yscale="log", ymax=5.0, ymin=1.00e-05, yformatter="log")
        plt.show()