# Evaluate 2D MENT-Flow model (neural network generator)

In [None]:
import os
import pickle
import sys
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import skimage.filters
import torch
from ipywidgets import interact
from ipywidgets import widgets

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel

import experiments.load
import plotting

In [None]:
# Apply the project's proplot rc settings (defined in the local `plotting` module).
plotting.set_proplot_rc()

## Settings

In [None]:
# Device and floating-point dtype used for every tensor handed to the model.
device = torch.device("cpu")
precision = torch.float32

def send(x):
    """Return tensor `x` converted to `precision` and placed on `device`."""
    return x.to(device=device, dtype=precision)

## Load data

In [None]:
# Identify the training run to evaluate: output/<data_name>/<script_name>/<timestamp>/
data_name = "two-spirals"
script_name = "train_nn"
timestamp = 240105172722
datadir = f"./output/{data_name}/{script_name}/{timestamp}/"

# Fail early with the full path so a wrong data_name, script_name, or timestamp
# is easy to spot (a plain file at that path is not accepted either).
if not os.path.isdir(datadir):
    raise ValueError(
        f"Run directory not found: {datadir!r} "
        "(check data_name, script_name, and timestamp)"
    )

In [None]:
run = experiments.load.load_run(datadir, gen_model="nn")

# Pull out the entries of the loaded run that the evaluation cells below use.
args, cfg, dist, model, checkpoints, history = (
    run[name] for name in ("args", "cfg", "dist", "model", "checkpoints", "history")
)

## Scalar history

In [None]:
keys_sorted = sorted(history)
# Slider length is taken from the first key; assumes all series have similar
# length (shorter series are simply truncated by slicing below).
imax = len(history[keys_sorted[0]])
# Fall back to the first key if this run did not record "D_norm".
default_key = "D_norm" if "D_norm" in history else keys_sorted[0]


def _ema(vals):
    """Return the exponential moving average (weight 0.01) of `vals`, seeded with vals[0]."""
    avgs = []
    if len(vals) == 0:
        return avgs
    avg = vals[0]
    for val in vals:
        avg = 0.99 * avg + 0.01 * val
        avgs.append(avg)
    return avgs


@interact(
    key=widgets.Dropdown(options=keys_sorted, value=default_key),
    irange=widgets.IntRangeSlider(min=0, max=imax, value=(0, imax)),
    log=False,
)
def update(key, irange, log):
    """Plot the raw (gray) and smoothed (black) history of `key` over `irange`.

    key: name of the history series to plot.
    irange: (start, stop) iteration range to display.
    log: if True, use a logarithmic y axis.
    """
    vals = history[key]
    avgs = _ema(vals)
    lo, hi = irange
    # Plot against global iteration numbers so the x axis matches its label
    # even when the range slider does not start at 0.
    iters = np.arange(len(vals))[lo:hi]

    fig, ax = pplt.subplots()
    ax.plot(iters, vals[lo:hi], color="gray")
    ax.plot(iters, avgs[lo:hi], color="black", lw=1.0)
    if log:
        ax.format(yscale="log", yformatter="log")
    ax.format(xlabel="Iteration (global)", ylabel=key)
    plt.show()

## Evaluation

### Density

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=125),
)
def update(index, n, xmax, bins):
    """Show 2D histograms of samples from `dist` (left) and the model (right).

    index: checkpoint to load into `model`.
    n: number of samples drawn from each of `dist` and `model`.
    xmax: half-width of the square histogram window.
    bins: number of histogram bins per axis.
    """
    model.load(checkpoints[index]["path"], device=device)

    # Only sampling needs autograd disabled; plotting happens outside.
    with torch.no_grad():
        x = grab(model.sample(int(n)))
        x0 = grab(dist.sample(int(n)))

    limits = 2 * [(-xmax, xmax)]
    fig, axs = pplt.subplots(ncols=2, xspineloc="neither", yspineloc="neither", space=0)
    for ax, samples, title in zip(axs, (x0, x), ("True", "Model")):
        ax.hist2d(samples[:, 0], samples[:, 1], bins=bins, range=limits)
        ax.format(title=title)
    pplt.show()

### Projections

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2, max=6, value=50000),
    log_ymin=widgets.FloatSlider(min=-5.0, max=-1.0, value=-3.0),
)
def update(index, n, log_ymin, kde=False, log=False):
    """Plot the model's measurements against its simulated predictions.

    index: checkpoint to load into `model`.
    n: number of model samples pushed through `model.simulate`.
    log_ymin: base-10 exponent of the lower y limit.
    kde: assigned to every diagnostic's `kde` attribute before simulating
        (the assignment persists on the model after this call).
    log: if True, use a logarithmic y axis.
    """
    model.load(checkpoints[index]["path"], device=device)

    with torch.no_grad():
        samples = send(model.sample(int(n)))

        for diag in model.diagnostics:
            diag.kde = kde
        simulated = model.simulate(samples)

        measured = [grab(m) for m in unravel(model.measurements)]
        predicted = [grab(p) for p in unravel(simulated)]
        # NOTE(review): bin edges are taken from the first diagnostic only;
        # presumably all diagnostics share them — confirm.
        edges = grab(model.diagnostics[0].bin_edges)
        plot_kws = dict(maxcols=7, kind="line", height=1.25, lw=1.5)
        fig, axs = plotting.plot_proj(measured, predicted, bin_edges=edges, **plot_kws)

        lower = 10.0 ** log_ymin
        axs.format(ymax=1.25, ymin=lower)
        if log:
            axs.format(yscale="log")
        plt.show()

### Grid warp

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
)
def update(index):
    """Show how `model.generator.transformer` warps a regular grid and base samples.

    Left panel: grid lines and 1000 samples from `model.generator.base`.
    Right panel: the same lines and samples after the transformer.
    """
    model.load(checkpoints[index]["path"], device=device)

    xmax = 3.0  # half-width of the displayed window
    n_lines = 100  # grid lines per direction
    scale = 3.0  # grid extends to scale * xmax so warped lines can enter the window
    res = 450  # points per line (150 per window width, as before the extent fix)
    grid = np.linspace(-scale * xmax, scale * xmax, n_lines)
    # Each line spans the same extent as the grid. Previously lines spanned only
    # +/- xmax, so only a cross-shaped region of the latent square was drawn.
    t = np.linspace(-scale * xmax, scale * xmax, res)
    line_points = []
    for c in grid:
        const = np.full(res, c)
        line_points.append(np.stack([const, t], axis=1))  # vertical line x = c
        line_points.append(np.stack([t, const], axis=1))  # horizontal line y = c

    fig, axs = pplt.subplots(ncols=2)
    line_kws = dict(color="black", lw=0.6, alpha=0.25)
    # Visualization only: no gradients needed through the transformer.
    with torch.no_grad():
        for points in line_points:
            axs[0].plot(points[:, 0], points[:, 1], **line_kws)
            # `send` applies the notebook-wide precision/device settings.
            warped = model.generator.transformer(send(torch.from_numpy(points)))
            warped = warped.detach().cpu().numpy()
            axs[1].plot(warped[:, 0], warped[:, 1], **line_kws)

        z = model.generator.base.sample((1000,))
        x = model.generator.transformer(z)
    z = grab(z)
    x = grab(x)
    point_kws = dict(c="black", zorder=999, s=1)
    axs[0].scatter(z[:, 0], z[:, 1], **point_kws)
    axs[1].scatter(x[:, 0], x[:, 1], **point_kws)

    axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))
    plt.show()