# Evaluate 2D MENT-Flow model

In [None]:
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import torch
from ipywidgets import interact, widgets

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel

from experiments.load import load_run
from setup import make_dist

In [None]:
# Apply mentflow's plotting defaults (presumably proplot rc settings such as
# fonts/colormaps — see mentflow.train.plot.set_proplot_rc for the details).
mf.train.plot.set_proplot_rc()

## Settings

In [None]:
# Evaluation runs on the CPU in single precision; `send` below uses both.
device, precision = torch.device("cpu"), torch.float32

def send(x):
    """Return tensor ``x`` cast to the module-level ``precision`` on ``device``."""
    return x.to(device=device, dtype=precision)

## Load data

In [None]:
# Identify the training run to evaluate: dataset, training script, and the
# timestamp used as that run's output folder name.
data_name, script_name, timestamp = "two-spirals", "train_flow", 240107212132
data_dir = f"./output/{data_name}/{script_name}/{timestamp}/"

In [None]:
# Load the saved run and pull out the pieces used by the cells below.
run = load_run(data_dir)
cfg, model, checkpoints, history = (
    run[field] for field in ("cfg", "model", "checkpoints", "history")
)

In [None]:
# Rebuild the distribution described by the run's config (`cfg` from the
# loaded run); the density cell below samples it next to the model's samples.
dist = make_dist(cfg)

## Scalar history

In [None]:
# Metric names recorded during training, and the iteration count of the first
# metric (used as the slider range; assumes all metrics have the same length).
keys_sorted = sorted(history)
imax = len(history[keys_sorted[0]])

@interact(
    key=widgets.Dropdown(
        options=keys_sorted,
        # Fall back to the first metric if this run did not record "D_norm";
        # a value outside `options` would make the Dropdown raise.
        value=("D_norm" if "D_norm" in history else keys_sorted[0]),
    ),
    irange=widgets.IntRangeSlider(min=0, max=imax, value=(0, imax)),
    log=False,
)
def update(key, irange, log):
    """Plot one scalar training metric and its exponential moving average.

    Args:
        key: Name of the metric in ``history``.
        irange: ``(start, stop)`` window of global iterations to display.
        log: If True, use a logarithmic y axis.
    """
    vals = history[key]
    avgs = mf.utils.exp_avg(vals, momentum=0.95)

    start, stop = irange
    # Plot against the true iteration numbers. Plotting the slices alone would
    # restart the x axis at 0 whenever `start` > 0, contradicting the label.
    # Slicing the index array the same way as the data keeps lengths equal.
    iters = np.arange(len(vals))[start:stop]

    fig, ax = pplt.subplots()
    ax.plot(iters, vals[start:stop], color="gray")
    ax.plot(iters, avgs[start:stop], color="black", lw=1.0)
    if log:
        ax.format(yscale="log", yformatter="log")
    ax.format(xlabel="Iteration (global)", ylabel=key)
    plt.show()

## Evaluation

### Density

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=125),
)
def update(index, n, xmax, bins):
    """Show 2D histograms of ``dist`` (left) and the model (right) side by side.

    Args:
        index: Checkpoint index to load into ``model``.
        n: Number of samples to draw from each source (truncated to int).
        xmax: Half-width of the square histogram window ``[-xmax, xmax]^2``.
        bins: Number of histogram bins per axis.
    """
    model.load(checkpoints[index]["path"], device)
    num_samples = int(n)
    window = [(-xmax, xmax), (-xmax, xmax)]

    with torch.no_grad():
        # Draw from the model first, then from `dist`, as before.
        x_model = grab(model.sample(num_samples))
        x_ref = grab(dist.sample(num_samples))

        _, axs = pplt.subplots(ncols=2, xspineloc="neither", yspineloc="neither", space=0)
        for ax, points in zip(axs, (x_ref, x_model)):
            ax.hist2d(points[:, 0], points[:, 1], bins=bins, range=window)
        pplt.show()

### Projections

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2, max=6, value=50000),
    log_ymin=widgets.FloatSlider(min=-5.0, max=-1.0, value=-3.0),
)
def update(index, n, log_ymin, kde=False, log=False):
    """Overlay measured and simulated 1D projections for one checkpoint.

    Samples the model, pushes the samples through ``model.transforms`` and
    ``model.diagnostics`` with ``mf.sim.forward``, and plots the predictions
    against ``model.measurements``.

    Args:
        index: Checkpoint index to load into ``model``.
        n: Number of model samples (truncated to int).
        log_ymin: Base-10 exponent of the lower y limit.
        kde: Value assigned to every diagnostic's ``kde`` attribute before
            simulating. This mutates ``model.diagnostics`` and persists after
            the call.
        log: If True, use a logarithmic y axis.
    """
    model.load(checkpoints[index]["path"], device=device)

    with torch.no_grad():
        samples = send(model.sample(int(n)))

        for diag in unravel(model.diagnostics):
            diag.kde = kde

        predictions = mf.sim.forward(samples, model.transforms, model.diagnostics)

        measured = [grab(meas) for meas in unravel(model.measurements)]
        simulated = [grab(pred) for pred in unravel(predictions)]
        edges = [grab(d.bin_edges) for d in unravel(model.diagnostics)]

        _, axs = mf.train.plot.plot_proj_1d(
            measured,
            simulated,
            edges,
            maxcols=7,
            kind="line",
            height=1.25,
            lw=1.5,
        )
        axs.format(ymax=1.25, ymin=(10.0 ** log_ymin))
        if log:
            axs.format(yscale="log")
        plt.show()

### Grid warp 

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    n_lines=widgets.IntSlider(min=0, max=200, value=100),
)
def update(index, xmax, n_lines):
    """Show how ``model.gen.forward`` warps a square grid of lines.

    Left panel: horizontal/vertical grid lines and 1000 base samples
    (``model.gen.sample_base``). Right panel: the same lines and samples after
    ``model.gen.forward``. Both panels are clipped to ``[-xmax, xmax]^2``.

    Args:
        index: Checkpoint index to load into ``model``.
        xmax: Half-width of the displayed window.
        n_lines: Number of grid lines in each direction (0 draws no lines).
    """
    model.load(checkpoints[index]["path"], device=device)

    # The grid extends past the visible window so that lines the flow pulls in
    # from outside [-xmax, xmax] still appear in the warped panel. Previously
    # the lines were *spaced* over +/-extent but only *drawn* over +/-xmax,
    # which produced a cross-shaped grid with empty corners.
    extent = 3.5 * xmax
    # 500 points over 2 * extent keeps roughly the old density of ~150 points
    # across the visible [-xmax, xmax] window.
    res = 500

    with torch.no_grad():
        fig, axs = pplt.subplots(ncols=2)

        offsets = np.linspace(-extent, extent, n_lines)
        t = np.linspace(-extent, extent, res)
        line_points = []
        for offset in offsets:
            const = np.full(res, offset)
            line_points.append(np.stack([const, t], axis=-1))  # vertical line
            line_points.append(np.stack([t, const], axis=-1))  # horizontal line

        kws = dict(color="black", lw=0.6, alpha=0.25)

        if line_points:  # np.concatenate rejects an empty list (n_lines == 0)
            # Push every grid point through the flow in one batched call
            # instead of one forward pass per line, then split the result back
            # into per-line segments for plotting.
            z_lines = np.concatenate(line_points, axis=0)
            x_lines = grab(model.gen.forward(send(torch.from_numpy(z_lines))))
            x_lines = x_lines.reshape(len(line_points), res, -1)
            for z, x in zip(line_points, x_lines):
                axs[0].plot(z[:, 0], z[:, 1], **kws)
                axs[1].plot(x[:, 0], x[:, 1], **kws)

        z = model.gen.sample_base(1000)
        x = model.gen.forward(z)
        z = grab(z)
        x = grab(x)
        kws = dict(c="black", zorder=999, s=1)
        axs[0].scatter(z[:, 0], z[:, 1], **kws)
        axs[1].scatter(x[:, 0], x[:, 1], **kws)
        axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))
        plt.show()

### Flow trajectory

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=125),
)
def update(index, n, xmax, bins):
    """Plot a 2D histogram of the samples after each step of the generator.

    Draws ``n`` base samples, runs ``model.gen.forward_steps`` on them, and
    shows one histogram panel per returned intermediate sample set.

    Args:
        index: Checkpoint index to load into ``model``.
        n: Number of base samples (truncated to int).
        xmax: Half-width of the square histogram window ``[-xmax, xmax]^2``.
        bins: Number of histogram bins per axis.
    """
    model.load(checkpoints[index]["path"], device)

    # Check for the method up front instead of using a bare `except:`. The old
    # handler swallowed every exception, including ones raised *inside*
    # `forward_steps`, and then crashed with NameError on the undefined `xt`.
    if not hasattr(model.gen, "forward_steps"):
        print("`model.gen` does not have `forward_steps` method.")
        return

    with torch.no_grad():
        z = model.gen.sample_base(int(n))
        xt = model.gen.forward_steps(z)

        fig, axs = pplt.subplots(
            figheight=2.0,
            ncols=len(xt),
            space=None,
            xticks=[],
            yticks=[],
            xspineloc="neither",
            yspineloc="neither",
        )
        limits = 2 * [(-xmax, xmax)]
        for ax, x in zip(axs, xt):
            x = grab(x)
            ax.hist2d(x[:, 0], x[:, 1], bins=bins, range=limits)
        axs.format(xlim=limits[0], ylim=limits[1])
        plt.show()