# Evaluate MENT-Flow model

In [None]:
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.visualization as psv
import scipy.ndimage
import torch
import zuko
from ipywidgets import interact
from ipywidgets import widgets

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel

from experiments.load import load_mentflow_run
from setup import make_dist

In [None]:
# Configure plotting defaults through mentflow's `set_proplot_rc` helper
# (presumably proplot rc settings -- see `mentflow.train.plot` to confirm).
mf.train.plot.set_proplot_rc()

## Settings

In [None]:
# Prefer the "mps" backend (the original setting), but fall back to the CPU
# when it is not available so the notebook also runs on other machines.
device = "mps" if torch.backends.mps.is_available() else "cpu"
precision = torch.float32


def send(x):
    """Return tensor `x` cast to `precision` and moved to `device`.

    The cast is applied before the transfer, so the dtype conversion happens
    on whichever device `x` currently lives on.
    """
    return x.type(precision).to(device)

## Load data

In [None]:
# Location of the training run to evaluate. The run directory is
# ./outputs/<script_name>/<timestamp>/.
script_name = "train_flow"
timestamp = 240130133321  # presumably YYMMDDhhmmss of the run -- TODO confirm
data_dir = f"./outputs/{script_name}/{timestamp}/"

In [None]:
# Load the saved run from `data_dir` and unpack the entries that the rest of
# the notebook uses: the config, the scalar history, the model, and the list of
# checkpoints.
run = load_mentflow_run(data_dir, device=device)

cfg, history, model, checkpoints = (
    run[name] for name in ("config", "history", "model", "checkpoints")
)

In [None]:
# Build the distribution object from the run config. Its samples are plotted
# alongside the model's samples in the evaluation cells below.
dist = make_dist(cfg)

## Scalar history

In [None]:
keys_sorted = sorted(history.keys())
imax = len(history[keys_sorted[0]])
# A Dropdown whose initial value is not among its options raises on
# construction, so only preselect "D_norm" when the run actually recorded it.
default_key = "D_norm" if "D_norm" in keys_sorted else keys_sorted[0]


@interact(
    key=widgets.Dropdown(options=keys_sorted, value=default_key),
    irange=widgets.IntRangeSlider(min=0, max=imax, value=(0, imax)),
    log=False,
)
def update(key, irange, log):
    """Plot one scalar from `history` together with its exponential average.

    Args:
        key: Name of the history entry to plot.
        irange: (start, stop) iteration indices; the slice `[start:stop]` of
            the series is shown.
        log: If True, use a logarithmic y axis.
    """
    vals = history[key]
    avgs = mf.utils.exp_avg(vals, momentum=0.95)
    start, stop = irange

    fig, ax = pplt.subplots()
    series_styles = (
        (vals, dict(color="gray")),
        (avgs, dict(color="black", lw=1.0)),
    )
    for series, style in series_styles:
        # Plot against the true iteration indices so the x axis stays correct
        # when the selected range does not start at zero.
        iters = np.arange(len(series))[start:stop]
        ax.plot(iters, series[start:stop], **style)
    if log:
        ax.format(yscale="log", yformatter="log")
    ax.format(xlabel="Iteration (global)", ylabel=key)
    plt.show()

## Evaluation

### Radial PDF

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n_samples=widgets.FloatLogSlider(min=2, max=6.5, value=1.00e+05),
    rmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.5),
    bins=widgets.IntSlider(min=4, max=150, value=50),
    kind=["step", "line"],
)
def update(index, n_samples, rmax, bins, kind, normalize=True):
    """Compare radial histograms of samples from `dist` and from the model.

    Both histograms are divided by the peak of the `dist` histogram, so the
    `dist` curve peaks at one. A dashed exp(-r^2 / 2) curve is drawn as a
    reference.

    Args:
        index: Index into `checkpoints`; that checkpoint is loaded first.
        n_samples: Number of samples to draw from each source (the slider
            yields a float, which is truncated to an int).
        rmax: Upper edge of the radial histogram.
        bins: Number of radial bins.
        kind: Profile style passed to `psv.plot_profile` ("step" or "line").
        normalize: If True, divide each bin by the volume of its spherical
            shell as given by `mf.utils.sphere_shell_volume`.
    """
    model.load(checkpoints[index]["path"], device)

    with torch.no_grad():
        n_samples = int(n_samples)
        x1 = grab(dist.sample(n_samples))
        x2 = grab(model.sample(n_samples))

        colors = ["red4", "black"]
        bin_edges = np.linspace(0.0, rmax, bins + 1)

        hists = []
        for x in (x1, x2):
            r = np.linalg.norm(x, axis=1)
            hist, _ = np.histogram(r, bins=bin_edges, density=True)
            if normalize:
                # Use dedicated names for the shell radii so the `rmax`
                # argument is not overwritten inside the loop.
                shells = zip(bin_edges[:-1], bin_edges[1:])
                for i, (r_lo, r_hi) in enumerate(shells):
                    volume = mf.utils.sphere_shell_volume(rmin=r_lo, rmax=r_hi, d=x.shape[1])
                    hist[i] = hist[i] / volume
            hists.append(hist)

        fig, ax = pplt.subplots(figsize=(3.0, 2.0))
        scale = hists[0].max()
        for hist, color in zip(hists, colors):
            # Honor the `kind` selection instead of always drawing steps.
            psv.plot_profile(hist / scale, edges=bin_edges, ax=ax, color=color, kind=kind, lw=1.5)
        ax.format(ymax=1.1)

        # Analytic reference curve. The previous version also sampled and
        # histogrammed "gaussian" and "kv" reference distributions here, but
        # never used the result, so that work has been dropped.
        r = np.linspace(0.0, rmax, 100)
        ax.plot(r, np.exp(-0.5 * r**2), color="black", alpha=0.1, ls="--", zorder=0, lw=1.5)
        ax.format(xlabel="Radius", ylabel="PDF")
        pplt.show()

In [None]:
# Draw one million samples from the model (without tracking gradients) and
# open the interactive 2D-projection slice viewer on them.
with torch.no_grad():
    x = grab(model.sample(1_000_000))

psv.points.proj2d_interactive_slice(
    x,
    cmap="viridis",
    options=dict(mask=True),
    autolim_kws=dict(sigma=3.5, zero_center=True),
)

### Radial CDF

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n_samples=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    rmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=50),
    kind=["step", "line"],
)
def update(index, n_samples, rmax, bins, kind):
    """Compare radial CDFs of samples from `dist` and from the model.

    The comparison itself is drawn by `mf.train.plot.plot_dist_radial_cdf`.
    Dashed CDFs of the "kv" and "gaussian" distributions from
    `mf.dist.dist_nd.gen_dist` are overlaid as references, using a fixed
    75-bin radial histogram.

    Args:
        index: Index into `checkpoints`; that checkpoint is loaded first.
        n_samples: Number of samples per source (truncated to an int).
        rmax: Upper radial limit.
        bins: Number of bins for the `dist`/model comparison.
        kind: Profile style forwarded to `plot_dist_radial_cdf`.
    """
    model.load(checkpoints[index]["path"], device)

    with torch.no_grad():
        count = int(n_samples)
        x_dist = grab(dist.sample(count))
        x_model = grab(model.sample(count))

        fig, ax = mf.train.plot.plot_dist_radial_cdf(
            x_dist,
            x_model,
            rmax=rmax,
            bins=bins,
            kind=kind,
            lw=1.5,
            colors=["red4", "black"],
        )

        reference_kws = dict(kind="line", color="black", alpha=0.1, ls="--", zorder=0, lw=1.5)
        for name in ("kv", "gaussian"):
            reference = mf.dist.dist_nd.gen_dist(name=name, noise=0.0)
            points = grab(reference.sample(count))
            hist, edges = ps.points.radial_histogram(points, bins=75, limits=(0.0, rmax))
            cumulative = np.cumsum(hist)
            # Normalize so the curve ends at one.
            cumulative = cumulative / cumulative[-1]
            psv.plot_profile(cumulative, edges=edges, ax=ax, **reference_kws)
        ax.format(xlabel="Radius", ylabel="CDF")
        pplt.show()

## 2D projections

In [None]:
@interact(
    dim1=widgets.Dropdown(options=range(cfg.d), value=0),
    dim2=widgets.Dropdown(options=range(cfg.d), value=1),
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2.0, max=6.0, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=85),
)
def update(dim1, dim2, index, n, xmax, bins):
    """Show side-by-side 2D projections of `dist` samples and model samples.

    Nothing is drawn when `dim1` and `dim2` are the same.

    Args:
        dim1, dim2: Indices of the two dimensions to project onto.
        index: Index into `checkpoints`; that checkpoint is loaded first.
        n: Number of samples per source (truncated to an int).
        xmax: Both axes span (-xmax, xmax).
        bins: Bin count passed to `psv.points.plot2d`.
    """
    if dim1 == dim2:
        return
    model.load(checkpoints[index]["path"], device)

    with torch.no_grad():
        n_samples = int(n)
        x_dist = grab(dist.sample(n_samples))
        x_model = grab(model.sample(n_samples))

        fig, axs = pplt.subplots(ncols=2, xspineloc="neither", yspineloc="neither")
        axis = (dim1, dim2)
        # Left panel: `dist` samples; right panel: model samples.
        for ax, points in zip(axs, (x_dist, x_model)):
            limits = [(-xmax, xmax), (-xmax, xmax)]
            psv.points.plot2d(points[:, axis], bins=bins, limits=limits, ax=ax, mask=False)
        pplt.show()

### Corner plot

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05, continuous_update=False),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=85),
)
def update(index, n, xmax, bins):
    """Draw a corner grid with model samples below and `dist` samples above.

    Model samples are plotted in blue with `lower=True`; `dist` samples are
    plotted in red with `upper=True`.

    Args:
        index: Index into `checkpoints`; that checkpoint is loaded first.
        n: Number of samples per source (truncated to an int).
        xmax: Every axis spans (-xmax, xmax).
        bins: Bin count passed to `CornerGrid.plot_points`.
    """
    model.load(checkpoints[index]["path"], device)

    with torch.no_grad():
        n_samples = int(n)
        x_dist = grab(dist.sample(n_samples))
        x_model = grab(model.sample(n_samples))

        ndim = x_dist.shape[1]
        grid = psv.CornerGrid(d=ndim, corner=False)

        # Settings common to both halves of the grid.
        shared_kws = dict(
            mask=True,
            bins=bins,
            limits=[(-xmax, xmax)] * ndim,
        )
        grid.plot_points(
            x_model,
            lower=True,
            upper=False,
            cmap="blues",
            diag_kws=dict(color="blue7", lw=1.5),
            **shared_kws,
        )
        grid.plot_points(
            x_dist,
            upper=True,
            lower=False,
            cmap="reds",
            diag_kws=dict(color="red7", lw=1.5),
            **shared_kws,
        )
        pplt.show()

### Simulated measurements

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0, 
        max=(len(checkpoints) - 1), 
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2, max=6, value=50000),
    log_ymin=widgets.FloatSlider(min=-5.0, max=-1.0, value=-3.0),
    maxcols=widgets.IntSlider(min=4, max=12, value=7),
    figwidthscale=widgets.FloatSlider(min=0.5, max=2.5, value=1.25),
)
def update(index, n, log_ymin, maxcols, figwidthscale, kde=False, log=False):
    """Plot the stored measurements above the model's simulated measurements.

    Model samples are pushed through `mf.sim.forward` using the model's
    transforms and diagnostics. Each measurement/prediction pair is divided by
    the measurement's maximum and shown as two vertically stacked images.

    Args:
        index: Index into `checkpoints`; that checkpoint is loaded first.
        n: Number of model samples (truncated to an int).
        log_ymin: Accepted from the widget but currently unused.
        maxcols: Maximum number of image columns in the figure.
        figwidthscale: Figure width per column (total width capped at 10.0).
        kde: Value assigned to the `kde` attribute of every diagnostic.
        log: If True, images use `norm="log"`.
    """
    model.load(checkpoints[index]["path"], device=device)

    with torch.no_grad():
        # Draw model samples and move them to the working device/precision.
        samples = send(model.sample(int(n)))

        # Configure the diagnostics, then simulate the measurements.
        for diag in unravel(model.diagnostics):
            diag.kde = kde

        predictions = mf.sim.forward(samples, model.transforms, model.diagnostics)

        y_meas = [grab(item) for item in unravel(model.measurements)]
        y_pred = [grab(item) for item in unravel(predictions)]
        coords = [
            [grab(c) for c in diag.bin_coords]
            for diag in unravel(model.diagnostics)
        ]

        # Figure layout: two grid rows (measured, predicted) per row of images.
        n_images = len(y_meas)
        ncols = min(n_images, maxcols)
        nrows = 2 * int(np.ceil(n_images / ncols))

        fig, axs = pplt.subplots(
            ncols=ncols,
            nrows=nrows,
            figwidth=min(figwidthscale * ncols, 10.0),
            xspineloc="neither",
            yspineloc="neither",
            space=0.0,
        )

        kws = dict(norm="log") if log else dict(norm=None)
        for i, measured in enumerate(y_meas):
            pair_row, col = divmod(i, ncols)
            # Flat index of the measured panel; the predicted panel sits one
            # grid row (ncols axes) below it.
            top = 2 * pair_row * ncols + col
            peak = np.max(measured)
            psv.image.plot2d(measured / peak, coords=coords[i], ax=axs[top], **kws)
            psv.image.plot2d(y_pred[i] / peak, coords=coords[i], ax=axs[top + ncols], **kws)

        plt.show()

### Grid warp 

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0, 
        max=(len(checkpoints) - 1), 
        value=(len(checkpoints) - 1),
    ),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    n_lines=widgets.IntSlider(min=0.0, max=200, value=100),
)
def update(index, xmax, n_lines):
    """Show how `model.gen.forward` deforms a regular grid in the first two dims.

    The left panel shows straight grid lines and 1000 base samples; the right
    panel shows the same lines and samples after `model.gen.forward`. Grid
    lines are zero-padded up to `cfg.d` dimensions before being mapped.

    Args:
        index: Index into `checkpoints`; that checkpoint is loaded first.
        xmax: Both panels are limited to (-xmax, xmax) on each axis.
        n_lines: Number of grid positions; one vertical and one horizontal
            line is drawn per position.
    """
    model.load(checkpoints[index]["path"], device=device)

    with torch.no_grad():
        fig, axs = pplt.subplots(ncols=2)

        res = 150
        # Line positions cover (-3.5 * xmax, 3.5 * xmax); each line itself
        # runs over (-xmax, xmax).
        positions = np.linspace(-3.5 * xmax, 3.5 * xmax, n_lines)
        line_points = []
        for position in positions:
            fixed = np.full(res, position)
            along = np.linspace(-xmax, xmax, res)
            line_points.append(np.stack([fixed, along], axis=1))
            line_points.append(np.stack([along, fixed], axis=1))

        line_kws = dict(color="black", lw=0.6, alpha=0.25)

        for line in line_points:
            axs[0].plot(line[:, 0], line[:, 1], **line_kws)
            padding = np.zeros((line.shape[0], cfg.d - line.shape[1]))
            z = send(torch.from_numpy(np.hstack([line, padding])))
            warped = grab(model.gen.forward(z))
            axs[1].plot(warped[:, 0], warped[:, 1], **line_kws)

        z_base = model.gen.sample_base(1000)
        x_base = model.gen.forward(z_base)
        z_base = grab(z_base)
        x_base = grab(x_base)
        point_kws = dict(c="black", zorder=999, s=1)
        axs[0].scatter(z_base[:, 0], z_base[:, 1], **point_kws)
        axs[1].scatter(x_base[:, 0], x_base[:, 1], **point_kws)
        axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))
        plt.show()

### Flow trajectory

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0, 
        max=(len(checkpoints) - 1), 
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=50),
)
def update(index, n, xmax, bins):
    """Plot 2D histograms of the entries returned by `model.gen.forward_steps`.

    Base samples are passed to `model.gen.forward_steps`; one panel is drawn
    per returned entry, histogramming its first two dimensions. If the
    generator has no `forward_steps` method, a message is printed and nothing
    is plotted.

    Args:
        index: Index into `checkpoints`; that checkpoint is loaded first.
        n: Number of base samples (truncated to an int).
        xmax: Both axes span (-xmax, xmax).
        bins: Number of histogram bins per axis.
    """
    model.load(checkpoints[index]["path"], device)

    # Check for the method explicitly rather than with a bare `except`, which
    # hid unrelated errors and then failed on the undefined `xt` below.
    forward_steps = getattr(model.gen, "forward_steps", None)
    if forward_steps is None:
        print("`model.gen` does not have `forward_steps` method.")
        return

    with torch.no_grad():
        z = model.gen.sample_base(int(n))
        xt = forward_steps(z)

        fig, axs = pplt.subplots(
            figheight=1.25,
            ncols=len(xt),
            space=None,
            xticks=[],
            yticks=[],
            xspineloc="neither",
            yspineloc="neither",
        )
        limits = 2 * [(-xmax, xmax)]
        for ax, x in zip(axs, xt):
            x = grab(x)
            ax.hist2d(x[:, 0], x[:, 1], bins=bins, range=limits)
        axs.format(xlim=limits[0], ylim=limits[1])
        plt.show()