# Train 2D MENT algorithm

In [None]:
import mentflow as mf
import numpy as np
import proplot as pplt
import scipy.interpolate
from tqdm.notebook import tqdm

# Global proplot settings used by every figure in this notebook: discrete
# colormap levels off, viridis as the sequential colormap, white figure
# background, and no grid lines.
_rc_overrides = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_overrides.items():
    pplt.rc[_rc_key] = _rc_value

In [None]:
# ---------------------------------------------------------------------------
# Problem setup: draw samples from a toy 2D distribution, build one rotation
# transfer matrix per measurement, histogram the transformed samples to form
# the measured 1D profiles, and hand everything to the MENT model.
# ---------------------------------------------------------------------------

# Toy distribution settings.
data_name = "circles"
data_size = int(1.00e+06)
data_noise = None

# Every 1D profile is binned on [-xmax, xmax] with n_bins equal-width bins.
xmax = 3.0
n_bins = 64

# Number of projections (one rotation angle each).
n_meas = 6

dist = mf.data.toy.gen_dist(data_name, noise=data_noise)

# Samples used to build the measurements below (and the reference density
# plotted later in the notebook).
x0 = dist.sample(data_size)

# Angles are evenly spaced over [0, pi); the endpoint is excluded.
angles = np.linspace(0.0, np.pi, n_meas, endpoint=False)
transfer_matrices = [mf.utils.rotation_matrix(angle) for angle in angles]

# For each transfer matrix, transform the samples and histogram the first
# coordinate. `density=True` normalizes each profile to unit area over the
# binned range; samples falling outside [-xmax, xmax] are ignored by
# np.histogram.
measurements = []
edges = []
for matrix in transfer_matrices:
    x_out = np.matmul(x0, matrix.T)
    # A fresh edge array per measurement, so the entries of `edges` never
    # alias one another.
    _edges = np.linspace(-xmax, xmax, n_bins + 1)
    _hist, _ = np.histogram(x_out[:, 0], bins=_edges, density=True)
    measurements.append(_hist)
    edges.append(_edges)

# Gaussian prior handed to the model. (The original cell first set
# `prior = None` and immediately overwrote it; that dead store is dropped.)
prior = mf.models.ment.GaussianDistribution(scale=1.0)

model = mf.models.MENT(
    transfer_matrices=transfer_matrices,
    measurements=measurements,
    edges=edges,
    prior=prior,
)

In [None]:
# ---------------------------------------------------------------------------
# Training loop. Each iteration plots the current model density next to the
# reference density, compares simulated and measured projections, prints the
# mean absolute discrepancy, and then takes one MENT update step.
# ---------------------------------------------------------------------------

# The evaluation grid and the reference density depend only on `xmax` and
# `x0`, neither of which changes during training, so build them once rather
# than re-histogramming all samples on every step.
res = 150
grid_coords = 2 * [np.linspace(-xmax, xmax, res)]
grid_points = mf.utils.get_grid_points(grid_coords)

limits = 2 * [(-xmax, xmax)]
true_hist, true_edges = np.histogramdd(x0, bins=75, range=limits)
true_hist = true_hist / np.max(true_hist)

for step in range(10):
    # Plot density. Both images are scaled to a peak value of 1 so they can
    # share one color scale.
    prob = model.prob(grid_points)
    prob = np.reshape(prob, (res, res))
    prob = prob / np.max(prob)

    fig, axs = pplt.subplots(ncols=2, space=0, xspineloc="neither", yspineloc="neither")
    image_kws = dict(vmin=0.0, vmax=1.0)
    axs[0].pcolormesh(true_edges[0], true_edges[1], true_hist.T, **image_kws)
    axs[1].pcolormesh(grid_coords[0], grid_coords[1], prob.T, **image_kws)
    axs.format(toplabels=["True", "MENT"])
    pplt.show()

    # Simulate projections: sample the current model, apply each transfer
    # matrix, and histogram the first coordinate on that measurement's edges.
    x = model.sample(100000)
    predictions = []
    for i, matrix in enumerate(model.transfer_matrices):
        x_out = np.matmul(x, matrix.T)
        hist, _ = np.histogram(x_out[:, 0], bins=model.edges[i], density=True)
        predictions.append(hist)

    # Plot projections (measured in black, simulated in red).
    ncols = 7
    # Ceiling division: the previous `1 + n_meas // ncols` allocated an extra,
    # empty row whenever n_meas was an exact multiple of ncols.
    nrows = (n_meas + ncols - 1) // ncols
    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figheight=(1.2 * nrows), figwidth=10.0)
    line_kws = dict(lw=1.25)
    for i in range(n_meas):
        # NOTE(review): model.coords[i] is presumably the bin-center array for
        # measurement i -- confirm against mentflow.
        axs[i].plot(model.coords[i], measurements[i], color="black", **line_kws)
        axs[i].plot(model.coords[i], predictions[i], color="red", **line_kws)
    pplt.show()

    # Compute cost: summed absolute difference per profile, averaged over
    # the measurements.
    cost = sum(
        np.sum(np.abs(prediction - measurement))
        for prediction, measurement in zip(predictions, measurements)
    ) / n_meas
    print(f"cost = {cost}")

    # Update lagrange multipliers.
    model.step()