## Train MENT-Flow without samples

This notebook illustrates an alternative way to train MENT-Flow. Instead of generating samples, we can numerically integrate the model probability density to generate the predicted profiles. This requires an invertible lattice. It should dramatically improve the dynamic range of the generated profiles without requiring a huge number of particles. (We still need to generate samples to estimate the entropy.)

**TODO**: This notebook is old and does not currently work. It needs to be updated and converted to the script `train_hdr.py`.

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import mentflow as mf
import numpy as np
import proplot as pplt
import torch
import zuko
from torch.distributions.utils import _sum_rightmost
from tqdm import tqdm

import plotting
import utils
from data import gen_data

In [None]:
# proplot rc overrides, applied in the order listed.
rc_overrides = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for rc_key, rc_value in rc_overrides.items():
    pplt.rc[rc_key] = rc_value

In [None]:
# --- Device and reproducibility ----------------------------------------------
device = torch.device("cpu")
seed = 0
# `seed` is also handed to `gen_data`; seeding torch here makes the flow
# initialisation and the sampling reproducible under Restart & Run All.
torch.manual_seed(seed)

# --- Ground-truth data -------------------------------------------------------
data_name = "spirals"
data_kws = dict()  # extra keyword arguments forwarded to `gen_data`
data_size = int(1.00e+06)
data_noise = 0.1
data_decorr = False  # NOTE(review): not passed to `gen_data` in the cells shown here
data_shuffle = True
data_normalize = True

# --- Measurements ------------------------------------------------------------
n_meas = 6  # number of projection angles
max_angle = 180.0  # [deg]; angles span [0, max_angle), endpoint excluded
meas_bins = 100  # number of histogram bins per measured profile
xmax = 2.75  # histogram and plot range is [-xmax, xmax]

# --- Training schedule -------------------------------------------------------
res = 75  # n evaluation points along integration axis
n_steps = 10  # outer steps (alpha/beta are updated after each one)
n_iterations = 500  # optimizer iterations per outer step
alpha_min = 500.0  # initial value of `model.alpha`
alpha_step = 0.0  # alpha <- alpha_mult * alpha + alpha_step
alpha_mult = 2.0
beta = 0.0

rtol = 0.25  # alpha only grows if avg cost < rtol * previous avg cost
patience = 100  # end the step once `meter.n_bad` exceeds this

# --- Visualization -----------------------------------------------------------
vis_size = 100000  # number of samples drawn for plots
vis_bins = 75  # bins for the 2D sample histograms
vis_res = 150  # grid points per axis for the density image
vis_freq = 100  # plot every `vis_freq` iterations

# --- Flow architecture -------------------------------------------------------
n_flows = 5
hidden_units = 64
hidden_layers = 3
spline_bins = 20

# --- Target (prior) distribution ---------------------------------------------
targ_scale = 1.0

# --- Optimizer ---------------------------------------------------------------
lr = 0.001
weight_decay = 1.00e-05


def cvt(x):
    """Cast tensor `x` to float32 and move it to `device`."""
    return x.type(torch.float32).to(device)


def grab(x):
    """Detach tensor `x`, copy it to the CPU, and return it as a NumPy array."""
    return x.detach().cpu().numpy()

In [None]:
# Number of dimensions; reused as the flow's feature count and the target size.
d = 2

# Generate the ground-truth samples and convert them straight to a float32
# tensor on `device` (via `cvt`).
x0 = cvt(
    torch.from_numpy(
        gen_data(
            name=data_name,
            size=data_size,
            shuffle=data_shuffle,
            normalize=data_normalize,
            noise=data_noise,
            seed=seed,
            **data_kws,
        )
    )
)

In [None]:
# Linear lattice; its matrix is swapped via `set_matrix` before each use.
lattice = mf.lattice.LinearLattice().to(device)

# One rotation matrix per projection angle, evenly spaced over
# [0, max_angle) degrees.
angles = np.linspace(0.0, np.radians(max_angle), n_meas, endpoint=False)
transfer_matrices = [
    cvt(torch.from_numpy(mf.utils.rotation_matrix(theta))) for theta in angles
]

In [None]:
# Shared axis range and histogram binning for every measured profile.
limits = [(-xmax, xmax)] * 2
bin_edges = cvt(torch.linspace(-xmax, xmax, meas_bins + 1))
bin_centers = mf.utils.centers_from_edges(bin_edges)

diagnostic = mf.diagnostics.Histogram1D(edges=bin_edges, bandwidth=None).to(device)


def _measure(transfer_matrix):
    """Load `transfer_matrix` into the lattice and histogram the transported `x0`."""
    lattice.set_matrix(transfer_matrix)
    return diagnostic(lattice(x0), kde=False)


# Ground-truth profiles, one per transfer matrix (the lattice is left holding
# the last matrix).
measurements = [_measure(transfer_matrix) for transfer_matrix in transfer_matrices]

In [None]:
# zuko NSF flow over `d` features; every transform uses the same list of
# hidden-layer widths.
hidden_features = [hidden_units for _ in range(hidden_layers)]
flow = zuko.flows.NSF(
    features=d,
    transforms=n_flows,
    bins=spline_bins,
    hidden_features=hidden_features,
    randperm=True,
)

In [None]:
# Target distribution handed to the model: zero-mean diagonal normal whose
# scale is `targ_scale` on every axis.
target_loc = cvt(torch.zeros(d))
target_scale = cvt(targ_scale * torch.ones(d))
target = zuko.distributions.DiagNormal(loc=target_loc, scale=target_scale)

In [None]:
# Assemble the MENT-Flow model from the pieces built above.
model = mf.MENTFlow(
    # Generative model and its target distribution.
    flow=flow,
    target=target,
    # Forward model: transport through the lattice, then histogram.
    lattice=lattice,
    transfer_matrices=transfer_matrices,
    diagnostic=diagnostic,
    measurements=measurements,
    # Initial alpha/beta; both are updated by the training loop.
    alpha=alpha_min,
    beta=beta,
    # Profile discrepancy; "kld" was the commented-out alternative.
    loss="mae",
)

In [None]:
# Baseline reconstructions ("sart" and "fbp") for comparison: reconstruct an
# image from the measured profiles, sample particles from it, and re-simulate
# the measurements.
edges_np = grab(bin_edges)
centers_np = 0.5 * (edges_np[1:] + edges_np[:-1])

for method in ["sart", "fbp"]:
    image = utils.reconstruct(measurements, angles, method=method, iterations=10)
    x = cvt(
        torch.from_numpy(
            utils.sample_image(image, edges=(edges_np, edges_np), n=vis_size)
        )
    )

    # Push the sampled particles through every transfer matrix.
    predictions = []
    for transfer_matrix in transfer_matrices:
        lattice.set_matrix(transfer_matrix)
        predictions.append(diagnostic(lattice(x), kde=False))

    # Left to right: true samples, samples from the reconstruction, the
    # reconstructed image.
    fig, axs = pplt.subplots(
        ncols=3,
        xspineloc="neither",
        yspineloc="neither",
        space=0.0,
        share=False,
    )
    plotting.plot_cloud(grab(x0)[:vis_size], bins=vis_bins, limits=limits, ax=axs[0])
    plotting.plot_cloud(grab(x), bins=vis_bins, limits=limits, ax=axs[1])
    plotting.plot_image(image, coords=(2 * [centers_np]), ax=axs[2])
    axs.format(xlim=limits[0], ylim=limits[1])
    plt.show()

    # Measured vs. re-simulated profiles.
    axs = plotting.plot_proj(measurements, predictions, edges=edges_np)
    plt.show()

In [None]:
# Batch size used only for the sample-based entropy estimate.
n_entropy_samples = 1000

# Integration grid in the measurement frame. Axis 0 holds the histogram bin
# centers (the measured coordinate); axis 1 holds `res` points along the
# integration axis. It depends only on fixed settings, so it is built once
# instead of once per matrix per iteration.
grid_coords = [grab(bin_centers), np.linspace(-xmax, xmax, res)]
grid_shape = [len(c) for c in grid_coords]
grid_points = cvt(torch.from_numpy(mf.utils.get_grid_points(grid_coords)))
bin_width = bin_edges[1] - bin_edges[0]


def predict_profiles():
    """Predict one profile per transfer matrix by integrating the flow density.

    For each matrix, the grid points are mapped back through
    `lattice.inverse`, the flow density is evaluated there, and the density is
    summed along the integration axis. Each profile is then normalized so that
    `sum(profile) * bin_width == 1`; any constant factor in the density cancels
    in that normalization.

    Gradients flow through the result unless the caller wraps the call in
    `torch.no_grad()`. The lattice is left holding the last transfer matrix.

    Returns:
        list of 1D tensors, one per entry of `transfer_matrices`, each with
        `len(bin_centers)` elements.
    """
    profiles = []
    for matrix in transfer_matrices:
        lattice.set_matrix(matrix)
        log_density = flow().log_prob(lattice.inverse(grid_points))
        density = torch.exp(log_density).reshape(grid_shape)
        profile = torch.sum(density, 1)
        profile = profile / torch.sum(profile) / bin_width
        profiles.append(profile)
    return profiles


last_avg_cost = float("inf")
stop = False  # never set to True while the early-stop branch below is disabled
meter = mf.train.RunningAverageMeter(momentum=0.99)

for step in range(1, n_steps + 1):
    print("step={} alpha={:0.4f} beta={}".format(step, model.alpha, model.beta))
    meter.reset()
    # A fresh optimizer is created for every outer step.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for iteration in range(1, n_iterations + 1):
        t0 = time.time()

        optimizer.zero_grad()

        # Estimate entropy from samples; the per-sample log-probabilities
        # are not needed here.
        x, _, H = model.sample(n_entropy_samples)

        # Predict profiles by integrating density.
        predictions = predict_profiles()

        # One discrepancy term per measurement.
        C = [
            model.loss_func(prediction, measurement)
            for prediction, measurement in zip(predictions, measurements)
        ]

        loss = model.loss_function(H, C)

        # Skip the update (but keep iterating) if the loss is not finite.
        if not (torch.isinf(loss) or torch.isnan(loss)):
            loss.backward()
            optimizer.step()

        cost = float(sum(C) / len(C))
        meter.action(cost)

        if iteration > 1 and (iteration % 10 == 0):
            message = "iter={:05.0f} t={:0.2f} loss={:0.3e} H={:0.3e} C={:0.3e} Cavg={:0.3e} alpha={:0.4f}".format(
                iteration, time.time() - t0, float(loss), float(H), float(cost), meter.avg, model.alpha,
            )
            print(message)

        if iteration > 1 and (iteration % vis_freq == 0):
            model.eval()
            with torch.no_grad():
                x, _, _ = model.sample(vis_size)
                x = cvt(x)

                # Predict profiles by integrating density.
                predictions = predict_profiles()

                # Evaluate the model density on a regular grid for the image
                # panel. The grid goes through `cvt` so it lives on the same
                # device as the flow; `coords` stay on the CPU for plotting.
                coords = 2 * [torch.linspace(-xmax, xmax, vis_res)]
                x_grid = torch.vstack(
                    [c.ravel() for c in torch.meshgrid(*coords, indexing="ij")]
                ).T
                log_prob = flow().log_prob(cvt(x_grid))
                log_prob = log_prob.reshape((vis_res, vis_res))

                x = grab(x)
                x_true = grab(x0[:vis_size])
                prob = grab(torch.exp(log_prob))

                # Left to right: true samples, model samples, model density.
                fig, axs = pplt.subplots(ncols=3, xspineloc="neither", yspineloc="neither", space=0.0)
                kws = dict()
                for ax, _x in zip(axs, [x_true, x]):
                    plotting.plot_cloud(_x[:vis_size], bins=vis_bins, limits=limits, ax=ax, **kws)
                plotting.plot_image(prob, coords=coords, ax=axs[-1], **kws)
                plt.show()

                # Measured (black) vs. predicted (red) profiles, each scaled
                # by the peak of the measured profile.
                maxcols = 7
                ncols = min(len(measurements), maxcols)
                nrows = int(np.ceil(len(measurements) / ncols))
                figheight = 1.75 * nrows
                figwidth = 1.25 * ncols * 1.75
                fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figheight=figheight, figwidth=figwidth)
                kws = dict(lw=1.25)
                edges_np = grab(bin_edges)
                for j in range(len(measurements)):
                    measured = grab(measurements[j])
                    scale = max(measured)
                    axs[j].stairs(measured / scale, edges=edges_np, color="black", **kws)
                    axs[j].stairs(grab(predictions[j]) / scale, edges=edges_np, color="red", **kws)
                axs.format(ymax=1.25)
                plt.show()

            model.train()

        if meter.n_bad > patience:
            print(f"{meter.n_bad} iters without decrease in avg cost: ending step")
            break

    if stop:
        print("Stopping training early.")
        break

    # Update beta using the costs from the last iteration of this step.
    # NOTE(review): indexing `model.beta[i]` assumes MENTFlow stores beta as
    # a per-measurement sequence; the config passes a scalar, so confirm.
    model.beta = [model.beta[i] - model.alpha * float(C[i]) for i in range(len(C))]

    if meter.avg < rtol * last_avg_cost:
        model.alpha = alpha_mult * model.alpha + alpha_step
    else:
        # Early stopping on insufficient relative decrease is disabled:
        # print("Relative cost decrease less than rtol: running one more step")
        # stop = True
        pass

    last_avg_cost = meter.avg