# Train 2D MENT-Flow model (neural network generator)

In [None]:
import argparse
import os
import pathlib
import pickle
import sys
import time

import numpy as np
import proplot as pplt
import scipy.interpolate
import torch
import zuko
from ipywidgets import interact
from ipywidgets import widgets
from tqdm.notebook import tqdm

import mentflow as mf
from mentflow.utils import grab
from mentflow.utils import unravel
from mentflow.utils import save_pickle
from mentflow.wrappers import WrappedZukoFlow

# Local
import arguments
import plotting
import utils

In [None]:
# Run the local `plotting` module's proplot rc setup (presumably global figure
# styling -- confirm in plotting.py) before any figure in this notebook is drawn.
plotting.set_proplot_rc()

## Setup

In [None]:
parser = arguments.make_parser(model="nn")

# Override parser defaults for this notebook.
# Prefer Apple's Metal backend ("mps") when this PyTorch build and machine
# support it; otherwise fall back to the CPU so the notebook still runs on
# other hardware instead of failing at the first `.to(device)` call.
parser.set_defaults(
    device="mps" if torch.backends.mps.is_available() else "cpu",
)

# Parse an empty argument list (a notebook has no real command line), so the
# values come from the parser defaults set above and in `arguments`.
args = parser.parse_args([])
args

In [None]:
# Resolve the notebook path and the per-dataset output directory relative to
# the current working directory.
base_dir = os.path.abspath(".")
file_path = os.path.join(base_dir, "train_nn.ipynb")
output_dir = os.path.join(base_dir, f"output/{args.data}/")

# The script manager owns the output location; ask it for the two
# sub-directories used by this run.
man = mf.train.ScriptManager(file_path, output_dir)
man.make_dirs("checkpoints", "figures")

# Keep a record of the run: the parsed arguments (as a plain dict) and a copy
# of this script.
man.save_pickle(vars(args), "args.pkl")
man.save_script_copy()

print("output_dir:", man.output_dir)


# Floating-point precision and device shared by the tensors in this notebook;
# the precision is also installed as PyTorch's default dtype.
precision = torch.float32
torch.set_default_dtype(precision)
device = torch.device(args.device)

# Seed the NumPy generator (later handed to the data distribution) and, when a
# seed was actually given, PyTorch's global generator as well.
rng = np.random.default_rng(args.seed)
if args.seed is not None:
    torch.manual_seed(args.seed)

def send(x):
    """Return `x` cast to the notebook-wide `precision` and moved to `device`.

    `precision` and `device` are the module-level names defined in this
    notebook. The cast happens first, then the transfer.
    """
    as_precision = x.to(precision)
    return as_precision.to(device)

## Data

In [None]:
# Number of dimensions of the model (also used as the generator's output size).
d = 2

# Build the toy input distribution selected by `args.data` and store it with
# the other run outputs.
dist = mf.data.toy.gen_dist(
    args.data, noise=args.data_noise, decorr=args.data_decorr, rng=rng
)
man.save_pickle(dist, "dist.pkl")

# Sample the input distribution and place the samples on the training device.
x0 = send(dist.sample(args.data_size))

# Rotation transforms between the minimum and maximum measurement angle, each
# moved to the training device.
transforms = [
    transform.to(device)
    for transform in utils.make_transforms_rotation(
        args.meas_angle_min, args.meas_angle_max, args.meas_num
    )
]

# One 1D histogram diagnostic along axis 0, with `args.meas_bins` bins spanning
# [-xmax, +xmax].
xmax = args.meas_xmax
bin_edges = send(torch.linspace(-xmax, xmax, args.meas_bins + 1))

diagnostic = mf.diagnostics.Histogram1D(axis=0, bin_edges=bin_edges).to(device)
diagnostics = [diagnostic]

# Training data: simulate the measurements of the input samples for every
# transform/diagnostic pair, then add measurement noise of the configured
# type and scale.
measurements = utils.add_measurement_noise(
    mf.simulate_nokde(x0, transforms, diagnostics),
    scale=args.meas_noise,
    noise_type=args.meas_noise_type,
    device=device,
)

View the data in the transformed space.

In [None]:
@interact
def update(index=(0, args.meas_num - 1)):
    """Plot the input samples under one transform, with its measured profile.

    `index` selects the transform and the matching measurement. The main axes
    show a 2D histogram of the transformed samples. The bottom panel overlays
    a histogram of their first coordinate (faint) on the stored measurement,
    both divided by the measurement's peak value.
    """
    projected = grab(transforms[index](x0))

    fig, ax = pplt.subplots()
    ax.hist2d(
        projected[:, 0],
        projected[:, 1],
        bins=100,
        range=[(-xmax, +xmax), (-xmax, +xmax)],
    )

    pax = ax.panel_axes("bottom", width=0.75)
    n_bins = len(bin_edges) - 1
    sample_hist, sample_edges = np.histogram(projected[:, 0], bins=n_bins, density=True)
    profile = grab(measurements[index][0])
    peak = profile.max()
    plotting.plot_hist(sample_hist / peak, sample_edges, ax=pax, color="black", alpha=0.3)
    plotting.plot_hist(profile / peak, grab(bin_edges), ax=pax, color="black")

Plot the integration lines in the input and transformed space.

In [None]:
@interact
def update(index=(0, len(measurements) - 1)):
    """Draw a grid of lines in the input and transformed spaces.

    `index` selects the transform. The left panel shows freshly drawn samples
    of the input distribution and the right panel shows the same samples after
    the transform. A grid of points is laid out in the transformed space and
    mapped back through `transform.inverse`; each group of points is drawn as
    one white line in both panels.
    """
    transform = transforms[index]

    # Background densities: new samples before and after the transform.
    samples_in = send(dist.sample(100000))
    samples_out = grab(transform(samples_in))
    samples_in = grab(samples_in)

    fig, axs = pplt.subplots(ncols=2)
    limits = [(-xmax, +xmax)] * 2
    for ax, points in zip(axs, (samples_in, samples_out)):
        ax.hist2d(points[:, 0], points[:, 1], bins=100, range=limits)

    # Grid of points in the transformed space, spanning twice the plot range.
    n_lines = 20
    n_dots_per_line = 100
    grid_out = send(
        mf.utils.get_grid_points_torch(
            2.0 * torch.linspace(-xmax, +xmax, n_lines),
            2.0 * torch.linspace(-xmax, +xmax, n_dots_per_line),
        )
    )

    # Map the grid back to the input space, then convert both for plotting.
    grid_in = grab(transform.inverse(grid_out))
    grid_out = grab(grid_out)

    # Splitting into `n_lines` equal chunks yields one line per chunk.
    for ax, points in zip(axs, (grid_in, grid_out)):
        for line in np.split(points, n_lines):
            ax.plot(line[:, 0], line[:, 1], color="white", alpha=0.5)
    axs.format(xlim=(-xmax, xmax), ylim=(-xmax, xmax))
    pplt.show()

## Model

In [None]:
# Generator = base distribution + neural-network transformer. Layer sizes,
# dropout and activation come from the parsed arguments; the output size is `d`.
transformer = mf.models.NNTransformer(
    input_features=args.input_features,
    output_features=d,
    hidden_units=args.hidden_units,
    hidden_layers=args.hidden_layers,
    activation=args.activation,
    dropout=args.dropout,
)

# Base distribution: zero-mean, unit-scale normal with one entry per input
# feature, created on the training device.
base = torch.distributions.Normal(
    loc=send(torch.zeros(args.input_features)),
    scale=send(torch.ones(args.input_features)),
)
generator = mf.models.NNGenerator(base, transformer).to(device)

# Entropy estimator selected by `args.entest`; any value other than "cov" or
# "knn" keeps the empty estimator.
entropy_estimator = mf.entropy.EmptyEntropyEstimator()
if args.entest == "cov":
    entropy_estimator = mf.entropy.CovarianceEntropyEstimator()
elif args.entest == "knn":
    entropy_estimator = mf.entropy.KNNEntropyEstimator(k=5)

# Assemble the full model (no prior) and move it to the training device.
model = mf.MENTFlow(
    generator=generator,
    prior=None,
    entropy_estimator=entropy_estimator,
    transforms=transforms,
    diagnostics=diagnostics,
    measurements=measurements,
    penalty_parameter=args.penalty,
    discrepancy_function=args.disc,
).to(device)
print(model)

# Save the generator architecture alongside the run outputs (presumably so the
# network can be rebuilt when loading checkpoints -- confirm against the
# evaluation scripts).
#
# `input_features` must be the value the NNTransformer was built with
# (`args.input_features`), not the output dimension `d`. Otherwise the saved
# configuration disagrees with the trained network whenever the two differ.
cfg = {
    "generator": {
        "input_features": args.input_features,
        "output_features": d,
        "hidden_units": args.hidden_units,
        "hidden_layers": args.hidden_layers,
        "dropout": args.dropout,
        "activation": args.activation,
    },
}
man.save_pickle(cfg, "cfg.pkl")

## Training

In [None]:
def plotter(model):
    """Return the figures that `plotting.plot_model` produces for `model`.

    The model is plotted against the notebook's input distribution `dist`.
    Sample count, binning, column layout and line style come from the
    `vis_*` arguments; `xmax` and `device` are the notebook globals.
    """
    plot_kws = dict(
        n=args.vis_size,
        bins=args.vis_bins,
        xmax=xmax,
        maxcols=args.vis_maxcols,
        kind=args.vis_line,
        device=device,
    )
    return plotting.plot_model(model, dist, **plot_kws)

In [None]:
# AdamW with weight decay switched off (original note: a nonzero weight decay
# would need to be reset after each epoch).
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.0)

# Multiply the learning rate by `args.lr_drop` after `args.lr_patience`
# scheduler steps without improvement, never going below `args.lr_min`.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=args.lr_drop,
    patience=args.lr_patience,
    min_lr=args.lr_min,
)

In [None]:
# The trainer ties together the model, the optimizer and scheduler, and the
# plotting callback, and writes into the script manager's output directory.
trainer = mf.train.Trainer(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    plotter=plotter,
    output_dir=man.output_dir,
    precision=precision,
    device=device,
    notebook=True,
)

# Training settings, all taken from the parsed arguments.
train_kws = dict(
    epochs=args.epochs,
    iterations=args.iters,
    batch_size=args.batch_size,
    rtol=args.rtol,
    atol=args.atol,
    dmax=args.dmax,
    penalty_step=args.penalty_step,
    penalty_scale=args.penalty_scale,
    penalty_max=args.penalty_max,
    vis_freq=args.vis_freq,
    eval_freq=args.eval_freq,
    savefig_kws=dict(ext=args.fig_ext, dpi=args.fig_dpi),
)
trainer.train(**train_kws)