# Evaluate 2D MENT model

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.visualization as psv
import torch
from ipywidgets import interact
from ipywidgets import widgets

import mentflow as mf
from mentflow.train.plot import set_proplot_rc
from mentflow.train.plot import plot_proj_1d
from mentflow.utils import grab
from mentflow.utils import unravel

sys.path.append("../../..")
from experiments.load import load_ment_run
from experiments.rec_2d.setup import make_distribution

In [None]:
# Apply mentflow's proplot rc settings (from mentflow.train.plot) before any
# figures are created in this notebook.
set_proplot_rc()

## Load data

In [None]:
# Select the training run to evaluate; its outputs are read from
# ./outputs/<script_name>/<timestamp>/.
script_name, timestamp = "train_ment", 240412120752
data_dir = f"./outputs/{script_name}/{timestamp}/"

In [None]:
# Load the saved run and unpack the pieces used below: the run config, the
# training history, the model object, and the list of saved checkpoints.
run = load_ment_run(data_dir)
cfg, history, model, checkpoints = (
    run[key] for key in ("config", "history", "model", "checkpoints")
)

In [None]:
# Rebuild the ground-truth distribution from the run config; its samples are
# plotted as the "true" reference in the density evaluation.
distribution = make_distribution(cfg)

## Evaluation

### Density

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    size=widgets.FloatLogSlider(min=2, max=6, value=1.00e05),
    xmax=widgets.FloatSlider(min=0.0, max=6.0, value=3.0),
    bins=widgets.IntSlider(min=4, max=150, value=125),
)
def update(index: int, size: float, xmax: float, bins: int):
    """Plot 2D histograms of true vs. model samples for one checkpoint.

    Parameters
    ----------
    index : int
        Index into `checkpoints`; model parameters are loaded from
        `checkpoints[index]["path"]` onto the CPU.
    size : float
        Number of samples drawn from each distribution. The log slider
        passes a float, so it is truncated to an int before sampling.
    xmax : float
        Both histogram axes span [-xmax, xmax].
    bins : int
        Number of histogram bins along each axis.
    """
    model.load(checkpoints[index]["path"], device="cpu")

    size = int(size)
    x_pred = grab(model.sample(size))
    x_true = grab(distribution.sample(size))

    limits = [(-xmax, xmax), (-xmax, xmax)]
    fig, axs = pplt.subplots(ncols=2, xspineloc="neither", yspineloc="neither", space=0)
    # Label each panel so the true and model histograms can be told apart.
    for ax, x, title in zip(axs, [x_true, x_pred], ["True", "Model"]):
        ax.hist2d(x[:, 0], x[:, 1], bins=bins, range=limits)
        ax.format(title=title)
    pplt.show()

### Projections

In [None]:
@interact(
    index=widgets.IntSlider(
        min=0,
        max=(len(checkpoints) - 1),
        value=(len(checkpoints) - 1),
    ),
    size=widgets.FloatLogSlider(min=2, max=6, value=50000),
    log_ymin=widgets.FloatSlider(min=-5.0, max=-1.0, value=-3.0),
)
def update(index: int, size: float, log_ymin: float, kde: bool = False, log: bool = False):
    """Compare simulated and measured 1D projections for one checkpoint.

    Parameters
    ----------
    index : int
        Index into `checkpoints`; model parameters are loaded from
        `checkpoints[index]["path"]` onto the CPU.
    size : float
        Number of model samples to simulate (truncated to an int).
    log_ymin : float
        Base-10 exponent of the lower y-axis limit.
    kde : bool
        Value assigned to `kde` on every model diagnostic before simulating.
    log : bool
        If True, use a logarithmic y-axis.
    """
    # Load model parameters.
    model.load(checkpoints[index]["path"], device="cpu")

    # Plot settings.
    kind = "line"
    lw = 1.25
    colors = ["red4", "black"]

    # Simulate the measurements.
    size = int(size)
    x_pred = model.sample(size)

    # NOTE: this mutates the model's diagnostics, so the kde setting persists
    # after this call until it is changed again.
    for diagnostic in unravel(model.diagnostics):
        diagnostic.kde = kde

    predictions = mf.simulate.forward(x_pred, model.transforms, model.diagnostics)

    # Plot simulated vs. measured profiles.
    y_pred = [grab(pred) for pred in unravel(predictions)]
    y_meas = [grab(meas) for meas in unravel(model.measurements)]
    edges = [grab(diag.edges) for diag in unravel(model.diagnostics)]

    fig, axs = plot_proj_1d(
        y_pred,
        y_meas,
        edges,
        maxcols=7,
        kind=kind,
        height=1.25,
        lw=lw,
        colors=colors,
    )
    axs.format(ymax=1.25, ymin=(10.0**log_ymin))
    if log:
        axs.format(yscale="log")
    pplt.show()