# 2D reconstruction from mismatched turn-by-turn profiles — McMillan map

In [None]:
import os
import sys
import time
from typing import Callable

import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.lines
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from ipywidgets import interact
from ipywidgets import widgets
from ipywidgets import Dropdown
from ipywidgets import IntSlider
from ipywidgets import BoundedIntText

import ment
from ment.sim import Transform
from ment.sim import LinearTransform
from ment.sim import ComposedTransform

In [None]:
# Global proplot styling: continuous (non-discrete) colormaps with viridis
# as the sequential default, white figure background, and no grid lines.
_rc_settings = {
    "cmap.discrete": False,
    "cmap.sequential": "viridis",
    "figure.facecolor": "white",
    "grid": False,
}
for _key, _value in _rc_settings.items():
    pplt.rc[_key] = _value

Create the lattice: a linear one-turn map followed by an axially symmetric (McMillan-type) nonlinear kick.

In [None]:
class AxiallySymmetricNonlinearKick(Transform):
    """Axially symmetric (McMillan-type) nonlinear kick.

    The kick depends only on the transverse radius r = sqrt(x^2 + y^2) and
    is applied along the radial direction. With theta = atan2(y, x), the
    x' component receives ``dr * cos(theta)`` and the y' component receives
    ``dr * sin(theta)``.

    Coordinates are taken as columns (x, x', y, y'): positions are read from
    columns 0 and 2, and kicks are added to columns 1 and 3. A 2-column input
    (x, x') is embedded in 4D with y = y' = 0, and the result is truncated
    back to 2 columns.

    Parameters
    ----------
    alpha, beta : float
        Lattice Twiss parameters. ``alpha`` is stored but not used by the
        kick; ``beta`` scales the kick strength.
    phi : float
        Phase advance in radians entering the kick via sin(phi) and tan(phi).
        Presumably this should equal the phase advance of the accompanying
        linear map.
    A, E, T : float
        Map parameters entering the term ``E * r / (A * r**2 + T)``.
    """

    def __init__(self, alpha: float, beta: float, phi: float, A: float, E: float, T: float) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.phi = phi
        self.A = A
        self.E = E
        self.T = T

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the kick to an (n, 2) or (n, 4) array of particle coordinates."""
        ndim = x.shape[1]
        if ndim == 2:
            # Embed in 4D with y = y' = 0 so a single radial formula handles both cases.
            x = np.hstack([x, np.zeros((x.shape[0], 2))])

        r = np.sqrt(x[:, 0] ** 2 + x[:, 2] ** 2)
        # arctan2 is well defined at r = 0 (where dr = 0 anyway).
        theta = np.arctan2(x[:, 2], x[:, 0])

        dr = -(1.0 / (self.beta * np.sin(self.phi))) * (
            (self.E * r) / (self.A * r**2 + self.T)
        ) - ((2.0 * r) / (self.beta * np.tan(self.phi)))

        x_out = np.copy(x)
        # Project the radial kick onto the x and y momenta respectively.
        x_out[:, 1] += dr * np.cos(theta)
        x_out[:, 3] += dr * np.sin(theta)
        return x_out[:, :ndim]

In [None]:
# Periodic lattice parameters: Twiss alpha/beta and the one-turn phase
# advance phi = 2*pi * tune.
alpha = 0.0
beta = 1.0
tune = 0.18
phi = 2.0 * np.pi * tune

# One-turn linear transfer matrix M = V @ R @ V^{-1}, where V^{-1} is the
# Twiss normalization matrix and R the phase-advance matrix.
Vinv = ps.cov.norm_matrix_from_twiss_2x2(alpha, beta)
V = np.linalg.inv(Vinv)
R = ps.ap.phase_adv_matrix(phi)
M = np.linalg.multi_dot([V, R, Vinv])

# The lattice composes the linear map with the nonlinear kick (assumed to be
# applied in the order listed).
linear_map = LinearTransform(M)
kick = AxiallySymmetricNonlinearKick(alpha, beta, phi, A=1.0, E=0.5, T=1.0)
lattice = ComposedTransform(linear_map, kick)

Make the ground-truth phase-space distribution. The distribution must not be matched to the lattice optics; otherwise every turn-by-turn profile is identical and provides no new information. We can optionally apply a linear mismatch (`mismatch = True`). Note, however, that the distribution is *not* matched to the nonlinear lattice either way --- it will diffuse over many turns.

In [None]:
n_parts = 100_000
mismatch = False

# Noisy ring: points on a circle of radius 2.5 with Gaussian jitter,
# then rescaled to unit standard deviation along each axis.
theta = np.linspace(0.0, 2.0 * np.pi, n_parts)
ring = np.column_stack([np.cos(theta), np.sin(theta)])
x_true = 2.5 * ring
x_true += np.random.normal(scale=0.25, size=x_true.shape)
x_true /= np.std(x_true, axis=0)

# Optional linear mismatch: transform by a normalization matrix built from
# different Twiss parameters (alpha=1.5, beta=1.0).
if mismatch:
    Vinv_mismatch = ps.cov.norm_matrix_from_twiss_2x2(alpha=1.5, beta=1.0)
    x_true = x_true @ Vinv_mismatch.T

Define the tomographic phase-space transformations. Each transformation tracks the beam through the lattice for a fixed number of turns: transformation *k* applies the one-turn map *k* times.

In [None]:
n_turns = 50

# transforms[k] applies the one-turn lattice k times. For k = 0 nothing is
# composed (presumably the identity transform in ment).
transforms = [ComposedTransform(*([lattice] * turn)) for turn in range(n_turns)]

In [None]:
xmax = 4.0
limits = 2 * [(-xmax, xmax)]
n_bins = 64
# n_bins histogram bins require n_bins + 1 edges; derive them from n_bins so
# the declared resolution is the one actually used.
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)
bin_coords = 0.5 * (bin_edges[1:] + bin_edges[:-1])

# The same 1D histogram of coordinate 0 is measured after every transform.
diagnostic = ment.diag.Histogram1D(axis=0, edges=bin_edges)
diagnostics = [[diagnostic] for _ in transforms]

Generate training data.

In [None]:
# Simulated training data: push the true distribution through every
# transform and apply the corresponding diagnostics.
measurements = ment.sim.forward(
    x_true,
    transforms,
    diagnostics,
)

Plot the transformed distribution and its x-profile for a selected turn.

In [None]:
@interact(
    index=IntSlider(min=0, max=(n_turns - 1), value=0),
    bins=(32, 128),
)
def int_plot_meas(index: int, bins: int):
    """Show the ground truth after `index` turns, with its coordinate-0 profile.

    The main axes show a 2D histogram of columns (0, 1) of the tracked
    particles. The bottom panel shows the 1D histogram of column 0,
    normalized to a unit peak.
    """
    u = transforms[index](x_true)

    profile, profile_edges = np.histogram(u[:, 0], bins=bins, range=limits[0])
    profile = profile / np.max(profile)

    fig, ax = pplt.subplots()
    ax.hist2d(u[:, 0], u[:, 1], bins=bins, range=limits, cmap="mono")
    pax = ax.panel_axes("bottom")
    pax.stairs(profile, profile_edges, color="black", lw=1.25, fill=False)
    pax.format(ylim=(0.0, 1.15))
    plt.show()

Create the MENT model.

In [None]:
ndim = 2
prior = ment.GaussianPrior(ndim=ndim, scale=1.0)

limits = ndim * [(-xmax, xmax)]

# Grid sampler over the same square region, 128 points per dimension.
samp_grid_limits = limits
samp_grid_shape = [128] * ndim
sampler = ment.samp.GridSampler(
    grid_limits=samp_grid_limits,
    grid_shape=samp_grid_shape,
)

# One integration range (the limits of coordinate 1) for each diagnostic of
# each transform.
integration_limits = [[limits[1]] for _ in transforms]
integration_size = 200

model = ment.MENT(
    ndim=ndim,
    measurements=measurements,
    transforms=transforms,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws={"method": "linear"},
    sampler=sampler,
    n_samples=100_000,
    integration_limits=integration_limits,
    integration_size=integration_size,
    verbose=2,
    mode="sample",  # {"integrate", "sample"}
)

Train the MENT model.

In [None]:
learning_rate = 0.80
n_epochs = 6

# Subplot grid for the per-turn profile comparison; it does not change
# between epochs, so compute it once.
ncols = min(len(transforms), 7)
nrows = int(np.ceil(len(transforms) / ncols))
figwidth = 1.5 * ncols
figheight = 1.0 * nrows

# Epoch -1 plots the model before any update step; every later epoch first
# performs one Gauss-Seidel step.
for epoch in range(-1, n_epochs):
    print("epoch =", epoch)

    if epoch >= 0:
        model.gauss_seidel_step(learning_rate)

    x_pred = model.sample(1_000_000)

    # Reconstructed distribution on a linear (left) and log (right) color scale.
    fig, axs = pplt.subplots(ncols=2)
    for i, ax in enumerate(axs):
        psv.plot_points(
            x_pred,
            limits=limits,
            bins=75,
            offset=1.0,
            norm=("log" if i else None),
            colorbar=True,
            ax=ax,
        )
    plt.show()

    # Measured (red) vs. predicted (black) profile for every transform.
    fig, axs = pplt.subplots(
        ncols=ncols, nrows=nrows, figwidth=figwidth, figheight=figheight, sharex=True, sharey=True
    )
    for index, transform in enumerate(transforms):
        # Select this transform's diagnostic before using it.
        diagnostic = diagnostics[index][0]
        values_pred = diagnostic(transform(x_pred))
        values_meas = measurements[index][0]

        # Scale both by the measured peak so their relative amplitude is kept;
        # plain division avoids mutating the stored measurement arrays.
        scale = np.max(values_meas)
        values_pred = values_pred / scale
        values_meas = values_meas / scale

        ax = axs[index]
        ax.plot(diagnostic.coords, values_meas, color="red3")
        ax.plot(diagnostic.coords, values_pred, color="black", marker=".", ms=1.0, lw=0)
        ax.format(ymax=1.25, xlim=(-xmax, xmax))
    plt.show()