# MENT: reconstructing a 4D distribution from 1D projections

In [None]:
import numpy as np
import proplot as pplt
import scipy.interpolate
import torch
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab

In [None]:
# Global proplot defaults applied to every figure in this notebook.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_settings.items():
    pplt.rc[_rc_key] = _rc_value

In [None]:
device = torch.device("cpu")
precision = torch.float32


def send(x):
    """Return ``x`` cast to ``precision`` and moved to ``device``."""
    converted = x.type(precision)
    return converted.to(device)

## Data

In [None]:
# Ground-truth distribution and the training samples drawn from it.
data_name = "spheres"
data_size = 10**6
data_noise = None
d = 4  # number of dimensions
xmax = 4.0  # histogram range is [-xmax, xmax]
n_bins = 45  # bins per 1D measurement
n_meas = 20  # number of 1D projections

dist = mf.data.radial.gen_dist(data_name, d=d, noise=data_noise)
x0 = send(dist.sample(data_size))

In [None]:
# Draw `n_meas` random projection directions uniformly on the unit sphere in
# R^d (normalized Gaussian vectors). Use `d` rather than a hard-coded 4 so the
# directions always match the data dimension.
unit_vectors = np.random.normal(size=(n_meas, d))
unit_vectors = unit_vectors / np.linalg.norm(unit_vectors, axis=1)[:, None]
unit_vectors = send(torch.from_numpy(unit_vectors))

transforms = []
for unit_vector in unit_vectors:
    # Identity matrix whose first row is replaced by the projection direction.
    # NOTE(review): presumably `mf.transform.Linear` maps x -> x @ matrix.T, so
    # output axis 0 (the axis histogrammed below) is the projection of x onto
    # `unit_vector` -- confirm against mentflow. det(matrix) == unit_vector[0],
    # so the matrix is singular only in the probability-zero case of a zero
    # first component.
    matrix = send(torch.eye(d))
    matrix[0, :] = unit_vector
    transform = mf.transform.Linear(matrix)
    transforms.append(transform.to(device))

# Histogram of axis 0 with `n_bins` uniform bins on [-xmax, xmax].
bin_edges = torch.linspace(-xmax, xmax, n_bins + 1)
diagnostic = mf.diagnostics.Histogram1D(axis=0, bin_edges=bin_edges, kde=False)
diagnostic = diagnostic.to(device)

# Simulated measurements: one histogram per transformed copy of the data x0.
measurements = [diagnostic(transform(x0)) for transform in transforms]

In [None]:
@interact
def update(index=(0, n_meas - 1)):
    """Draw the measured 1D profile chosen by the ``index`` slider."""
    _, axis = pplt.subplots(figsize=(3, 2))
    profile = grab(measurements[index])
    edges = grab(bin_edges)
    axis.stairs(profile, edges, color="black", lw=1.25)

## Model

In [None]:
# Gaussian prior. Alternative:
# prior = mf.models.ment.UniformPrior(d=d, scale=20.0, device=device)
prior = mf.models.ment.GaussianPrior(d=d, scale=1.0, device=device)

# Grid sampler over the measurement window, widened by 5% on each side.
sampler_limits = np.multiply(d * [(-xmax, +xmax)], 1.05)
sampler = mf.sample.GridSampler(limits=sampler_limits, res=64, device=device)

ment_kws = dict(
    d=d,
    transforms=transforms,
    measurements=measurements,
    diagnostic=diagnostic,
    prior=prior,
    sampler=sampler,
    interpolate="linear",  # {"nearest", "linear", "pchip"}
    device=device,
)
model = mf.models.ment.MENT(**ment_kws)

## Training

In [None]:
# Integration grid used by `method="integrate"`: `int_res` points along each
# of the `model.d_int` axes, spanning the measurement window widened by 10%.
int_res = 50
int_limits = np.multiply(model.d_int * [(-xmax, xmax)], 1.1)
int_shape = (int_res,) * model.d_int

In [None]:
# Training loop. Each iteration first plots the radial distribution of
# ground-truth samples (black) against model samples (red), then runs one
# `gauss_seidel_iterate` update of the model's Lagrange functions.
n = 10000
for iteration in range(4):
    # Evaluate the model's performance. Store the ground-truth samples in
    # `x_true`, not `x0`: `x0` holds the full training set used to build
    # `measurements` and must not be replaced by a small numpy array.
    x_model = grab(model.sample(n))
    x_true = grab(dist.sample(n))

    _, ax = pplt.subplots(figsize=(3, 2))
    for samples, color in ((x_true, "black"), (x_model, "red")):
        radii = np.linalg.norm(samples, axis=1)
        hist_r, edges_r = np.histogram(radii, bins=50, range=(0.0, xmax), density=True)
        ax.stairs(hist_r, edges_r, color=color, lw=1.5)
    pplt.show()

    # Update the Lagrange functions with the grid-integration method.
    # Sample-based alternative:
    # model.gauss_seidel_iterate(method="sample", n=100000)
    model.gauss_seidel_iterate(method="integrate", limits=int_limits, shape=int_shape)

## Evaluation

In [None]:
# Simulate the measurements from the trained model with the "integrate"
# method on the grid defined above. Sample-based alternative:
# predictions = model.simulate(method="sample", n=100000)
predictions = model.simulate(
    method="integrate",
    shape=int_shape,
    limits=int_limits,
)

In [None]:
@interact
def update(index=(0, len(measurements) - 1)):
    """Overlay the measured (black) and predicted (red) profile for one projection.

    Both curves are divided by the peak of the measured profile, so the
    measurement tops out at 1 and the y-axis is fixed to [0, 1.25].
    """
    measured = grab(measurements[index])
    predicted = grab(predictions[index])
    scale = np.max(measured)
    measured, predicted = measured / scale, predicted / scale
    edges = grab(diagnostic.bin_edges)

    _, ax = pplt.subplots(figsize=(3.0, 2.0))
    for profile, color in ((measured, "black"), (predicted, "red")):
        ax.stairs(profile, edges, color=color, lw=1.25)
    ax.format(ymin=0.0, ymax=1.25)
    pplt.show()

In [None]:
# Larger sample sets from the ground truth (x0) and the trained model (x),
# drawn in that order, for the corner plot.
n = 100_000
x0, x = grab(dist.sample(n)), grab(model.sample(n))

In [None]:
import os
import sys

# psdist is only needed for the corner plot. Prefer an installed copy; otherwise
# fall back to a local checkout. Its location can be set with the PSDIST_PATH
# environment variable and defaults to the original author's path.
try:
    import psdist as ps
except ImportError:
    sys.path.append(os.environ.get("PSDIST_PATH", "/Users/46h/repo/psdist/"))
    import psdist as ps
import psdist.visualization as psv

# Corner plot comparing model samples `x` (blue) with ground-truth samples `x0`
# (red). `upper=False` / `lower=False` presumably assign each sample set to one
# off-diagonal triangle; confirm against psdist.
grid = psv.CornerGrid(d=d, corner=False)
limits = d * [(-xmax, xmax)]
kws = dict(limits=limits, bins=50, mask=False)
grid.plot_points(
    x,
    upper=False,
    diag_kws=dict(color="blue7", lw=1.5),
    cmap=pplt.Colormap("vlag_r", left=0.5),
    **kws,
)
grid.plot_points(
    x0,
    lower=False,
    diag_kws=dict(color="red7", lw=1.5),
    cmap=pplt.Colormap("vlag", left=0.5),
    **kws,
)
pplt.show()