# Reconstruct the 2D phase space density from turn-by-turn (TBT) profiles of a mismatched beam

In [None]:
import os
import sys
import time
from typing import Callable

import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.lines
import numpy as np
import proplot as pplt
import psdist as ps
import psdist.visualization as psv
from ipywidgets import interact
from ipywidgets import widgets

import ment

In [None]:
# Notebook-wide plotting defaults, written into proplot's `rc` configuration.
pplt.rc["cmap.discrete"] = False  # switch off the "cmap.discrete" option
pplt.rc["cmap.sequential"] = "viridis"  # default sequential colormap
pplt.rc["figure.facecolor"] = "white"  # white figure background
pplt.rc["grid"] = False  # no grid lines by default

In [None]:
class ComposedTransform:
    """Chain several transforms into one callable map.

    Every element of ``transforms`` is invoked as ``transform(x)`` when going
    forward. Going backward requires each element to expose
    ``transform.inverse(u)``.
    """

    def __init__(self, *transforms) -> None:
        # Kept as the tuple of positional arguments, in application order.
        self.transforms = transforms

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the transforms one after another, first to last."""
        result = x
        for step in self.transforms:
            result = step(result)
        return result

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Undo the chain by applying each element's inverse, last to first."""
        result = u
        for step in reversed(self.transforms):
            result = step.inverse(result)
        return result

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Calling the object is shorthand for ``forward``."""
        return self.forward(x)

In [None]:
class Transform:
    """Interface for array transforms offering ``forward`` and ``inverse``.

    The base class implements neither direction; subclasses override the
    methods they support.
    """

    def __init__(self) -> None:
        pass

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Map ``x`` forward. Not implemented in the base class."""
        raise NotImplementedError

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Map ``u`` backward. Not implemented in the base class."""
        raise NotImplementedError

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Calling the object is shorthand for ``forward``."""
        return self.forward(x)

In [None]:
class LinearTransform(Transform):
    """Linear map given by a square matrix, applied to rows of an array."""

    def __init__(self, matrix: np.ndarray) -> None:
        super().__init__()
        self.set_matrix(matrix)

    def set_matrix(self, matrix: np.ndarray) -> None:
        """Store ``matrix`` and cache its inverse from ``np.linalg.inv``."""
        self.matrix = matrix
        self.matrix_inv = np.linalg.inv(matrix)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` multiplied on the right by the transposed matrix."""
        matrix_t = self.matrix.T
        return np.matmul(x, matrix_t)

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Return ``u`` multiplied on the right by the transposed inverse."""
        matrix_inv_t = self.matrix_inv.T
        return np.matmul(u, matrix_inv_t)

In [None]:
class AxiallySymmetricNonlinearKick(Transform):
    """Radial kick whose magnitude depends only on ``r = sqrt(x0**2 + x2**2)``.

    Columns 0 and 2 of the input are read as the transverse positions and are
    left unchanged; columns 1 and 3 receive the kick ``dr(r)`` projected on the
    radial direction ``(cos(theta), sin(theta))``. Two-column input is treated
    as the plane where columns 2 and 3 are zero.

    Args:
        alpha: Stored on the instance; not used by the kick itself.
        beta: Scales both terms of the kick as ``1 / beta``.
        phi: Angle entering the kick through ``sin(phi)`` and ``tan(phi)``.
        A: Coefficient of ``r**2`` in the denominator of the nonlinear term.
        E: Coefficient of ``r`` in the numerator of the nonlinear term.
        T: Constant in the denominator of the nonlinear term.
    """

    def __init__(self, alpha: float, beta: float, phi: float, A: float, E: float, T: float) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.phi = phi
        self.A = A
        self.E = E
        self.T = T

    def _apply_kick(self, x: np.ndarray, sign: float) -> np.ndarray:
        """Return a copy of ``x`` with ``sign`` times the kick added to columns 1 and 3."""
        ndim = x.shape[1]
        if ndim == 2:
            # Embed the 2D coordinates in 4D with zeros in columns 2 and 3.
            x = np.hstack([x, np.zeros((x.shape[0], 2))])

        r = np.sqrt(x[:, 0] ** 2 + x[:, 2] ** 2)
        theta = np.arctan2(x[:, 2], x[:, 0])

        nonlinear = -(1.0 / (self.beta * np.sin(self.phi))) * ((self.E * r) / (self.A * r**2 + self.T))
        linear = (2.0 * r) / (self.beta * np.tan(self.phi))
        dr = nonlinear - linear

        x_out = np.copy(x)
        x_out[:, 1] += sign * dr * np.cos(theta)
        # The second component of the radial direction is sin(theta); using
        # cos(theta) here as well would break the axial symmetry.
        x_out[:, 3] += sign * dr * np.sin(theta)
        return x_out[:, :ndim]

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Apply the kick to an array with 2 or 4 columns and return the result."""
        return self._apply_kick(x, 1.0)

    def inverse(self, u: np.ndarray) -> np.ndarray:
        """Undo ``forward``.

        The kick depends only on columns 0 and 2, which ``forward`` leaves
        untouched, so subtracting the same kick restores the input.
        """
        return self._apply_kick(u, -1.0)

In [None]:
def norm_matrix(alpha: float, beta: float) -> np.ndarray:
    """Return the inverse of ``V = [[sqrt(beta), 0], [-alpha/sqrt(beta), 1/sqrt(beta)]]``."""
    V = np.array(
        [
            [np.sqrt(beta), 0.0],
            [-alpha / np.sqrt(beta), 1.0 / np.sqrt(beta)],
        ],
        dtype=float,
    )
    return np.linalg.inv(V)

In [None]:
def phase_advance_matrix(phi: float) -> np.ndarray:
    """Return the 2x2 rotation matrix ``[[cos, sin], [-sin, cos]]`` for angle ``phi``."""
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)
    top_row = [cos_phi, sin_phi]
    bottom_row = [-sin_phi, cos_phi]
    return np.array([top_row, bottom_row])

## Lattice

In [None]:
# Define twiss parameters at origin
alpha = 0.0
beta = 1.0
phi = 2.0 * np.pi * 0.18  # angle passed to phase_advance_matrix: 0.18 of a full 2*pi rotation

# Create lattice.
Vinv = norm_matrix(alpha, beta)  # norm_matrix already returns the inverse of V
V = np.linalg.inv(Vinv)
R = phase_advance_matrix(phi)
M = np.linalg.multi_dot([V, R, Vinv])  # linear part of the map: M = V R V^-1

# One application of `lattice` is the linear map M followed by the nonlinear kick.
lattice = ComposedTransform(
    LinearTransform(M),
    AxiallySymmetricNonlinearKick(alpha, beta, phi, A=1.0, E=0.5, T=1.0),
)

## Bunch

In [None]:
# Create distribution
n_parts = 300_000

# Points on a circle of radius 2.5, blurred with Gaussian noise of scale 0.25.
t = np.linspace(0.0, 2.0 * np.pi, n_parts)
x_true = np.vstack([np.cos(t), np.sin(t)]).T  # one row per particle, two columns
x_true *= 2.5
# NOTE(review): the global NumPy RNG is not seeded, so the bunch differs from run to run.
x_true += np.random.normal(scale=0.25, size=x_true.shape)
# Shuffle the rows so that a leading slice x_true[:n] is a random subset of the bunch.
np.random.shuffle(x_true)

## Apply linear mismatch
# x_true = np.matmul(x_true, np.linalg.inv(norm_matrix(alpha=1.5, beta=1.0).T))

## Diagnostics

In [None]:
# Turns after which a profile is recorded (1 through 19).
turn_indices = np.arange(1, 20, 1)
xmax = 7.5
n_bins = 75

limits = 2 * [(-xmax, xmax)]
# n_bins histogram bins need n_bins + 1 edges. The edge count was previously
# hard-coded to 100, which silently ignored n_bins.
bin_edges = np.linspace(-xmax, xmax, n_bins + 1)
bin_coords = 0.5 * (bin_edges[1:] + bin_edges[:-1])  # bin centers

# List of diagnostics applied at each turn (the same histogram object is
# reused for every turn).
diagnostic = ment.diag.Histogram1D(axis=0, bin_edges=bin_edges)
diagnostics = [[diagnostic] for _ in turn_indices]

Show the turn-by-turn (TBT) evolution of the true distribution:

In [None]:
@interact(
    turn=widgets.IntSlider(min=0, max=max(turn_indices), value=0),
    kind=["hist", "scatter"],
    bins=widgets.IntSlider(min=64, max=128, value=128),
)
def update(turn, kind: str, bins: int):
    """Track the true bunch through ``turn`` lattice passes and plot it.

    The main panel shows the 2D distribution as a histogram or a scatter plot
    (selected by ``kind``); the bottom panel shows the profile of column 0.
    """
    fig, ax = pplt.subplots()

    # Work on a copy of the source bunch and apply the lattice `turn` times.
    coords = x_true.copy()
    for _ in range(turn):
        coords = lattice(coords)

    horizontal = coords[:, 0]
    vertical = coords[:, 1]
    if kind == "hist":
        ax.hist2d(horizontal, vertical, bins=bins, range=limits, cmap="mono")
    elif kind == "scatter":
        ax.scatter(horizontal, vertical, color="k", s=3)
    else:
        raise ValueError

    # 1D profile of column 0 in a panel below the main axes.
    profile_ax = ax.panel_axes("bottom")
    profile, profile_edges = np.histogram(horizontal, bins=bins, range=limits[0], density=True)
    psv.plot_profile(profile, edges=profile_edges, ax=profile_ax, color="black", fill=True)
    profile_ax.format(yticks=[])

    ax.format(xlim=limits[0], ylim=limits[1])
    plt.show()

## Generate data

In [None]:
# measurements[i] holds one entry per diagnostic in diagnostics[i].
measurements = []

x = np.copy(x_true)
index = 0
for turn in range(max(turn_indices)):
    x = lattice(x)
    # `turn` is zero-based, so the bunch has now completed turn + 1 passes.
    if (turn + 1) not in turn_indices:
        continue
    measurements.append([diag(x) for diag in diagnostics[index]])
    index += 1

## Reconstruction

In [None]:
ndim = 2
# Prior handed to the model below.
prior = ment.GaussianPrior(ndim=ndim, scale=2.0)

# The same (-xmax, xmax) range on each of the ndim axes.
limits = [(-xmax, xmax)] * ndim
# Sampler configured with a [128, 128] grid shape over `limits` and noise=0.0.
# NOTE(review): presumably draws particles from a gridded density — confirm against the ment docs.
sampler = ment.samp.GridSampler(
    grid_limits=limits, 
    grid_shape=(ndim * [128]),
    noise=0.0,
)

# The model receives the lattice, the per-turn diagnostics and the profiles
# generated from the true bunch, together with the prior and sampler above.
model = ment.MENTRing(
    ndim=ndim,
    turn_indices=turn_indices,
    transform=lattice,
    measurements=measurements,
    diagnostics=diagnostics,
    prior=prior,
    interpolation=dict(method="linear"),
    mode="sample",
    sampler=sampler,
    n_samples=200_000,
    verbose=True,
)

# Training settings: `learning_rate` is passed as `lr` to
# model.gauss_seidel_step, and `n_epochs` is the number of update steps.
learning_rate = 0.80
n_epochs = 10

In [None]:
## Plot settings that do not change between epochs
bins = 128
ncols = min(len(turn_indices), 7)
nrows = int(np.ceil(len(turn_indices) / ncols))
figwidth = 1.6 * ncols
figheight = 1.2 * nrows
log = False

# epoch = -1 evaluates and plots the model before the first update step.
for epoch in range(-1, n_epochs):
    if epoch >= 0:
        model.gauss_seidel_step(lr=learning_rate)

    ## Sample particles from reconstructed distribution
    x_pred = model.sample(x_true.shape[0])

    ## Plot 2D density vs. ground truth
    fig, axs = pplt.subplots(ncols=2)
    for ax, x in zip(axs, [x_pred, x_true]):
        hist, edges = np.histogramdd(x, bins=bins, range=limits)
        ax.pcolormesh(edges[0], edges[1], hist.T + 1.0, cmap="mono")
    pplt.show()

    ## Plot measured vs simulated 1D projections
    error = 0.0

    fig, axs = pplt.subplots(ncols=ncols, nrows=nrows, figwidth=figwidth, figheight=figheight, sharex=True, sharey=True)
    for index in range(len(turn_indices)):
        y_pred = model.simulate(index, diag_index=0).copy()
        y_meas = measurements[index][0].copy()

        # Scale both profiles by the peak of the measured one, computed once
        # before either array is modified.
        normalization = y_meas.max()
        y_pred /= normalization
        y_meas /= normalization

        # Accumulate the mean absolute difference over all measured turns.
        error += np.mean(np.abs(y_pred - y_meas))

        ax = axs[index]
        ax.plot(y_meas, color="black")
        ax.plot(y_pred, color="red")
        ax.format(ymax=1.25)
        if log:
            ax.format(yscale="log", ymax=5.0, ymin=1.00e-05, yformatter="log")
    plt.show()

    print(f"Epoch = {epoch}")
    print(f"Error = {error:0.3e}")

Show predicted 2D distribution vs. transformed source distribution. Here we may include turns that were not in the training data.

In [None]:
@interact(
    turn=widgets.IntSlider(min=0, max=(max(turn_indices) * 2), value=0, continuous_update=False),
    kind=["hist", "scatter"],
    bins=widgets.IntSlider(min=64, max=128, value=128),
    n_parts=widgets.FloatLogSlider(min=3.0, max=6.0, value=100_000),
)
def update(turn, kind: str, bins: int, n_parts: float = 5.0):
    """Compare the model and the true bunch after ``turn`` lattice passes.

    Args:
        turn: Number of times the lattice is applied to both bunches.
        kind: "hist" for a 2D histogram or "scatter" for a scatter plot.
        bins: Number of histogram bins per axis.
        n_parts: Number of particles to plot; the slider supplies a float, so
            it is truncated to an integer.
            NOTE(review): the signature default of 5.0 only applies when the
            function is called directly and then means five particles —
            confirm whether that is intended.

    Raises:
        ValueError: If ``kind`` is neither "hist" nor "scatter".
    """
    fig, axs = pplt.subplots(ncols=2)

    n_parts = int(n_parts)
    u_true = x_true[:n_parts].copy()
    # Draw as many model particles as were actually taken from the true bunch.
    u_pred = model.sample(u_true.shape[0])
    for _ in range(turn):
        u_true = lattice(u_true)
        u_pred = lattice(u_pred)

    for ax, u in zip(axs, [u_pred, u_true]):
        # Honor the `kind` selector, which was previously ignored.
        if kind == "scatter":
            ax.scatter(u[:, 0], u[:, 1], color="k", s=3)
        elif kind == "hist":
            ax.hist2d(u[:, 0], u[:, 1], bins=bins, range=limits, cmap="mono")
        else:
            raise ValueError(f"Unknown plot kind: {kind!r}")

        # 1D profile of column 0 in a panel below each main axes.
        pax = ax.panel_axes("bottom")
        hist, edges = np.histogram(u[:, 0], bins=bins, range=limits[0], density=True)
        psv.plot_profile(hist, edges=edges, ax=pax, color="black", fill=True)
        pax.format(yticks=[])

    axs.format(xlim=limits[0], ylim=limits[1], toplabels=["Model", "True"])
    plt.show()