# Train a 2D MENT model

In [None]:
import mentflow as mf
import numpy as np
import proplot as pplt
import scipy.interpolate
from tqdm.notebook import tqdm

# Global proplot style for every figure in this notebook: non-discrete
# colormaps, viridis as the sequential colormap, a white figure background,
# and no grid lines.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _key, _value in _rc_settings.items():
    pplt.rc[_key] = _value

In [None]:
# Ground-truth distribution and measurement settings.
data_name = "spirals"
data_size = int(1.00e+06)
data_noise = None
xmax = 3.0  # measurement histograms cover [-xmax, xmax]
n_bins = 64
n_meas = 6

z0 = mf.data.toy.gen_data(data_name, size=data_size, noise=data_noise)

# One transfer matrix per measurement, built by mf.utils.rotation_matrix at
# n_meas angles evenly spaced in [0, pi).
angles = np.linspace(0.0, np.pi, n_meas, endpoint=False)
transfer_matrices = [mf.utils.rotation_matrix(angle) for angle in angles]

# Simulate each measurement as a density-normalized 1D histogram of column 0
# of the transformed points. The matrix is applied to all points in a single
# vectorized product (M @ z for every row z equals Z @ M.T). This replaces a
# per-row Python callback over the full sample, which was very slow for 1e6 points.
points = np.asarray(z0)
measurements = []
edges = []
for matrix in transfer_matrices:
    z_out = points @ np.asarray(matrix).T
    # A fresh edges array per measurement, so the model never receives aliased arrays.
    _edges = np.linspace(-xmax, xmax, n_bins + 1)
    _hist, _ = np.histogram(z_out[:, 0], bins=_edges, density=True)
    measurements.append(_hist)
    edges.append(_edges)

prior = mf.models.ment.GaussianPrior(scale=1.0)

model = mf.models.MENT(
    transfer_matrices=transfer_matrices,
    measurements=measurements,
    edges=edges,
    prior=prior,
)

In [None]:
# Run a fixed number of MENT iterations. Each step prints the projection error,
# plots measured (black) vs. predicted (red) projections and the true vs.
# reconstructed 2D density, and then updates the model from the predictions.
n_steps = 10

# Everything below is independent of the model state, so compute it once
# instead of on every step.
res = 150
grid_coords = 2 * [np.linspace(-xmax, xmax, res)]
grid_points = mf.utils.get_grid_points(grid_coords)

# Ground-truth image, normalized to a peak of 1 for a shared color scale.
limits = 2 * [(-xmax, xmax)]
_hist, _edges = np.histogramdd(z0, bins=75, range=limits)
_hist = _hist / np.max(_hist)

proj_kws = dict(lw=1.25)
image_kws = dict(vmin=0.0, vmax=1.0)

for step in range(n_steps):
    # NOTE(review): xmax=8.0 here differs from the measurement range (xmax);
    # presumably intentional (wider sampling domain) -- confirm.
    predictions = model.forward_samples(n=10000, xmax=8.0, res=200)
    # Sum over measurements of the mean absolute histogram difference.
    error = sum(
        np.mean(np.abs(prediction - measurement))
        for prediction, measurement in zip(predictions, measurements)
    )

    print(f"step={step}")
    print(f"error={error}")
    print(model.lagrange_multipliers[0])

    # Plot projections: one panel per measurement (previously hard-coded to 7,
    # which left an empty panel).
    fig, axs = pplt.subplots(ncols=n_meas, figheight=1.25, figwidth=10.0)
    for i in range(n_meas):
        axs[i].stairs(measurements[i], model.edges[i], color="black", **proj_kws)
        axs[i].stairs(predictions[i], model.edges[i], color="red", **proj_kws)
    pplt.show()

    # Plot density: true histogram (left) vs. model probability on the grid (right).
    prob = model.prob_vectorized(grid_points)
    prob = np.reshape(prob, (res, res))
    prob = prob / np.max(prob)

    fig, axs = pplt.subplots(ncols=2, space=0, xspineloc="neither", yspineloc="neither")
    axs[0].pcolormesh(_edges[0], _edges[1], _hist.T, **image_kws)
    axs[1].pcolormesh(grid_coords[0], grid_coords[1], prob.T, **image_kws)
    pplt.show()

    model.step(predictions)