# Benchmark of 4D phase space reconstruction algorithms 

In [None]:
import sys
import importlib
import numpy as np
import pandas as pd
from skimage import filters
from matplotlib import pyplot as plt
import proplot as pplt
from tqdm.notebook import tqdm
from tqdm.notebook import trange

import reconstruct as rec
from tools import ap
from tools import analysis as ba
from tools import plotting as mplt
from tools import utils

In [None]:
# Global proplot style: no grid by default (and nearly transparent if one is
# drawn), default sequential/diverging colormaps, and a white figure background.
for _key, _value in [
    ('grid', False),
    ('cmap.sequential', 'dusk_r'),
    ('cmap.discrete', False),
    ('cmap.diverging', 'vlag'),
    ('figure.facecolor', 'white'),
    ('grid.alpha', 0.04),
]:
    pplt.rc[_key] = _value
del _key, _value

## Setup 

Load the bunch. The data can be downloaded here: https://www.dropbox.com/sh/1f2yoh5n1wrgfxm/AAB0C__P9cKjcz7YzOA7iqmsa?dl=0. Each data set is an $N \times 6$ array, where $N$ is the number of macroparticles; the columns are {x [mm], x' [mrad], y [mm], y' [mrad], z [mm], dE [MeV]}. Only the four transverse columns are used below.

`coords_SNS_elliptical_painting_300turns.npy` is the output of a phase space painting simulation in the Spallation Neutron Source (SNS) ring with 600,000 macroparticles (28 MB). The injection kicker waveforms were chosen such that large linear cross-plane correlations were maintained in the transverse phase space.

`2022-07-01_run_MEBT123_HZ04_VT34a_bunch.npy` is the output of a simulation of transport through the Beam Test Facility (BTF) with 10,000,000 macroparticles (480 MB). There are no linear $x$-$y$ correlations, but there are some higher-order $x$-$y$ correlations; there are also strong transverse-energy correlations, both linear and nonlinear.

In [None]:
# Pick the data set; swap the comment to use the BTF bunch instead.
filename = 'data/coords_SNS_elliptical_painting_300turns.npy'
# filename = 'data/2022-07-01_run_MEBT123_HZ04_VT34a_bunch.npy'

# Keep only the transverse columns (x, x', y, y'); z and dE are discarded.
X = np.load(filename)[:, :4]
print('X.shape =', X.shape)

Optionally, use a random sample of the particles (uncomment the line below to keep 10% of them).

In [None]:
# X = utils.rand_rows(X, int(0.1 * X.shape[0]))

Normalize $x$-$x'$ and $y$-$y'$ using the statistical 2D Twiss parameters.

In [None]:
# Compute the statistical Twiss parameters from the 4x4 covariance matrix.
Sigma = np.cov(X.T)
alpha_x, alpha_y, beta_x, beta_y = ba.twiss2D(Sigma)
print('Statistical Twiss parameters:')
print('  alpha_x = {}'.format(alpha_x))
print('  alpha_y = {}'.format(alpha_y))
print('  beta_x = {}'.format(beta_x))
print('  beta_y = {}'.format(beta_y))

# Normalize the distribution using the statistical Twiss parameters.
# Xn = Vinv X, so Vinv maps real -> normalized coordinates and V maps back.
V = ap.norm_matrix_4x4_uncoupled(alpha_x, alpha_y, beta_x, beta_y)
Vinv = np.linalg.inv(V)
Xn = utils.apply(Vinv, X)

# Define the lattice Twiss parameters at the reconstruction point.
# `dalpha` and `dbeta` are the maximum deviations from the statistical
# Twiss parameters. This will affect the reconstruction accuracy since
# the actual phase advances (i.e., projection angles in normalized phase
# space) will be different than those assumed from the lattice model.
# Draw from the seeded generator (not the global one) so the perturbation
# is reproducible.
state = np.random.RandomState(17)
# Use magnitudes so the sampling interval [-d, d] is well-formed even when
# the statistical alphas are negative.
dalpha = 0.0 * max(abs(alpha_x), abs(alpha_y))
dbeta = 0.0 * max(abs(beta_x), abs(beta_y))
alpha_x += state.uniform(-dalpha, dalpha)
alpha_y += state.uniform(-dalpha, dalpha)
beta_x += state.uniform(-dbeta, dbeta)
beta_y += state.uniform(-dbeta, dbeta)
# NOTE(review): V and Vinv above were computed before this perturbation, and
# no later cell in this notebook reads alpha_x/alpha_y/beta_x/beta_y, so the
# perturbed values are only printed. With dalpha = dbeta = 0 this is moot.
print('Lattice Twiss parameters:')
print('  alpha_x = {}'.format(alpha_x))
print('  alpha_y = {}'.format(alpha_y))
print('  beta_x = {}'.format(beta_x))
print('  beta_y = {}'.format(beta_y))

Bin the normalized coordinates at the reconstruction location for later comparison with the reconstructed distribution. 

In [None]:
# 4D reconstruction grid: the same number of bins along x, x', y, y', with
# per-axis (min, max) limits from mplt.auto_limits at sigma=3.25.
n_bins = 50
limits_rec = mplt.auto_limits(Xn, sigma=3.25)

# Density-normalized 4D histogram of the true normalized coordinates, plus
# the bin centers and bin volume of the grid.
f_true, edges_rec = np.histogramdd(Xn, bins=n_bins, range=limits_rec, density=True)
grid_rec = [rec.get_bin_centers(edge_arr) for edge_arr in edges_rec]
bin_volume = rec.get_bin_volume(limits_rec, n_bins)

View projections of the distribution.

In [None]:
# Axis names and units, in column order (x, x', y, y').
dims = ["x", "x'", "y", "y'"]
units = ["mm", "mrad", "mm", "mrad"]
labels = ["{} [{}]".format(dim, unit) for dim, unit in zip(dims, units)]
# Normalized coordinates get an "n" subscript (matplotlib math text).
dims_n = [dim + r"$_n$" for dim in dims]
labels_n = dims_n

# Corner plot of every 2D projection in normalized phase space. Normalized
# coordinates all have units of sqrt(mm*mrad), which are left off the labels.
axes = mplt.corner(f_true, coords=grid_rec, labels=labels_n, diag_kind='line')
plt.show()

In [None]:
# Interactive 1D projections of the true distribution (slice_type='int').
mplt.interactive_proj1d(f_true, dims=dims, coords=grid_rec, slice_type='int', kind='step')

In [None]:
# Interactive 1D projections of the true distribution (slice_type='range').
mplt.interactive_proj1d(f_true, dims=dims, coords=grid_rec, slice_type='range', kind='step')

In [None]:
# Interactive 2D projections of the true distribution (slice_type='int').
mplt.interactive_proj2d(f_true, dims=dims, coords=grid_rec, slice_type='int')

In [None]:
# Interactive 2D projections of the true distribution (slice_type='range').
mplt.interactive_proj2d(f_true, dims=dims, coords=grid_rec, slice_type='range')

## Reconstruction 

We will reconstruct in normalized phase space so that the transfer matrices from reconstruction location to measurement location become rotations in $x$-$x'$ and $y$-$y'$ by angles equal to the phase advances.

### Hock's method

This method is detailed here: https://www.sciencedirect.com/science/article/pii/S0168900213005202. It requires independent control of the horizontal and vertical phase advances and the measurement of the $x$-$y$ projection.

Since the horizontal and vertical optics are varied independently, we can simulate the measurements by transporting $x$-$x'$ and $y$-$y'$ separately. We assume that the phase advances can be evenly spaced over 180 degrees. The Twiss parameters at the screen are drawn at random for each optics setting ($\beta \in [1, 3]$, $\alpha \in [-0.2, 0.2]$).

In [None]:
def make_tmats(phase_advances, betas, alphas, Vinv):
    """Return one 2x2 transfer matrix per optics setting.

    Each matrix is ``V_screen @ R(phase_adv) @ Vinv``, where ``R`` is
    ``utils.rotation_matrix(phase_adv)`` and ``V_screen`` is
    ``ap.norm_matrix(alpha, beta)``.

    Parameters
    ----------
    phase_advances : sequence of float
        Phase advances from the reconstruction location to the measurement
        location.
    betas, alphas : sequence of float
        Twiss parameters at the measurement location.
    Vinv : array
        Inverse normalizing matrix at the reconstruction location.

    Returns
    -------
    list
        One matrix per ``(phase_adv, beta, alpha)`` triple; like ``zip``,
        iteration stops at the shortest input sequence.
    """
    settings = zip(phase_advances, betas, alphas)
    return [
        np.linalg.multi_dot([ap.norm_matrix(alpha, beta), utils.rotation_matrix(phase_adv), Vinv])
        for phase_adv, beta, alpha in settings
    ]

In [None]:
# Number of horizontal (K) and vertical (L) optics settings.
K, L = 12, 12

# Horizontal optics: K phase advances evenly spaced over [0, pi), with
# screen Twiss parameters drawn at random for each setting.
phase_adv_x = np.linspace(0.0, np.pi, num=K, endpoint=False)
betas_x = np.random.uniform(low=1.0, high=3.0, size=K)
alphas_x = np.random.uniform(low=-0.2, high=0.2, size=K)

# Vertical optics, chosen the same way.
phase_adv_y = np.linspace(0.0, np.pi, num=L, endpoint=False)
betas_y = np.random.uniform(low=1.0, high=3.0, size=L)
alphas_y = np.random.uniform(low=-0.2, high=0.2, size=L)

# Transport x-x' and y-y' independently from the real coordinates X and keep
# only the position (first output column) seen on the screen for each setting.
tmats_x = make_tmats(phase_adv_x, betas_x, alphas_x, Vinv[0:2, 0:2])
tmats_y = make_tmats(phase_adv_y, betas_y, alphas_y, Vinv[2:4, 2:4])
print('Transporting x-x.')
xx_list = [utils.apply(M, X[:, 0:2])[:, 0] for M in tqdm(tmats_x)]
print('Transporting y-y.')
yy_list = [utils.apply(M, X[:, 2:4])[:, 0] for M in tqdm(tmats_y)]

Bin the coordinates on the screen.

In [None]:
# Measurement grid: a square n_bins_meas x n_bins_meas screen, symmetric
# about zero and just wide enough (plus `pad`) to contain every simulated
# x (resp. y) coordinate.
n_bins_meas = 50  # resolution of measured images
pad = 0.0  # fractional padding on min/max measured coordinates
# Reduce each projection separately instead of stacking all K (or L) arrays
# into one large temporary; the result is the same maximum.
xmax_meas = (1.0 + pad) * max(np.max(np.abs(xx)) for xx in xx_list)
ymax_meas = (1.0 + pad) * max(np.max(np.abs(yy)) for yy in yy_list)

# Create the measurement grid.
edges_meas = [
    np.linspace(-xmax_meas, xmax_meas, n_bins_meas + 1),
    np.linspace(-ymax_meas, ymax_meas, n_bins_meas + 1),
]
grid_meas = [rec.get_bin_centers(_edges) for _edges in edges_meas]

# Bin the coordinates on the measurement grid: S[:, :, k, l] is the x-y
# image for horizontal setting k and vertical setting l. The image shape is
# set by the measurement grid (n_bins_meas), not the reconstruction grid.
S = np.zeros((n_bins_meas, n_bins_meas, K, L))
for k, xx in enumerate(tqdm(xx_list)):
    for l, yy in enumerate(yy_list):
        S[:, :, k, l], _, _ = np.histogram2d(xx, yy, edges_meas)

In [None]:
# K x L grid of simulated screen images: row i is horizontal setting i,
# column j is vertical setting j.
fig, axes = pplt.subplots(nrows=K, ncols=L, figwidth=10, space=0.175)
for i, j in np.ndindex(axes.shape):
    ax = axes[i, j]
    mplt.plot_image(S[:, :, i, j], x=grid_meas[0], y=grid_meas[1], ax=ax)
axes.format(xticks=[], yticks=[], xlabel='x', ylabel='y', suptitle='simulated images')
plt.show()

Reconstruct the normalized phase space distribution.

In [None]:
# Right-multiply by the normalizing matrix so each transfer matrix acts on
# normalized coordinates at the reconstruction point (the grid of f_true).
tmats_x_n = [Mx @ V[:2, :2] for Mx in tmats_x]
tmats_y_n = [My @ V[2:, 2:] for My in tmats_y]

In [None]:
# Hock's method (SART, 3 iterations) on the simulated x-y images, using the
# x and y coordinates of the reconstruction grid.
f_rec = rec.hock4D(
    S,
    grid_meas,
    (grid_rec[0], grid_rec[2]),
    tmats_x_n,
    tmats_y_n,
    method='SART',
    iterations=3,
)
# Post-process on the reconstruction limits (keep_positive / density options).
f_rec = rec.process(f_rec, limits=limits_rec, density=True, keep_positive=True)

In [None]:
# Sanity check: compare value ranges and the integrated density
# (sum * bin volume) of the reconstruction and the truth.
print(f'min(f_rec) = {np.min(f_rec)}')
print(f'max(f_rec) = {np.max(f_rec)}')
print(f'sum(f_rec) * bin_volume = {np.sum(f_rec) * bin_volume}')
print()
print(f'min(f_true) = {np.min(f_true)}')
print(f'max(f_true) = {np.max(f_true)}')
print(f'sum(f_true) * bin_volume = {np.sum(f_true) * bin_volume}')

In [None]:
def plot_compare_1D(f1, f2, coords=None):
    """Plot 2D projections of two 4D distributions and their difference.

    Despite the name, this draws 2D projections onto six coordinate pairs
    (one per column). Row 1 shows ``f1``, row 2 ``f2`` and row 3
    ``f1 - f2``. The rows are labeled 'Reconstructed', 'True' and
    'Difference', so pass the reconstruction as ``f1`` and the reference as
    ``f2``.

    Parameters
    ----------
    f1, f2 : ndarray
        4D distributions on the same grid.
    coords : list of 1D arrays, optional
        Bin-center coordinates along each of the four axes. Defaults to the
        notebook-level ``grid_rec`` at call time.

    Returns
    -------
    axes
        The proplot subplot grid (3 rows x 6 columns).
    """
    if coords is None:
        coords = grid_rec
    fig, axes = pplt.subplots(ncols=6, nrows=3, figwidth=8, space=0.5)
    inds = [(0, 2), (0, 1), (2, 3), (0, 3), (2, 1), (1, 3)]
    for j, ind in enumerate(inds):
        x = coords[ind[0]]
        y = coords[ind[1]]
        im1 = rec.project(f1, ind)
        im2 = rec.project(f2, ind)
        # The difference can be negative: plot `np.abs(im1 - im2)` instead
        # if switching plot_image to a log norm.
        for ax, H in zip(axes[:, j], [im1, im2, im1 - im2]):
            mplt.plot_image(H, x=x, y=y, ax=ax)
        axes[0, j].format(title=f'{dims[ind[0]]}-{dims[ind[1]]}')
    axes.format(xticks=[], yticks=[], leftlabels=['Reconstructed', 'True', 'Difference'])
    plt.show()
    # Callers assign the result (`axes = plot_compare_1D(...)`), so return it.
    return axes

In [None]:
# Compare Hock's reconstruction (top row) with the true distribution.
axes = plot_compare_1D(f1=f_rec, f2=f_true)

In [None]:
# Error between the reconstruction and the truth after scaling each to a peak of 1.
f_err = np.subtract(f_rec / np.max(f_rec), f_true / np.max(f_true))
mplt.interactive_proj1d(f_err, coords=grid_rec, dims=dims, kind='line')

In [None]:
# Interactive 2D projections of Hock's reconstruction.
mplt.interactive_proj2d(f_rec, coords=grid_rec, dims=dims)

## 4D ART 

This method is described here: https://journals.aps.org/prab/abstract/10.1103/PhysRevAccelBeams.23.032804. The maximum practical reconstruction grid resolution $N$ for this method is approximately 50. I have run $N = 50$ with $8 \times 8$ projections in approximately 8 hours; $N = 25$ executes in a few minutes with the same number of projections. $N = 30$ is used below.

The accuracy is acceptable with $N = 50$ and $8 \times 8$ projections, although there are generally streaking artifacts in $x$-$x'$ that would probably disappear with more projections. I also found some improvement when the measured projections are not scaled, i.e., when the transfer matrices are pure rotations; the simulated optics below therefore use $\beta = 1$ and $\alpha = 0$ at the screen.

In [None]:
# Coarser reconstruction grid for 4D ART, on the same limits as before.
n_bins = 30
f_true, edges_rec = np.histogramdd(Xn, bins=n_bins, range=limits_rec, density=True)
grid_rec = [rec.get_bin_centers(edge_arr) for edge_arr in edges_rec]
bin_volume = rec.get_bin_volume(limits_rec, n_bins)

# Coarser measurement grid with the same screen extent as the previous section.
n_bins_meas = 30
edges_meas = [
    np.linspace(-half_width, half_width, n_bins_meas + 1)
    for half_width in (xmax_meas, ymax_meas)
]
grid_meas = [rec.get_bin_centers(edge_arr) for edge_arr in edges_meas]

Simulate measurements with a new screen resolution and number of projections.

In [None]:
K = 8  # number of horizontal optics settings
L = 8  # number of vertical optics settings

# Phase advances evenly spaced over [0, pi). The screen Twiss parameters are
# fixed at beta = 1, alpha = 0 for every setting (unscaled projections).
phase_adv_x = np.linspace(0.0, np.pi, K, endpoint=False)
betas_x = np.ones(K)
alphas_x = np.zeros(K)

phase_adv_y = np.linspace(0.0, np.pi, L, endpoint=False)
betas_y = np.ones(L)
alphas_y = np.zeros(L)

tmats_x = make_tmats(phase_adv_x, betas_x, alphas_x, Vinv[:2, :2])
tmats_y = make_tmats(phase_adv_y, betas_y, alphas_y, Vinv[2:, 2:])
print('Transporting x-x')
xx_list = [utils.apply(Mx, X[:, :2])[:, 0] for Mx in tqdm(tmats_x)]
print('Transporting y-y')
yy_list = [utils.apply(My, X[:, 2:])[:, 0] for My in tqdm(tmats_y)]

# Refit the measurement grid to these projections. The previous extent was
# computed from the Hock section's optics (different phase advances and
# screen Twiss parameters), and np.histogram2d silently drops particles that
# fall outside the bin edges, so a stale extent could clip these images.
xmax_meas = (1.0 + pad) * max(np.max(np.abs(xx)) for xx in xx_list)
ymax_meas = (1.0 + pad) * max(np.max(np.abs(yy)) for yy in yy_list)
edges_meas = [
    np.linspace(-xmax_meas, xmax_meas, n_bins_meas + 1),
    np.linspace(-ymax_meas, ymax_meas, n_bins_meas + 1),
]
grid_meas = [rec.get_bin_centers(_edges) for _edges in edges_meas]

In [None]:
# Transfer matrices acting on normalized coordinates at the reconstruction point.
tmats_x_n = [np.matmul(Mx, V[:2, :2]) for Mx in tmats_x]
tmats_y_n = [np.matmul(My, V[2:, 2:]) for My in tmats_y]

# One x-y image and one block-diagonal 4x4 transfer matrix per optics pair,
# ordered with the horizontal setting as the outer index.
projections, tmats_n = [], []
# zip() has no len(), so pass tqdm the total explicitly for a real progress bar.
n_outer = min(len(xx_list), len(tmats_x_n))
for xx, Mx in tqdm(zip(xx_list, tmats_x_n), total=n_outer):
    for yy, My in zip(yy_list, tmats_y_n):
        projection, _, _ = np.histogram2d(xx, yy, edges_meas)
        projections.append(projection)
        M = np.zeros((4, 4))
        M[:2, :2] = Mx
        M[2:, 2:] = My
        tmats_n.append(M)

Launch the reconstruction.

In [None]:
# 4D ART from the x-y images and their 4x4 transfer matrices, then the same
# post-processing as for Hock's method.
f_rec = rec.art4D(projections, tmats_n, grid_rec, grid_meas)
f_rec = rec.process(f_rec, limits=limits_rec, density=True, keep_positive=True)

In [None]:
# Compare the 4D ART reconstruction (top row) with the true distribution.
axes = plot_compare_1D(f1=f_rec, f2=f_true)

In [None]:
# Interactive 2D projections of the 4D ART reconstruction.
mplt.interactive_proj2d(f_rec, coords=grid_rec, dims=dims)