# Train 2D MENT model

In [None]:
import argparse
import os
import pathlib
import pickle
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import ot
import proplot as pplt
import scipy.interpolate
import torch
import zuko
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel

# Local
import arguments
import plotting
import utils

In [None]:
# Global proplot style settings applied to every figure in this notebook.
# NOTE: these mutate proplot's global rc state; statement order is kept as-is.
pplt.rc["cmap.discrete"] = False  # presumably disables discrete colormap levels — confirm in proplot docs
pplt.rc["cmap.sequential"] = pplt.Colormap("dark_r", space="hpl")  # default sequential colormap
pplt.rc["figure.facecolor"] = "white"
pplt.rc["grid"] = False  # turn off axis grid lines

## Setup

In [None]:
# Build the MENT argument parser and take its defaults: an explicit empty argv,
# because a notebook has no meaningful command line. Not every option is used.
parser = arguments.make_parser(model="ment")
args = parser.parse_args(args=[])
args

In [None]:
# Paths are resolved relative to the notebook's working directory.
_notebook_dir = os.path.abspath(".")
file_path = os.path.join(_notebook_dir, "train_ment.ipynb")
output_dir = os.path.join(_notebook_dir, f"output/{args.data}/")

# The script manager owns the output tree; create its subfolders.
man = mf.train.ScriptManager(file_path, output_dir)
man.make_dirs("checkpoints", "figures")

# Record the run configuration and a snapshot of this notebook.
man.save_pickle(vars(args), "args.pkl")
man.save_script_copy()

print("output_dir:", man.output_dir)


# Seed torch's global RNG only when a seed was given; the numpy Generator is
# built from the same (possibly None) seed and is passed to the data generator.
if args.seed is not None:
    torch.manual_seed(args.seed)
rng = np.random.default_rng(seed=args.seed)

# Everything in this notebook runs on `device` in single precision.
precision = torch.float32
device = torch.device(args.device)
torch.set_default_dtype(precision)

def send(x):
    """Cast tensor `x` to the notebook's working dtype and move it to `device`.

    A single `Tensor.to(device=..., dtype=...)` call performs the cast and the
    transfer together. This avoids allocating a separate intermediate tensor,
    which happens when `.type(...)` and `.to(...)` are chained.

    Args:
        x: a torch tensor.

    Returns:
        A tensor with dtype `precision` on `device`. This may be `x` itself
        if it already matches.
    """
    return x.to(device=device, dtype=precision)

## Data

In [None]:
# Input distribution (2D toy), saved alongside the run outputs.
d = 2
dist = mf.data.toy.gen_dist(
    args.data,
    noise=args.data_noise,
    decorr=args.data_decorr,
    rng=rng,
)
man.save_pickle(dist, "dist.pkl")

# Ground-truth samples on the working device/dtype.
x0 = send(dist.sample(args.data_size))

# Measurement transforms (from `make_transforms_rotation`), moved to the device.
transforms = [
    transform.to(device)
    for transform in utils.make_transforms_rotation(
        args.meas_angle_min, args.meas_angle_max, args.meas_num
    )
]

# Histogram diagnostic along axis 0 with uniform bins on [-xmax, xmax].
xmax = args.meas_xmax
bin_edges = send(torch.linspace(-xmax, xmax, args.meas_bins + 1))
diagnostic = mf.diagnostics.Histogram1D(axis=0, bin_edges=bin_edges).to(device)
diagnostics = [diagnostic]

# Simulated training measurements of the true samples, then noise is added.
measurements = mf.simulate_nokde(x0, transforms, diagnostics)
measurements = utils.add_measurement_noise(
    measurements,
    scale=args.meas_noise,
    noise_type=args.meas_noise_type,
    device=device,
)

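View the distribution in the transformed space. The bottom panel compares the x-projection of the samples (faint) with the noisy measured profile (solid).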

In [None]:
@interact
def update(index=(0, args.meas_num - 1)):
    """Show the true samples after transform `index` and compare their
    x-projection with the corresponding training measurement.

    The sample histogram is binned on the diagnostic's own `bin_edges`, so the
    two profiles in the bottom panel are compared bin-for-bin. Previously the
    bins followed the samples' min/max range instead.
    """
    x = grab(transforms[index](x0))

    fig, ax = pplt.subplots()
    limits = 2 * [(-xmax, +xmax)]
    ax.hist2d(x[:, 0], x[:, 1], bins=100, range=limits)

    # Bottom panel: sample projection (faint) vs. measured profile (solid),
    # both scaled by the measurement's peak value.
    pax = ax.panel_axes("bottom", width=0.75)
    edges = grab(bin_edges)
    hist, _ = np.histogram(x[:, 0], bins=edges, density=True)
    meas = grab(measurements[index][0])
    scale = meas.max()
    plotting.plot_hist(hist / scale, edges, ax=pax, color="black", alpha=0.3)
    plotting.plot_hist(meas / scale, edges, ax=pax, color="black")
    pplt.show()

Plot the integration lines in the input and transformed space.

In [None]:
@interact
def update(index=(0, len(measurements) - 1)):
    """Overlay integration lines on true samples, shown before and after the
    transform selected by `index`."""
    transform = transforms[index]
    limits = 2 * [(-xmax, +xmax)]

    # True samples in input space (left) and transformed space (right).
    pts_in = send(dist.sample(100000))
    pts_out = grab(transform(pts_in))
    pts_in = grab(pts_in)

    fig, axs = pplt.subplots(ncols=2)
    for ax, pts in zip(axs, [pts_in, pts_out]):
        ax.hist2d(pts[:, 0], pts[:, 1], bins=100, range=limits)

    # A regular grid in transformed space, pulled back through the inverse
    # transform. Each of the `n_lines` equal chunks is drawn as one polyline.
    n_lines, n_dots_per_line = 20, 100
    grid_out = send(
        mf.utils.get_grid_points_torch(
            2.0 * torch.linspace(-xmax, +xmax, n_lines),
            2.0 * torch.linspace(-xmax, +xmax, n_dots_per_line),
        )
    )
    grid_in = grab(transform.inverse(grid_out))
    grid_out = grab(grid_out)

    for ax, pts in zip(axs, [grid_in, grid_out]):
        for line in np.split(pts, n_lines):
            ax.plot(line[:, 0], line[:, 1], color="white", alpha=0.5)
    axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))
    pplt.show()

## Model

In [None]:
# Prior over the input space. It stays None if args.prior matches neither
# option below.
prior = None
if args.prior == "gaussian":
    prior = mf.models.ment.GaussianPrior(d=d, scale=args.prior_scale, device=device)
elif args.prior == "uniform":
    prior = mf.models.ment.UniformPrior(d=d, scale=(10.0 * xmax), device=device)

# Grid sampler covering the measurement window plus a 10% margin on each axis.
sampler_limits = 1.1 * np.array(2 * [(-xmax, +xmax)])
sampler = mf.sample.GridSampler(limits=sampler_limits, res=200, device=device)

# Use the dimension `d` from the data cell rather than re-hard-coding 2.
model = mf.models.ment.MENT(
    d=d,
    transforms=transforms,
    measurements=measurements,
    diagnostics=diagnostics,
    discrepancy_function=args.disc,
    prior=prior,
    sampler=sampler,
    interpolate=args.interpolate,  # {"nearest", "linear", "pchip"}
    device=device,
)

## Training

In [None]:
# Resolve the figure/checkpoint folders under the manager's output directory.
# `man.make_dirs` presumably already created them; this is a safeguard.
output_dir = man.output_dir

if output_dir is not None:
    fig_dir = os.path.join(output_dir, "figures")
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    # exist_ok avoids the check-then-create race. makedirs also creates
    # output_dir itself if it is missing.
    os.makedirs(fig_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Simulation settings shared by evaluation, plotting and the Gauss-Seidel
# updates. The method is either integration or particle tracking.
sim_method = args.method
sim_kws = dict(
    limits=[(-xmax, +xmax)],  # integration limits
    shape=(300,),  # integration grid resolution
    n=1_000_000,  # number of particles (used when sim_method == "sample")
)


In [None]:
# Training diagnostics:

class Plotter:
    """Callable that renders diagnostic figures for the current model state."""

    def __call__(self, model):
        """Return the figures produced by `plotting.plot_model` for `model`."""
        sim_options = dict(method=sim_method, **sim_kws)
        return plotting.plot_model(
            model,
            dist,
            n=args.vis_size,
            sim_kws=sim_options,
            bins=args.vis_bins,
            xmax=xmax,
            maxcols=args.vis_maxcols,
            kind=args.vis_line,
            device=device,
        )


class Evaluator:
    """Compute, print and return quality metrics for the current model.

    Metrics:
    - the simulation-measurement discrepancy, averaged over all measurements;
    - optionally, the sliced Wasserstein distance (SWD) between samples from
      the model and samples from the true input distribution `dist`.
    """

    def __init__(self, swd=True, swd_max_samples=100000, swd_projections=50):
        """
        Args:
            swd: if True, also compute the sliced Wasserstein distance.
            swd_max_samples: at most this many samples from each set are used
                for the SWD.
            swd_projections: number of random projections passed to
                `ot.sliced.sliced_wasserstein_distance`.
        """
        self.swd = swd
        self.swd_max_samples = swd_max_samples
        self.swd_projections = swd_projections

    def __call__(self, model):
        """Evaluate `model`, print a summary, and return the metrics.

        Returns:
            dict with keys "discrepancy" (mean over measurements) and "swd"
            (None when the SWD is disabled).
        """
        # Samples are only needed for the SWD. They are drawn before simulating,
        # so the random-number order matches the swd=True path.
        x = x_true = None
        if self.swd:
            x_true = send(dist.sample(args.vis_size))
            x = send(model.sample(args.vis_size))

        # Simulation-measurement discrepancy, averaged over measurements.
        predictions = model.simulate(method=sim_method, **sim_kws)
        discrepancy_vector = model.discrepancy(predictions)
        discrepancy = sum(discrepancy_vector) / len(discrepancy_vector)

        # Sliced Wasserstein distance between model and true samples.
        distance = None
        if self.swd:
            n = self.swd_max_samples
            distance = ot.sliced.sliced_wasserstein_distance(
                x[:n],
                x_true[:n],
                n_projections=self.swd_projections,
                p=2,
            )

        print(f"D(y_model, y_true) = {discrepancy}")
        print(f"SWD(x_model, x_true) = {distance}")
        return {"discrepancy": discrepancy, "swd": distance}


# Per-epoch diagnostics: figures and printed metrics.
plotter = Plotter()
evaluator = Evaluator(swd=True)

Training loop:

In [None]:
# Each pass evaluates, checkpoints and plots the current model, then applies
# one Gauss-Seidel update. An extra final pass (epoch == args.epochs) records
# the model produced by the last update, so every update's result is
# evaluated and saved.
for epoch in range(args.epochs + 1):
    print(f"epoch = {epoch}")

    # Evaluate model.
    result = evaluator(model)

    # Save a checkpoint.
    if output_dir is not None:
        filename = os.path.join(checkpoint_dir, f"checkpoint_{epoch:03.0f}_00000.pt")
        print(f"Saving file {filename}")
        model.save(filename)

    # Make figures. Save all of them first, then show and close once, instead
    # of closing every open figure partway through the loop.
    figs = plotter(model)
    if output_dir is not None:
        for index, fig in enumerate(figs):
            filename = os.path.join(
                fig_dir, f"fig_{index:02.0f}_{epoch:03.0f}_00000.{args.fig_ext}"
            )
            print(f"Saving file {filename}")
            fig.savefig(filename, dpi=args.fig_dpi)
    plt.show()
    plt.close("all")

    # The final pass only records the fully updated model.
    if epoch == args.epochs:
        break

    # Gauss-Seidel update.
    model.gauss_seidel_iterate(omega=args.omega, sim_method=sim_method, **sim_kws)