# 2:1 MENT: reconstructing a 2D distribution from 1D projections

In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import proplot as pplt

import ment

In [None]:
# Global plot styling: continuous (non-discrete) colormaps with viridis as the
# default sequential map, a white figure background, and no grid lines.
pplt.rc.update(
    {
        "cmap.discrete": False,
        "cmap.sequential": "viridis",
        "figure.facecolor": "white",
        "grid": False,
    }
)

Settings:

In [None]:
# Experiment settings used by the cells below.
dist_name: str = "swissroll"  # name passed to ment.dist.get_dist
ndim: int = 2  # number of dimensions of the distribution / model
n_meas: int = 6  # number of measurements (one per phase advance)
n_bins: int = 75  # number of bins in each 1D measurement histogram
xmax: float = 3.0  # binning and plotting window is [-xmax, xmax]
seed: int = 0  # seed passed to ment.dist.get_dist

Define the source distribution.

In [None]:
# Ground-truth distribution; its samples are used to generate the measurements.
dist = ment.dist.get_dist(dist_name, ndim=ndim, seed=seed)
x_true = dist.sample(1_000_000)

# One (min, max) window per dimension. Tied to `ndim` (not a literal 2) because
# the grid sampler later in the notebook also consumes these limits.
limits = ndim * [(-xmax, xmax)]

# Plot the first two coordinates of the true samples; hist2d takes exactly two
# (min, max) pairs, so only the first two limits are passed.
fig, ax = pplt.subplots()
ax.hist2d(x_true[:, 0], x_true[:, 1], bins=75, range=limits[:2])
plt.show()

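Create the measurement data: 1D histograms along the first axis of the distribution after rotating it by `n_meas` evenly spaced phase advances in $[0, \pi)$.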

In [None]:
# Evenly spaced phase advances in [0, pi). The endpoint is excluded, presumably
# because a rotation by pi only mirrors the projection at 0.
phase_advances = np.linspace(0.0, np.pi, n_meas, endpoint=False)

# One linear (rotation) transform per phase advance.
transfer_matrices = [ment.sim.rotation_matrix(angle) for angle in phase_advances]
transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

# Shared bin edges: n_bins equal-width bins spanning [-xmax, xmax].
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)

# diagnostics[i] is the list of diagnostics applied after transforms[i]; here
# each transform gets a single 1D histogram along axis 0.
diagnostics = [
    [ment.diag.Histogram1D(axis=0, bin_edges=bin_edges)] for _ in transforms
]

# measurements[i][j] is diagnostics[i][j] evaluated on the transformed true
# samples. Each transform is applied once and reused by all its diagnostics.
measurements = []
for transform, diagnostic_list in zip(transforms, diagnostics):
    u = transform(x_true)
    measurements.append([diagnostic(u) for diagnostic in diagnostic_list])

In [None]:
# Uniform prior; scale = 2 * xmax matches the width of the [-xmax, xmax] window.
# Uses `ndim` (not a literal 2) so it stays consistent with the model below.
prior = ment.UniformPrior(ndim=ndim, scale=(2.0 * xmax))

# Grid sampler over `limits` with 100 grid points per dimension.
sampler = ment.samp.GridSampler(grid_limits=limits, grid_shape=(ndim * [100]))

model = ment.MENT(
    ndim=ndim,
    measurements=measurements,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    sampler=sampler,
    n_samples=500_000,
    interpolation=dict(method="linear"),
    verbose=True,
)

# Training hyperparameters consumed by the Gauss-Seidel loop.
learning_rate = 0.99
n_epochs = 10

In [None]:
# Epoch -1 only plots the initial (untrained) model; every later epoch first
# takes one Gauss-Seidel step, then plots.
for epoch in range(-1, n_epochs):
    if epoch >= 0:
        model.gauss_seidel_step(lr=learning_rate)

    # Samples from the current model (first two coordinates).
    x = model.sample(1_000_000)

    fig, ax = pplt.subplots()
    ax.hist2d(x[:, 0], x[:, 1], bins=75, range=limits[:2])
    plt.show()

    # Compare simulated vs. measured profiles, one panel per measurement,
    # at most 7 panels per row.
    ncols = min(n_meas, 7)
    nrows = int(np.ceil(n_meas / ncols))
    fig, axs = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        figsize=(2.0 * ncols, 1.25 * nrows),
        sharex=True,
        sharey=True,
        squeeze=False,  # always a 2D array, so .ravel() also works for one panel
    )
    axs = axs.ravel()
    for index in range(len(transforms)):
        prediction = model.simulate(index, diag_index=0)
        measurement = measurements[index][0]

        # Scale both curves by the measurement's peak so they share an axis.
        normalization = measurement.max()
        axs[index].plot(measurement / normalization, color="red3")
        axs[index].plot(prediction / normalization, color="black")

    # Hide leftover panels when n_meas is not a multiple of ncols.
    for unused_ax in axs[len(transforms):]:
        unused_ax.set_visible(False)
    plt.show()