# 2D MENT-Flow

In [None]:
import numpy as np
import proplot as pplt
import scipy.interpolate
import torch
import zuko
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel
from mentflow.wrappers import WrappedZukoFlow

# Local
import plotting
import utils

In [None]:
# Global proplot settings, applied in the order listed.
_rc_overrides = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_overrides.items():
    pplt.rc[_rc_key] = _rc_value

In [None]:
# Select the compute device. Apple's Metal backend ("mps") is preferred, as in
# the original notebook; CUDA and then the CPU are used as fallbacks so that
# `.to(device)` does not fail on machines without an MPS device.
_mps_backend = getattr(torch.backends, "mps", None)  # absent in older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Floating-point dtype that `send` casts tensors to.
precision = torch.float32

def send(x):
    """Cast tensor `x` to the global `precision`, then move it to the global `device`."""
    cast = x.type(precision)
    return cast.to(device)

## Data

In [None]:
# Settings for the input distribution and the simulated measurements.
data_name = "spirals"
data_size = 1_000_000  # number of samples drawn from the input distribution
data_noise = None
xmax = 3.25  # histogram bin edges span [-xmax, xmax]
n_bins = 75
n_meas = 6  # number of projection angles
min_angle, max_angle = 0.0, 180.0  # [deg]
meas_noise = 0.0  # fractional noise on the measurements; falsy disables it

# Input distribution; `d` is the feature count used for the flow and the prior.
d = 2
dist = mf.data.toy.gen_dist(data_name, noise=data_noise)

# Sample the input distribution and send the samples to the compute device.
x0 = send(dist.sample(data_size))

# One linear transformation per projection angle; the end angle is excluded.
angle_start = np.radians(min_angle)
angle_stop = np.radians(max_angle)
angles = np.linspace(angle_start, angle_stop, num=n_meas, endpoint=False)

transforms = []
for angle in angles:
    matrix = send(mf.transform.rotation_matrix(angle))
    transform = mf.transform.Linear(matrix).to(device)
    transforms.append(transform)

# One-dimensional histogram diagnostic on axis 0 with `n_bins` evenly spaced
# bins spanning [-xmax, xmax].
bin_edges = send(torch.linspace(-xmax, xmax, n_bins + 1))
diagnostic = mf.diagnostics.Histogram1D(axis=0, bin_edges=bin_edges).to(device)
diagnostics = [diagnostic]

# Simulate the training measurements with the diagnostic's `kde` flag off,
# then switch the flag back on afterwards.
diagnostic.kde = False
measurements = mf.simulate(x0, transforms, diagnostics)
if meas_noise:
    # Scale each bin by (1 + meas_noise * standard normal draw) and clip
    # negative values to zero.
    for i, row in enumerate(measurements):
        for j, measurement in enumerate(row):
            frac_noise = send(meas_noise * torch.randn(measurement.shape[0]))
            measurement *= (1.0 + frac_noise)
            measurement = torch.clamp(measurement, 0.0, None)
            measurements[i][j] = measurement
diagnostic.kde = True

## Model

In [None]:
# Hyperparameters for the normalizing flow and the MENT-Flow model built in the
# next cell.
n_transforms = 3  # passed as `transforms` to zuko.flows.NSF
n_spline_bins = 20  # passed as `bins` to zuko.flows.NSF
n_hidden_layers = 5  # length of the `hidden_features` list
n_hidden_units = 64  # value of every entry in the `hidden_features` list
prior_scale = 1.0  # multiplies torch.ones(d) in the prior's second argument
penalty_parameter = 5.0  # forwarded to mf.MENTFlow(penalty_parameter=...)
discrepancy_function = "kld"  # forwarded to mf.MENTFlow(discrepancy_function=...)

In [None]:
# Neural spline flow from zuko. The flow is rebuilt from the inverse transform
# and the same base distribution for faster sampling, then moved to the device
# and wrapped for use with mentflow.
nsf = zuko.flows.NSF(
    features=d,
    transforms=n_transforms,
    bins=n_spline_bins,
    hidden_features=[n_hidden_units] * n_hidden_layers,
    randperm=True,
)
flow = WrappedZukoFlow(zuko.flows.Flow(nsf.transform.inv, nsf.base).to(device))

# Diagonal normal prior built from a zero vector and a `prior_scale` vector.
prior_loc = send(torch.zeros(d))
prior_std = send(prior_scale * torch.ones(d))
prior = zuko.distributions.DiagNormal(prior_loc, prior_std)

entropy_estimator = mf.entropy.MonteCarloEntropyEstimator()

# Assemble the MENT-Flow model and move it to the compute device.
model_options = dict(
    generator=flow,
    prior=prior,
    entropy_estimator=entropy_estimator,
    transforms=transforms,
    diagnostics=diagnostics,
    measurements=measurements,
    penalty_parameter=penalty_parameter,
    discrepancy_function=discrepancy_function,
)
model = mf.MENTFlow(**model_options).to(device)

## Diagnostics

In [None]:
# Settings for monitoring and visualization.
monitor_freq = 25  # passed as `freq` to mf.train.Monitor
vis_bins = 75  # passed as `n_bins` to plotting.plot_dist
vis_size = int(1.00e+05)  # number of samples drawn for the plots
vis_res = 150  # grid points per axis when evaluating the model density
vis_maxcols = 7  # passed as `maxcols` to plotting.plot_proj

In [None]:
# Training monitor attached to the model.
monitor = mf.train.Monitor(
    model=model,
    momentum=0.98,
    freq=monitor_freq,
)

In [None]:
def make_plots(x, prob, predictions):
    """Create the figures comparing the model with the reference data.

    Parameters
    ----------
    x
        Model samples, passed to `plotting.plot_dist`.
    prob
        Model density values; one coordinate axis spanning [-xmax, xmax] is
        built for each entry of `prob.shape`.
    predictions
        Simulated projections, passed to `plotting.plot_proj`.

    Returns
    -------
    list
        Two figures: the distribution plot, then the projection plot.
    """
    # True samples, model samples, and model density.
    axis_coords = [np.linspace(-xmax, xmax, size) for size in prob.shape]
    axis_limits = [(-xmax, xmax)] * 2
    dist_fig, _ = plotting.plot_dist(
        dist.sample(vis_size),
        x,
        prob=prob,
        coords=axis_coords,
        n_bins=vis_bins,
        limits=axis_limits,
    )

    # Simulated projections overlaid on the measured ones.
    measured = [grab(measurement) for measurement in unravel(measurements)]
    proj_fig, _ = plotting.plot_proj(
        measured,
        predictions,
        bin_edges=grab(diagnostic.bin_edges),
        maxcols=vis_maxcols,
        kind="line",
    )

    return [dist_fig, proj_fig]


def plotter(model):
    """Build the diagnostic figures for the current state of `model`.

    Evaluates `model.log_prob` on a `vis_res` x `vis_res` grid spanning
    [-xmax, xmax] on both axes, draws `vis_size` samples from the model,
    simulates the measurements from those samples, and passes the results
    to `make_plots`.

    Parameters
    ----------
    model
        Object providing `log_prob`, `sample`, `simulate` and `diagnostics`.

    Returns
    -------
    list
        The figures returned by `make_plots`.
    """
    # Evaluate the model density on a grid.
    res = vis_res
    grid_coords = [np.linspace(-xmax, xmax, res) for _ in range(2)]
    grid_points = mf.utils.get_grid_points(grid_coords)
    grid_points = send(torch.from_numpy(grid_points))
    log_prob = model.log_prob(grid_points).reshape((res, res))
    prob = torch.exp(log_prob)

    # Draw samples from the model.
    x = send(model.sample(vis_size))

    # Simulate the measurements with `kde` switched off on the model's own
    # diagnostics. The previous loops iterated over `model.diagnostics` but
    # toggled the module-level `diagnostic` instead of the loop item.
    # NOTE(review): assumes `model.diagnostics` is a flat iterable of diagnostic
    # objects, matching the list passed to mf.MENTFlow in this file - confirm.
    model_diagnostics = list(model.diagnostics)
    for diag in model_diagnostics:
        diag.kde = False
    try:
        predictions = model.simulate(x)
    finally:
        # Restore the flag even if the simulation raises.
        for diag in model_diagnostics:
            diag.kde = True
    predictions = [grab(prediction) for prediction in unravel(predictions)]

    return make_plots(grab(x), grab(prob), predictions)

## Training

In [None]:
# Optimizer and learning-rate schedule settings.
lr = 0.001  # initial learning rate for AdamW
weight_decay = 0.0
# NOTE(review): lr_min equals lr, and lr_min is passed as `min_lr` to
# ReduceLROnPlateau, so the scheduler appears to have no room to lower the
# learning rate - confirm this is intended.
lr_min = 0.001
lr_patience = 500  # passed as `patience` to ReduceLROnPlateau
lr_drop = 0.25  # passed as `factor` to ReduceLROnPlateau

# Training loop sizes, forwarded to trainer.train.
n_epochs = 10
n_iterations = 400
batch_size = 30000

# Penalty schedule, forwarded to trainer.train.
penalty_step = 20.0
penalty_scale = 1.1
penalty_max = None

# Visualization and checkpoint frequencies, forwarded to trainer.train as
# `vis_freq` and `checkpoint_freq`.
vis_freq = None
check_freq = None

In [None]:
# AdamW over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Plateau-based learning-rate schedule attached to the optimizer.
scheduler_options = dict(min_lr=lr_min, patience=lr_patience, factor=lr_drop)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_options)

# Build the trainer from the model, optimizer, scheduler, monitor and plotter.
trainer_options = dict(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    monitor=monitor,
    plotter=plotter,
    precision=precision,
    device=device,
    save=False,
)
trainer = mf.train.Trainer(**trainer_options)

# Run the training loop with the settings defined above; nothing is saved.
train_options = dict(
    epochs=n_epochs,
    iterations=n_iterations,
    batch_size=batch_size,
    penalty_step=penalty_step,
    penalty_scale=penalty_scale,
    penalty_max=penalty_max,
    save=False,
    vis_freq=vis_freq,
    checkpoint_freq=check_freq,
)
trainer.train(**train_options)