# Locating longitudinal-transverse correlations in phase space 

The goal of this notebook is to identify the "width" of the measured longitudinal-transverse correlation in the phase space distribution of the BTF bunch at the first emittance station. We will look at the energy distribution of particles within a boundary in 4D transverse phase space.

In [None]:
import sys
import os
from os.path import join
import importlib
import numpy as np
import h5py
import itertools
from tqdm.notebook import tqdm
from tqdm.notebook import trange
from plotly import graph_objects as go
from matplotlib import pyplot as plt
from ipywidgets import interactive
from ipywidgets import widgets
import proplot as pplt

sys.path.append('../../')
from tools import plotting as mplt
from tools import utils
from tools.utils import project
from tools import analysis as ba

In [None]:
# Global proplot settings applied to every figure in this notebook.
_rc_overrides = {
    'grid': False,
    'cmap.discrete': False,
    'cmap.sequential': 'viridis',
    'figure.facecolor': 'white',
}
for _key, _value in _rc_overrides.items():
    pplt.rc[_key] = _value

## Load data 

In [None]:
# Directory containing the saved arrays and metadata loaded in the cells below.
folder = '_saved/2022-07-15-VS06/'

In [None]:
# Load the pickled `info` object stored with the data (it is indexed by
# 'filename' below); the bare expression displays it as the cell output.
info = utils.load_pickle(join(folder, 'info.pkl'))
info

In [None]:
# Load one coordinate array per dimension and derive the grid shape from them.
filename = info['filename']
coords_path = join(folder, f'coords_{filename}.npz')
coords = utils.load_stacked_arrays(coords_path)
shape = tuple(len(c) for c in coords)
print('shape:', shape)

In [None]:
# Open the density array as a read-only memory map whose shape matches the
# coordinate grid, so the file is not read into memory all at once.
f = np.memmap(join(folder, f'f_{filename}.mmp'), shape=shape, dtype='float', mode='r')

Crop the 5D array.

In [None]:
# Names and units of the five axes of `f`, in array-axis order.
dims = ["x", "x'", "y", "y'", "w"]
units = ["mm", "mrad", "mm", "mrad", "MeV"]
# Axis labels of the form "name [unit]".
dims_units = [f'{d} [{u}]' for d, u in zip(dims, units)]
# Line style for profile overlays.
# NOTE(review): not referenced by name in the cells shown here; the corner
# plot of the full array passes an equal inline dict instead.
prof_kws = dict(lw=0.5, alpha=0.7, color='white', scale=0.12)

In [None]:
# 1D projection of `f` onto each of the five axes, each scaled to unit sum.
prof = [utils.project(f, i) for i in range(5)]
prof = [p / np.sum(p) for p in prof]

In [None]:
# Shift every coordinate array so that its profile-weighted mean is zero.
for i, (c, weights) in enumerate(zip(coords, prof)):
    coords[i] = c - np.average(c, weights=weights)

In [None]:
# (start, stop) pixel indices kept along each axis.
crop = (
    (0, f.shape[0]),
    (10, f.shape[1] - 10),
    (0, f.shape[2]),
    (16, f.shape[3] - 16),
    (12, f.shape[4] - 19),
)

# Shade the crop window on top of each 1D profile to check the limits by eye.
fig, axes = pplt.subplots(ncols=5, figwidth=7, spanx=False, figheight=None)
for ax, profile, (lo, hi), dim in zip(axes, prof, crop, dims):
    ax.plot(profile, color='black')
    ax.axvspan(lo, hi - 1, color='black', alpha=0.1)
    ax.format(xlabel=dim + ' [pixel]')
# axes.format(yscale='log')
plt.show()

In [None]:
# Apply the crop window to the array and to each coordinate array.
ind = tuple(slice(lo, hi) for lo, hi in crop)
f = f[ind]
coords = [c[s] for c, s in zip(coords, ind)]

Clip small negative values.

In [None]:
# Record the extrema before any modification; `f_max` is used afterwards to
# normalize the array.
f_max = np.max(f)
f_min = np.min(f)
if f_min < 0.0:
    print(f'min(f) = {f_min}')
    print('Clipping to zero.')
    # np.clip is called without `out`, so it returns a new array rather than
    # writing into `f`.
    f = np.clip(f, 0.0, None)

In [None]:
# Scale the array by the maximum recorded before clipping.
f = f / f_max

## Rectangular slices

### 2D projections 

First, view the 2D projections at integer slices.

In [None]:
# Interactive 2D projection viewer with slice_type='int'.
mplt.interactive_proj2d(f, coords=coords, dims=dims, units=units, slice_type='int')

Then view the 2D projections at range slices.

In [None]:
# Same interactive viewer, now with slice_type='range'.
mplt.interactive_proj2d(f, coords=coords, dims=dims, units=units,
                        slice_type='range')

In [None]:
# Threshold handed to the corner plot below together with thresh_type='frac'.
frac_thresh = 10.0**-3.5

Now view all 2D projections together.

In [None]:
# Corner plot of the full 5D array on a log color scale; no diagonal panels
# are requested (diag_kind='None') and `frac_thresh` is applied as a
# fractional threshold.
axes = mplt.corner(
    f,
    coords=coords,
    diag_kind='None',
    labels=dims_units,
    norm='log', 
    handle_log='floor',
    thresh=frac_thresh,
    thresh_type='frac',
    prof='edges',
    prof_kws=dict(lw=0.5, alpha=0.7, color='white', scale=0.12),
    linewidth=0, rasterized=True,
)
axes.format(xlabel_kw=dict(fontsize='large'), ylabel_kw=dict(fontsize='large'))

### 1D projections 

In [None]:
# Interactive 1D projection viewer (line plot) with slice_type='int';
# default_ind=4 is the index of "w" in `dims`.
mplt.interactive_proj1d(f, coords=coords, dims=dims, units=units, default_ind=4,
                        kind='line', slice_type='int')

In [None]:
# Same interactive 1D viewer, now with slice_type='range'.
mplt.interactive_proj1d(f, coords=coords, dims=dims, units=units, default_ind=4,
                        kind='line', slice_type='range')

In [None]:
# pplt.rc['pdf.fonttype'] = 42

In [None]:
# Index of the grid point closest to zero along every axis.
_ind = [np.argmin(np.abs(c)) for c in coords]

fig, axes = pplt.subplots(ncols=4, figwidth=4.5, figheight=1.35)
axes.format(xlabel="w [MeV]", yticklabels=[], xlim=(-0.065, 0.065))
for i, ax in enumerate(axes):
    # Panel i pins the first i + 1 axes at their near-zero index.
    nfixed = i + 1
    idx = utils.make_slice(4, np.arange(nfixed), _ind[:nfixed])
    f_slice = f[idx]
    frac = np.sum(f_slice) / np.sum(f)

    # Energy profile of the slice, scaled to unit sum.
    pw = project(f_slice, 3 - i)
    pw = pw / np.sum(pw)
    ax.bar(coords[4], pw, color='black', width=1)

    fixed_dims = "".join(r"${} \approx$".format(d) for d in dims[:nfixed])
    title = fixed_dims + "0" + f"\n({100.0*frac:.2f}%)"
    ax.format(title=title, title_kw=dict(fontsize='medium'))
axes.format(ylim=(0.0, axes[0].get_ylim()[1]),
            xspineloc='bottom', yspineloc='neither',
            # ylabel='Density'
           )
# for png in [False, True]:
#     figname = '_output/w_slices'
#     if png:
#         figname += '.png'
#     plt.savefig(figname)

## Contour slices

We can also use contour slices — volumes defined by density contours.

### Energy distribution 

Here, we observe the energy distribution within density contours in the 4D transverse phase space.

In [None]:
# Project onto the first four axes and scale the result to a maximum of 1,
# so thresholds below are fractions of the peak value.
ftr = project(f, axis=(0, 1, 2, 3))
ftr = ftr / np.max(ftr)

View the distribution of pixel values in the array.

In [None]:
# Histogram of the normalized 4D pixel values, with a log-scaled count axis.
fig, ax = pplt.subplots()
ax.hist(ftr.ravel(), bins=20, color='black')
ax.format(yscale='log', ylabel='Count', xlabel='4D pixel value')
plt.show()

In [None]:
def energy_proj(f, level=0.5, ftr=None, normalize=True, return_frac=False):
    """Sum `f` over the pixels where `ftr` exceeds `level`, keeping the last axis.

    Parameters
    ----------
    f : ndarray
        N-D array; the last axis is the one that is kept.
    level : float
        Threshold applied to `ftr` after `ftr` is scaled to a maximum of 1.
    ftr : ndarray, optional
        Array used to select pixels. Defaults to `f` summed over its last axis.
    normalize : bool
        If True, divide the returned profile by its sum.
    return_frac : bool
        If True, also return the fraction of the total of `ftr` that lies
        above `level`.

    Returns
    -------
    pw : ndarray
        Sum of `f` over the selected pixels, one value per last-axis index.
    frac : float
        Only returned when `return_frac` is True.
    """
    if ftr is None:
        ftr = np.sum(f, axis=-1)
    scaled = ftr / np.max(ftr)
    selected = np.nonzero(scaled > level)
    frac = np.sum(scaled[selected]) / np.sum(scaled)
    # f[selected] has one row per selected pixel; collapse the rows.
    pw = np.sum(f[selected], axis=0)
    if normalize:
        pw = pw / np.sum(pw)
    return (pw, frac) if return_frac else pw

In [None]:
def update(thresh):
    """Bar plot of the energy profile above `thresh`; the title shows the enclosed fraction."""
    pw, frac = energy_proj(f, thresh, ftr=ftr, return_frac=True, normalize=True)
    _, ax = pplt.subplots(figsize=(4, 1.5))
    ax.format(xlabel=dims_units[4], title=f'frac = {frac:.2f}')
    ax.bar(coords[4], pw, color='black')
    plt.show()


# Slider over the threshold: (min, max, step).
interactive(update, thresh=(0.0, 0.99, 0.01))

Here are other plots of the same data.

In [None]:
n = 20
levels = np.linspace(0.9, 0.0, n)

# Energy profile and enclosed fraction at every threshold level.
_results = [
    energy_proj(f, level, ftr=ftr, normalize=True, return_frac=True)
    for level in levels
]
pws = [pw for pw, _ in _results]
fracs = [frac for _, frac in _results]
# Stack the profiles into one 2D array and scale by the global maximum.
pws = np.asarray(pws) / np.max(pws)

In [None]:
# Enclosed fraction versus threshold; both sequences are reversed so the
# threshold runs in increasing order.
fig, ax = pplt.subplots()
ax.plot(
    levels[::-1], fracs[::-1], color='black',
    marker='.', lw=0,
)
ax.format(xlabel="Threshold (x-x'-y-y')", ylabel='Fraction of particles', xlim=(-0.02, 1.0))

In [None]:
# Image of the energy profiles: energy on the horizontal axis, threshold on
# the vertical axis (rows reversed so the threshold increases upward).
fig, ax = pplt.subplots(figsize=(4, 1.75))
ax.pcolormesh(coords[4], levels[::-1], pws[::-1],
              colorbar=True, colorbar_kw=dict(label='Density (arb. units)', width=0.1))
ax.format(xlabel='Energy [MeV]', ylabel='4D thresh')
plt.show()

In [None]:
# Same profiles drawn as one line per threshold, colored by threshold value.
cmap = pplt.Colormap('fire_r', left=0.0, right=0.9)
# cmap = pplt.Colormap('crest', left=0.0, right=1.0)

fig, ax = pplt.subplots(figsize=(4, 1.75))
# `pws` is transposed so that each column passed to `plot` is one profile.
ax.plot(coords[4], pws[::-1].T, cycle=cmap, lw=1, colorbar=True, 
        colorbar_kw=dict(values=levels[::-1], label="Threshold (x-x'-y-y')"))
ax.format(xlabel="Energy [MeV]")
plt.show()

In [None]:
fig, ax = pplt.subplots(figsize=(4.5, 1.55))

# Alternate y-axis (log scale) for the enclosed beam fraction.
alpha = 0.3
color = 'red6'
ax2 = ax.alty(color=color)
ax2.format(ylabel='Fraction of beam', yscale='log', ylim=(0.001, 1.0))

# Enclosed fraction on a finer grid of thresholds than `levels`.
_levels = np.linspace(0.0, 0.95, 35)
_fracs = []
for _level in _levels:
    _fracs.append(
        energy_proj(f, _level, ftr=ftr, normalize=True, return_frac=True)[1]
    )
ax2.plot(_levels[::-1], _fracs[::-1], zorder=0, color=color, alpha=alpha, lw=1.25)

# Waterfall: each energy profile is drawn with `plotx`, offset to its threshold.
for level, pw in zip(levels, pws):
    offset_profile = level + 0.045 * pw
    # color='black', alpha=0.3,
    ax.plotx(coords[4], offset_profile, color='black', alpha=1, zorder=999999)
ax.format(
    ylabel='Energy [MeV]',
    xlabel="Threshold (x-x'-y-y')",
    ylim=(-0.09, 0.09), xlim=(-0.03, 0.97),
)
# plt.savefig('_output/waterfall')

In [None]:
# Z = pws[::-1]
# X, Y = np.meshgrid(levels[::-1], coords[4], indexing='ij')    
# lines = []
# line_marker = dict(color='black', width=3)
# for x, y, z in zip(X, Y, Z):
#     lines.append(go.Scatter3d(x=x, y=y, z=z, mode='lines', line=line_marker))
# uaxis= dict(
#     gridcolor='rgb(255, 255, 255)',
#     zerolinecolor='rgb(255, 255, 255)',
#     showbackground=True,
#     backgroundcolor='rgb(230, 230,230)',
# )
# layout = go.Layout(
#     width=500,
#     height=500,
#     showlegend=False,
#     scene=dict(
#         xaxis=uaxis, 
#         yaxis=uaxis,
#         zaxis=uaxis,
#     ),
# )
# fig = go.Figure(data=lines, layout=layout)
# fig.show()

In [None]:
# fig = go.Figure(data=[go.Surface(x=levels[::-1], y=coords[4], z=pws[::-1].T)])
# fig.update_layout(width=500, height=500)
# fig.show()

Plot the 2D projections of the 4D phase space distribution as the boundary is changed, along with the energy distribution within the boundary.

In [None]:
def update(log=False, thresh=0.5):
    """Corner plot of the thresholded 4D array plus the matching energy profile.

    Parameters
    ----------
    log : bool
        If True, the corner plot uses norm='log'.
    thresh : float
        Threshold applied to the global `ftr`.

    Returns
    -------
    axes, frac
        The axes returned by `mplt.corner` and the fraction of the total of
        `ftr` above `thresh` (also printed).
    """
    # `ftr` is only read here, so this declaration is not strictly required.
    global ftr 
    # NOTE(review): `frac` counts `ftr > thresh` while the mask below hides
    # `ftr < thresh`; the two selections differ only where ftr == thresh.
    frac = np.sum(ftr[ftr > thresh]) / np.sum(ftr)
    print(f'frac = {frac:.4f}')
    axes = mplt.corner(
        np.ma.masked_where(ftr < thresh, ftr),
        coords=coords[:4],
        diag_kind='None',
        # NOTE(review): `dims_units` has five entries but only four
        # coordinate arrays are passed — confirm the extra label is ignored.
        labels=dims_units,
        thresh=10.0**-3.0,
        thresh_type='frac',
        fill_value=0.0,
        norm='log' if log else None,
        handle_log='floor',
        prof='edges',
        prof_kws=dict(kind='step', lw=0.4, alpha=0.8),
        linewidth=0,rasterized=True
    )
    pw = energy_proj(f, thresh, ftr=ftr, normalize=True)
    # The panel at [0, 2] is reused for the energy profile (presumably an
    # otherwise empty panel of the corner layout — TODO confirm).
    ax = axes[0, 2]
    # Remove `ax` from the shared-axis groups of the panels with j <= i.
    # NOTE(review): `_shared_x_axes` / `_shared_y_axes` are private matplotlib
    # attributes and may not be available in every version.
    for i in range(3):
        for j in range(i + 1):
            axes[i, j]._shared_x_axes.remove(ax)
            axes[i, j]._shared_y_axes.remove(ax)
    ax.axis('on')
    ax.plot(coords[4], pw, color='black')
    ax.format(
        xlim=(np.min(coords[4]), np.max(coords[4])),
        ylim=(0, 0.1),
        xspineloc='bottom', yspineloc='left',
        title='energy projection',
    )
    return axes, frac
    
# Checkbox for `log`, slider (min, max, step) for `thresh`.
interactive(update, log=False, thresh=(0.0, 0.99, 0.01))

Plot the 2D projections of the 5D phase space distribution as the 4D boundary is changed. To slice the array, we compute a 4D mask in the transverse phase space, then copy the mask along the last axis of the array to get a 5D mask. *This is very slow.*

In [None]:
def mask_4d(f, thresh=0.0, ftr=None):
    """Mask N-D array `f` where `ftr` is below `thresh`.

    If `ftr` is not given, it is `f.sum(axis=-1)` scaled to a maximum of 1.
    A supplied `ftr` is used as-is (it is not rescaled here). The resulting
    condition is copied along the last axis of `f` before masking.
    """
    if ftr is None:
        summed = np.sum(f, axis=-1)
        ftr = summed / np.max(summed)
    below = ftr < thresh
    return np.ma.masked_where(utils.copy_into_new_dim(below, f.shape[-1]), f)

In [None]:
def update(log, thresh):
    """Corner plot of the 5D array with the 4D mask at `thresh` applied."""
    masked = mask_4d(f, thresh, ftr=ftr)
    corner_kws = dict(
        coords=coords,
        diag_kind='None',
        labels=dims_units,
        thresh=10.0**-3.0,
        thresh_type='frac',
        fill_value=0.0,
        norm=('log' if log else None),
        handle_log='floor',
        prof='edges',
        prof_kws=dict(kind='step', lw=0.4, alpha=0.8),
    )
    axes = mplt.corner(masked, **corner_kws)
    plt.show()


# Checkbox for `log`, slider (min, max, step) for `thresh`.
interactive(update, log=False, thresh=(0.0, 0.99, 0.01))

### Repeat for any 1D or 2D projection

In [None]:
def proj1d_mask(f, thresh=0.1, fpr=None, normalize=True, return_frac=False, axis=0):
    """1D profile of `f` along `axis` over the region selected in the other axes.

    Parameters
    ----------
    f : ndarray
        N-D array.
    thresh : float
        Threshold applied to `fpr` after `fpr` is scaled to a maximum of 1.
    fpr : ndarray, optional
        Array over the axes other than `axis` used to select pixels.
        Defaults to `utils.project(f, <other axes>)`.
    normalize : bool
        If True, divide the returned profile by its sum.
    return_frac : bool
        If True, also return the fraction of the total of `fpr` above `thresh`.
    axis : int
        Axis of `f` that is kept.

    Returns
    -------
    p : ndarray
        Profile along `axis`.
    frac : float
        Only returned when `return_frac` is True.
    """
    other_axes = [k for k in range(f.ndim) if k != axis]
    if fpr is None:
        fpr = utils.project(f, other_axes)
    fpr = fpr / np.max(fpr)
    selected = np.nonzero(fpr > thresh)
    frac = np.sum(fpr[selected]) / np.sum(fpr)

    idx = utils.make_slice(f.ndim, other_axes, selected)
    # Assuming `make_slice` places the index arrays on `other_axes` and a full
    # slice on `axis`: NumPy puts the selected-points axis first unless the
    # full slice leads (axis == 0), in which case it comes second.
    points_axis = 1 if axis == 0 else 0
    p = np.sum(f[idx], axis=points_axis)

    if normalize:
        p = p / np.sum(p)
    return (p, frac) if return_frac else p

In [None]:
def update(thresh, dim):
    """Bar plot of the masked 1D profile along `dim`; the title shows the enclosed fraction."""
    axis = dims.index(dim)
    p, frac = proj1d_mask(f, thresh, axis=axis, normalize=True, return_frac=True)

    _, ax = pplt.subplots(figsize=(4, 1.5))
    ax.format(xlabel=dims_units[axis], title=f'frac = {frac:.2f}')
    ax.bar(coords[axis], p, color='black')
    plt.show()


# Fine-grained slider for the threshold; dimension choices are listed in reverse order.
thresh = widgets.FloatSlider(
    value=0.5, min=0.0, max=0.99, step=0.001, description='4D thresh',
)
interactive(update, thresh=thresh, dim=reversed(dims))

In [None]:
def proj2d_mask(f, thresh=0.1, fpr=None, normalize=True, return_frac=False, axis=(2, 3)):
    """2D projection of `f` onto `axis` over the region selected in the other axes.

    Parameters
    ----------
    f : ndarray
        N-D array.
    thresh : float
        Threshold applied to `fpr` after `fpr` is scaled to a maximum of 1;
        pixels with `fpr < thresh` are masked.
    fpr : ndarray, optional
        Array over the axes not in `axis` used to build the mask.
        Defaults to `utils.project(f, <other axes>)`.
    normalize : bool
        NOTE(review): accepted for signature parity with `proj1d_mask` but
        currently unused — the returned image is not normalized.
    return_frac : bool
        If True, also return the fraction of the total of `fpr` that is
        not masked.
    axis : sequence of two ints
        Axes of `f` that are kept.

    Returns
    -------
    im : array
        Result of `utils.project` applied to the masked array.
    frac : float
        Only returned when `return_frac` is True.
    """
    # Build the mask over the axes that are NOT kept.
    axis_proj = [i for i in range(f.ndim) if i not in axis]
    if fpr is None:
        fpr = utils.project(f, axis_proj)
    fpr = fpr / np.max(fpr)
    mask = fpr < thresh
    frac = np.sum(fpr[~mask]) / np.sum(fpr)

    # Copy the mask into the two kept dimensions (appended as trailing axes).
    mask = utils.copy_into_new_dim(mask, (f.shape[axis[0]], f.shape[axis[1]]), axis=-1, copy=True)
    # The mask axes are now ordered as `axis_proj + axis`; move each one to
    # its position in `f`. Use f.ndim rather than a hard-coded 5 so arrays of
    # other dimensionality work too.
    isort = np.argsort(list(axis_proj) + list(axis))
    mask = np.moveaxis(mask, isort, np.arange(f.ndim))

    # Project the masked array onto the requested axes.
    im = utils.project(np.ma.masked_array(f, mask=mask), axis=axis)
    if return_frac:
        return im, frac
    return im

In [None]:
def update(thresh, dim1, dim2, log=False, **kws):
    """Plot the masked 2D projection onto (`dim1`, `dim2`); does nothing if they are equal.

    Extra keyword arguments are forwarded to `mplt.plot_image`; `norm` is
    always overridden by `log`.
    """
    if dim1 == dim2:
        return
    axis = [dims.index(dim1), dims.index(dim2)]
    im, frac = proj2d_mask(f, thresh, normalize=True, return_frac=True, axis=axis)
    print('frac = {}'.format(frac))

    fig, ax = pplt.subplots()
    kws['norm'] = 'log' if log else None
    plot_defaults = (
        ('thresh', 10.0**-3.5),
        ('thresh_type', 'frac'),
        ('colorbar', True),
        ('profx', True),
        ('profy', True),
    )
    for key, value in plot_defaults:
        kws.setdefault(key, value)
    xaxis, yaxis = axis
    ax.format(xlabel=dims_units[xaxis], ylabel=dims_units[yaxis], title=f'frac = {frac:.2f}')
    mplt.plot_image(im, x=coords[xaxis], y=coords[yaxis], ax=ax, **kws)
    ax.format(xlim=sorted(ax.get_xlim()))
    plt.show()


thresh = widgets.FloatSlider(
    value=0.5, min=0.0, max=0.99, step=0.001, description='3D thresh',
)
kws = {}
interactive(update, thresh=thresh, dim1=dims, dim2=reversed(dims), log=False)

## Radial density

In [None]:
# Zero every pixel below 10**-3.5. This modifies `f` in place, so every later
# cell sees the thresholded array.
f[f < 10.0**-3.5] = 0.0

In [None]:
def get_radii(coords, Sigma):
    """Return r = sqrt(x^T Sigma^-1 x) at every point of a regular grid.

    The quadratic form is accumulated term by term with broadcasting, so no
    per-pixel Python loop or dense coordinate mesh is needed.

    Parameters
    ----------
    coords : list of 1D arrays
        Grid coordinates along each dimension.
    Sigma : ndarray, shape (len(coords), len(coords))
        Matrix whose inverse defines the quadratic form.

    Returns
    -------
    R : ndarray
        Array with shape `tuple(len(c) for c in coords)` holding the radius
        of each grid point.
    """
    shape = tuple(len(c) for c in coords)
    Sigma_inv = np.linalg.inv(Sigma)
    # Sparse mesh: grids[k] varies along axis k only and has length 1 elsewhere.
    grids = np.meshgrid(*coords, indexing='ij', sparse=True)
    R2 = np.zeros(shape)
    for i, xi in enumerate(grids):
        for j, xj in enumerate(grids):
            R2 += Sigma_inv[i, j] * xi * xj
    return np.sqrt(R2)

def radial_density(f, R, radii, dr=None):
    """Mean of `f` inside the shell r <= R <= r + dr, for each r in `radii`.

    Parameters
    ----------
    f : ndarray
        Values on the grid.
    R : ndarray
        Radius of every grid point (same shape as `f`).
    radii : iterable of float
        Inner radius of each shell.
    dr : float, optional
        Shell width. Defaults to `0.5 * max(R) / (len(R) - 1)`.

    Returns
    -------
    ndarray
        One mean value per entry of `radii`.
    """
    if dr is None:
        # NOTE(review): len(R) is the size of R's first axis, not the number
        # of radii — confirm this is the intended shell width.
        dr = 0.5 * np.max(R) / (len(R) - 1)

    def shell_mean(r):
        outside = np.logical_or(R < r, R > r + dr)
        return np.mean(np.ma.masked_where(outside, f))

    return np.array([shell_mean(r) for r in tqdm(radii)])

In [None]:
ndim = 3  # number of dimensions in projected array
nr = 70  # number of radial points
alpha = 0.2  # opacity of the Gaussian reference curve
for axis in itertools.combinations(range(5), ndim):
    print('axis =', axis)

    # Project the distribution onto this axis combination and compute its
    # covariance matrix.
    _f = utils.project(f, axis)
    _coords = [coords[k] for k in axis]
    Sigma, mu = ba.dist_cov(_f, _coords, disp=True)

    # Radius of every grid point in the space normalized by Sigma.
    print('Computing radii...')
    R = get_radii(_coords, Sigma)

    # Approximate radial density, scaled to a maximum of 1 (NaNs ignored).
    radii = np.linspace(0.0, np.max(R), nr)
    fr = radial_density(_f, R, radii, dr=None)
    fr = fr / np.nanmax(fr)

    # Left panel: linear scale. Right panel: log scale. Both show the data on
    # top of a Gaussian with the same peak normalization.
    fig, axes = pplt.subplots(ncols=2, figwidth=5.75, figheight=2.55, share=False, space=10)
    gauss = np.exp(-0.5 * radii**2)
    for ax in axes:
        ax.plot(radii, gauss, color='red', alpha=alpha, label='gauss')
        # ax.plot(rs, [float(r <= 2.0) for r in rs], color='blue', alpha=alpha, label='uniform')
        ax.plot(radii, fr, color='black', label='data', marker='.', lw=0, ms=3)
        ax.legend(ncols=1, loc='upper right', handlelength=1)
    axes.format(xlim=(0.0, 5.0))
    axes[1].format(yscale='log', yformatter='log')
    axes[1].format(ymin=1e-5, ymax=1.5)

    title = "-".join(dims[k] for k in axis)
    axes.format(xlabel=r"$r = \sqrt{\mathbf{x}^T\mathbf{\Sigma}^{-1}\mathbf{x}}$",
                ylabel=r"$f(r)$", title=title)
    # plt.savefig(f'_output/{title}.png')
    plt.show()