# Evaluate 2D MENT-Flow model

In [None]:
import os
import pickle
import sys
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import skimage.filters
import torch
from ipywidgets import interact
from ipywidgets import widgets

import mentflow as mf
import experiments.setup

import plotting

In [None]:
# Notebook-wide proplot rc overrides, applied in this order.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _rc_key, _rc_value in _rc_settings.items():
    pplt.rc[_rc_key] = _rc_value

## Setup

In [None]:
# Identify the training run to evaluate: dataset, training script, and run timestamp.
data_name = "spirals"
timestamp = 231128232746
script_name = "train_nsf_penalty"
# Trailing "" keeps the final "/" on the directory path.
datadir = "/".join([".", "data_output", data_name, script_name, str(timestamp), ""])

In [None]:
# Evaluate everything on the CPU in single precision.
precision = torch.float32
device = torch.device("cpu")

In [None]:
def grab(x):
    """Convert tensor `x` to a NumPy array, detached from autograd and on the CPU.

    For a tensor that already lives on the CPU the returned array shares
    memory with `x` (this is how `Tensor.numpy` behaves).
    """
    detached = x.detach()
    return detached.to("cpu").numpy()

In [None]:
def cvt(x):
    """Cast tensor `x` to the notebook's `precision` dtype and move it to `device`."""
    cast = x.to(dtype=precision)
    return cast.to(device=device)

Parse the training-script arguments (the `Namespace(...)` line) from the log file.

In [None]:
def _split_top_level(text, sep=","):
    """Split `text` on `sep`, ignoring separators nested in brackets or quotes.

    Needed because values in a `Namespace(...)` line can themselves contain
    commas (lists, tuples, quoted strings).
    """
    parts = []
    depth = 0
    quote = None
    escaped = False
    start = 0
    for i, char in enumerate(text):
        if quote is not None:
            # Inside a quoted string: only an unescaped matching quote ends it.
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
        elif char in "'\"":
            quote = char
        elif char in "([{":
            depth += 1
        elif char in ")]}":
            depth -= 1
        elif char == sep and depth == 0:
            parts.append(text[start:i])
            start = i + 1
    parts.append(text[start:])
    return parts


def _parse_namespace_log(path):
    """Collect `key=value` pairs from every `Namespace(...)` line of the log at `path`.

    Values are kept as the raw strings written in the log (quoted strings keep
    their quotes). If several Namespace lines define the same key, the last one wins.
    NOTE(review): the line is presumably the repr of an argparse.Namespace — confirm.
    """
    parsed = {}
    with open(path, "r") as file:
        for line in file:
            line = line.rstrip()
            if not line.startswith("Namespace"):
                continue
            _, found, body = line.partition("Namespace(")
            if not found:
                continue
            # Strip only the closing ")" of the Namespace so parenthesized
            # values (e.g. tuples) are not truncated.
            body = body.rsplit(")", 1)[0]
            for item in _split_top_level(body):
                item = item.strip()
                if not item:
                    # e.g. an empty `Namespace()`.
                    continue
                # Split on the first "=" only; values may contain "=".
                key, _, val = item.partition("=")
                parsed[key] = val
    return parsed


path = os.path.join(datadir, "log.txt")
info = _parse_namespace_log(path)

print(info)

Get the information needed to generate the training data.

## Training history

In [None]:
# NOTE: pickle can execute arbitrary code on load; only open trusted run output
# (presumably written by the training script).
path = os.path.join(datadir, "history.pkl")
with open(path, "rb") as file:
    history = pickle.loads(file.read())

# Show which quantities were recorded.
print([key for key in history])

In [None]:
keys_sorted = sorted(history)
# Slider range comes from the first key's history length.
imax = len(history[keys_sorted[0]])
# Fall back to the first key if "C_norm" was not recorded, instead of crashing.
default_key = "C_norm" if "C_norm" in history else keys_sorted[0]


def _moving_average(values, momentum=0.99):
    """Return the exponential moving average of `values` (same length; [] if empty).

    The average is seeded with the first value, and each step computes
    `momentum * avg + (1 - momentum) * val`.
    """
    if len(values) == 0:
        return []
    weight = 1.0 - momentum
    avg = values[0]
    averages = []
    for val in values:
        avg = momentum * avg + weight * val
        averages.append(avg)
    return averages


@interact(
    key=widgets.Dropdown(options=keys_sorted, value=default_key),
    irange=widgets.IntRangeSlider(min=0, max=imax, value=(0, imax)),
    log=False,
)
def update(key, irange, log):
    """Plot the history of `key` (gray) and its moving average (black).

    `irange` selects the half-open window [start, stop) of entries to show.
    `log` switches the y axis to a log scale.
    """
    vals = history[key]
    avgs = _moving_average(vals)

    start, stop = irange
    # Slice an index array exactly like the data so the x values are the true
    # iteration numbers of the window, not 0..len(window).
    iterations = np.arange(len(vals))[start:stop]

    fig, ax = pplt.subplots()
    ax.plot(iterations, vals[start:stop], color="gray")
    ax.plot(iterations, avgs[start:stop], color="black", lw=1.0)
    if log:
        ax.format(yscale="log", yformatter="log")
    ax.format(xlabel="Iteration (global)", ylabel=key)
    plt.show()

In [None]:
# Scatter-line of the recorded "C_norm" values against the recorded "epoch" values.
epochs, c_norm = history["epoch"], history["C_norm"]
fig, ax = pplt.subplots()
ax.plot(epochs, c_norm, marker=".")

## Load model

Set up the model from the saved config file.

In [None]:
# Training configuration saved next to the checkpoints (trusted local pickle).
path = os.path.join(datadir, "config.pkl")
with open(path, "rb") as file:
    cfg = pickle.loads(file.read())

pprint(cfg)

In [None]:
# Build the model from the saved training config (`cfg`); weights are loaded later via model.load.
model = experiments.setup.setup_model(cfg)
# NOTE(review): presumably switches the model to evaluation mode (torch
# nn.Module convention) — confirm against mentflow's model class.
model.eval()

Get the model checkpoint filenames, along with their step and iteration numbers.

In [None]:
subdir = "checkpoints"
subdir = os.path.join(datadir, subdir)
checkpoint_paths = sorted(os.path.join(subdir, f) for f in os.listdir(subdir))

# One record per checkpoint file: its step, iteration, and full path.
checkpoints = []
for path in checkpoint_paths:
    (step, iteration) = experiments.setup.get_step_and_iteration_number(path)
    checkpoints.append({"step": step, "iteration": iteration, "path": path})

# Order checkpoints chronologically by the parsed numbers. A plain string sort
# puts e.g. "..._10_..." before "..._2_..." when the numbers are not
# zero-padded. The sort is stable, so ties keep the path order from above.
# NOTE(review): assumes step/iteration are comparable numbers — confirm.
checkpoints.sort(key=lambda c: (c["step"], c["iteration"]))

Get the diagnostic, lattices, and measurements. These are reloaded by every `model.load(path)` call but are identical for every checkpoint.

In [None]:
# Any checkpoint will do: these objects are the same for every checkpoint.
model.load(checkpoints[0]["path"], device=device)
diagnostic, lattices, measurements = model.diagnostic, model.lattices, model.measurements
measurements_np = list(map(grab, measurements))

## Evaluation

In [None]:
def load_data(n):
    """Generate `n` points with `mf.data.toy.gen_data`, configured by `cfg["data"]`."""
    data_cfg = cfg["data"]
    name = data_cfg["name"]
    # These config entries are forwarded to gen_data under the same names.
    passthrough = ("noise", "shuffle", "decorr", "seed", "warp")
    options = {option: data_cfg[option] for option in passthrough}
    return mf.data.toy.gen_data(name=name, size=n, **options)

### Density

In [None]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(checkpoints) - 1), value=0),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.00e+06),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=125),
    res=widgets.IntSlider(min=4, max=250, value=150),
)
def update(index, n, xmax, bins, res):
    """Compare data, model samples, and model density for checkpoint `index`.

    The panels show, in order:
    - a 2D histogram of `int(n)` points from `load_data`;
    - a 2D histogram of `int(n)` model samples;
    - `exp(model.log_prob)` on a `res` x `res` grid over [-xmax, xmax]^2.
    """
    model.load(checkpoints[index]["path"], device=device)

    n = int(n)

    limits = 2 * [(-xmax, xmax)]
    coords = 2 * [torch.linspace(-xmax, xmax, res)]
    x_grid = torch.vstack([C.ravel() for C in torch.meshgrid(*coords, indexing="ij")]).T

    # Evaluation only: skip autograd bookkeeping (as the projections cell does).
    with torch.no_grad():
        x = grab(model.sample(n))
        log_prob = model.log_prob(cvt(x_grid))
        # "ij" meshgrid: prob[i, j] is the density at (coords[0][i], coords[1][j]).
        prob = grab(torch.exp(log_prob).reshape((res, res)))

    x0 = load_data(n)

    fig, axs = pplt.subplots(ncols=3)
    kws = dict()  # extra plotting options shared by all panels
    axs[0].hist2d(x0[:, 0], x0[:, 1], bins=bins, range=limits, **kws)
    axs[1].hist2d(x[:, 0], x[:, 1], bins=bins, range=limits, **kws)
    # pcolormesh expects rows to index the y coordinate, hence the transpose.
    axs[2].pcolormesh(grab(coords[0]), grab(coords[1]), prob.T, **kws)
    plt.show()

### Projections

In [None]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(checkpoints) - 1), value=0),
    n=widgets.FloatLogSlider(min=2, max=6, value=50000),
    log_ymin=widgets.FloatSlider(min=-5.0, max=-1.0, value=-3.0),
)
def update(index, n, log_ymin, kde=True, log=False):
    """Overlay measured and simulated 1D projections for checkpoint `index`.

    Draws `int(n)` model samples and passes them to `model.simulate` (with
    `kde` forwarded). Each panel shows one measurement (black) and the
    matching prediction (red). Both curves are divided by the measurement's
    maximum. The y axis runs from 10**log_ymin to 1.25; `log` makes it
    logarithmic.
    """
    model.load(checkpoints[index]["path"], device=device)

    with torch.no_grad():
        x = model.sample(int(n))
        x = cvt(x)
        predictions = model.simulate(x, kde=kde)

        _predictions = [grab(prediction) for prediction in predictions]
        _measurements = [grab(measurement) for measurement in model.measurements]

        edges = grab(diagnostic.bin_edges)

    # Bin centers are the midpoints of adjacent edges. The parentheses matter:
    # `0.5 * edges[:-1] + edges[1:]` would halve only the left edges.
    coords = 0.5 * (edges[:-1] + edges[1:])

    kws = dict()

    maxcols = 7
    ncols = min(len(_measurements), maxcols)
    nrows = int(np.ceil(len(_measurements) / ncols))
    figheight = 1.3 * nrows
    figwidth = min(10.0, 1.75 * ncols)

    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figheight=figheight, figwidth=figwidth)
    # Iterate over this checkpoint's own measurements/predictions (not the
    # notebook-level `measurements`) so the panels always match the data.
    for j, (measurement, prediction) in enumerate(zip(_measurements, _predictions)):
        ax = axs[j]
        scale = np.max(measurement)
        if scale <= 0:
            # Avoid dividing by zero for an all-zero profile.
            scale = 1.0
        ax.plot(coords, measurement / scale, color="black", **kws)
        ax.plot(coords, prediction / scale, color="red", **kws)
    axs.format(ymax=1.25)
    axs.format(ymin=(10.0 ** log_ymin))
    if log:
        axs.format(yscale="log")
    plt.show()