# Locating correlations in five-dimensional phase space 

In [None]:
import sys
import os
from os.path import join
import importlib
import numpy as np
import h5py
import itertools
from tqdm.notebook import tqdm
from tqdm.notebook import trange
from plotly import graph_objects as go
from matplotlib import pyplot as plt
from ipywidgets import interactive
from ipywidgets import widgets
import proplot as pplt

sys.path.append('../../')
from tools import plotting as mplt
from tools import utils
from tools.utils import project
from tools import analysis as ba

In [None]:
# Global proplot settings; they apply to every figure created below.
pplt.rc['grid'] = False
pplt.rc['cmap.discrete'] = False
pplt.rc['cmap.sequential'] = 'viridis'
# White figure background.
pplt.rc['figure.facecolor'] = 'white'

## Load data 

In [None]:
# Directory containing the saved files for this measurement (relative path).
folder = '_saved/2022-07-15-VS06/'

Load the phase space density array `f` and grid coordinates `coords`.

In [None]:
# Metadata for the run; only its 'filename' entry is used to build the other paths.
info = utils.load_pickle(join(folder, 'info.pkl'))
filename = info['filename']
# One coordinate array per axis of the density array.
coords = utils.load_stacked_arrays(join(folder, f'coords_{filename}.npz'))
# The memmap file carries no shape information, so the shape is taken from the grid.
shape = tuple([len(c) for c in coords])
# Open read-only (mode='r') as a memory map instead of reading the file up front.
f = np.memmap(join(folder, f'f_{filename}.mmp'), shape=shape, dtype='float', mode='r')

print('f.shape:', shape)

Center the coordinates on the centroid calculated from the profiles.

In [None]:
# Shift every coordinate axis so that the centroid of its 1D profile is at zero.
# The normalized profiles are kept in `profs` (they are plotted with the crop
# window afterwards). The number of axes is taken from `f` rather than being
# hard-coded, so the cell keeps working if the array dimensionality changes.
profs = []
for i in range(f.ndim):
    prof = utils.project(f, axis=i)
    # Normalize to unit sum so the profile can serve as averaging weights.
    prof = prof / np.sum(prof)
    coords[i] = coords[i] - np.average(coords[i], weights=prof)
    profs.append(prof)

Crop the array.

In [None]:
# (start, stop) index window to keep along each of the five axes of `f`.
crop = (
    (0, f.shape[0]),
    (10, f.shape[1] - 10),
    (0, f.shape[2]),
    (16, f.shape[3] - 16),
    (12, f.shape[4] - 19),
)

# Axis names and units, listed in the axis order of `f`.
dims = ["x", "x'", "y", "y'", "w"]
units = ["mm", "mrad", "mm", "mrad", "MeV"]
dims_units = [f'{d} [{u}]' for d, u in zip(dims, units)]

# Plot each 1D profile against pixel index and shade the crop window.
fig, axes = pplt.subplots(ncols=5, figwidth=7, spanx=False, figheight=None)
for i, ax in enumerate(axes):
    ax.plot(profs[i], color='black')
    # The shaded span ends at stop - 1 (the crop stop index minus one).
    ax.axvspan(crop[i][0], crop[i][1] - 1, color='black', alpha=0.1)
    ax.format(xlabel=dims[i] + ' [pixel]')
# axes.format(yscale='log')
plt.show()

In [None]:
# Turn each (start, stop) crop window into a slice, then cut both the density
# array and the matching coordinate arrays down to that window.
ind = tuple(slice(start, stop) for start, stop in crop)
f = f[ind]
coords = [axis_coords[window] for axis_coords, window in zip(coords, ind)]

Normalize `f` to the range [0, 1].

In [None]:
# Negative densities are reported and clipped to zero before scaling.
f_min = np.min(f)
if f_min < 0.0:
    print(f'min(f) = {f_min}')
    print('Clipping to zero.')
    f = np.clip(f, 0.0, None)
# Scale so the peak density equals 1. The division produces a new array, so
# `f` no longer refers to the array it was bound to before this line.
f = f / np.max(f)

# Sanity check of the resulting range.
print(f'f_min = {np.min(f)}')
print(f'f_max = {np.max(f)}')

Apply a threshold to `f`.

In [None]:
# Zero every bin whose density is below 10^-3.5 (modifies `f` in place).
f[f < 10.0**-3.5] = 0

## Rectangular slices

### 1D projections 

`slice_type` can be 'int', in which case only one index is selected along the axis, or 'range', in which case a range of indices is selected.

In [None]:
# Interactive 1D projection viewer provided by the project's plotting module.
# The trailing comments list the option values named for each keyword.
mplt.interactive_proj1d(
    f, coords=coords, dims=dims, units=units, default_ind=4,
    kind='line',  # {'line', 'step', 'bar'}
    slice_type='int',  # {'int', 'range'}
)

Zoom in on the center of transverse phase space.

In [None]:
# Index of the grid point closest to zero along each axis.
ind_center = [np.argmin(np.abs(c)) for c in coords]

fig, axes = pplt.subplots(ncols=4, figwidth=4.5, figheight=1.35)
axes.format(xlabel="w [MeV]", yticklabels=[], xlim=(-0.065, 0.065))
for i, ax in enumerate(axes):
    # Panel i fixes the first i + 1 axes at their central index. The index is
    # built for four axes only; the fifth axis (w) is never indexed.
    # NOTE(review): relies on `utils.make_slice` pairing axis[j] with ind[j] — confirm.
    idx = utils.make_slice(4, np.arange(i + 1), ind_center[:i+1])
    # Fraction of the total density contained in this slice.
    frac = np.sum(f[idx]) / np.sum(f)
    # The slice has 4 - i axes left, so 3 - i is its last axis (w).
    pw = project(f[idx], 3 - i)
    pw = pw / np.sum(pw)
    ax.bar(coords[4], pw, color='black', width=1)
    # Title lists the fixed dimensions, e.g. "x ≈ x' ≈ 0", plus the fraction kept.
    title = ""
    for j in range(i + 1):
        title += r"${} \approx$".format(dims[j])
    title += "0"
    title += f"\n({100.0*frac:.2f}%)"
    ax.format(title=title, title_kw=dict(fontsize='medium'))
# Use the first panel's upper y limit for every panel.
axes.format(ylim=(0.0, axes[0].get_ylim()[1]), xspineloc='bottom', yspineloc='neither')

### 2D projections 

The `thresh` slider applies a fractional threshold to the 2D projection; a good value is around -3.5 for the most recent measurements.

In [None]:
# Style of the 1D profile lines; `prof_kws` is reused by several later cells.
prof_kws = dict(lw=0.5, alpha=0.7, color='white', scale=0.12)  # for 1D profiles
# Interactive 2D projection viewer provided by the project's plotting module.
mplt.interactive_proj2d(
    f, coords=coords, dims=dims, units=units, prof_kws=prof_kws, 
    slice_type='int',  # {'int', 'range'}
)

## Contour slices

We can also use contour slices — volumes defined by density contours.

### 1D projections

Plot the 1D projection of the distribution along any axis. The `4D thresh` slider applies a threshold to the remaining four dimensions before computing the projection. For example, if we are looking at the energy projection, `4D thresh` applies to the transverse phase space x-x'-y-y'. The density in this space is normalized to the range [0, 1], so a threshold of 0 selects all particles, while a threshold of 1 selects only the brightest pixel in x-x'-y-y'. The fraction of the beam selected is printed on the figure (`frac`).

In [None]:
def proj1d_mask(f, thresh=0.1, fpr=None, normalize=True, return_frac=False, axis=0):
    """Project `f` onto one axis using only bins inside a density contour.

    The contour is defined on the projection of `f` onto all the other axes
    (or on `fpr` if given), scaled to a maximum of 1: bins whose scaled
    density exceeds `thresh` are selected.

    Parameters
    ----------
    f : ndarray
        Density array.
    thresh : float
        Threshold applied to the scaled density over the other axes.
    fpr : ndarray, optional
        Precomputed projection of `f` onto the other axes.
    normalize : bool
        If True, scale the returned profile to unit sum.
    return_frac : bool
        If True, also return the fraction of the density that is selected.
    axis : int
        Axis of `f` to project onto.

    Returns
    -------
    ndarray, or (ndarray, float) if `return_frac` is True.
    """
    other_axes = [k for k in range(f.ndim) if k != axis]
    density = utils.project(f, other_axes) if fpr is None else fpr
    density = density / np.max(density)

    selected = np.where(density > thresh)
    frac = np.sum(density[selected]) / np.sum(density)

    # Index `f` with the selected points on the other axes and a full slice
    # along `axis`. NumPy places the selected-point dimension first unless
    # the full slice is the leading index (axis == 0), hence the choice of
    # summation axis below.
    index = utils.make_slice(f.ndim, other_axes, selected)
    point_axis = 1 if axis == 0 else 0
    profile = np.sum(f[index], axis=point_axis)
    if normalize:
        profile = profile / np.sum(profile)
    return (profile, frac) if return_frac else profile

In [None]:
def update(thresh, dim):
    """Redraw the contour-masked 1D projection.

    `thresh` is the threshold passed to `proj1d_mask`; `dim` is the name of
    the axis to plot and must be an entry of `dims`.
    """
    axis = dims.index(dim)
    p, frac = proj1d_mask(f, thresh, normalize=True, return_frac=True, axis=axis)

    fig, ax = pplt.subplots(figsize=(4, 1.5))
    # The title reports the fraction of the density selected by the threshold.
    ax.format(xlabel=dims_units[axis], title=f'frac = {frac:.2f}')
    ax.bar(coords[axis], p, color='black')
    plt.show()


# The dimension choices are given in reverse order of `dims`.
thresh = widgets.FloatSlider(description='4D thresh', min=0.0, max=0.99, value=0.5, step=0.001)
interactive(update, thresh=thresh, dim=reversed(dims))

### 2D projections 

We can do the same thing with 2D projections, applying a threshold in the remaining three dimensions.

In [None]:
def proj2d_mask(f, thresh=0.1, fpr=None, normalize=True, return_frac=False, axis=(2, 3)):
    """Project `f` onto two axes using only bins inside a density contour.

    The contour is defined on the projection of `f` onto the remaining axes
    (or on `fpr` if given), scaled to a maximum of 1: bins whose scaled
    density is below `thresh` are masked before projecting.

    Parameters
    ----------
    f : ndarray
        Density array of any dimensionality greater than two.
    thresh : float
        Threshold applied to the scaled density over the remaining axes.
    fpr : ndarray, optional
        Precomputed projection of `f` onto the remaining axes.
    normalize : bool
        Accepted for symmetry with `proj1d_mask`; currently has no effect.
    return_frac : bool
        If True, also return the fraction of the density that is kept.
    axis : sequence of two ints
        Axes of `f` to project onto.

    Returns
    -------
    Projected image, or (image, float) if `return_frac` is True.
    """
    # Compute the mask over the remaining (non-plotted) dimensions.
    axis_proj = [i for i in range(f.ndim) if i not in axis]
    if fpr is None:
        fpr = utils.project(f, axis_proj)
    fpr = fpr / np.max(fpr)
    mask = fpr < thresh
    frac = np.sum(fpr[~mask]) / np.sum(fpr)

    # Copy the mask into the two projected dimensions.
    # NOTE(review): assumes the helper appends the new dimensions at the end,
    # so the mask axes are ordered `axis_proj + axis` — confirm.
    mask = utils.copy_into_new_dim(mask, (f.shape[axis[0]], f.shape[axis[1]]), axis=-1, copy=True)
    # Put the dimensions back in the axis order of `f`. The destination range
    # follows `f.ndim` instead of a hard-coded 5 so that arrays of other
    # dimensionality are handled as well.
    isort = np.argsort(list(axis_proj) + list(axis))
    mask = np.moveaxis(mask, isort, np.arange(f.ndim))

    # Project the masked array onto the requested axes.
    im = utils.project(np.ma.masked_array(f, mask=mask), axis=axis)
    if return_frac:
        return im, frac
    return im

In [None]:
def update(thresh, dim1, dim2, log=False, **kws):
    """Redraw the contour-masked 2D projection.

    `thresh` is the threshold passed to `proj2d_mask`; `dim1` and `dim2` are
    the names of the axes to plot (entries of `dims`). Nothing is drawn when
    both names are equal. Extra keyword arguments are forwarded to
    `mplt.plot_image`.
    """
    if dim1 == dim2:
        return
    axis = [dims.index(dim1), dims.index(dim2)]
    im, frac = proj2d_mask(f, thresh, normalize=True, return_frac=True, axis=axis)

    fig, ax = pplt.subplots()
    kws['norm'] = 'log' if log else None
    # Defaults that the caller may override through **kws.
    kws.setdefault('thresh', 10.0**-3.5)
    kws.setdefault('thresh_type', 'frac')
    kws.setdefault('colorbar', True)
    kws.setdefault('profx', True)
    kws.setdefault('profy', True)
    ax.format(xlabel=dims_units[axis[0]], ylabel=dims_units[axis[1]], title=f'frac = {frac:.2f}')
    mplt.plot_image(im, x=coords[axis[0]], y=coords[axis[1]], ax=ax, prof_kws=prof_kws, **kws)
    # Re-apply the x limits in ascending order.
    ax.format(xlim=sorted(ax.get_xlim()))
    plt.show()


thresh = widgets.FloatSlider(description='3D thresh', min=0.0, max=0.99, value=0.5, step=0.001)
# NOTE(review): `kws` is created here but not passed to `interactive` below.
kws = dict()
interactive(update, thresh=thresh, dim1=dims, dim2=reversed(dims), log=False)

## Radial density

Radial density plots may also convey useful information. We define the radius as $r = \sqrt{\mathbf{x}^T \mathbf{\Sigma}^{-1} \mathbf{x}}$, where $\mathbf{x}$ is the phase space coordinate vector and $\mathbf{\Sigma} = \langle{\mathbf{x}\mathbf{x}^T}\rangle$ is the covariance matrix. So we are looking at the density within nested ellipsoidal shells in the phase space. (This hides information unless the distribution has ellipsoidal symmetry.)

In [None]:
def get_radii(coords, Sigma):
    """Return the normalized radius at every point of a regular grid.

    The radius of a grid point with coordinate vector x is
    sqrt(x^T Sigma^-1 x).

    Parameters
    ----------
    coords : list of 1D arrays
        Grid coordinates along each axis.
    Sigma : ndarray, shape (n, n)
        Covariance matrix, where n = len(coords).

    Returns
    -------
    R : ndarray, shape (len(coords[0]), ..., len(coords[n - 1]))
        Radius at each grid point.
    """
    Sigma_inv = np.linalg.inv(Sigma)
    shape = tuple(len(c) for c in coords)
    # Sparse mesh: grid k varies only along dimension k and broadcasts against
    # the others, so no full-size coordinate arrays are materialized.
    grids = np.meshgrid(*coords, indexing='ij', sparse=True)
    # Accumulate the quadratic form sum_ij Sigma_inv[i, j] * x_i * x_j with
    # array arithmetic instead of looping over every grid point in Python.
    R2 = np.zeros(shape)
    for i, Ci in enumerate(grids):
        for j, Cj in enumerate(grids):
            R2 += Sigma_inv[i, j] * Ci * Cj
    return np.sqrt(R2)

def radial_density(f, R, radii, dr=None):
    """Return the mean of `f` within a shell starting at each radius.

    For each r in `radii` the shell contains the points with
    r <= R <= r + dr. When `dr` is None it defaults to
    0.5 * max(R) / (len(R) - 1).

    Returns an array with one value per entry of `radii`.
    """
    if dr is None:
        dr = 0.5 * np.max(R) / (len(R) - 1)
    # Mask everything outside the shell and average what is left.
    shell_means = [
        np.mean(np.ma.masked_where((R < r) | (R > r + dr), f))
        for r in tqdm(radii)
    ]
    return np.array(shell_means)

In [None]:
def plot_radial_density(axis=None, nr=70):
    """Plot the radial density of a projection of `f` against a Gaussian.

    Parameters
    ----------
    axis : sequence of int
        Axes of `f` to project onto. NOTE(review): the default `None` is not
        usable as-is, because `axis` is iterated below.
    nr : int
        Number of radial points.

    Returns
    -------
    The two axes of the figure: linear scale (left) and log scale (right).
    """
    # Project the distribution onto the specified axis.
    _f = utils.project(f, axis)
    
    # Compute the covariance matrix.
    _coords = [coords[i] for i in axis]
    Sigma, mu = ba.dist_cov(_f, _coords, disp=True)
    
    # Compute the radii in normalized space.
    print('Computing radii on mesh...')
    R = get_radii(_coords, Sigma)
    
    # Compute the approximate radial density in normalized space.
    radii = np.linspace(0.0, np.max(R), nr)
    fr = radial_density(_f, R, radii, dr=None)
    # Scale to a peak of 1, ignoring NaN entries.
    fr = fr / np.nanmax(fr)
    
    # Plot on top of a Gaussian which is normalized in the same way.
    fig, axes = pplt.subplots(ncols=2, figwidth=5.75, figheight=2.55, share=False, space=10)
    for ax in axes:
        alpha = 0.2
        ax.plot(radii, np.exp(-0.5 * radii**2), color='red', alpha=alpha, label='gauss')
        # ax.plot(rs, [float(r <= 2.0) for r in rs], color='blue', alpha=alpha, label='uniform')
        ax.plot(
            radii, fr, color='black', label='data',
            marker='.', lw=0, ms=3,
        )
        ax.legend(ncols=1, loc='upper right', handlelength=1)
    axes.format(xlim=(0.0, 5.0))
    # The right-hand panel shows the same curves on a log scale.
    axes[1].format(yscale='log', yformatter='log')
    axes[1].format(ymin=1e-5, ymax=1.5)
    
    # Title names the projected dimensions, e.g. "x-y-w".
    title = "-".join([dims[i] for i in axis])
    axes.format(xlabel=r"$r = \sqrt{\mathbf{x}^T\mathbf{\Sigma}^{-1}\mathbf{x}}$", 
                ylabel=r"$f(r)$", title=title)
    return axes

Here are a few examples in 3D. Note the x-y-w distribution, in which the hollowish core is visible.

In [None]:
ndim = 3  # number of dimensions in projected array
nr = 70  # number of radial points
# One figure for every combination of `ndim` axes out of the five.
for axis in itertools.combinations([0, 1, 2, 3, 4], ndim):    
    print('axis =', axis)
    axes = plot_radial_density(axis, nr)
    plt.show()

## Corner plots 

### Corner plot with rectangular slices

This shows the full 2D projections.

In [None]:
def update(cmap=None, log=False):
    """Draw the corner plot of `f` with the chosen colormap and color scale.

    `cmap` is forwarded to `mplt.corner`; `log` switches the color
    normalization between 'log' and the default.
    """
    corner_kws = dict(
        coords=coords,
        diag_kind='None',
        labels=dims_units,
        norm=('log' if log else None),
        handle_log='floor',
        thresh=10.0**-3.5,
        thresh_type='frac',
        prof='edges',
        prof_kws=prof_kws,
        linewidth=0,
        rasterized=True,
        cmap=cmap,
    )
    axes = mplt.corner(f, **corner_kws)
    # Enlarge the axis labels.
    axes.format(
        xlabel_kw=dict(fontsize='large'),
        ylabel_kw=dict(fontsize='large'),
    )
    return axes

interactive(update, cmap=mplt.CMAPS, log=False)

This plot shows everything — all 2D projections — for rectangular slices of the 5D array. (*Very slow*.)

In [None]:
# 'int' selects a single index per sliced axis; 'range' selects an index range.
slice_type = 'int'

# Sliders
# One slider and one checkbox per dimension. A slider stays hidden until its
# checkbox is ticked.
n = f.ndim
sliders, checks = [], []
for k in range(n):
    if slice_type == "int":
        # The slider value is used as an array index (it is turned into the
        # slice [i, i + 1) in `update`), so the largest valid value is
        # shape - 1; allowing `shape` itself would select an empty slice.
        slider = widgets.IntSlider(
            min=0,
            max=f.shape[k] - 1,
            value=f.shape[k] // 2,
            description=dims[k],
            continuous_update=True,
        )
    elif slice_type == "range":
        # The upper handle is an exclusive stop, so `shape` is a valid value.
        slider = widgets.IntRangeSlider(
            value=(0, f.shape[k]),
            min=0,
            max=f.shape[k],
            description=dims[k],
            continuous_update=True,
        )
    else:
        raise ValueError("Invalid `slice_type`.")
    slider.layout.display = "none"
    sliders.append(slider)
    checks.append(widgets.Checkbox(description=f"slice {dims[k]}"))

def hide(button):
    """Show each slider only while its checkbox is ticked.

    Registered as an observer on every checkbox; the argument is not used.
    """
    for k in range(n):
        for element in [sliders[k], checks[k]]:
            element.layout.display = None
        if not checks[k].value:
            sliders[k].layout.display = "none"
            
for check in checks:
    check.observe(hide, names="value")
            
# Masked view of `f` whose mask is rewritten in place on every update.
mask = np.full(f.shape, False)
_f = np.ma.masked_array(f, mask=mask)
            
def update(
    cmap,
    handle_log,
    log,
    check1,
    check2,
    check3,
    check4,
    check5,
    slider1,
    slider2,
    slider3,
    slider4,
    slider5,
):
    """Draw the corner plot of a rectangular slice of `f`.

    `check1`..`check5` say whether each dimension is sliced, and
    `slider1`..`slider5` give the index (int) or index range (tuple) to keep
    along that dimension. Everything outside the slice is masked in `_f`
    before plotting. `cmap`, `handle_log`, and `log` are plot options.
    """
    # Make slice that keeps the original dimensions of the array.
    checks = [check1, check2, check3, check4, check5]
    sliders = [slider1, slider2, slider3, slider4, slider5]
    axis_slice = [dims.index(dim) for dim, check in zip(dims, checks) if check]
    # Keep only the slider values of the checked axes so that `ind[j]`
    # corresponds to `axis_slice[j]`; passing all five values would pair the
    # first sliders with whichever axes happen to be checked.
    ind = []
    for k in axis_slice:
        value = sliders[k]
        if isinstance(value, int):
            # A single index i becomes the length-one range [i, i + 1).
            value = (value, value + 1)
        ind.append(value)
    idx = utils.make_slice(f.ndim, axis=axis_slice, ind=ind)

    # Mask everything, then unmask the selected slice.
    _f.mask[...] = True
    _f.mask[idx] = False
    
    # Corner plot of this slice.
    axes = mplt.corner(
        _f,
        coords=coords,
        diag_kind='None',
        labels=dims_units,
        norm='log' if log else None,
        thresh=10.0**-3.5,
        thresh_type='frac',
        prof='edges',
        prof_kws=prof_kws,
        linewidth=0, rasterized=True,
        cmap=cmap,
        handle_log=handle_log,
    )
    axes.format(xlabel_kw=dict(fontsize='large'), ylabel_kw=dict(fontsize='large'))
    return axes

# Map each widget to the matching `update` parameter (check1..5, slider1..5),
# after the three plot options, and build the interactive view.
kws = dict(cmap=mplt.CMAPS, handle_log='floor', log=False)
kws.update({f"check{i}": check for i, check in enumerate(checks, start=1)})
kws.update({f"slider{i}": slider for i, slider in enumerate(sliders, start=1)})
interactive(update, **kws)

### Corner plot with contour slices 

#### 4D transverse contours: 2D transverse projections and 1D energy projection

Here, we observe the energy distribution within density contours in the 4D transverse phase space.

In [None]:
# Projection of `f` onto the first four axes, scaled to a peak of 1.
ftr = project(f, axis=(0, 1, 2, 3))
ftr = ftr / np.max(ftr)

def update(log=False, thresh=0.5):
    """Corner plot of `ftr` inside a density contour, plus the energy profile.

    Bins of `ftr` below `thresh` are masked in the corner plot, and the 1D
    projection of `f` onto the last axis within the same contour is drawn in
    the panel at row 0, column 2. Returns `(axes, frac)`, where `frac` is
    the fraction of the density selected.
    """
    # `ftr` is only read here, never reassigned.
    global ftr 
    # Strict `>` here; the corner plot below masks with `<`, so bins exactly
    # at the threshold are drawn but not counted in `frac`.
    frac = np.sum(ftr[ftr > thresh]) / np.sum(ftr)
    print(f'frac = {frac:.4f}')
    axes = mplt.corner(
        np.ma.masked_where(ftr < thresh, ftr),
        coords=coords[:4],
        diag_kind='None',
        labels=dims_units,
        thresh=10.0**-3.0,
        thresh_type='frac',
        fill_value=0.0,
        norm='log' if log else None,
        handle_log='floor',
        prof='edges',
        prof_kws=dict(kind='step', lw=0.4, alpha=0.8),
        linewidth=0,rasterized=True
    )
    pw = proj1d_mask(f, thresh, fpr=ftr, normalize=True, return_frac=False, axis=4)
    # Reuse the panel at row 0, column 2 for the energy projection.
    ax = axes[0, 2]
    # Remove this panel from the shared-axis groups of the lower-triangle
    # panels so its limits can be set independently.
    # NOTE(review): `_shared_x_axes` / `_shared_y_axes` are private matplotlib
    # attributes — confirm they exist in the installed version.
    for i in range(3):
        for j in range(i + 1):
            axes[i, j]._shared_x_axes.remove(ax)
            axes[i, j]._shared_y_axes.remove(ax)
    ax.axis('on')
    ax.bar(coords[4], pw, color='black')
    ax.format(
        xlim=(np.min(coords[4]), np.max(coords[4])),
        ylim=(0, 0.1),
        xspineloc='bottom', yspineloc='left',
        title='energy projection',
    )
    return axes, frac
    
# The tuple abbreviation gives the slider's (min, max, step).
interactive(update, log=False, thresh=(0.0, 0.99, 0.01))

#### 4D transverse contours: all projections

Plot the 2D projections of the 5D phase space distribution as the 4D boundary is changed in transverse phase space. To slice the array, we compute a 4D mask in the transverse phase space, then copy the mask along the last axis of the array to get a 5D mask. *(This is slow.)* The only new information, relative to the previous plot, is the set of 2D correlations between energy and the other variables.

In [None]:
def mask_4d(f, thresh=0.0, ftr=None):
    """Mask `f` wherever its sum over the last axis falls below `thresh`.

    If `ftr` is not given it is computed as `f.sum(axis=-1)` scaled to a
    maximum of 1; a supplied `ftr` is used as-is. Returns a masked array
    with the same data as `f`.
    """
    if ftr is None:
        summed = np.sum(f, axis=-1)
        ftr = summed / np.max(summed)
    below = ftr < thresh
    # Repeat the lower-dimensional condition along the last axis of `f`.
    return np.ma.masked_where(utils.copy_into_new_dim(below, f.shape[-1]), f)

In [None]:
def update(log, thresh):
    """Corner plot of `f` masked by a contour of the global `ftr`.

    `thresh` is the contour level passed to `mask_4d`; `log` switches the
    color normalization to 'log'.
    """
    axes = mplt.corner(
        mask_4d(f, thresh, ftr=ftr),
        coords=coords,
        diag_kind='None',
        labels=dims_units,
        thresh=10.0**-3.0,
        thresh_type='frac',
        fill_value=0.0,
        norm='log' if log else None,
        handle_log='floor',
        prof='edges',
        prof_kws=dict(kind='step', lw=0.4, alpha=0.8),
    )
    plt.show()
    
thresh = widgets.FloatSlider(description="x-x'-y-y' thresh", min=0.0, max=0.99, value=0.5, step=0.001)
interactive(update, log=False, thresh=thresh)

#### 5D contours

Plot the 2D projections of the 5D array within 5D density contours.

In [None]:
def update(cmap, log, thresh):
    """Corner plot of `f` restricted to bins with density above `thresh`.

    Bins with density less than or equal to `thresh` are masked. `cmap` and
    `log` select the colormap and the color normalization.
    """
    inside_contour = np.ma.masked_less_equal(f, thresh)
    norm = 'log' if log else None
    axes = mplt.corner(
        inside_contour,
        coords=coords,
        diag_kind='None',
        labels=dims_units,
        norm=norm,
        handle_log='floor',
        thresh=10.0**-3.5,
        thresh_type='frac',
        prof='edges',
        prof_kws=prof_kws,
        linewidth=0,
        rasterized=True,
        cmap=cmap,
    )
    # Enlarge the axis labels.
    axes.format(
        xlabel_kw=dict(fontsize='large'),
        ylabel_kw=dict(fontsize='large'),
    )
    return axes

thresh = widgets.FloatSlider(description="5D thresh", min=0.0, max=0.99, value=0.5, step=0.001)
interactive(update, cmap=mplt.CMAPS, log=False, thresh=thresh)

View the x-w projection within shrinking 5D density contours.

In [None]:
# For each of twelve threshold levels between 0 and 0.7, mask the bins of `f`
# at or below the level and project what remains onto axes (4, 0).
threshs = np.linspace(0.0, 0.7, 12)
ims = [
    utils.project(np.ma.masked_less_equal(f, level), axis=(4, 0))
    for level in tqdm(threshs)
]

In [None]:
# Lay the images out on a grid with six columns and as many rows as needed.
ncols = 6
nrows = int(np.ceil(len(ims) / ncols))
fig, axes = pplt.subplots(nrows=nrows, ncols=ncols, figwidth=8)
axes.format(xlabel=dims_units[4], ylabel=dims_units[0])
for ax, im, thresh in zip(axes, ims, threshs):
    mplt.plot_image(im, x=coords[4], y=coords[0], ax=ax)
    # Label each panel with its threshold level.
    ax.annotate(f"t = {thresh:.2f}", xy=(0.03, 0.03), xycoords='axes fraction', 
                color='red', fontsize='small')
plt.show()