# 4D-from-2D MENT

In [None]:
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.visualization as psv
import scipy.interpolate
import torch
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab

In [None]:
# Notebook-wide proplot defaults: continuous colormaps, a custom sequential
# colormap, white figure background, and no grid lines.
_rc_overrides = {
    "cmap.discrete": False,
    "cmap.sequential": pplt.Colormap("dark_r", space="hpl"),
    "figure.facecolor": "white",
    "grid": False,
}
for _key, _value in _rc_overrides.items():
    pplt.rc[_key] = _value
del _rc_overrides, _key, _value

In [None]:
# All tensors in this notebook live on the CPU in single precision.
device = torch.device("cpu")
precision = torch.float32


def send(x):
    """Return tensor `x` cast to the notebook dtype and moved to the notebook device."""
    converted = x.type(precision)
    return converted.to(device)

## Data

In [None]:
# Input distribution settings.
data_name = "rings"  # name passed to mf.data.radial.gen_dist
data_kws = {}
data_size = 1_000_000
data_noise = None

# Measurement settings: square histogram domain [-xmax, xmax] with n_bins bins per axis.
xmax = 3.5
n_bins = 50

# Fractional (multiplicative) measurement noise level; 0.0 disables it.
meas_noise = 0.0

In [None]:
# Draw samples from the input distribution.
d = 4
dist = mf.data.radial.gen_dist(data_name, d=d, noise=data_noise)
x0 = send(dist.sample(data_size))

# Define transforms: independent rotations of the [:2, :2] and [2:, 2:] blocks,
# on a k x k grid of angles in [0, pi).
k = 5
angles = np.linspace(0.0, np.pi, k, endpoint=False)
transforms = []
for angle_x in angles:
    for angle_y in angles:
        matrix = torch.eye(d)
        matrix[:2, :2] = mf.transform.rotation_matrix(angle_x)
        matrix[2:, 2:] = mf.transform.rotation_matrix(angle_y)
        transform = mf.transform.Linear(send(matrix)).to(device)
        transforms.append(transform)

# Each measurement is a 2D histogram of coordinates (0, 2) of the transformed samples.
bin_edges = 2 * [torch.linspace(-xmax, xmax, n_bins + 1)]
diagnostic = mf.diagnostics.Histogram2D(axis=(0, 2), bin_edges=bin_edges, kde=False)
diagnostic = diagnostic.to(device)
diagnostics = [diagnostic]

# Generate training data with KDE disabled (plain histograms); KDE is re-enabled below.
diagnostic.kde = False
# Apply each transform before measuring; otherwise every measurement would be identical.
measurements = [diagnostic(transform(x0)) for transform in transforms]
if meas_noise:
    for i, measurement in enumerate(measurements):
        # Independent fractional noise for every histogram bin (full 2D shape),
        # not one value per column.
        frac_noise = send(meas_noise * torch.randn(measurement.shape))
        measurement = measurement * (1.0 + frac_noise)
        # Noise can push counts negative; clip to keep the histogram non-negative.
        measurements[i] = torch.clamp(measurement, 0.0, None)
diagnostic.kde = True

View the 2D projections of the data after the transform selected by `index`.

In [None]:
# Corner plot of all 2D projections of the samples after transform `index`.
index = 0
y = grab(transforms[index](x0))

limits = d * [(-xmax, xmax)]

# Use `d` and `limits` rather than hard-coding the dimension.
grid = psv.CornerGrid(d=d)
grid.plot_points(y, bins=75, limits=limits, mask=False)

View the measurements (2D histograms of coordinates 0 and 2); use the slider to select a transform.

In [None]:
@interact
def update(index=(0, len(measurements) - 1)):
    """Draw measurement number `index` as an image over the histogram bin edges."""
    x_edges, y_edges = (grab(edges) for edges in bin_edges)
    # Transposed so the first histogram axis runs horizontally in pcolormesh.
    image = grab(measurements[index].T)
    _fig, ax = pplt.subplots()
    ax.pcolormesh(x_edges, y_edges, image)
    pplt.show()

## Model 

In [None]:
# Display the first measurement (a notebook cell shows the value of its last expression).
measurements[0]

In [None]:
# Prior distribution. An alternative is
# mf.models.ment.UniformPrior(d=d, scale=20.0, device=device).
prior = mf.models.ment.GaussianPrior(d=d, scale=1.0, device=device)

# Grid sampler over the measurement domain scaled up by a factor of 1.25.
sampler_limits = np.multiply(d * [(-xmax, +xmax)], 1.25)
sampler = mf.sample.GridSampler(limits=sampler_limits, res=50, device=device)

# MENT model constrained by the measurements under each transform.
model = mf.models.ment.MENT(
    d=d,
    prior=prior,
    sampler=sampler,
    transforms=transforms,
    diagnostic=diagnostic,
    measurements=measurements,
    interpolate="linear",  # one of {"nearest", "linear", "pchip"}
    device=device,
)

## Training

In [None]:
# Define integration grid.
int_limits = model.d_int * [(-xmax, xmax)]
int_limits = np.multiply(int_limits, 1.1)
int_res = 50
int_shape = tuple(model.d_int * [int_res])

In [None]:
# Training loop
for iteration in range(5):
    n = 1000000
    
    x = model.sample(n)
    x = grab(x)

    x0 = dist.sample(n)

    # Radial distribution (unnormalized)
    fig, ax = pplt.subplots(figsize=(3, 2))
    for i, _x in enumerate([x0, x]):
        r = np.linalg.norm(_x, axis=1)
        hist_r, edges_r = np.histogram(r, bins=50, range=(0.0, xmax), density=True)
        ax.stairs(hist_r, edges_r, color=["black", "red"][i], lw=1.5)
    pplt.show()
    

    ## Corner plot
    grid = psv.CornerGrid(d=d, corner=False)
    for i, _x in enumerate([x0, x]):
        grid.plot_points(
            _x, 
            bins=75, 
            limits=(4 * [(-xmax, xmax)]),
            upper=(not i), 
            lower=i,
            diag_kws=dict(color="black", kind="line", lw=1.5),
            mask=False,
        )
    pplt.show()

    # model.gauss_seidel_iterate(method="integrate", limits=int_limits, shape=int_shape)
    model.gauss_seidel_iterate(method="sample", n=1000000)