# 2D MENT

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import ot
import proplot as pplt
import scipy.interpolate
import torch
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel

# local
import plotting 

In [None]:
# Global plot styling via proplot's rc settings (applies to all later figures).
pplt.rc["cmap.discrete"] = False  # NOTE(review): presumably turns off proplot's default discrete colormap levels — confirm
pplt.rc["cmap.sequential"] = pplt.Colormap("dark_r", space="hpl")  # default sequential colormap: "dark_r" built in the "hpl" color space
pplt.rc["figure.facecolor"] = "white"  # white figure background
pplt.rc["grid"] = False  # no axis grid lines

## Data

In [None]:
device = torch.device("cpu")
precision = torch.float32


def send(x):
    """Return ``x`` cast to the notebook-wide ``precision`` and placed on ``device``.

    The dtype conversion is applied first, then the device transfer.
    """
    cast = x.type(precision)
    return cast.to(device)

In [None]:
# Parameters
data_name = "two-spirals"  # toy distribution name passed to mf.data.toy.gen_dist
data_kws = {}  # extra keyword arguments for the toy distribution
data_size = 1_000_000  # number of ground-truth samples to draw
data_noise = None  # forwarded as `noise=` to gen_dist
xmax = 3.5  # half-width of the histogram / plotting window
n_bins = 80  # number of measurement histogram bins
n_meas = 5  # number of measurement angles
meas_noise = 0.0  # fractional (multiplicative) measurement noise scale
meas_noise_type = "gaussian"  # {"uniform", "gaussian"}
meas_min_angle = 0.0  # first measurement angle [rad]
meas_max_angle = np.pi  # upper angle bound [rad] (excluded from the angle grid)
seed = 0

In [None]:
# Set random state.
if seed is not None:
    torch.manual_seed(seed)
rng = np.random.default_rng(seed)

# Draw samples from the input distribution.
d = 2
dist = mf.data.toy.gen_dist(data_name, noise=data_noise, **data_kws)
x0 = dist.sample(data_size)
x0 = send(x0)

# Define transforms: one rotation per angle, evenly spaced over
# [meas_min_angle, meas_max_angle) (the endpoint is excluded).
angles = np.linspace(meas_min_angle, meas_max_angle, n_meas, endpoint=False)
transforms = []
for angle in angles:
    matrix = mf.transform.rotation_matrix(angle)
    matrix = send(matrix)
    transform = mf.transform.LinearTransform(matrix)
    transform = transform.to(device)
    transforms.append(transform)

# Create histogram diagnostic along axis 0 with n_bins bins on [-xmax, xmax].
bin_edges = torch.linspace(-xmax, xmax, n_bins + 1)
bin_edges = send(bin_edges)
diagnostic = mf.diagnostics.Histogram1D(axis=0, bin_edges=bin_edges)
diagnostic = diagnostic.to(device)
diagnostics = [diagnostic]

# Generate training data. Indexed below as measurements[i][j], where j runs
# over the diagnostics (presumably i runs over the transforms — confirm).
measurements = mf.simulate(x0, transforms, diagnostics)


def _sample_frac_noise(size):
    """Return ``size`` zero-mean fractional noise factors (CPU tensor).

    - "uniform":  i.i.d. on [-meas_noise, +meas_noise]
                  (std = meas_noise / sqrt(3)).
    - "gaussian": i.i.d. normal with mean 0 and std ``meas_noise``.

    Raises:
        ValueError: if ``meas_noise_type`` is not one of the above.
    """
    if meas_noise_type == "uniform":
        # Centered on zero; the previous `rand * 2.0` was always >= 0 and
        # therefore inflated every bin instead of perturbing it.
        return meas_noise * (2.0 * torch.rand(size) - 1.0)
    if meas_noise_type == "gaussian":
        return meas_noise * torch.randn(size)
    raise ValueError("Invalid meas_noise_type")


# Add multiplicative measurement noise: each bin value m becomes
# m * (1 + eps), then is clipped at zero so histogram values stay non-negative.
for meas_list in measurements:
    for j, measurement in enumerate(meas_list):
        frac_noise = send(_sample_frac_noise(measurement.shape[0]))
        measurement = measurement + frac_noise * measurement
        measurement = torch.clamp(measurement, 0.0, None)
        meas_list[j] = measurement

View the distribution in the transformed space. The bottom panel overlays the x-projection of the samples (light) on the measured profile (dark).

In [None]:
@interact
def update(index=(0, len(transforms) - 1)):
    """Show the ground-truth samples ``x0`` after transform ``index``.

    Main axes: 2D histogram of the transformed samples on [-xmax, xmax]^2.
    Bottom panel: the x-projection of those samples (light) overlaid on the
    measurement for that transform (dark). Both curves are divided by the
    measurement's maximum so they share one vertical scale.
    """
    u = grab(transforms[index](x0))

    fig, ax = pplt.subplots()
    limits = 2 * [(-xmax, +xmax)]
    ax.hist2d(u[:, 0], u[:, 1], bins=100, range=limits)

    pax = ax.panel_axes("bottom", width=0.75)
    edges = grab(bin_edges)
    # Bin the projection on the measurement's own bin edges. Passing only a
    # bin count would derive the edges from the data min/max, so the two
    # curves would not be on the same grid.
    # NOTE(review): assumes the Histogram1D measurement is also
    # density-normalized on bin_edges — confirm.
    hist, _ = np.histogram(u[:, 0], bins=edges, density=True)
    meas = grab(measurements[index][0])
    scale = meas.max()
    if scale <= 0:
        # All-zero measurement (e.g. after clipping): avoid dividing by zero.
        scale = 1.0
    plotting.plot_hist(hist / scale, edges, ax=pax, color="black", alpha=0.3)
    plotting.plot_hist(meas / scale, edges, ax=pax, color="black")
    pplt.show()

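Plot the integration lines in the input and transformed spaces. The lines are laid out on a grid in the transformed space and mapped back to the input space with the inverse transform.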

In [None]:
@interact
def update(index=(0, len(transforms) - 1)):
    """Overlay integration lines on samples for transform ``index``.

    Left: samples in the input space. Right: the same samples after the
    transform. The white lines come from a grid of points built in the
    transformed space, mapped back with ``transform.inverse``, and split
    into ``n_lines`` equal consecutive chunks (one chunk per drawn line).
    """
    transform = transforms[index]

    x_samples = send(dist.sample(100000))
    u_samples = grab(transform(x_samples))
    x_samples = grab(x_samples)

    fig, axs = pplt.subplots(ncols=2)
    view = 2 * [(-xmax, +xmax)]
    for ax, pts in zip(axs, (x_samples, u_samples)):
        ax.hist2d(pts[:, 0], pts[:, 1], bins=100, range=view)

    n_lines = 20
    n_dots_per_line = 100
    # Grid deliberately extends to 2 * xmax so lines cross the whole view.
    # NOTE(review): each chunk of n_dots_per_line consecutive points is drawn
    # as one line; this relies on get_grid_points_torch's point ordering.
    u_grid = send(
        mf.utils.get_grid_points_torch(
            2.0 * torch.linspace(-xmax, +xmax, n_lines),
            2.0 * torch.linspace(-xmax, +xmax, n_dots_per_line),
        )
    )
    x_grid = grab(transform.inverse(u_grid))
    u_grid = grab(u_grid)

    for ax, pts in zip(axs, (x_grid, u_grid)):
        for segment in np.split(pts, n_lines):
            ax.plot(segment[:, 0], segment[:, 1], color="white", alpha=0.5)
    axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))
    pplt.show()

## Model

In [None]:
# Gaussian prior (scale 1.0) over the 2D space.
prior = mf.models.ment.GaussianPrior(d=2, scale=1.0, device=device)

# Sampler grid spans the plotting window enlarged by a factor of 1.1,
# i.e. [-1.1 * xmax, +1.1 * xmax] on each axis, with res=200.
sampler_limits = 1.1 * np.array(2 * [(-xmax, +xmax)])
sampler = mf.sample.GridSampler(limits=sampler_limits, res=200, device=device)

# MENT model built from the same transforms/diagnostics that generated the
# (noisy) training measurements.
ment_kws = dict(
    d=2,
    transforms=transforms,
    measurements=measurements,
    diagnostics=diagnostics,
    prior=prior,
    sampler=sampler,
    interpolate="pchip",  # {"nearest", "linear", "pchip"}
    device=device,
)
model = mf.models.ment.MENT(**ment_kws)

## Training

In [None]:
# Settings: optimization
n_iterations = 10  # number of model updates in the training loop
omega = 0.25  # "learning rate" in range (0.0, 1.0].
sim_method = "integrate"  # {"integrate", "sample"}
sim_kws = dict(
    limits=[(-xmax, +xmax)],  # integration limits
    shape=(150,),  # integration grid shape (resolution)
    n=1_000_000,  # number of particles (only used when sim_method == "sample")
)

# Settings: visualization
vis_n_bins = 125
vis_n_samples = 1_000_000

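Training loop. Each pass evaluates the current model (measurement discrepancy, sliced Wasserstein distance to the true samples, and plots), and the Lagrange functions are updated with Gauss–Seidel iterations.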

In [None]:
# Pass k evaluates the model after k updates. The extra final pass
# (iteration == n_iterations) evaluates without updating, so the result of
# the last Gauss-Seidel update is also reported and plotted.
for iteration in range(n_iterations + 1):

    # Draw samples from the model and true distribution.
    x_true = send(dist.sample(vis_n_samples))
    x = send(model.sample(vis_n_samples))

    # Simulate measurements from the current model.
    predictions = model.simulate(method=sim_method, **sim_kws)

    # Mean absolute simulation-measurement discrepancy per measurement,
    # then averaged over all measurements.
    discrepancy_vector = [
        float(torch.mean(torch.abs(prediction - measurement)))
        for prediction, measurement in zip(unravel(predictions), unravel(measurements))
    ]
    discrepancy = sum(discrepancy_vector) / len(discrepancy_vector)

    # Sliced Wasserstein distance between true and model samples. Seeding
    # fixes the random projections, so values are comparable across passes.
    n_samples = 100000
    distance = ot.sliced.sliced_wasserstein_distance(
        x[:n_samples],
        x_true[:n_samples],
        n_projections=50,
        p=2,
        seed=seed,
    )

    # Print summary
    print(f"iteration = {iteration}")
    print(f"D(y_model, y_true) = {discrepancy}")
    print(f"SWD(x_model, x_true) = {float(distance)}")

    # Make plots. The figures are consumed one at a time (plot_model may
    # produce them lazily) and closed after display.
    figs = plotting.plot_model(
        model,
        dist,
        sim_kws=sim_kws,
        n=vis_n_samples,
        bins=vis_n_bins,
        xmax=xmax,
        maxcols=7,
        kind="line",
        colors=["black", "red"],
        device=device,
    )
    for fig in figs:
        plt.show()
        plt.close("all")

    # Update Lagrange functions (skipped on the final evaluation-only pass).
    if iteration < n_iterations:
        model.gauss_seidel_iterate(omega=omega, sim_method=sim_method, **sim_kws)