# Train 2D iterative MENT solver

In [None]:
import numpy as np
import proplot as pplt
import scipy.interpolate
from tqdm.notebook import tqdm

import mentflow as mf

# Global proplot settings for every figure in this notebook: discrete colormap
# levels off, "viridis" as the sequential colormap, white figure background,
# and grid lines off.
plot_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for setting_name, setting_value in plot_settings.items():
    pplt.rc[setting_name] = setting_value

## Data

In [None]:
# --- Settings ---------------------------------------------------------------
data_name = "spirals"    # name of the toy distribution passed to gen_dist
data_size = 1_000_000    # number of samples drawn from that distribution
data_noise = None        # forwarded as the `noise` argument of gen_dist
xmax = 3.25              # histogram bin edges span [-xmax, xmax]
n_bins = 64              # number of histogram bins per measurement
n_meas = 30              # number of measurements (one per rotation angle)

# --- Ground-truth samples ---------------------------------------------------
dist = mf.data.toy.gen_dist(data_name, noise=data_noise)
x0 = dist.sample(data_size)

# --- Transforms -------------------------------------------------------------
# One rotation per measurement, at angles evenly spaced over [0, pi); the
# endpoint pi is excluded. `rotation_matrix` returns an object with a
# `.numpy()` method, which is used here to obtain the matrix passed on to
# `LinearTransform`.
angles = np.linspace(0.0, np.pi, n_meas, endpoint=False)
transfer_matrices = [mf.transform.rotation_matrix(theta).numpy() for theta in angles]
transforms = [mf.models.ment.LinearTransform(m) for m in transfer_matrices]

# --- Diagnostic -------------------------------------------------------------
# A single histogram diagnostic (constructed with axis=0) shared by all
# measurements; n_bins bins need n_bins + 1 edges.
bin_edges = np.linspace(-xmax, xmax, num=n_bins + 1)
diagnostic = mf.models.ment.HistogramDiagnostic(bin_edges, axis=0)

# --- Measurements -----------------------------------------------------------
# Each measurement is the diagnostic applied to the transformed true samples.
measurements = [diagnostic(tf(x0)) for tf in transforms]

## Model

In [None]:
# Prior distribution handed to the MENT model.
prior = mf.models.ment.GaussianPrior(scale=1.0)
# Alternative: prior = mf.models.ment.UniformPrior(scale=20.0)

# Grid sampler over the square [-xmax, xmax] x [-xmax, xmax], constructed
# with res=200.
sampler_limits = [(-xmax, xmax), (-xmax, xmax)]
sampler = mf.sample.GridSampler(limits=sampler_limits, res=200)

# Assemble the model from the transforms, measurements and diagnostic built
# in the data cell.
model_kws = dict(
    transforms=transforms,
    measurements=measurements,
    diagnostic=diagnostic,
    prior=prior,
    sampler=sampler,
    interpolate="pchip",  # {False, "linear", "pchip"}
)
model = mf.models.ment.MENT_2D1D(**model_kws)

## Training

In [None]:
# Iterative training: each pass evaluates and plots the current model, reports
# its discrepancy with the measurements, then updates the model.
n_iterations = 10

# Evaluation settings do not change between iterations, so build them once.
res = 200  # grid points per axis for the density image
grid_coords = 2 * [np.linspace(-xmax, xmax, res)]
grid_points = mf.utils.get_grid_points(grid_coords)
n_samples = 100000  # samples drawn from the model and used for simulation
limits = 2 * [(-xmax, xmax)]

# Projection-plot layout: just enough rows to hold n_meas panels. Ceiling
# division avoids a spare empty row when n_meas is a multiple of ncols.
ncols = 7
nrows = max(1, (n_meas + ncols - 1) // ncols)

for iteration in range(n_iterations):
    # Compute the model density on the grid, scaled so that its maximum is 1.
    prob = model.prob(grid_points)
    prob = np.reshape(prob, (res, res))
    prob = prob / np.max(prob)

    # Draw samples from the model.
    x = model.sample(n_samples)

    # Simulate the measurements.
    predictions = model.simulate(method="sample", n=n_samples)

    # Plot density: true samples, model density on the grid, model samples.
    fig, axs = pplt.subplots(ncols=3, space=0, xspineloc="neither", yspineloc="neither")
    axs[0].hist2d(x0[:n_samples, 0], x0[:n_samples, 1], bins=100, range=limits)
    # NOTE(review): the transpose assumes the first axis of `prob` runs along
    # grid_coords[0] -- confirm against mf.utils.get_grid_points.
    axs[1].pcolormesh(grid_coords[0], grid_coords[1], prob.T)
    axs[2].hist2d(x[:, 0], x[:, 1], bins=100, range=limits)
    axs.format(toplabels=["True samples", "MENT density", "MENT samples"])
    pplt.show()

    # Plot projections: measured profile in black, simulated profile in red.
    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figheight=(1.1 * nrows), figwidth=10.0)
    line_kws = dict(lw=1.25)
    for i, (measurement, prediction) in enumerate(zip(measurements, predictions)):
        axs[i].plot(model.bin_coords, measurement, color="black", **line_kws)
        axs[i].plot(model.bin_coords, prediction, color="red", **line_kws)
    pplt.show()

    # Compute discrepancy: summed absolute difference per measurement,
    # averaged over the n_meas measurements.
    discrepancy = sum(
        np.sum(np.abs(prediction - measurement))
        for prediction, measurement in zip(predictions, measurements)
    ) / n_meas

    # Print a summary.
    print(f"iteration = {iteration}")
    print(f"discrepancy = {discrepancy}")

    # Update Lagrange multipliers. The update runs after the evaluation above,
    # so the state produced by the final update is not plotted in this loop.
    model.gauss_seidel_iterate(method="integrate")
    # Alternative: model.gauss_seidel_iterate(method="sample", n=100000)