# 4D reconstruction from mismatched turn-by-turn profiles — McMillan map

In [None]:
import os
import sys
import time
from typing import Callable

import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.lines
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.plot as psv
from ipywidgets import interact
from ipywidgets import widgets
from ipywidgets import Dropdown
from ipywidgets import IntSlider
from ipywidgets import BoundedIntText
from tqdm import trange

import ment
from ment.sim import Transform
from ment.sim import LinearTransform
from ment.sim import ComposedTransform
from ment.utils import unravel

In [None]:
# Notebook-wide proplot rc settings, applied one key at a time.
for _rc_key, _rc_value in (
    ("cmap.discrete", False),
    ("cmap.sequential", "viridis"),
    ("figure.facecolor", "white"),
    ("grid", False),
):
    pplt.rc[_rc_key] = _rc_value

Create lattice.

In [None]:
class AxiallySymmetricNonlinearKick(Transform):
    """Axially symmetric nonlinear momentum kick.

    The kick strength depends only on the transverse radius
    ``r = sqrt(x[:, 0]**2 + x[:, 2]**2)`` and acts along the radial direction,
    so the position columns (0 and 2) are left untouched and only the momentum
    columns (1 and 3) change::

        dr = -(1 / (beta * sin(phi))) * E * r / (A * r**2 + T)
             - 2 * r / (beta * tan(phi))
        x[:, 1] += dr * cos(theta)
        x[:, 3] += dr * sin(theta)        # theta = arctan2(x[:, 2], x[:, 0])

    Parameters
    ----------
    alpha : float
        Stored on the instance; not used by the kick itself.
    beta : float
        Scales the kick as ``1 / beta``.
    phi : float
        Angle entering through ``sin(phi)`` and ``tan(phi)``; must not be a
        multiple of pi, otherwise the kick is singular.
    A, E, T : float
        Coefficients of the rational term ``E * r / (A * r**2 + T)``.
    """

    def __init__(self, alpha: float, beta: float, phi: float, A: float, E: float, T: float) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.phi = phi
        self.A = A
        self.E = E
        self.T = T

    def _kick(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the momentum kicks for columns 1 and 3 evaluated at the positions in ``x``."""
        r = np.sqrt(x[:, 0] ** 2 + x[:, 2] ** 2)
        theta = np.arctan2(x[:, 2], x[:, 0])

        dr = -(1.0 / (self.beta * np.sin(self.phi))) * (
            (self.E * r) / (self.A * r**2 + self.T)
        ) - ((2.0 * r) / (self.beta * np.tan(self.phi)))

        # Project the radial kick onto the two transverse planes. The second
        # component must use sin(theta); using cos(theta) for both breaks the
        # axial symmetry of the map.
        return dr * np.cos(theta), dr * np.sin(theta)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the kick to a copy of ``x`` (shape ``(n, >=4)``) and return it."""
        kick_1, kick_3 = self._kick(x)
        x_out = np.copy(x)
        x_out[:, 1] += kick_1
        x_out[:, 3] += kick_3
        return x_out

    def inverse(self, x: np.ndarray) -> np.ndarray:
        """Undo :meth:`forward` on a copy of ``x``.

        The kick depends only on the positions, which ``forward`` does not
        change, so the exact inverse subtracts the same kick from both
        momentum columns. The input array is not modified.
        """
        kick_1, kick_3 = self._kick(x)
        x_out = np.copy(x)
        x_out[:, 1] -= kick_1
        x_out[:, 3] -= kick_3
        return x_out

In [None]:
# Periodic lattice parameters: the same (alpha, beta) and one-turn phase
# advance are used in both transverse planes.
alpha_x = alpha_y = 0.0
beta_x = beta_y = 1.0
phi_x = phi_y = 2.0 * np.pi * 0.18

# One-turn transfer matrix M = V R V^-1.
# Vinv is block-diagonal, one 2x2 block per plane as returned by
# ps.cov.norm_matrix_from_twiss_2x2 (presumably the normalization matrix;
# confirm against psdist). R is built from the two phase advances.
Vinv = np.eye(4)
Vinv[:2, :2] = ps.cov.norm_matrix_from_twiss_2x2(alpha_x, beta_x)
Vinv[2:, 2:] = ps.cov.norm_matrix_from_twiss_2x2(alpha_y, beta_y)
V = np.linalg.inv(Vinv)
R = ps.ap.phase_adv_matrix(phi_x, phi_y)
M = np.linalg.multi_dot([V, R, Vinv])

# One turn of the lattice = linear transport followed by the nonlinear kick.
_linear_map = LinearTransform(M)
_nonlinear_kick = AxiallySymmetricNonlinearKick(alpha_x, beta_x, phi_x, A=1.0, E=0.5, T=1.0)
lattice = ComposedTransform(_linear_map, _nonlinear_kick)

Create the ground-truth phase space distribution. The distribution must not be matched to the lattice optics; otherwise the profiles would be the same on every turn and additional turns would provide no new information. We therefore apply a linear mismatch. Note that, even without the linear mismatch, the distribution would *not* be matched to the nonlinear lattice --- it will diffuse.

In [None]:
n_parts = 1_000_000
mismatch = True

# Alternative ground-truth distributions, kept for reference:
#
# x_true = np.random.normal(size=(n_parts, 4))
# x_true = x_true / np.linalg.norm(x_true, axis=1)[:, None]
# x_true = x_true / np.std(x_true, axis=0)
#
# x_true[:, (0, 2)] = ment.dist.get_dist("galaxy").sample(x_true.shape[0])

# Baseline: independent unit-variance Gaussian samples in all four columns.
x_true = np.random.normal(size=(n_parts, 4))

if mismatch:
    # Block-diagonal 4x4 matrix with opposite-sign alpha in the two planes,
    # applied to every particle (row) of x_true.
    Vinv_mismatch = np.eye(4)
    Vinv_mismatch[:2, :2] = ps.cov.norm_matrix_from_twiss_2x2(alpha=+0.5, beta=1.0)
    Vinv_mismatch[2:, 2:] = ps.cov.norm_matrix_from_twiss_2x2(alpha=-0.5, beta=1.0)
    x_true = x_true @ Vinv_mismatch.T

Define tomographic phase space transformations. Each transformation just tracks the beam for a certain number of turns.

In [None]:
n_turns = 100

# transforms[k] chains k copies of the one-turn lattice map, so transforms[0]
# composes nothing and transforms[k] tracks the beam for k turns.
transforms = [ComposedTransform(*([lattice] * turn)) for turn in range(n_turns)]

In [None]:
# Histogram range and resolution shared by all 1D profile diagnostics.
xmax = 6.0
limits = 2 * [(-xmax, xmax)]
n_bins = 64

# n_bins bins need n_bins + 1 edges. The edge count used to be hard-coded to
# 100 (i.e. 99 bins), which left n_bins unused and inconsistent with the
# actual histogram resolution.
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)
bin_coords = 0.5 * (bin_edges[1:] + bin_edges[:-1])  # bin centers

# 1D profiles along column 0 and column 2 of the phase space coordinates.
diagnostic_x = ment.diag.Histogram1D(axis=0, edges=bin_edges)
diagnostic_y = ment.diag.Histogram1D(axis=2, edges=bin_edges)

# The same pair of diagnostics is attached to every transform (turn).
diagnostics = [[diagnostic_x, diagnostic_y] for _ in transforms]

Generate training data.

In [None]:
# Equivalent one-liner, which re-tracks from turn zero for every transform:
# projections = ment.sim.forward(x_true, transforms, diagnostics)

# Track a working copy of the beam one turn at a time and record the
# profiles after each turn. Entry 0 is measured on the untracked beam.
projections = []
u = np.copy(x_true)
for turn_index in trange(n_turns):
    if turn_index:
        u = lattice(u)
    turn_profiles = [diagnostic(u) for diagnostic in diagnostics[turn_index]]
    projections.append(turn_profiles)

Plot transformed distribution.

In [None]:
@interact(
    index=IntSlider(min=0, max=(n_turns - 1), value=0),
    dim1=["x", "x'", "y", "y'"],
    dim2=["x", "x'", "y", "y'"],
    bins=(32, 128),
)
def int_plot_meas(index: int, dim1: str = "x", dim2: str = "x'", bins: int = 64):
    """Interactively plot a 2D projection of the tracked ground-truth beam.

    Parameters
    ----------
    index : int
        Index into ``transforms``; the beam is tracked with ``transforms[index]``.
    dim1, dim2 : str
        Names of the horizontal and vertical plot dimensions, each one of
        ``"x"``, ``"x'"``, ``"y"``, ``"y'"`` (columns 0-3 of the coordinates).
    bins : int
        Number of bins per dimension for both the 2D histogram and the
        1D profile in the bottom panel.
    """
    dims = ["x", "x'", "y", "y'"]
    axis = [dims.index(dim) for dim in (dim1, dim2)]

    transform = transforms[index]
    u = transform(x_true)

    # Profile along the *plotted* horizontal dimension, so the bottom panel
    # always matches the axis above it (it used to show column 0 regardless
    # of dim1).
    _hist, _edges = np.histogram(u[:, axis[0]], bins=bins, range=(-xmax, xmax))
    _peak = np.max(_hist)
    if _peak > 0:
        # Normalize to unit peak; skip when no particle falls inside the range
        # to avoid a division by zero.
        _hist = _hist / _peak

    fig, ax = pplt.subplots()
    ax.hist2d(u[:, axis[0]], u[:, axis[1]], bins=bins, range=limits, cmap="mono")
    pax = ax.panel_axes("bottom")
    pax.stairs(_hist, _edges, color="black", lw=1.25, fill=False)
    pax.format(ylim=(0.0, 1.15))

    plt.show()

Create MENT model.

In [None]:
ndim = 4
prior = ment.GaussianPrior(ndim=ndim, scale=4.0)

# Grid sampler: the same limits and resolution along every dimension.
samp_grid_limits = ndim * [(-3.5, 3.5)]
samp_grid_res = 25
samp_grid_noise = 0.5
samp_grid_shape = [samp_grid_res] * ndim
sampler = ment.samp.GridSampler(
    grid_limits=samp_grid_limits,
    grid_shape=samp_grid_shape,
    noise=samp_grid_noise,
)

# Integration settings: one set of (ndim - 1) limits per diagnostic, two
# diagnostics per transform. Presumably only consulted in "integrate" mode;
# confirm against ment.
_limits_per_diagnostic = (ndim - 1) * [(-5.0, 5.0)]
integration_limits = [[_limits_per_diagnostic, _limits_per_diagnostic] for _ in transforms]
integration_size = 25 ** (ndim - 1)

_model_kws = dict(
    ndim=ndim,
    transforms=transforms,
    projections=projections,
    diagnostics=diagnostics,
    prior=prior,
    interpolation_kws=dict(method="linear"),
    integration_limits=integration_limits,
    integration_size=integration_size,
    sampler=sampler,
    nsamp=10_000,
    verbose=2,
    mode="sample",  # {"integrate", "sample"}
)
model = ment.MENT(**_model_kws)

Train MENT model.

In [None]:
learning_rate = 0.80
n_epochs = 6

# Epoch 0 only evaluates the untrained model; every later epoch first runs
# one Gauss-Seidel update and then produces the same diagnostic plots.
for epoch in range(n_epochs + 1):
    print("epoch =", epoch)

    if epoch:
        model.gauss_seidel_step(learning_rate)

    x_pred = model.sample(100_000)

    # Corner plot of the samples drawn from the current model.
    grid = psv.CornerGrid(ndim)
    grid.plot_points(x_pred, limits=samp_grid_limits, bins=75, mask=False)

    # Simulated vs. measured profiles; only as many as there are subplots
    # (7 x 3) are drawn.
    projections_pred = ment.sim.forward(x_pred, transforms, diagnostics)
    projections_meas = projections

    fig, axs = pplt.subplots(ncols=7, nrows=3, figsize=(9.0, 2.5))
    _profile_pairs = zip(unravel(projections_pred), unravel(projections_meas))
    for ax, (y_pred, y_meas) in zip(axs, _profile_pairs):
        # Both curves are scaled by the measured peak so they share one scale.
        _scale = y_meas.max()
        ax.plot(bin_coords, y_pred / _scale, color="red3")
        ax.plot(bin_coords, y_meas / _scale, color="black", marker=".", ms=1.0, lw=0.0)
    axs.format(xlim=(-xmax, xmax), ylim=(0.0, 1.25))
    plt.show()