# Evaluate 2D MENT model

In [None]:
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import torch
from ipywidgets import interact, widgets

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel

from experiments.load import load_ment_run
from setup import make_dist

In [None]:
# Apply mentflow's proplot rc configuration before any figure in this notebook
# is created (presumably project-wide plot style defaults -- confirm in
# mentflow.train.plot).
mf.train.plot.set_proplot_rc()

## Load data

In [None]:
# Select the run to evaluate. The run directory is
# ./outputs/<script_name>/<timestamp>/ relative to the working directory.
# NOTE(review): the timestamp looks like YYMMDDHHMMSS -- confirm against the
# training script that writes these directories.
script_name = "train_ment"
timestamp = 240112163454
data_dir = f"./outputs/{script_name}/{timestamp}/"

In [None]:
# Load everything saved for the selected run, then expose the pieces the
# evaluation cells below rely on as top-level names.
run = load_ment_run(data_dir)

cfg, history, model, checkpoints = (
    run[key] for key in ("config", "history", "model", "checkpoints")
)

In [None]:
# Build the distribution object from the run configuration. It is sampled in
# the density cell below and plotted next to the model's samples (presumably
# the ground-truth distribution used for training -- confirm in setup.make_dist).
dist = make_dist(cfg)

## Evaluation

### Density

In [None]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(checkpoints) - 1), value=(len(checkpoints) - 1)),
    n=widgets.FloatLogSlider(min=2, max=6, value=1.0e5),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=125),
)
def update(index, n, xmax, bins):
    """Compare 2D sample histograms of `dist` and the model at one checkpoint.

    Parameters
    ----------
    index : int
        Position in `checkpoints`; the entry's "path" is loaded into `model`.
    n : float
        Number of samples to draw from each source (truncated with `int`).
    xmax : float
        Both histogram axes span (-xmax, xmax).
    bins : int
        Bin count forwarded to `hist2d`.

    The left panel shows samples from `dist`, the right panel samples from
    the model.
    """
    model.load(checkpoints[index]["path"])
    n_samples = int(n)

    # Draw the model samples first, then the reference samples, without
    # tracking gradients.
    with torch.no_grad():
        model_samples = grab(model.sample(n_samples))
        reference_samples = grab(dist.sample(n_samples))

    fig, axs = pplt.subplots(ncols=2, xspineloc="neither", yspineloc="neither", space=0)
    limits = [(-xmax, xmax), (-xmax, xmax)]
    for column, points in enumerate((reference_samples, model_samples)):
        axs[column].hist2d(points[:, 0], points[:, 1], bins=bins, range=limits)
    pplt.show()

### Projections

In [None]:
@interact(
    index=widgets.IntSlider(min=0, max=(len(checkpoints) - 1), value=(len(checkpoints) - 1)),
    n=widgets.FloatLogSlider(min=2, max=6, value=50000),
    log_ymin=widgets.FloatSlider(min=-5.0, max=-1.0, value=-3.0),
)
def update(index, n, log_ymin, kde=False, log=False):
    """Plot measured and simulated 1D projections at one checkpoint.

    Parameters
    ----------
    index : int
        Position in `checkpoints`; the entry's "path" is loaded into `model`.
    n : float
        Number of model samples to draw (truncated with `int`).
    log_ymin : float
        The lower y-axis limit is set to 10 ** log_ymin.
    kde : bool
        Value assigned to the `kde` attribute of every diagnostic before the
        forward simulation (presumably toggles kernel density estimation --
        confirm in mentflow). The assignment is not undone afterwards.
    log : bool
        If true, the y axis is switched to a logarithmic scale.
    """
    model.load(checkpoints[index]["path"])

    with torch.no_grad():
        samples = model.sample(int(n))

        # Set the estimator flag on every diagnostic before simulating.
        for diagnostic in unravel(model.diagnostics):
            diagnostic.kde = kde

        predictions = mf.sim.forward(samples, model.transforms, model.diagnostics)

        measured_profiles = [grab(measurement) for measurement in unravel(model.measurements)]
        predicted_profiles = [grab(prediction) for prediction in unravel(predictions)]
        bin_edges = [grab(diagnostic.bin_edges) for diagnostic in unravel(model.diagnostics)]

    fig, axs = mf.train.plot.plot_proj_1d(
        measured_profiles,
        predicted_profiles,
        bin_edges,
        maxcols=7,
        kind="line",
        height=1.25,
        lw=1.5,
    )
    y_floor = 10.0 ** log_ymin
    axs.format(ymax=1.25, ymin=y_floor)
    if log:
        axs.format(yscale="log")
    plt.show()