# 2D MENT

In [None]:
import os
import sys

import numpy as np
import ot
import proplot as pplt
import scipy.interpolate
import torch
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab

# local
import plotting 
import utils

In [None]:
# Notebook-wide proplot style defaults, applied to figures created after this cell.
# NOTE(review): the Colormap on the second line is constructed after
# "cmap.discrete" has been set; keep this order unless it is confirmed that the
# Colormap constructor does not read rc state.
pplt.rc["cmap.discrete"] = False  # presumably: continuous (non-discretized) colormaps by default
pplt.rc["cmap.sequential"] = pplt.Colormap("dark_r", space="hpl")  # default sequential colormap
pplt.rc["figure.facecolor"] = "white"  # white figure background
pplt.rc["grid"] = False  # no axes gridlines

## Data

In [None]:
device = torch.device("cpu")
precision = torch.float32


def send(x):
    """Return tensor ``x`` cast to ``precision`` and placed on ``device``.

    Every tensor in this notebook goes through this helper so that data,
    transfer matrices and noise share one dtype and one device.
    """
    return x.to(device=device, dtype=precision)

In [None]:
# Parameters

# Ground-truth distribution (forwarded to mf.data.toy.gen_dist below).
data_name = "two-spirals"       # toy distribution name
data_kws = {}                   # extra keyword arguments for gen_dist
data_size = int(1.00e+06)       # number of ground-truth samples
data_noise = None               # `noise` argument for gen_dist

# Histogram diagnostic and plotting range.
xmax = 3.5                      # bins and plots span [-xmax, +xmax]
n_bins = 80                     # number of histogram bins

# Measurements (one rotation angle per measurement).
n_meas = 6                      # number of measurements / angles
meas_noise = 0.0                # relative std of per-bin multiplicative noise; 0 disables it
meas_min_angle = 0.0 * np.pi    # first rotation angle
meas_max_angle = 1.0 * np.pi    # upper angle bound (excluded: linspace uses endpoint=False)

In [None]:
# Ground-truth distribution and a large sample from it.
dist = mf.data.toy.gen_dist(data_name, noise=data_noise, **data_kws)
x0 = send(dist.sample(data_size))

# One rotation per measurement. `endpoint=False` excludes meas_max_angle, so the
# angles are evenly spaced over [meas_min_angle, meas_max_angle).
angles = np.linspace(meas_min_angle, meas_max_angle, n_meas, endpoint=False)
transfer_matrices = [send(mf.transform.rotation_matrix(angle)) for angle in angles]
transforms = [mf.transform.Linear(matrix).to(device) for matrix in transfer_matrices]

# 1D histogram along axis 0 with n_bins bins spanning [-xmax, xmax].
bin_edges = torch.linspace(-xmax, xmax, n_bins + 1)
diagnostic = mf.diagnostics.Histogram1D(axis=0, bin_edges=bin_edges, kde=False)
diagnostic = diagnostic.to(device)

# Synthetic measurements: the histogram of x0 after each transform.
measurements = [diagnostic(transform(x0)) for transform in transforms]

# Optionally scale every bin by (1 + eps), eps ~ N(0, meas_noise^2) drawn
# independently per bin, then clip negative values to zero.
if meas_noise:
    noisy_measurements = []
    for hist in measurements:
        rel_noise = send(meas_noise * torch.randn(hist.shape[0]))
        hist *= 1.0 + rel_noise
        noisy_measurements.append(torch.clamp(hist, 0.0, None))
    measurements = noisy_measurements

View the true distribution in each transformed (rotated) space, with the corresponding measured 1D projection shown below it.

In [None]:
@interact
def update(index=(0, n_meas - 1)):
    """Show the ground truth after transform ``index`` and its measured projection.

    ``index`` is exposed as an integer slider over the measurement indices.
    """
    x_transformed = grab(transforms[index](x0))
    limits = [(-xmax, +xmax), (-xmax, +xmax)]

    fig, ax = pplt.subplots()
    ax.hist2d(x_transformed[:, 0], x_transformed[:, 1], bins=100, range=limits)

    # Panel attached below the main axes, showing the 1D measurement.
    panel = ax.panel_axes("bottom", width=0.75)
    plotting.plot_hist(
        grab(measurements[index]),
        grab(bin_edges),
        ax=panel,
        color="black",
    )

## Model

In [None]:
# Prior for the MENT model. A uniform alternative is
# mf.models.ment.UniformPrior(d=2, scale=20.0, device=device).
prior = mf.models.ment.GaussianPrior(d=2, scale=2.0, device=device)

# Grid sampler whose limits are the plotting range [-xmax, xmax] in each of the
# two dimensions, scaled up by 10%.
sampler_limits = 1.1 * np.array([(-xmax, +xmax)] * 2)
sampler = mf.sample.GridSampler(limits=sampler_limits, res=200, device=device)

# Options for `interpolate`: "nearest", "linear", "pchip".
model = mf.models.ment.MENT(
    d=2,
    transforms=transforms,
    measurements=measurements,
    diagnostic=diagnostic,
    prior=prior,
    sampler=sampler,
    interpolate="pchip",
    device=device,
)

## Training

In [None]:
# Settings
n_iterations = 10  # number of Gauss-Seidel updates of the Lagrange functions
method = "integrate"  # {"integrate", "sample"}
method_kws = {
    "limits": [(-xmax, +xmax)],  # integration limits
    "shape": (150,),  # integration grid shape (resolution)
    "n": int(1.00e+05),  # number of particles (if method="sample")
}

vis_n_bins = 100
vis_n_samples = int(1.00e+06)
vis_limits = 2 * [(-xmax, +xmax)]

# Set to True to also report the sliced Wasserstein distance (SWD) between true
# and model samples. It is off by default to keep each evaluation cheap.
compute_swd = False

# The true distribution does not change during training, so sample it once.
# This also makes the "true" panel identical across iterations.
x_true = grab(dist.sample(vis_n_samples))

# Evaluate n_iterations + 1 times: the untrained model first, then after each
# update. This way the result of the final update is also shown.
for iteration in range(n_iterations + 1):
    # Draw samples from the current model.
    x = grab(model.sample(vis_n_samples))

    # Simulate the measurements from the current model.
    predictions = model.simulate(method=method, **method_kws)

    # Plot the 2D density of the true and model samples.
    fig, axs = plotting.plot_dist(
        x_true,
        x,
        bins=vis_n_bins,
        limits=vis_limits,
    )
    axs.format(toplabels=["true", "model"])
    pplt.show()

    # Plot overlaid simulated/measured projections.
    fig, axs = plotting.plot_proj(
        [grab(measurement) for measurement in measurements],
        [grab(prediction) for prediction in predictions],
        bin_edges=grab(diagnostic.bin_edges),
        maxcols=7,
        kind="step",
        height=1.25,
    )
    axs.format(ymin=0.0)
    pplt.show()

    # Simulation-measurement discrepancy: for each measurement, the sum of
    # absolute bin differences; then the mean over measurements. Converted to a
    # plain float so it prints as a number, not as a tensor repr.
    discrepancy = sum(
        torch.sum(torch.abs(prediction - measurement))
        for prediction, measurement in zip(predictions, measurements)
    )
    discrepancy = float(discrepancy) / len(measurements)

    # Sliced Wasserstein distance between true and model samples. Both inputs
    # are numpy arrays from `grab`, because POT requires matching backends
    # (x0 is a torch tensor and must not be passed here).
    distance = None
    if compute_swd:
        distance = ot.sliced.sliced_wasserstein_distance(
            x_true, x, n_projections=50, p=2
        )

    # Print summary
    print("iteration = {}".format(iteration))
    print("D(y_model, y_true) = {}".format(discrepancy))
    print("SWD(x_model, x_true) = {}".format(distance))

    # Update the Lagrange functions. Skip this after the last evaluation, so
    # the total number of updates is still n_iterations.
    if iteration < n_iterations:
        model.gauss_seidel_iterate(method=method, **method_kws)