# Step 4: visualization

In [None]:
import sys
import os
from os.path import join
import time
from datetime import datetime
import importlib
import numpy as np
import pandas as pd
import h5py
import imageio
from scipy import ndimage
from scipy import interpolate
import skimage
from tqdm import tqdm
from tqdm import trange
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib import patches
import plotly.graph_objs as go
from ipywidgets import interact
from ipywidgets import interactive
from ipywidgets import interactive_output
from ipywidgets import widgets
from IPython.display import display
from IPython.display import clear_output
import proplot as pplt

sys.path.append('../..')
from tools import energyVS06 as energy
from tools import image_processing as ip
from tools import plotting as mplt
from tools import utils
from tools.utils import project

In [None]:
# Matplotlib: aggressive path simplification plus the 'fast' style for quicker rendering.
mpl.rcParams.update({
    'path.simplify': True,
    'path.simplify_threshold': 1.0,
})
mpl.style.use('fast')
# ProPlot defaults: no grid, continuous colormaps, viridis, white figure background.
for _rc_key, _rc_val in {
    'grid': False,
    'cmap.discrete': False,
    'cmap.sequential': 'viridis',
    'figure.facecolor': 'white',
}.items():
    pplt.rc[_rc_key] = _rc_val

## Load data 

In [None]:
# Directory with the saved data of this run: the array shape (a5d_new_shape.txt),
# the grid coordinates (coords.npz) and the 5D memmap (a5d_new.mmp).
folder = '/'.join(('_saved', '2022-04-29'))

In [None]:
# Shape of the saved 5D array, stored as plain text next to the memmap.
shape_file = join(folder, 'a5d_new_shape.txt')
shape = tuple(np.loadtxt(shape_file).astype(int))
shape

In [None]:
# Grid coordinates (presumably one 1D array per dimension — confirm in utils.load_stacked_arrays).
coords_file = join(folder, 'coords.npz')
coords = utils.load_stacked_arrays(coords_file)
for c in coords:
    print(c.shape)

In [None]:
# Open the 5D distribution read-only as a memory map instead of reading it eagerly.
f = np.memmap(
    join(folder, 'a5d_new.mmp'),
    dtype='float',
    mode='r',
    shape=shape,
)

In [None]:
f_min = np.min(f)
print(f'min(f) = {f_min}')
# Zero out negative entries. np.clip returns a new array, so the read-only memmap is untouched.
f = np.clip(f, 0.0, None)

Flip the energy axis (for now).

In [None]:
# Reverse the last (energy, "w") axis; np.flip returns the same view as f[..., ::-1].
f = np.flip(f, axis=-1)

## Static 

In [None]:
dims = ["x", "xp", "y", "yp", "w"]
units = ["mm", "mrad", "mm", "mrad", "MeV"]
# Axis labels such as 'x [mm]'.
dims_units = [f'{dim} [{unit}]' for dim, unit in zip(dims, units)]
# Dimension name -> axis index of f.
dim_to_int = dict(zip(dims, range(len(dims))))
# Style of the 1D profiles overlaid on images.
prof_kws = {'lw': 1.0, 'alpha': 0.5, 'color': 'white', 'scale': 0.15}

### Projections

In [None]:
# Corner plot of all 2D projections, drawn once with a linear and once with a log color scale.
corner_kws = dict(
    coords=coords,
    diag_kind='None',  # {'line', 'None'}
    prof='edges',  # {True, False, 'edges'}
    prof_kws=prof_kws,
    labels=dims_units,
    frac_thresh=1e-6,
)
for norm in [None, 'log']:
    axes = mplt.corner(f, norm=norm, **corner_kws)
    plt.savefig(f'_output/corner_int_norm{norm}.png')
    plt.show()

### Slices

In [None]:
# Multi-dimensional grid index of the maximum of f; the slices below pass through it.
ind = tuple(np.unravel_index(np.argmax(f), f.shape))
print(ind)

In [None]:
frac_thresh = 1e-5
prof = True

# Every combination of three axes (k < j < i) to slice through the peak index `ind`;
# the two remaining axes are shown as an image.
axes_slice = [(k, j, i) for i in range(f.ndim) for j in range(i) for k in range(j)]
axes_view = [tuple([i for i in range(f.ndim) if i not in axis])
             for axis in axes_slice]
for axis, axis_view in zip(axes_slice, axes_view):
    # Use f.ndim rather than a hard-coded 5 so the cell follows the data.
    idx = utils.make_slice(f.ndim, axis, [ind[i] for i in axis])
    f_slice = f[idx]
    f_slice_max = np.max(f_slice)
    # Normalize to a peak of 1; skip an all-zero slice, which would otherwise become NaN.
    if f_slice_max > 0:
        f_slice = f_slice / f_slice_max

    dim1, dim2 = [dims[i] for i in axis_view]

    fig, plot_axes = pplt.subplots(ncols=2)
    for ax, norm in zip(plot_axes, [None, 'log']):
        mplt.plot_image(f_slice, x=coords[axis_view[0]], y=coords[axis_view[1]],
                        ax=ax,
                        profx=prof, profy=prof, prof_kws=prof_kws,
                        frac_thresh=frac_thresh, norm=norm, colorbar=True)
    plot_axes.format(xlabel=dim1, ylabel=dim2)
    # Title: the fixed coordinate value of each sliced axis.
    suptitle = ',  '.join(
        f'{dims[i]} = {coords[i][ind[i]]:.2f} [{units[i]}]' for i in axis
    )
    plot_axes.format(suptitle=suptitle)
    plot_axes.format(suptitle_kw=dict(fontweight='normal'))
    filename = '_output/slice_' + ''.join(f'{dims[i]}-{ind[i]}' for i in axis)
    plt.savefig(filename + '.png')
    plt.show()

## Interactive plots

In [None]:
# Widgets
default_ind = (2, 3)  # indices into `dims` of the two dimensions plotted on initial render
cmaps = ['viridis', 'dusk_r', 'mono_r', 'plasma']
cmap = widgets.Dropdown(options=cmaps, description='cmap')
thresh = widgets.FloatSlider(value=-5.0, min=-8.0, max=0.0, step=0.1, description='thresh', continuous_update=True)
# NOTE: `discrete`, `contour` and `scale` are not passed to the interactive cells below.
discrete = widgets.Checkbox(value=False, description='discrete')
log = widgets.Checkbox(value=False, description='log')
contour = widgets.Checkbox(value=False, description='contour')
profiles = widgets.Checkbox(value=True, description='profiles')
scale = widgets.FloatSlider(value=0.15, min=0.0, max=1.0, step=0.01, description='scale')
dim1 = widgets.Dropdown(options=dims, index=default_ind[0], description='dim 1')
dim2 = widgets.Dropdown(options=dims, index=default_ind[1], description='dim 2')

# Sliders: per dimension, one index slider, one range slider and one "slice" checkbox.
sliders, range_sliders, checks = [], [], []
for k in range(len(dims)):
    # Valid indices along axis k are 0 .. shape[k] - 1; a max of shape[k]
    # would allow an out-of-range index that selects an empty slice.
    slider = widgets.IntSlider(
        min=0, max=shape[k] - 1, value=shape[k] // 2,
        description=dims[k],
        continuous_update=True,
    )
    # (lo, hi) bounds, spanning the whole axis by default.
    range_slider = widgets.IntRangeSlider(
        value=(0, f.shape[k]), min=0, max=f.shape[k],
        description=dims[k],
        continuous_update=True,
    )
    # Start both sliders hidden; they are shown once the dimension's checkbox is ticked.
    for _slider in (slider, range_slider):
        _slider.layout.display = 'none'
    sliders.append(slider)
    range_sliders.append(range_slider)
    checks.append(widgets.Checkbox(description=f'slice {dims[k]}'))
    
# Hide/show sliders.
def hide(button):
    """Sync widget visibility with the dim1/dim2 choice and the slice checkboxes.

    Registered as an ipywidgets observer, so `button` receives the change
    notification; its content is not used.

    - Controls (sliders and checkbox) of a plotted dimension are hidden and
      its checkbox is cleared.
    - Controls of other dimensions are shown, except that sliders stay hidden
      while their checkbox is unticked.
    """
    plotted = (dim1.value, dim2.value)
    for dim, index_slider, span_slider, box in zip(dims, sliders, range_sliders, checks):
        if dim in plotted:
            for widget in (index_slider, span_slider, box):
                widget.layout.display = 'none'
            # Clearing the box fires this observer again (re-entrant call).
            if box.value:
                box.value = False
        else:
            for widget in (index_slider, span_slider, box):
                widget.layout.display = None
        if not box.value:
            index_slider.layout.display = 'none'
            span_slider.layout.display = 'none'
                
# Re-run `hide` whenever the plotted dimensions or any slice checkbox changes.
for element in [dim1, dim2] + checks:
    element.observe(hide, names='value')
# Initial hide: no checkbox for the default plotted dimensions, and every slider hidden.
for k, (check, slider, range_slider) in enumerate(zip(checks, sliders, range_sliders)):
    if k in default_ind:
        check.layout.display = 'none'
    slider.layout.display = 'none'
    range_slider.layout.display = 'none'

In [None]:
def update(dim1, dim2, check_x, check_xp, check_y, check_yp, check_w,
           x, xp, y, yp, w, log, profiles, thresh, cmap):
    """Plot the 2D projection of the (optionally sliced) distribution onto (dim1, dim2).

    Parameters
    ----------
    dim1, dim2 : str
        Names from `dims` for the horizontal and vertical plot axes.
    check_x, check_xp, check_y, check_yp, check_w : bool
        Whether to slice along the corresponding dimension before projecting.
    x, xp, y, yp, w : int or (int, int)
        Slice position along each dimension: a single index (IntSlider) or a
        (lo, hi) pair (IntRangeSlider). Both are passed to `utils.make_slice`
        as (lo, hi) pairs (assumed to mean the half-open range lo:hi, as the
        single-index form (i, i + 1) implies — confirm in utils.make_slice).
    log : bool
        Use a logarithmic color scale.
    profiles : bool
        Overlay 1D profiles along both axes.
    thresh : float
        log10 of the value passed to `mplt.plot_image` as `frac_thresh`.
    cmap : str
        Colormap name.
    """
    if dim1 == dim2:
        return
    checks = [check_x, check_xp, check_y, check_yp, check_w]
    # A dimension cannot be both plotted and sliced; skip such widget states.
    for dim, check in zip(dims, checks):
        if check and dim in (dim1, dim2):
            return
    axis_view = [dim_to_int[dim] for dim in (dim1, dim2)]
    axis_slice = [dim_to_int[dim] for dim, check in zip(dims, checks) if check]
    positions = [x, xp, y, yp, w]
    ind = []
    for k in axis_slice:
        pos = positions[k]
        if np.ndim(pos) == 0:
            # IntSlider value: a single index, used as the one-bin range (i, i + 1).
            lo, hi = pos, pos + 1
        else:
            # IntRangeSlider value: a (lo, hi) pair. Adding 1 to it (as for an
            # index) would raise TypeError. Widen a collapsed range to one bin.
            lo, hi = pos
            hi = max(hi, lo + 1)
        ind.append((lo, hi))
    H = f[utils.make_slice(f.ndim, axis_slice, ind)]
    H = utils.project(H, axis_view)
    H_max = np.max(H)
    # Normalize to a peak of 1 unless the selection is all zero.
    if H_max > 0:
        H = H / H_max

    fig, ax = pplt.subplots()
    mplt.plot_image(
        H, x=coords[axis_view[0]], y=coords[axis_view[1]], ax=ax,
        norm='log' if log else None,
        frac_thresh=10.0**thresh,
        profx=profiles,
        profy=profiles,
        prof_kws=dict(lw=1.0, alpha=0.75, color='white', scale=0.15),
        cmap=cmap,
        colorbar=True,
    )
    ax.format(xlabel=dims_units[axis_view[0]], ylabel=dims_units[axis_view[1]])
    plt.show()

### 2D projection of single-index slice

Slicing along each dimension is controlled by the checkboxes and sliders. The sliced distribution is projected onto dimensions `dim1` and `dim2`. 

In [None]:
# Index-slice view: each IntSlider selects a single bin along its dimension.
interactive(
    update,
    dim1=dim1,
    dim2=dim2,
    **{f'check_{dim}': check for dim, check in zip(dims, checks)},
    **dict(zip(dims, sliders)),
    log=log,
    profiles=profiles,
    thresh=thresh,
    cmap=cmap,
)

### 2D projection of range slice

Note: the ipywidgets range slider does not currently let you drag the whole selected range by its center; each endpoint has to be moved separately.

In [None]:
# Range-slice view: each IntRangeSlider selects a (lo, hi) span along its dimension.
interactive(
    update,
    dim1=dim1,
    dim2=dim2,
    **{f'check_{dim}': check for dim, check in zip(dims, checks)},
    **dict(zip(dims, range_sliders)),
    log=log,
    profiles=profiles,
    thresh=thresh,
    cmap=cmap,
)

## Emittance vs. energy

Compute the 5$\times$5 covariance matrix (this will take a few minutes).

In [None]:
# 5x5 covariance matrix `Sigma` and mean vector `means` of the full distribution.
# The computation is slow, so it is skipped when both names already exist in this
# kernel (rerunning after changing `f` requires `del Sigma` first). Previously this
# line was commented out, so a fresh kernel hit NameError in every later cell.
if 'Sigma' not in globals() or 'means' not in globals():
    Sigma, means = utils.dist_cov(f, coords)

In [None]:
# Print the 5x5 covariance matrix; `Sigma` must already exist (from utils.dist_cov).
print(Sigma)

In [None]:
# Overlay the ellipse from utils.rms_ellipse_dims for the (x, y) covariance block on
# the x-y projection; the width and height drawn are 4*cx and 4*cy.
i, j = (0, 2)
sig_ii, sig_jj, sig_ij = Sigma[i, i], Sigma[j, j], Sigma[i, j]
angle, cx, cy = utils.rms_ellipse_dims(sig_ii, sig_jj, sig_ij)
ellipse = patches.Ellipse(
    (means[i], means[j]),
    width=4.0 * cx,
    height=4.0 * cy,
    angle=-np.degrees(angle),
    ec='white',
    fill=False,
)

fig, ax = pplt.subplots()
ax.pcolormesh(coords[i], coords[j], utils.project(f, (i, j)).T)
ax.add_patch(ellipse)
plt.show()

In [None]:
axes = mplt.corner(
    f,
    coords=coords,
    fig_kws=dict(figwidth=1.5*4, space=1),
    diag_kind='None',
    prof='edges',
    prof_kws=dict(alpha=0.5, lw=0.9),
    labels=dims_units,
    frac_thresh=1e-6,
)
# Overlay the covariance ellipse on every lower-triangle panel: panel axes[i, j]
# has dimension j on the horizontal axis and dimension i + 1 on the vertical axis.
for i in range(4):
    row_dim = i + 1
    for j in range(i + 1):
        ax = axes[i, j]
        angle, cx, cy = utils.rms_ellipse_dims(
            Sigma[j, j], Sigma[row_dim, row_dim], Sigma[row_dim, j]
        )
        ax.add_patch(
            patches.Ellipse(
                (means[j], means[row_dim]),
                4.0 * cx,
                4.0 * cy,
                angle=-np.degrees(angle),
                ec='white',
                fill=False,
            )
        )
plt.savefig('_output/corner_nodiag_cov.png')
plt.show()

In [None]:
# NOTE(review): `twiss2D` and `emittances` are neither imported nor defined in the
# visible notebook code — confirm which module provides them before running.
alpha_x, alpha_y, beta_x, beta_y = twiss2D(Sigma)
eps_x, eps_y, eps_1, eps_2 = emittances(Sigma)
for label, value in [
    ('alpha_x', alpha_x),
    ('alpha_y', alpha_y),
    ('beta_x', beta_x),
    ('beta_y', beta_y),
    ('epsx', eps_x),
    ('epsy', eps_y),
    ('eps1', eps_1),
    ('eps2', eps_2),
]:
    print(f'{label} = {value}')

Try computing emittances for each energy slice.

In [None]:
# Project onto explicit axes instead of reusing `i, j` leaked from earlier loops:
# after the corner-ellipse cell both equal 3, giving the duplicated axes (3, 3, 4).
# NOTE(review): (0, 2) is the x-y plane used by the single-ellipse cell — confirm intent.
f3d = utils.project(f, (0, 2, 4))

In [None]:
# Covariance matrix and mean of the first four dimensions (x, xp, y, yp) for each
# energy bin k, i.e. each slice f[..., k].
_Sigmas, _means = [], []
for k in trange(shape[4]):
    cov_k, mean_k = utils.dist_cov(f[..., k], coords[:4])
    _Sigmas.append(cov_k)
    _means.append(mean_k)

In [None]:
# Row k holds emittances(_Sigmas[k]) for energy bin k; columns are unpacked elsewhere
# as (eps_x, eps_y, eps_1, eps_2).
_emittances = np.array([emittances(_Sigma) for _Sigma in _Sigmas])

In [None]:
# Emittances of the full distribution; drawn as horizontal reference lines in the next plot.
eps_x, eps_y, eps_1, eps_2 = emittances(Sigma)

In [None]:
# Colors of the 'colorblind' cycle. Named `cycle_colors` so the `colors` module
# imported from matplotlib at the top of the notebook is not shadowed.
cycle_colors = pplt.Cycle('colorblind').by_key()['color']
labels = [r'$\varepsilon_x$', r'$\varepsilon_y$', r'$\varepsilon_1$', r'$\varepsilon_2$']
eps_full = [eps_x, eps_y, eps_1, eps_2]

fig, ax = pplt.subplots(figsize=(4.5, 2.5))
# Per-energy-slice emittances (x and y only) ...
for i in range(2):
    ax.plot(coords[4], _emittances[:, i], label=labels[i], marker='.', ms=3)
# ... against the emittances of the full distribution as faint horizontal lines.
for i in range(2):
    ax.axhline(eps_full[i], color=cycle_colors[i], label=labels[i] + ' (full)',
               alpha=0.3, ls='-')
ax.format(ylabel='RMS emittance [mm mrad]', xlabel=dims_units[4], title='Energy slice emittances')
ax.legend(ncols=1, loc='r')
plt.savefig('_output/slice_emittances.png')
plt.show()

In [None]:
n = 12  # number of energy slices to show per phase-space plane
ncols = 6
nrows = int(np.ceil(n / ncols))
w_coords = coords[4]

for (i, j) in [(0, 1), (2, 3)]:
    # Projection onto (dim i, dim j, w); the last axis indexes energy bins.
    f3d = utils.project(f, (i, j, 4))
    # Evenly spaced energy indices. max(1, ...) keeps the step positive when there
    # are fewer than n energy bins (np.arange raises on a zero step).
    step = max(1, f3d.shape[2] // n)
    ks = np.arange(0, f3d.shape[2], step)

    fig, axes = pplt.subplots(ncols=ncols, nrows=nrows, figwidth=9.0)
    # zip stops at the shorter of axes / ks, so at most nrows * ncols panels are filled.
    for ax, k in zip(axes, ks):
        ax.pcolormesh(coords[i], coords[j], f3d[:, :, k].T)
        ax.annotate(f'w = {w_coords[k]:.2f} [MeV]', xy=(0.02, 0.98), verticalalignment='top',
                    xycoords='axes fraction', fontsize='small', color='white')
    axes.format(xlabel=dims_units[i], ylabel=dims_units[j])
    plt.savefig(f'_output/slice_plots_{dims[i]}.png')
    plt.show()