# 4D-from-2D MENT

In [None]:
import numpy as np
import proplot as pplt
import scipy.interpolate
import torch
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab

In [None]:
# Global proplot styling for every figure in this notebook:
# continuous (non-discrete) colormaps, viridis as the sequential default,
# a white figure background, and no grid lines.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _key, _value in _rc_settings.items():
    pplt.rc[_key] = _value

## Data

In [None]:
# Every tensor in this notebook lives on `device` with dtype `precision`.
precision = torch.float32
device = torch.device("cpu")

def send(x):
    """Return tensor ``x`` cast to ``precision`` and placed on ``device``."""
    return x.to(device=device, dtype=precision)

In [None]:
# Problem dimensions and the measurement window [-xmax, xmax] with n_bins bins.
d = 4
xmax = 4.5
n_bins = 35

# Ground-truth distribution and a large reference sample drawn from it.
data_name = "spheres"
data_noise = None
data_size = 1_000_000

dist = mf.data.radial.gen_dist(data_name, d=d, noise=data_noise)
x0 = send(dist.sample(data_size))

In [None]:
# Build k * k linear transforms. Each one rotates coordinates (0, 1) by
# angle_x and coordinates (2, 3) by angle_y, using a block-diagonal 4x4 matrix.
k = 5
angles = np.linspace(0.0, np.pi, k, endpoint=False)

transforms = []
for angle_x in angles:
    for angle_y in angles:
        # Start from zeros so the off-diagonal 2x2 blocks are empty and the two
        # planes rotate independently. Starting from ones (as before) couples
        # every coordinate and the result is not a block rotation.
        matrix = torch.zeros((4, 4))
        matrix[:2, :2] = mf.transform.rotation_matrix(angle_x)
        matrix[2:, 2:] = mf.transform.rotation_matrix(angle_y)
        transform = mf.transform.Linear(send(matrix))
        transforms.append(transform.to(device))

# 2D histogram of coordinates (0, 2) over [-xmax, xmax]^2 with n_bins bins per axis.
bin_edges = 2 * [torch.linspace(-xmax, xmax, n_bins + 1)]
diagnostic = mf.diagnostics.Histogram2D(axis=(0, 2), bin_edges=bin_edges, kde=False)
diagnostic = diagnostic.to(device)

# Simulated measurements: one histogram of the transformed ground-truth sample per transform.
measurements = [diagnostic(transform(x0)) for transform in transforms]

## Model 

In [None]:
# Uniform prior in d dimensions.
# NOTE(review): scale=20.0 presumably sets the prior's extent -- confirm in mentflow.
prior = mf.models.ment.UniformPrior(d=d, scale=20.0, device=device)

# Grid sampler covering the measurement window padded by 25% on each side.
# NOTE(review): res=50 is presumably points per dimension -- confirm.
sampler_limits = 1.25 * np.array(d * [(-xmax, +xmax)])
sampler = mf.sample.GridSampler(limits=sampler_limits, res=50, device=device)

# MENT model tying the transforms, their measured projections, and the
# diagnostic that produced them.
model = mf.models.ment.MENT(
    d=d,
    device=device,
    prior=prior,
    sampler=sampler,
    transforms=transforms,
    diagnostic=diagnostic,
    measurements=measurements,
    interpolate="linear",  # presumably one of {"nearest", "linear", "pchip"}
)

## Training

In [None]:
# Define integration grid.
int_limits = model.d_int * [(-xmax, xmax)]
int_limits = np.multiply(int_limits, 1.1)
int_res = 50
int_shape = tuple(model.d_int * [int_res])

In [None]:
# Training loop: before each Gauss-Seidel update, plot the radial density of
# model samples (red) against fresh ground-truth samples (black).
n_iterations = 5
n = 10000  # samples drawn for each diagnostic plot
colors = ["black", "red"]

for iteration in range(n_iterations):
    x = grab(model.sample(n))
    # Use a separate name. Rebinding `x0` here would overwrite the full
    # ground-truth sample that `measurements` was built from.
    x_true = grab(dist.sample(n))

    fig, ax = pplt.subplots(figsize=(3, 2))
    for _x, color in zip([x_true, x], colors):
        r = np.linalg.norm(_x, axis=1)
        hist_r, edges_r = np.histogram(r, bins=50, range=(0.0, xmax), density=True)
        ax.stairs(hist_r, edges_r, color=color, lw=1.5)
    pplt.show()

    # NOTE: the plot above shows the model *before* this update, so the effect
    # of the final iteration is only visible in later cells.
    model.gauss_seidel_iterate(method="integrate", limits=int_limits, shape=int_shape)

In [None]:
# Compute predictions: one simulated projection per entry of `measurements`
# (they are compared index-by-index in the viewer below).
predictions = model.simulate(method="sample")
# NOTE(review): an integration-based alternative may accept the grid defined
# earlier (limits=int_limits, shape=int_shape) -- check the signature of
# model.simulate before using it.

In [None]:
@interact
def update(index=(0, len(measurements) - 1)):
    """Show measurement ``index`` (left) next to the model prediction (right).

    The ``(low, high)`` tuple default makes ``interact`` render an integer
    slider over all measurement indices.
    """
    images = [grab(measurements[index]), grab(predictions[index])]

    _, axes = pplt.subplots(ncols=2)
    for col, image in enumerate(images):
        axes[col].pcolormesh(image.T)
    pplt.show()