# N:1 MENT-S — marginal projections

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from tqdm.notebook import tqdm
from tqdm.notebook import trange

import ment
from ment.train.plot import PlotDistCorner
from ment.train.plot import PlotProj1D
from ment.utils import unravel

In [None]:
# Global proplot style applied to every figure drawn in this notebook:
# continuous (non-discrete) viridis colormaps, white figure background, no grid.
pplt.rc.update(
    {
        "cmap.discrete": False,
        "cmap.sequential": "viridis",
        "figure.facecolor": "white",
        "grid": False,
    }
)

## Settings

In [None]:
# Experiment settings used by the cells below.
ndim = 6  # dimension of the source distribution
xmax = 3.5  # plots and histogram bins span [-xmax, xmax] along every axis
dist_name = "gaussian_mixture"  # name passed to ment.dist.get_dist
seed = 12345  # seed passed to the source distribution

## Source distribution

In [None]:
# Ground-truth distribution and a large reference sample (10**6 points) drawn
# from it; the sample is what the synthetic measurements are computed from.
dist = ment.dist.get_dist(dist_name, ndim=ndim, seed=seed)
X_true = dist.sample(10**6)

In [None]:
# Per-axis plotting limits, identical for every dimension; also reused as the
# grid-sampler limits in the reconstruction cell.
limits = [(-xmax, xmax)] * ndim

# Corner plot of the ground-truth sample, 64 bins per axis, unmasked.
grid = psv.CornerGrid(ndim, figwidth=1.5 * ndim)
grid.plot_points(X_true, mask=False, bins=64, limits=limits)
plt.show()

## Data generation

In [None]:
# Settings
nbins = 40  # number of histogram bins per 1D projection
nmeas = ndim  # number of measured projections; must be in [1, ndim]
kde = False
kde_bandwidth = 1.0

## Measure 1D marginals
# Every projection is histogrammed along the same output axis `axis_meas`.
# Measurement i uses a matrix that swaps axes i and `axis_meas`, so its
# histogram is the marginal of coordinate i (assuming LinearTransform applies
# the matrix to each point; the swap matrix is symmetric, so orientation does
# not matter). For i == axis_meas the matrix is the identity.
axis_meas = 0
if not 1 <= nmeas <= ndim:
    raise ValueError(f"nmeas must be in [1, {ndim}], got {nmeas}")

# Create transforms (transposition matrices swapping axes i and axis_meas).
transfer_matrices = []
for i in range(nmeas):
    matrix = np.identity(ndim)
    matrix[i, i] = matrix[axis_meas, axis_meas] = 0.0
    matrix[i, axis_meas] = matrix[axis_meas, i] = 1.0
    transfer_matrices.append(matrix)

transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

# Create histogram diagnostics: a fresh single-element list holding one
# Histogram1D for each transform, all sharing the same bin edges.
bin_edges = np.linspace(-xmax, xmax, nbins + 1)
diagnostics = [
    [
        ment.diag.Histogram1D(
            axis=axis_meas, edges=bin_edges, kde=kde, kde_bandwidth=kde_bandwidth
        )
    ]
    for _ in transforms
]

# Generate data.
projections = ment.sim.forward(X_true, transforms, diagnostics)

## Reconstruction model

In [None]:
# Sampler used by the MENT model. The selector lives in `sampler_name` and the
# instance in `sampler`, so re-running this cell does not compare the already
# constructed sampler object against the selector strings.
sampler_name = "mcmc"  # "grid" or "mcmc"
nsamp = 500_000
burnin = 10_000

if sampler_name == "grid":
    samp_grid_res = 32
    samp_noise = 0.5
    samp_grid_shape = ndim * [samp_grid_res]
    samp_grid_limits = limits

    sampler = ment.samp.GridSampler(
        grid_limits=samp_grid_limits,
        grid_shape=samp_grid_shape,
        noise=samp_noise,
    )
elif sampler_name == "mcmc":
    sampler = ment.samp.MetropolisHastingsSampler(
        ndim=ndim,
        chains=248,
        proposal_cov=np.eye(ndim),  # identity proposal covariance
        burnin=burnin,
        shuffle=True,
        verbose=True,
    )
else:
    raise ValueError(
        f"Unknown sampler {sampler_name!r}; expected 'grid' or 'mcmc'"
    )

prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

# MENT model fit to the 1D projections generated above; `nsamp` samples are
# drawn with `sampler` (mode="sample").
model = ment.MENT(
    ndim=ndim,
    projections=projections,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),
    sampler=sampler,
    nsamp=nsamp,
    mode="sample",
    verbose=True,
)

## Training

In [None]:
# Number of samples and histogram bins used by the training plots/evaluator.
plot_nsamp, plot_nbins = 1_000_000, 64

In [None]:
# Plots produced during training:
# - the measured 1D projections, first on a linear and then on a log scale;
# - a corner plot for the distribution (`plot_dist`).
# Both the plotter and the evaluator use `plot_nsamp` samples.
plot_model = ment.train.Plotter(
    n_samples=plot_nsamp,
    plot_proj=[PlotProj1D(log=use_log) for use_log in (False, True)],
    plot_dist=[
        PlotDistCorner(
            fig_kws={"figwidth": 1.25 * ndim, "diag_shrink": 0.8},
            limits=[(-xmax, xmax)] * ndim,
            bins=plot_nbins,
        )
    ],
)

eval_model = ment.train.Evaluator(n_samples=plot_nsamp)

In [None]:
# Attach the plotting and evaluation callbacks to the model, then run three
# training epochs with learning rate 0.8.
trainer = ment.train.Trainer(
    model, plot_func=plot_model, eval_func=eval_model, notebook=True
)
trainer.train(learning_rate=0.8, epochs=3)

## Evaluation

In [None]:
X_pred = model.sample(10**5)  # draw 100k samples from the trained model

In [None]:
# color = "pink"
# bins = 32

# grid = psv.CornerGrid(ndim, figwidth=(ndim * 1.5))
# grid.plot_points(
#     X_true[: X_pred.shape[0], :],
#     limits=(ndim * [(-xmax, xmax)]),
#     bins=bins,
#     mask=False,
#     cmap="mono",
#     diag_kws=dict(lw=1.25, color="black"),
# )
# grid.plot_points(
#     X_pred,
#     limits=(ndim * [(-xmax, xmax)]),
#     bins=bins,
#     diag_kws=dict(lw=1.25, color="pink5"),
#     alpha=0.0,
# )
# grid.plot_points(
#     X_pred[:1000, :],
#     diag=False,
#     kind="scatter",
#     c=color,
#     s=0.5,
# )
# grid.set_limits(ndim * [(-xmax, xmax)])
# grid.set_labels([r"$x$", r"$p_x$", r"$y$", r"$p_y$", r"$z$", r"$p_z$"])
# plt.show()