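# 2:1 MENT

Reconstruct a two-dimensional distribution from a set of one-dimensional projections using the `ment` package.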

In [1]:
import os
import sys
import time

import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np
import proplot as pplt

import ment

In [2]:
# Plot styling, applied through proplot's rc settings in the order listed.
rc_overrides = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for rc_key, rc_value in rc_overrides.items():
    pplt.rc[rc_key] = rc_value

Settings:

In [3]:
# Source distribution: name, dimensionality and seed handed to `ment.dist.get_dist`.
dist_name = "galaxy"
ndim = 2
seed = 0

# Measurement setup.
n_meas = 6  # number of phase advances (one 1D histogram each)
n_bins = 75  # bins per 1D histogram
xmax = 4.5  # histograms and plots span (-xmax, xmax)

Define the source distribution.

In [None]:
# Draw the ground-truth samples from the named source distribution.
dist = ment.dist.get_dist(dist_name, ndim=ndim, seed=seed)
x_true = dist.sample(1_000_000)

# Same (-xmax, xmax) window on every axis; later cells reuse `limits`.
limits = ndim * [(-xmax, xmax)]

# Bin once and plot the same counts twice: linear scale (left), log scale (right).
hist, edges = np.histogramdd(x_true, bins=75, range=limits)
# Offset by one so empty bins stay valid on the log scale (LogNorm vmin=1.0).
values = hist.T + 1.0
# Upper color limit is taken from the plotted (offset) values so the peak is not clipped.
norms = [None, matplotlib.colors.LogNorm(vmin=1.0, vmax=np.max(values))]

fig, axs = plt.subplots(ncols=2, figsize=(5.5, 2.5), constrained_layout=True)
for ax, norm in zip(axs, norms):
    m = ax.pcolormesh(edges[0], edges[1], values, norm=norm)
# `m` is the log-scale mesh; attach its colorbar to the panel it belongs to.
fig.colorbar(m, ax=axs[1])
plt.show()

Create the measurement data.

In [None]:
# n_meas phase advances evenly spaced over [0, pi); pi itself is excluded.
phase_advances = np.linspace(0.0, np.pi, n_meas, endpoint=False)

# One rotation matrix per phase advance, each wrapped in a linear transform.
transfer_matrices = [
    ment.sim.rotation_matrix(phase_advance) for phase_advance in phase_advances
]
transforms = [ment.sim.LinearTransform(matrix) for matrix in transfer_matrices]

# All 1D histograms share the same n_bins bins spanning (-xmax, xmax).
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)

# One list of diagnostics per transform; here a single histogram along axis 0.
# A separate Histogram1D instance is created for every transform.
diagnostics = [
    [ment.diag.Histogram1D(axis=0, bin_edges=bin_edges)] for _ in transforms
]

# Apply each transform to the true samples and evaluate its diagnostics on the
# result, giving measurements[i][j] for transform i and diagnostic j.
measurements = []
for transform, transform_diagnostics in zip(transforms, diagnostics):
    u = transform(x_true)
    measurements.append([diagnostic(u) for diagnostic in transform_diagnostics])

In [None]:
# Uniform prior whose scale equals the full width of the (-xmax, xmax) window.
prior = ment.UniformPrior(ndim=ndim, scale=(2.0 * xmax))

# Grid sampler over the plotting window with 100 grid points per dimension.
sampler = ment.samp.GridSampler(grid_limits=limits, grid_shape=(ndim * [100]))

model = ment.MENT(
    ndim=ndim,
    measurements=measurements,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation=dict(method="linear"),

    sampler=sampler,
    n_samples=200_000,

    # One single-entry list per transform, holding the limits of axis 1
    # (the diagnostics histogram axis 0, so presumably axis 1 is the one
    # integrated over; confirm against the ment documentation).
    integration_limits=[[limits[1]] for _ in transforms],
    integration_size=50_000,
    integration_batches=1,

    verbose=True,
    mode="integrate",  # {"integrate", "sample"}
)

# Training hyper-parameters consumed by the training loop.
learning_rate = 0.90
n_epochs = 10

In [None]:
# Layout of the profile-comparison figure; it does not depend on the epoch.
ncols = min(n_meas, 7)
nrows = int(np.ceil(n_meas / ncols))
figwidth = 1.5 * ncols
figheight = 1.25 * nrows

# Epoch -1 only plots the model as constructed; every later epoch performs one
# Gauss-Seidel update before plotting.
for epoch in range(-1, n_epochs):
    if epoch >= 0:
        model.gauss_seidel_step(lr=learning_rate)

    x = model.sample(1_000_000)

    # Density of the model samples: linear scale (left), log scale (right).
    hist, edges = np.histogramdd(x, bins=75, range=limits)
    # Offset by one so empty bins stay valid on the log scale (LogNorm vmin=1.0).
    values = hist.T + 1.0
    # Upper color limit is taken from the plotted (offset) values so the peak is not clipped.
    norms = [None, matplotlib.colors.LogNorm(vmin=1.0, vmax=np.max(values))]

    fig, axs = plt.subplots(ncols=2, figsize=(5.5, 2.5), constrained_layout=True)
    for ax, norm in zip(axs, norms):
        m = ax.pcolormesh(edges[0], edges[1], values, norm=norm)
    fig.colorbar(m, ax=axs[1])
    plt.show()

    # Measured (red) versus simulated (black) profiles, first on a linear and
    # then on a logarithmic vertical axis.
    for log in [False, True]:
        fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figwidth=figwidth, figheight=figheight, sharex=True, sharey=True)
        for index in range(len(transforms)):
            prediction = model.simulate(index, diag_index=0).copy()
            measurement = measurements[index][0].copy()
            ax = axs[index]

            # Scale both curves by the measurement peak so the measured profile tops out at 1.
            normalization = measurement.max()
            prediction = prediction / normalization
            measurement = measurement / normalization
            ax.plot(measurement, color="red3")
            ax.plot(prediction, color="black")
            ax.format(ymax=1.25)
            if log:
                # The log view replaces the linear y-limits set just above.
                ax.format(yscale="log", ymax=5.0, ymin=1.00e-05, yformatter="log")
        plt.show()