# Train a 4D normalizing flow on 2D projections

In [None]:
import os
import pickle
import sys
import time

import numpy as np
import proplot as pplt
import psdist as ps
import psdist.visualization as psv
import scipy.interpolate
import scipy.ndimage
import torch
import zuko
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel
from mentflow.wrappers import WrappedZukoFlow

# Local
import plotting

In [None]:
# Global proplot styling: continuous (non-discrete) colormaps, the reversed
# "dark" map built in the "hpl" space as the default sequential colormap,
# a white figure background, and no grid lines.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": pplt.Colormap("dark_r", space="hpl"),
    "figure.facecolor": "white",
    "grid": False,
}
for _key, _value in _rc_settings.items():
    pplt.rc[_key] = _value

In [None]:
def _select_device():
    """Return the preferred available torch device: MPS, then CUDA, then CPU."""
    # The notebook was written for Apple silicon (MPS); fall back gracefully so
    # it still runs on CUDA machines or plain CPUs instead of failing on `.to`.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = _select_device()
precision = torch.float32


def send(x):
    """Cast tensor ``x`` to the notebook's ``precision`` and move it to ``device``."""
    return x.type(precision).to(device)

## Data

In [None]:
# Input distribution (passed by name to mf.data.radial.gen_dist).
data_name = "rings"
data_kws = {}
data_size = 1_000_000  # number of samples drawn from the input distribution
data_noise = None  # forwarded as `noise=` to gen_dist

# Histogram / plotting grid: n_bins bins per axis spanning [-xmax, xmax].
xmax = 3.5
n_bins = 50

# Fractional multiplicative noise on the measurements (0.0 disables it).
meas_noise = 0.0

In [None]:
# Draw samples from the input distribution. `data_kws` holds extra keyword
# arguments for the distribution factory; it was previously defined but never
# used, so it is forwarded here (a no-op while it is empty).
d = 4
dist = mf.data.radial.gen_dist(data_name, d=d, noise=data_noise, **data_kws)
x0 = dist.sample(data_size)
x0 = send(x0)  # cast to `precision` and move to `device`

# Define transforms: one linear map per pair (angle_x, angle_y) drawn from a
# grid of k angles in [0, pi). angle_x rotates coordinates 0-1 and angle_y
# rotates coordinates 2 onward; angle_x varies slowest in `transforms`.
k = 5
angles = np.linspace(0.0, np.pi, k, endpoint=False)
transforms = []
for i, angle_x in enumerate(angles):
    for j, angle_y in enumerate(angles):
        matrix = torch.eye(d)
        matrix[:2, :2] = mf.transform.rotation_matrix(angle_x)
        matrix[2:, 2:] = mf.transform.rotation_matrix(angle_y)
        matrix = send(matrix)
        transform = mf.transform.Linear(matrix).to(device)
        transforms.append(transform)

# A single 2D histogram diagnostic over coordinates (0, 2), using the same
# n_bins edges spanning [-xmax, xmax] on both axes.
bin_edges = [torch.linspace(-xmax, xmax, n_bins + 1)] * 2
diagnostic = mf.diagnostics.Histogram2D(
    axis=(0, 2),
    bin_edges=bin_edges,
    kde=False,
).to(device)
diagnostics = [diagnostic]

# Generate training data. The measurements are computed with KDE smoothing
# switched off; it is switched back on afterwards (even if simulation fails).
diagnostic.kde = False
try:
    measurements = mf.simulate(x0, transforms, diagnostics)
finally:
    diagnostic.kde = True

if meas_noise:
    # Multiply every histogram bin by an independent factor
    # (1 + meas_noise * N(0, 1)) and clip negative values to zero.
    # `randn_like` draws one sample per bin (a 1D vector of length
    # shape[0] would broadcast over the last axis and give every bin in a
    # column the same factor) and matches the measurement's device/dtype.
    for i in range(len(measurements)):
        for j in range(len(measurements[i])):
            measurement = measurements[i][j]
            frac_noise = meas_noise * torch.randn_like(measurement)
            measurement = measurement * (1.0 + frac_noise)
            measurements[i][j] = torch.clamp(measurement, 0.0, None)

View the 2D projections of the data.

In [None]:
# Corner plot of the samples after transform `index`.
index = 0
y = transforms[index](x0)
y = grab(y)

# Axis limits shared by every dimension (also reused by later plots).
limits = d * [(-xmax, xmax)]

grid = psv.CornerGrid(d=d)
grid.plot_points(y, bins=75, limits=limits, mask=False)

View the measurements.

In [None]:
@interact
def update(index=(0, len(measurements) - 1)):
    """Show the first histogram stored in ``measurements[index]``.

    The slider spans every index of ``measurements`` (presumably one entry
    per transform, each holding one array per diagnostic — confirm against
    ``mf.simulate``).
    """
    fig, ax = pplt.subplots()
    x_edges = grab(bin_edges[0])
    y_edges = grab(bin_edges[1])
    image = grab(measurements[index][0].T)
    ax.pcolormesh(x_edges, y_edges, image)
    pplt.show()

## Model

In [None]:
# Flow architecture (zuko neural spline flow).
n_transforms = 5  # number of NSF transforms
n_spline_bins = 20  # spline bins per transform
n_hidden_layers = 5  # depth of each conditioner network
n_hidden_units = 64  # width of each hidden layer

# Prior and loss.
prior_scale = 1.0  # standard deviation of the diagonal Gaussian prior
discrepancy_function = "mae"

In [None]:
# Diagonal Gaussian prior: zero mean, standard deviation `prior_scale` in
# every dimension.
prior = zuko.distributions.DiagNormal(
    send(torch.zeros(d)),
    send(torch.ones(d) * prior_scale),
)

# Generator: a neural spline flow whose transform is swapped for its inverse
# (the original author notes this gives faster sampling), moved to `device`
# and wrapped for mentflow.
flow = zuko.flows.NSF(
    features=d,
    transforms=n_transforms,
    bins=n_spline_bins,
    hidden_features=[n_hidden_units] * n_hidden_layers,
    randperm=True,
)
flow = zuko.flows.Flow(flow.transform.inv, flow.base)
flow = WrappedZukoFlow(flow.to(device))

entropy_estimator = mf.entropy.MonteCarloEntropyEstimator()

# The penalty starts at zero here; the training cell sets its actual value.
model = mf.MENTFlow(
    generator=flow,
    prior=prior,
    entropy_estimator=entropy_estimator,
    transforms=transforms,
    diagnostics=diagnostics,
    measurements=measurements,
    penalty_parameter=0.0,
    discrepancy_function=discrepancy_function,
).to(device)

## Training

In [None]:
# Training loop size.
n_epochs = 10
n_iterations = 300
batch_size = 30_000

# AdamW + ReduceLROnPlateau settings.
lr = 1e-2
weight_decay = 0.0
lr_min = 1e-3  # scheduler floor (`min_lr`)
lr_patience = 500  # scheduler `patience`
lr_drop = 0.1  # scheduler `factor`

# Penalty parameter schedule (initial value, then trainer-side updates).
penalty_parameter = 5.0
penalty_step = 20.0
penalty_scale = 1.2
penalty_max = None

# Monitoring / visualization.
monitor_freq = 25
vis_freq = None
vis_size = 1_000_000  # number of model samples drawn by the plotter

In [None]:
def plotter(model):
    """Build figures comparing model samples with the true input distribution.

    Parameters
    ----------
    model : mf.MENTFlow
        Model providing ``sample``, ``simulate`` and ``diagnostics``.

    Returns
    -------
    list
        Three figures: radial histograms, a corner plot (true samples above
        the diagonal, model samples below), and measured vs. simulated
        histograms.
    """
    # Only sample values are plotted, so skip autograd bookkeeping for the
    # (large) visualization batch.
    with torch.no_grad():
        x = send(model.sample(vis_size))
        x_true = send(dist.sample(x.shape[0]))
        predictions = _simulate_without_kde(model, x)

    x = grab(x)
    x_true = grab(x_true)

    return [
        _plot_radial_hist(x_true, x),
        _plot_corner(x_true, x),
        _plot_measurements(predictions),
    ]


def _simulate_without_kde(model, x):
    """Return ``model.simulate(x)`` with KDE disabled on every diagnostic.

    KDE is switched back on afterwards, also when simulation raises, so an
    error during plotting cannot leave the diagnostics in the wrong mode.
    """
    for diagnostic in model.diagnostics:
        diagnostic.kde = False
    try:
        return model.simulate(x)
    finally:
        for diagnostic in model.diagnostics:
            diagnostic.kde = True


def _plot_radial_hist(x_true, x, bins=85):
    """Plot histograms of the radii |x| for true (black) and model (red) samples."""
    fig, ax = pplt.subplots(figsize=(2.75, 1.5))
    for color, points in zip(["black", "red"], [x_true, x]):
        r = np.linalg.norm(points, axis=1)
        # density=True normalizes each histogram to unit area over
        # [0, xmax]; no r**(d-1) volume correction is applied.
        hist_r, edges_r = np.histogram(r, bins=bins, range=(0.0, xmax), density=True)
        plotting.plot_hist(hist_r, edges_r, ax=ax, kind="line", color=color)
    ax.format(xlabel="radius", ylabel="density")
    return fig


def _plot_corner(x_true, x):
    """Corner plot: true samples above the diagonal, model samples below."""
    limits = d * [(-xmax, xmax)]
    grid = psv.CornerGrid(d=d, corner=False)
    for i, (color, points) in enumerate(zip(["black", "red"], [x_true, x])):
        grid.plot_points(
            points,
            bins=75,
            limits=limits,
            upper=(i == 0),
            lower=(i == 1),
            # Match the radial-plot colors so the two 1D profiles on the
            # diagonal can be told apart.
            diag_kws=dict(color=color, kind="line", lw=1.5),
            mask=False,
        )
    return grid.fig


def _plot_measurements(predictions, maxcols=7):
    """Plot every measured histogram directly above its simulated counterpart."""
    # Flatten the nested [i][j] structure into (measured, predicted) pairs.
    pairs = [
        (meas, pred)
        for meas_i, pred_i in zip(measurements, predictions)
        for meas, pred in zip(meas_i, pred_i)
    ]
    n_meas = len(pairs)
    ncols = min(n_meas, maxcols)
    nrows = int(np.ceil(n_meas / ncols))

    # Each measurement occupies two stacked rows: measured, then simulated.
    fig, axs = pplt.subplots(ncols=ncols, nrows=(2 * nrows), figwidth=10.0)
    for counter, (meas, pred) in enumerate(pairs):
        row, col = divmod(counter, ncols)
        row *= 2
        for ax, prob in zip(axs[row: row + 2, col], [meas, pred]):
            ax.pcolormesh(grab(prob).T)
    return fig

In [None]:
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

# Multiply the learning rate by `lr_drop` (never going below `lr_min`) once
# the monitored quantity has not improved for `lr_patience` scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=lr_drop,
    patience=lr_patience,
    min_lr=lr_min,
)

# Initial penalty parameter; the trainer receives the schedule settings below.
model.penalty_parameter = penalty_parameter

monitor = mf.train.Monitor(model=model, momentum=0.98, freq=monitor_freq)

trainer = mf.train.Trainer(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    monitor=monitor,
    plotter=plotter,
    precision=precision,
    device=device,
    save=False,
)

# NOTE(review): the meaning of `dmax` is not visible here — confirm in
# mentflow.train.Trainer.train.
train_kws = dict(
    epochs=n_epochs,
    iterations=n_iterations,
    batch_size=batch_size,
    penalty_step=penalty_step,
    penalty_scale=penalty_scale,
    penalty_max=penalty_max,
    vis_freq=vis_freq,
    dmax=5.00e-04,
)
trainer.train(**train_kws)