# 2:1 MENT — toy problem

In [None]:
import os
import time

import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv

import ment
from ment.utils import unravel

In [None]:
# Global proplot settings applied to every figure in this notebook.
pplt.rc.update(
    {
        "cmap.discrete": False,
        "cmap.sequential": "viridis",
        "figure.facecolor": "white",
        "grid": False,
    }
)

## Settings

In [None]:
# Source distribution: name, dimension and seed handed to `ment.dist.get_dist`.
dist_name = "galaxy"
ndim = 2
seed = 0

# Measurement setup: number of phase advances and bins per 1D histogram.
n_meas = 6
n_bins = 80

# Plot limits and histogram bin edges span [-xmax, xmax].
xmax = 6.0

## Source distribution

In [None]:
# Build the named source distribution, then draw the reference sample that
# stands in for the "true" phase space distribution in the rest of the notebook.
dist = ment.dist.get_dist(
    dist_name,
    ndim=ndim,
    seed=seed,
    normalize=True,
)
X_true = dist.sample(10**6)

In [None]:
# Same symmetric range on both plotted axes.
limits = 2 * [(-xmax, xmax)]

# The 2D histogram of the true sample does not depend on the panel, so it is
# computed once here instead of once per axis inside the loop.
# `hist` and `edges` are not consumed by the plotting call below.
hist, edges = np.histogramdd(X_true, bins=75, range=limits)

# Two panels of the same sample: the first with the default color
# normalization (norm=None), the second with norm="log".
fig, axs = pplt.subplots(ncols=2)
for i, ax in enumerate(axs):
    psv.plot_points(
        X_true,
        limits=limits,
        bins=75,
        offset=1.0,
        norm=("log" if i else None),
        colorbar=True,
        ax=ax,
    )
plt.show()

## Data generation

In [None]:
# n_meas phase advances evenly spaced over [0, pi), endpoint excluded.
phase_advances = np.linspace(0.0, np.pi, n_meas, endpoint=False)

# One transfer matrix per phase advance, each wrapped in a LinearTransform.
transfer_matrices = [ment.sim.rotation_matrix(mu) for mu in phase_advances]
transforms = [ment.sim.LinearTransform(M) for M in transfer_matrices]

# All histograms share the same n_bins bins covering [-xmax, xmax].
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)

# Each transform gets its own single-element list holding one 1D histogram
# diagnostic along axis 0.
diagnostics = [[ment.diag.Histogram1D(axis=0, edges=bin_edges)] for _ in transforms]

# Measured projections are produced from the true sample.
projections = ment.sim.forward(X_true, transforms, diagnostics)

## Reconstruction model

In [None]:
# Gaussian prior; its dimension follows the notebook-wide `ndim` setting so it
# stays consistent with the sampler grid and the model below.
prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

# Grid sampler over the plotting limits, with 200 entries per dimension.
samp_grid_limits = limits
samp_grid_shape = ndim * [200]
sampler = ment.samp.GridSampler(grid_limits=samp_grid_limits, grid_shape=samp_grid_shape)

# One `[(-xmax, xmax)]` entry per transform, built in a single step so the
# name is never bound to a differently nested intermediate value.
integration_limits = [[(-xmax, xmax)] for _ in transforms]
integration_size = 200

# Assemble the MENT model from the measurements and the pieces defined above.
model = ment.MENT(
    ndim=ndim,
    projections=projections,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),
    sampler=sampler,
    nsamp=1_000_000,
    integration_limits=integration_limits,
    integration_size=integration_size,
    verbose=True,
)

## Training

In [None]:
def plot_model(model):
    """Plot the model's samples and their fit to the measured profiles.

    Draws 1,000,000 points from ``model`` and builds three figures:

    1. the sampled distribution in two panels (default and log color
       normalization);
    2. measured vs. simulated 1D profiles on a linear y-scale;
    3. the same profiles on a log y-scale.

    The function reads these notebook-level names: ``limits``, ``xmax``,
    ``n_meas``, ``transforms``, ``diagnostics`` and ``projections``.

    Parameters
    ----------
    model : object
        Must provide ``sample(n)``. The returned points are passed to
        ``psv.plot_points`` and to each transform.

    Returns
    -------
    list
        The three figures, in the order listed above.
    """
    figs = []

    # Plot the reconstructed distribution.
    X_pred = model.sample(1_000_000)

    fig, axs = pplt.subplots(ncols=2)
    for i, ax in enumerate(axs):
        psv.plot_points(
            X_pred,
            limits=limits,
            bins=128,
            offset=1.0,
            norm=("log" if i else None),
            colorbar=True,
            discrete=False,
            ax=ax,
        )
    figs.append(fig)

    # Plot simulated vs. measured projections.
    # The profiles are identical in the linear and log figures, so each one is
    # simulated once here rather than once per panel per figure.
    flat_projections = unravel(projections)
    flat_diagnostics = unravel(diagnostics)

    profiles = []
    for index in range(n_meas):
        diagnostic = flat_diagnostics[index]
        values_meas = flat_projections[index]
        values_pred = diagnostic(transforms[index](X_pred))
        # Both curves are normalized by the peak of the measured profile.
        scale = np.max(values_meas)
        profiles.append((diagnostic.coords, values_meas / scale, values_pred / scale))

    ncols = min(n_meas, 7)
    nrows = int(np.ceil(n_meas / ncols))

    for log in [False, True]:
        fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figsize=(1.5 * ncols, 1.25 * nrows))
        # The grid holds ncols * nrows axes, which can exceed n_meas once more
        # than one row is needed. zip stops at the last profile, so spare axes
        # stay empty instead of triggering an IndexError.
        for (coords, meas, pred), ax in zip(profiles, axs):
            ax.plot(coords, meas, color="lightgray")
            ax.plot(coords, pred, color="black", marker=".", lw=0, ms=1.0)
            ax.format(ymax=1.25, xlim=(-xmax, xmax))
            if log:
                ax.format(yscale="log", ymax=5.0, ymin=1.00e-05, yformatter="log")
        figs.append(fig)

    return figs

In [None]:
# Put the model in "sample" mode before handing it to the trainer.
model.mode = "sample"

# `plot_model` is supplied as the plotting callback; no evaluation function is used.
trainer = ment.train.Trainer(model, plot_func=plot_model, eval_func=None, notebook=True)

# Run 4 epochs with a learning rate of 0.90.
trainer.train(
    epochs=4,
    learning_rate=0.90,
)