# Train 2D normalizing flow

In [None]:
import os
import pickle
import sys
import time

import numpy as np
import proplot as pplt
import scipy.interpolate
import torch
import zuko
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel
from mentflow.wrappers import WrappedZukoFlow

# Local
import plotting

In [None]:
# ProPlot rc settings for this notebook. They are written one at a time, in
# this order, through the shared `pplt.rc` object.
pplt.rc["cmap.discrete"] = False
# Sequential colormap built from the "dark_r" name with space="hpl".
pplt.rc["cmap.sequential"] = pplt.Colormap("dark_r", space="hpl")
pplt.rc["figure.facecolor"] = "white"
pplt.rc["grid"] = False

In [None]:
def _select_device():
    """Return the torch device to run on: MPS if available, else CUDA, else CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Previously hard-coded to "mps", which made every `send` call fail on
# machines without Apple MPS. On an MPS machine the result is unchanged.
device = _select_device()
precision = torch.float32


def send(x):
    """Cast tensor `x` to `precision` and move it to `device`.

    The dtype cast is done first, on the tensor's current device, and the
    device transfer second (same order as before).
    """
    return x.type(precision).to(device)

## Data

In [None]:
# Parameters

# Input distribution: name, extra keyword arguments and noise handed to
# `mf.data.toy.gen_dist`, and the number of samples drawn from it.
data_name = "galaxy"
data_kws = {}
data_noise = None
data_size = 1_000_000

# Histogram diagnostic: bin edges span [-xmax, +xmax] with `n_bins` bins.
xmax = 3.5
n_bins = 80

# Measurements: `n_meas` rotation angles spaced over
# [meas_min_angle, meas_max_angle), and the scale of the optional
# multiplicative noise applied to each measured histogram.
n_meas = 6
meas_min_angle = 0.0
meas_max_angle = np.pi
meas_noise = 0.0

In [None]:
# Draw samples from the input distribution and move them to the device.
d = 2
dist = mf.data.toy.gen_dist(data_name, noise=data_noise, **data_kws)
x0 = send(dist.sample(data_size))

# Define transforms: one linear map per rotation angle.
angles = np.linspace(meas_min_angle, meas_max_angle, n_meas, endpoint=False)
transforms = []
for angle in angles:
    rotation = send(mf.transform.rotation_matrix(angle))
    transforms.append(mf.transform.Linear(rotation).to(device))

# Create histogram diagnostic (x axis).
bin_edges = send(torch.linspace(-xmax, xmax, n_bins + 1))
diagnostic = mf.diagnostics.Histogram1D(axis=0, bin_edges=bin_edges).to(device)
diagnostics = [diagnostic]

# Generate training data. The `kde` flag is switched off while the
# measurements are simulated and switched back on afterwards.
diagnostic.kde = False
measurements = mf.simulate(x0, transforms, diagnostics)
if meas_noise:
    # Scale every entry of every measurement by (1 + meas_noise * N(0, 1)),
    # then clip negative values to zero.
    for i, row in enumerate(measurements):
        for j, values in enumerate(row):
            jitter = send(meas_noise * torch.randn(values.shape[0]))
            values *= 1.0 + jitter
            measurements[i][j] = torch.clamp(values, min=0.0)
diagnostic.kde = True

View the data in the transformed space.

In [None]:
@interact
def update(index=(0, n_meas - 1)):
    """Show the samples after transform `index`, with its measurement below.

    The main axes holds a 2D histogram of `x0` passed through
    `transforms[index]`; the bottom panel shows `measurements[index][0]`
    plotted against `bin_edges`.
    """
    points = grab(transforms[index](x0))

    fig, ax = pplt.subplots()
    extent = [(-xmax, +xmax)] * 2
    ax.hist2d(points[:, 0], points[:, 1], bins=100, range=extent)

    # Bottom panel: the measured histogram for this transform.
    pax = ax.panel_axes("bottom", width=0.75)
    profile = grab(measurements[index][0])
    plotting.plot_hist(profile, grab(bin_edges), ax=pax, color="black")
    pplt.show()

Plot the integration lines in the input and transformed space.

In [None]:
@interact
def update(index=(0, len(measurements) - 1)):
    """Plot integration lines in the input (left) and transformed (right) space.

    Each panel shows a 2D histogram of fresh samples from `dist`. A grid of
    lines is laid out in the transformed space and mapped back through
    `transform.inverse` to get the matching lines in the input space.
    """
    transform = transforms[index]

    # Fresh samples, before and after the selected transform.
    samples_in = send(dist.sample(100000))
    samples_out = grab(transform(samples_in))
    samples_in = grab(samples_in)

    fig, axs = pplt.subplots(ncols=2)
    extent = [(-xmax, +xmax)] * 2
    for ax, pts in zip(axs, (samples_in, samples_out)):
        ax.hist2d(pts[:, 0], pts[:, 1], bins=100, range=extent)

    # Grid of points in the transformed space, covering twice the plot range.
    n_lines = 20
    n_dots_per_line = 100
    grid_out = send(
        mf.utils.get_grid_points_torch(
            2.0 * torch.linspace(-xmax, +xmax, n_lines),
            2.0 * torch.linspace(-xmax, +xmax, n_dots_per_line),
        )
    )
    grid_in = grab(transform.inverse(grid_out))
    grid_out = grab(grid_out)

    # Split each point set into `n_lines` equal chunks; draw one line per chunk.
    for ax, pts in zip(axs, (grid_in, grid_out)):
        for segment in np.split(pts, n_lines):
            ax.plot(segment[:, 0], segment[:, 1], color="pink4", alpha=0.5)
    axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))
    pplt.show()

## Model

In [None]:
# Parameters

# Flow architecture (passed to zuko.flows.NSF).
n_transforms = 3
n_spline_bins = 20
n_hidden_units = 64
n_hidden_layers = 5

# Prior and loss.
prior_scale = 1.0  # Gaussian prior std
discrepancy_function = "kld"

In [None]:
# Build a neural spline flow, then re-wrap its inverse transform as the
# flow's forward direction (original author's note: faster sampling).
nsf = zuko.flows.NSF(
    features=d,
    transforms=n_transforms,
    bins=n_spline_bins,
    hidden_features=[n_hidden_units] * n_hidden_layers,
    randperm=True,
)
flow = WrappedZukoFlow(zuko.flows.Flow(nsf.transform.inv, nsf.base).to(device))

# Diagonal Gaussian prior: zero mean, `prior_scale` in every dimension.
prior_loc = send(torch.zeros(d))
prior_std = send(prior_scale * torch.ones(d))
prior = zuko.distributions.DiagNormal(prior_loc, prior_std)

entropy_estimator = mf.entropy.MonteCarloEntropyEstimator()

# Assemble the model. The penalty parameter starts at zero here.
model = mf.MENTFlow(
    generator=flow,
    prior=prior,
    entropy_estimator=entropy_estimator,
    transforms=transforms,
    diagnostics=diagnostics,
    measurements=measurements,
    penalty_parameter=0.0,
    discrepancy_function=discrepancy_function,
).to(device)

## Training

In [None]:
# Parameters

# Training loop size (passed to trainer.train).
n_epochs = 10
n_iterations = 300
batch_size = 30000

# Optimizer (AdamW) and ReduceLROnPlateau schedule.
lr = 0.01
weight_decay = 0.0
lr_min = 0.001  # scheduler `min_lr`
lr_patience = 500  # scheduler `patience`
lr_drop = 0.1  # scheduler `factor`

# Penalty parameter: initial value assigned to the model, plus the
# step / scale / max values handed to trainer.train.
penalty_parameter = 5.0
penalty_step = 20.0
penalty_scale = 1.2
penalty_max = None

# Monitoring and visualization.
monitor_freq = 25
vis_freq = None
vis_bins = 100
vis_size = 1_000_000

In [None]:
def make_plots(x, predictions):
    """Create the two diagnostic figures for a set of model samples.

    Parameters
    ----------
    x
        Model samples, passed to `plotting.plot_dist`; `x.shape[0]` sets how
        many reference samples are drawn from `dist`.
    predictions
        Simulated projections, passed to `plotting.plot_proj` together with
        the measured ones.

    Returns
    -------
    list
        `[density_figure, projection_figure]`.
    """
    # 2D density: fresh samples from the true distribution vs. model samples.
    reference = grab(dist.sample(x.shape[0]))
    density_fig, _ = plotting.plot_dist(
        reference,
        x,
        bins=vis_bins,
        limits=[(-xmax, xmax)] * 2,
    )

    # Overlaid simulated/measured projections.
    measured = [grab(measurement) for measurement in unravel(measurements)]
    proj_fig, _ = plotting.plot_proj(
        measured,
        predictions,
        bin_edges=grab(diagnostic.bin_edges),
        maxcols=7,
        kind="line",
        height=1.25,
        lw=1.5,
    )

    return [density_fig, proj_fig]


def plotter(model):
    """Sample from `model`, simulate its measurements, and build the figures.

    Draws `vis_size` samples, simulates the projections with `kde` switched
    off on every diagnostic of the model, then passes samples and
    predictions to `make_plots`.

    Parameters
    ----------
    model
        Object providing `sample(n)`, `simulate(x)` and `diagnostics`.
        Assumes `diagnostics` is a flat iterable of diagnostic objects, as
        passed to `mf.MENTFlow` in this notebook -- confirm if it is nested.

    Returns
    -------
    list
        The figures returned by `make_plots`.
    """
    x = send(model.sample(vis_size))

    # Fix: the loop variable used to be named `diagnostics`, so the body
    # always set `kde` on the module-level `diagnostic` instead of on each
    # element of `model.diagnostics`.
    for diagnostic in model.diagnostics:
        diagnostic.kde = False
    try:
        predictions = model.simulate(x)
    finally:
        # Restore `kde` even if the simulation raises.
        for diagnostic in model.diagnostics:
            diagnostic.kde = True

    predictions = [grab(prediction) for prediction in unravel(predictions)]
    return make_plots(grab(x), predictions)

In [None]:
# Optimizer and plateau-based learning-rate schedule.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=lr_drop,
    patience=lr_patience,
    min_lr=lr_min,
)

# The model was built with a zero penalty; set the starting value before
# the monitor and trainer are created.
model.penalty_parameter = penalty_parameter

monitor = mf.train.Monitor(model=model, momentum=0.98, freq=monitor_freq)

trainer = mf.train.Trainer(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    monitor=monitor,
    plotter=plotter,
    precision=precision,
    device=device,
    save=False,
)

# Options for the training run, collected in one place.
train_kws = dict(
    epochs=n_epochs,
    iterations=n_iterations,
    batch_size=batch_size,
    penalty_step=penalty_step,
    penalty_scale=penalty_scale,
    penalty_max=penalty_max,
    vis_freq=vis_freq,
    dmax=5.00e-04,
)
trainer.train(**train_kws)