# Train the 2D MENT algorithm on 1D projections

In [None]:
import numpy as np
import proplot as pplt
import scipy.interpolate
from tqdm.notebook import tqdm

import mentflow as mf

# Global proplot styling for every figure in this notebook: continuous
# colormaps, viridis as the default sequential map, a white figure
# background, and no grid lines. Settings are applied in the order listed.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_settings.items():
    pplt.rc[_rc_key] = _rc_value

In [None]:
# --- Ground-truth distribution ---------------------------------------------
data_name = "spirals"
data_size = int(1.00e+06)
data_noise = None
xmax = 3.0   # half-width of the square window used for histograms and plots
n_bins = 90  # number of bins in each 1D measurement histogram
n_meas = 6   # number of measurement angles

dist = mf.data.toy.gen_dist(data_name, noise=data_noise)
x0 = dist.sample(data_size)

# --- Forward model ---------------------------------------------------------
# n_meas angles evenly spaced over [0, pi); endpoint=False keeps pi itself
# out of the set.
angles = np.linspace(0.0, np.pi, n_meas, endpoint=False)

# `.numpy()` presumably converts a tensor returned by rotation_matrix into a
# NumPy array for the numpy-based MENT model -- TODO confirm.
transfer_matrices = [mf.transform.rotation_matrix(angle).numpy() for angle in angles]
transforms = [mf.models.ment.LinearTransform(matrix) for matrix in transfer_matrices]

# n_bins + 1 edges give n_bins bins spanning [-xmax, xmax].
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)
# NOTE(review): axis=0 presumably means "histogram coordinate 0" -- confirm.
diagnostic = mf.models.ment.HistogramDiagnostic(bin_edges, axis=0)

# Simulated measurements: one diagnostic output per transform of the true samples.
measurements = [diagnostic(transform(x0)) for transform in transforms]

# --- Prior and model -------------------------------------------------------
# Alternative prior: mf.models.ment.UniformPrior(scale=20.0)
prior = mf.models.ment.GaussianPrior(scale=1.0)

model = mf.models.ment.MENT_2D1D_numpy(
    transforms=transforms,
    measurements=measurements,
    diagnostic=diagnostic,
    prior=prior,
)

In [None]:
n_iterations = 10

# Quantities that do not change between iterations are computed once.
res = 200  # grid resolution (per axis) for evaluating the model density
grid_coords = 2 * [np.linspace(-xmax, xmax, res)]
grid_points = mf.utils.get_grid_points(grid_coords)
n_samples = 100000  # samples drawn from the model (and shown from x0) per iteration
limits = 2 * [(-xmax, xmax)]
line_kws = dict(lw=1.25)

# Projection-plot layout: ceil(n_meas / ncols) rows, so no empty row is
# created when n_meas is an exact multiple of ncols.
ncols = 7
nrows = -(-n_meas // ncols)

# Perform n_iterations updates, but evaluate the model n_iterations + 1 times
# so the model produced by the final update is also plotted and scored.
for iteration in range(n_iterations + 1):
    # Evaluate the model density on the grid and scale its peak to 1.
    prob = model.prob(grid_points)
    prob = np.reshape(prob, (res, res))
    prob = prob / np.max(prob)

    # Draw samples from the model over a window slightly larger than the plot.
    x = model.sample(n_samples, xmax=(1.1 * xmax))

    fig, axs = pplt.subplots(ncols=3, space=0, xspineloc="neither", yspineloc="neither")
    axs[0].hist2d(x0[:n_samples, 0], x0[:n_samples, 1], bins=100, range=limits)
    # NOTE(review): the transpose assumes get_grid_points orders points with
    # the first coordinate varying slowest (ij indexing) -- TODO confirm.
    axs[1].pcolormesh(grid_coords[0], grid_coords[1], prob.T)
    axs[2].hist2d(x[:, 0], x[:, 1], bins=100, range=limits)
    axs.format(toplabels=["True samples", "MENT density", "MENT samples"])
    pplt.show()

    # Simulate the measurements: density-normalized histogram of coordinate 0
    # of each transformed sample set. NOTE(review): the cost below assumes the
    # diagnostic's measurements use the same normalization -- confirm.
    predictions = []
    for transform in model.transforms:
        x_out = transform(x)
        hist, _ = np.histogram(x_out[:, 0], bins=model.bin_edges, density=True)
        predictions.append(hist)

    # Plot simulated (red) vs. measured (black) projections.
    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figheight=(1.1 * nrows), figwidth=10.0)
    for i in range(n_meas):
        axs[i].plot(model.bin_coords, measurements[i], color="black", **line_kws)
        axs[i].plot(model.bin_coords, predictions[i], color="red", **line_kws)
    pplt.show()

    # Mean (over measurements) of the summed absolute bin-wise discrepancy.
    cost = sum(
        np.sum(np.abs(prediction - measurement))
        for prediction, measurement in zip(predictions, measurements)
    ) / n_meas
    print(f"cost = {cost}")

    # Update the Lagrange multipliers, except after the final evaluation.
    if iteration < n_iterations:
        model.gauss_seidel_iterate(method="integrate", n_samples=int(1.00e+06))