# Evaluate 2D MENT-Flow model

In [None]:
import os
import pickle
import sys
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import skimage.filters
import torch
from ipywidgets import interact
from ipywidgets import widgets

import mentflow as mf
from mentflow.utils import grab

import experiments.load
import plotting

In [None]:
# Notebook-wide proplot defaults: colormap behaviour, white figure
# background, and no grid lines. Applied in insertion order.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _key, _value in _rc_settings.items():
    pplt.rc[_key] = _value

## Load data

In [None]:
# Identify the training run to evaluate. Its outputs live under
# ./output/<data_name>/<script_name>/<timestamp>/ (trailing slash kept).
data_name = "swissroll"
script_name = "train"
timestamp = 231217205901
datadir = "./output/{}/{}/{}/".format(data_name, script_name, timestamp)

In [None]:
# Evaluate on the CPU in single precision; `send` below reads both globals.
device, precision = torch.device("cpu"), torch.float32

In [None]:
def send(x):
    """Return ``x`` cast to the global ``precision`` and moved to ``device``.

    Both ``precision`` and ``device`` are module-level globals looked up at
    call time, so changing them later affects subsequent calls.
    """
    recast = x.type(precision)
    return recast.to(device)

In [None]:
# Load the saved training run from `datadir` and unpack its components.
run = experiments.load.load_run(datadir)

args, cfg, dist, model, checkpoints, history = (
    run[name]
    for name in ("args", "cfg", "dist", "model", "checkpoints", "history")
)

## Scalar history

In [None]:
keys_sorted = sorted(history)
imax = len(history[keys_sorted[0]])

@interact(
    # Fall back to the first key if this run did not record "D_norm".
    key=widgets.Dropdown(
        options=keys_sorted,
        value=("D_norm" if "D_norm" in keys_sorted else keys_sorted[0]),
    ),
    irange=widgets.IntRangeSlider(min=0, max=imax, value=(0, imax)),
    log=False,
)
def update(key, irange, log):
    """Plot one scalar training-history series with its moving average.

    Parameters
    ----------
    key : str
        Entry of ``history`` to plot.
    irange : tuple[int, int]
        ``(start, stop)`` slice of global iterations to display.
    log : bool
        If True, use a logarithmic y axis.
    """
    vals = history[key]
    if len(vals) == 0:
        # Nothing recorded for this key; avoid indexing vals[0] below.
        return

    # Exponential moving average (weight 0.01 on the newest value) to show
    # the trend through noisy per-iteration values.
    avg = vals[0]
    avgs = []
    for val in vals:
        avg = 0.99 * avg + 0.01 * val
        avgs.append(avg)

    # Plot against the true global iteration index so that narrowing
    # `irange` does not relabel the slice as starting at iteration 0.
    start, stop = irange
    iters = np.arange(len(vals))[start:stop]

    fig, ax = pplt.subplots()
    ax.plot(iters, vals[start:stop], color="gray")
    ax.plot(iters, avgs[start:stop], color="black", lw=1.0)
    if log:
        ax.format(yscale="log", yformatter="log")
    ax.format(xlabel="Iteration (global)", ylabel=key)
    plt.show()

## Evaluation

### Density

In [None]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(checkpoints) - 1), value=0),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=125),
    res=widgets.IntSlider(min=4, max=250, value=150),
    cmap=["viridis", "plasma", "blues", "mono", pplt.Colormap("dark_r", space="hpl")],
)
def update(index, n, xmax, bins, res, cmap):
    """Compare data samples, model samples, and model density at a checkpoint.

    Panels, left to right: histogram of ``dist`` samples, histogram of model
    samples, and ``exp(model.log_prob)`` evaluated on a regular grid.

    Parameters
    ----------
    index : int
        Index into ``checkpoints`` of the weights to load.
    n : float
        Number of samples to draw (cast to ``int``).
    xmax : float
        All panels cover ``[-xmax, xmax]`` in both dimensions.
    bins : int
        Histogram bins per dimension.
    res : int
        Grid points per dimension for the density evaluation.
    cmap : str or Colormap
        Colormap shared by all three panels.
    """
    model.load(checkpoints[index]["path"], device)

    with torch.no_grad():
        x = grab(model.sample(int(n)))
        x0 = dist.sample(int(n))
        # NOTE(review): dist.sample may return a torch tensor; convert to a
        # host array for hist2d only in that case.
        if torch.is_tensor(x0):
            x0 = grab(x0)

        # Evaluate the model density on a res x res grid over [-xmax, xmax]^2.
        limits = 2 * [(-xmax, xmax)]
        coords = 2 * [torch.linspace(-xmax, xmax, res)]
        x_grid = torch.vstack(
            [C.ravel() for C in torch.meshgrid(*coords, indexing="ij")]
        ).T
        x_grid = send(x_grid)  # same dtype/device handling as other cells
        prob = torch.exp(model.log_prob(x_grid))
        prob = grab(prob.reshape((res, res)))

    coords = [grab(c) for c in coords]

    fig, axs = pplt.subplots(ncols=3, figheight=None)
    kws = dict(cmap=cmap)
    axs[0].hist2d(x0[:, 0], x0[:, 1], bins=bins, range=limits, **kws)
    axs[1].hist2d(x[:, 0], x[:, 1], bins=bins, range=limits, **kws)
    # With "ij" indexing prob[i, j] corresponds to (coords[0][i], coords[1][j]);
    # pcolormesh expects C[row=y, col=x], hence the transpose.
    axs[2].pcolormesh(coords[0], coords[1], prob.T, **kws)
    plt.show()

### Flow trajectory

In [None]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(checkpoints) - 1), value=0),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=125),
    res=widgets.IntSlider(min=4, max=250, value=150),
    cmap=["viridis", "plasma", "blues", "mono", pplt.Colormap("dark_r", space="hpl")],
)
def update(index, n, xmax, bins, res, cmap):
    """Plot base samples after each successive transform of the flow.

    The first panel shows samples from ``flow.base``. Each following panel
    applies one more transform from ``flow.transform.inv.transforms``.

    Parameters
    ----------
    index : int
        Index into ``checkpoints`` of the weights to load.
    n : float
        Number of base samples (cast to ``int``).
    xmax : float
        Panels cover ``[-xmax, xmax]`` in both dimensions.
    bins : int
        Histogram bins per dimension.
    res : int
        Unused here; kept so the widget set matches the density cell.
    cmap : str or Colormap
        Colormap for every panel.
    """
    model.load(checkpoints[index]["path"], device)

    with torch.no_grad():
        flow = model.generator()
        x = flow.base.sample((int(n),))
        xs = [x]
        # Apply the component transforms one at a time so that every
        # intermediate distribution can be plotted.
        for t in flow.transform.inv.transforms:
            xs.append(t(xs[-1]))
        xs = [grab(x) for x in xs]

    fig, axs = pplt.subplots(
        figheight=2.0,
        ncols=len(xs),
        space=None,
        xticks=[],
        yticks=[],
        xspineloc="neither",
        yspineloc="neither",
    )
    limits = 2 * [(-xmax, xmax)]
    for ax, x in zip(axs, xs):
        # Honour the bins/cmap widgets; they were previously ignored in
        # favour of hard-coded values.
        ax.hist2d(x[:, 0], x[:, 1], bins=bins, range=limits, cmap=cmap)
    axs.format(xlim=limits[0], ylim=limits[1])
    plt.show()

### Grid warp 

In [None]:
# Visualize how the final-checkpoint flow maps a regular grid of lines and
# a set of base samples: left panel = inputs, right panel = flow.transform.inv
# applied to them.
fig, axs = pplt.subplots(ncols=2)

model.load(checkpoints[-1]["path"], device=device)
# Build the flow once instead of once per plotted line.
flow = model.generator()

xmax = 3.0
n_lines = 100
res = 150
scale = 3.0
# NOTE(review): line positions span [-scale*xmax, scale*xmax] while each
# line only extends over [-xmax, xmax]; confirm this asymmetry is intended.
grid = np.linspace(-scale * xmax, scale * xmax, n_lines)
span = np.linspace(-xmax, xmax, res)
line_points = []
for c in grid:
    const = np.full(res, c)
    line_points.append(np.vstack([const, span]).T)  # vertical line x = c
    line_points.append(np.vstack([span, const]).T)  # horizontal line y = c

kws = dict(color="black", lw=0.6, alpha=0.5)
with torch.no_grad():
    for points in line_points:
        axs[0].plot(points[:, 0], points[:, 1], **kws)
        warped = grab(flow.transform.inv(send(torch.from_numpy(points))))
        for ax in axs[1:]:
            ax.plot(warped[:, 0], warped[:, 1], **kws)

    z = flow.base.sample((1000,))
    x = flow.transform.inv(z)
    z = grab(z)
    x = grab(x)

kws = dict(c="red", zorder=999, s=1)
axs[0].scatter(z[:, 0], z[:, 1], **kws)
axs[1].scatter(x[:, 0], x[:, 1], **kws)

axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))

### Projections

In [None]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(checkpoints) - 1), value=0),
    n=widgets.FloatLogSlider(min=2, max=6, value=50000),
    log_ymin=widgets.FloatSlider(min=-5.0, max=-1.0, value=-3.0),
)
def update(index, n, log_ymin, kde=True, log=False):
    """Overlay simulated projections of model samples on the measured ones.

    Each panel shows one measurement (black) and the corresponding
    prediction (red), both divided by the measurement's maximum.

    Parameters
    ----------
    index : int
        Index into ``checkpoints`` of the weights to load.
    n : float
        Number of model samples (cast to ``int``).
    log_ymin : float
        Lower y limit is ``10 ** log_ymin``.
    kde : bool
        Assigned to every diagnostic's ``kde`` attribute before simulating.
    log : bool
        If True, use a logarithmic y axis.
    """
    model.load(checkpoints[index]["path"], device=device)

    with torch.no_grad():
        x = send(model.sample(int(n)))

        for diagnostic in model.diagnostics:
            diagnostic.kde = kde
        predictions = model.simulate(x)

        _predictions = [grab(p) for p in mf.utils.unravel(predictions)]
        _measurements = [grab(m) for m in mf.utils.unravel(model.measurements)]

        # Bin centers are the midpoints of adjacent edges (the previous form
        # `0.5 * edges[:-1] + edges[1:]` mis-parenthesized this).
        # NOTE(review): assumes all diagnostics share diagnostics[0]'s edges.
        edges = grab(model.diagnostics[0].bin_edges)
        coords = 0.5 * (edges[:-1] + edges[1:])

    maxcols = 7
    ncols = min(len(_measurements), maxcols)
    nrows = int(np.ceil(len(_measurements) / ncols))
    figheight = 1.3 * nrows
    figwidth = min(10.0, 1.75 * ncols)

    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figheight=figheight, figwidth=figwidth)
    # Iterate over the flattened lists: model.measurements is passed through
    # unravel, so its top-level length need not equal the number of profiles.
    for ax, measurement, prediction in zip(axs, _measurements, _predictions):
        scale = np.max(measurement)
        ax.plot(coords, measurement / scale, color="black")
        ax.plot(coords, prediction / scale, color="red")
    axs.format(ymax=1.25)
    axs.format(ymin=(10.0 ** log_ymin))
    if log:
        axs.format(yscale="log")
    plt.show()