# 2D iterative MENT (maximum-entropy tomography) solver

In [None]:
import numpy as np
import proplot as pplt
import scipy.interpolate
import torch
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab

In [None]:
# Global proplot style for every figure in this notebook: non-discrete
# colormap levels, viridis as the sequential colormap, a white figure
# background, and no grid lines.
_plot_style = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _key, _value in _plot_style.items():
    pplt.rc[_key] = _value

## Data

In [None]:
# Working device and floating-point dtype for every tensor in this notebook.
device = torch.device("cpu")
precision = torch.float32


def send(x):
    """Cast a tensor to the working dtype and move it to the working device.

    Uses the module-level ``precision`` (dtype) and ``device`` defined just
    above, so changing those two settings retargets the whole notebook.
    """
    return x.to(device=device, dtype=precision)

In [None]:
# Ground-truth toy distribution and experiment configuration.
data_name = "spirals"
data_size = int(1.00e+06)  # number of ground-truth samples used to build the measurements
data_noise = None
xmax = 3.25  # histogram range is [-xmax, xmax]
n_bins = 64  # number of bins per 1D measurement histogram
n_meas = 6  # number of projection angles

dist = mf.data.toy.gen_dist(data_name, noise=data_noise)
x0 = send(dist.sample(data_size))

# Evenly spaced angles in [0, pi); each becomes a rotation-matrix transform.
angles = np.linspace(0.0, np.pi, n_meas, endpoint=False)
transfer_matrices = [send(mf.transform.rotation_matrix(angle)) for angle in angles]
transforms = [mf.transform.Linear(matrix).to(device) for matrix in transfer_matrices]

# 1D histogram along axis 0 with n_bins equal-width bins over [-xmax, xmax].
bin_edges = torch.linspace(-xmax, xmax, n_bins + 1)
diagnostic = mf.diagnostics.Histogram1D(axis=0, bin_edges=bin_edges, kde=False).to(device)

# One measured profile per transform: the diagnostic applied to the transformed samples.
measurements = [diagnostic(transform(x0)) for transform in transforms]

## Model

In [None]:
# Prior for the 2D MENT model. A uniform prior is an alternative:
#   prior = mf.models.ment.UniformPrior(d=2, scale=20.0, device=device)
prior = mf.models.ment.GaussianPrior(d=2, scale=1.0, device=device)

# Grid sampler over the square [-xmax, xmax]^2 (res=200, presumably points per axis).
sampler = mf.sample.GridSampler(limits=[(-xmax, xmax)] * 2, res=200, device=device)

# MENT model tying together the forward transforms, their measured profiles,
# the diagnostic, the prior, and the sampler.
ment_kws = dict(
    transforms=transforms,
    measurements=measurements,
    diagnostic=diagnostic,
    prior=prior,
    sampler=sampler,
    interpolate="nearest",  # options: "nearest", "linear", "pchip"
    device=device,
)
model = mf.models.ment.MENT_2D1D(**ment_kws)

## Training

In [None]:
# Iteratively fit the MENT model. Each pass:
# - evaluates and plots the current model against the ground truth and the
#   measured projections;
# - prints the mean L1 discrepancy between simulated and measured profiles;
# - updates the Lagrange multipliers.
# The update made in the final pass is not plotted.

# Loop-invariant setup, computed once instead of on every iteration.
res = 200  # points per axis of the grid used to evaluate the model density
grid_coords = 2 * [torch.linspace(-xmax, xmax, res)]
grid_points = send(mf.utils.get_grid_points_torch(grid_coords))

n_samples = 100000
# Reference samples from the true distribution. They are drawn once, so the
# "True samples" panel is the same in every iteration.
x_true = grab(dist.sample(n_samples))

limits = 2 * [(-xmax, xmax)]
# Bins for the 2D sample histograms. A separate name keeps the measurement
# setting `n_bins` from being overwritten.
n_plot_bins = 100
bin_coords = grab(mf.utils.centers_from_edges(diagnostic.bin_edges))
line_kws = dict(lw=1.25)

# Projection-plot layout: at most 7 columns, and just enough rows (ceiling
# division) to hold all n_meas projections without an empty trailing row.
ncols = min(n_meas, 7)
nrows = (n_meas + ncols - 1) // ncols

for iteration in range(10):
    # Evaluate the model density on the fixed grid.
    prob = model.prob(grid_points)
    prob = grab(prob.reshape(res, res))

    # Draw samples from the model and simulate the measurements.
    x = grab(model.sample(n_samples))
    predictions = model.simulate(method="sample", n=n_samples)

    # Plot true samples, model density, and model samples side by side.
    fig, axs = pplt.subplots(ncols=3, space=0, xspineloc="neither", yspineloc="neither")
    axs[0].hist2d(x_true[:, 0], x_true[:, 1], bins=n_plot_bins, range=limits)
    axs[1].pcolormesh(grid_coords[0], grid_coords[1], prob.T)
    axs[2].hist2d(x[:, 0], x[:, 1], bins=n_plot_bins, range=limits)
    axs.format(toplabels=["True samples", "MENT density", "MENT samples"])
    pplt.show()

    # Plot measured (black) vs. simulated (red) projections.
    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figheight=(1.1 * nrows), figwidth=10.0)
    for i, (measurement, prediction) in enumerate(zip(measurements, predictions)):
        axs[i].plot(bin_coords, grab(measurement), color="black", **line_kws)
        axs[i].plot(bin_coords, grab(prediction), color="red", **line_kws)
    pplt.show()

    # Mean over projections of the L1 distance between simulated and measured
    # profiles. Converted to a Python float so it prints as a plain number.
    discrepancy = float(
        sum(
            torch.sum(torch.abs(prediction - measurement))
            for prediction, measurement in zip(predictions, measurements)
        )
        / n_meas
    )

    # Print a summary.
    print(f"iteration = {iteration}")
    print(f"discrepancy = {discrepancy}")

    # Update the Lagrange multipliers.
    model.gauss_seidel_iterate(method="integrate")
    # Sample-based alternative:
    # model.gauss_seidel_iterate(method="sample", n=int(1.00e+06))