# Step 2: visualization

In [None]:
import sys
import os
from os.path import join
import time
from datetime import datetime
import importlib
import numpy as np
import pandas as pd
import h5py
import sympy
from tqdm import tqdm
from tqdm import trange
import matplotlib as mpl
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib import patches
import plotly.graph_objs as go
from ipywidgets import interactive
from ipywidgets import widgets
from IPython.display import display
from IPython.display import clear_output
import proplot as pplt

sys.path.append('../..')
from tools import analysis as ba
from tools import energyVS06 as energy
from tools import image_processing as ip
from tools import plotting as mplt
from tools import utils
from tools.utils import project

In [None]:
# Matplotlib: turn on path simplification, then apply the 'fast' style sheet.
mpl.rcParams.update({
    'path.simplify': True,
    'path.simplify_threshold': 1.0,
})
mpl.style.use('fast')

# Proplot defaults shared by every figure in this notebook.
pplt.rc.update({
    'grid': False,
    'cmap.discrete': False,
    'cmap.sequential': 'viridis',
    'figure.facecolor': 'white',
})

## Load data 

In [None]:
# Directory that the input files below (info, coords, f) are loaded from.
folder = '_output'

In [None]:
# Load the pickled `info` object from the input folder; its 'filename' entry
# is used below to build the names of the other input files.
info = utils.load_pickle(join(folder, 'info.pkl'))
info  # bare expression so the notebook displays it as the cell output

In [None]:
# The coordinate arrays live in a single .npz file named after the run; the
# grid shape is the length of each coordinate array, in order.
filename = info['filename']
coords = utils.load_stacked_arrays(join(folder, f'coords_{filename}.npz'))
shape = tuple(map(len, coords))
print('shape:', shape)

In [None]:
# Open the distribution file as a read-only memory map (mode='r') with the
# grid shape derived from the coordinate arrays.
f = np.memmap(join(folder, f'f_{filename}.mmp'), shape=shape, dtype='float', mode='r')

In [None]:
# Record the extrema first: `f_max` is reused later as a normalization factor.
f_max, f_min = np.max(f), np.min(f)
if f_min < 0.0:
    # Negative values are reported and then replaced by zero.
    print(f'min(f) = {f_min}')
    print('Clipping to zero.')
    f = np.clip(f, a_min=0.0, a_max=None)

In [None]:
# Names of the five dimensions and their units, in matching order.
dims = ["x", "xp", "y", "yp", "w"]
units = ["mm", "mrad", "mm", "mrad", "MeV"]
# Axis labels of the form 'name [unit]'.
dims_units = [f'{dim} [{unit}]' for dim, unit in zip(dims, units)]
# Keyword arguments handed to the plotting helpers as `prof_kws`.
prof_kws = {'lw': 0.5, 'alpha': 0.7, 'color': 'white', 'scale': 0.12}

## Interactive

### 2D projection of int slice 

Slicing along each dimension is controlled by the checkboxes and sliders. The sliced distribution is projected onto dimensions `dim1` and `dim2`. 

In [None]:
# Interactive viewer with integer sliders (slider_type='int'); `f` is divided
# by its maximum `f_max` before being passed in.
# NOTE(review): default_ind=(2, 3) presumably selects y-yp as the initially
# projected dimensions — confirm against mplt.interactive_proj2d.
mplt.interactive_proj2d(f / f_max, coords=coords, default_ind=(2, 3),
                        slider_type='int', dims=dims, units=units)

### 2D projection of range slice

`ipywidgets` does not currently allow you to drag the center of the slider. 

In [None]:
# Interactive viewer with range sliders (slider_type='range'); `f` is divided
# by its maximum `f_max` before being passed in.
# NOTE(review): default_ind=(2, 3) presumably selects y-yp as the initially
# projected dimensions — confirm against mplt.interactive_proj2d.
mplt.interactive_proj2d(f / f_max, coords=coords, default_ind=(2, 3),
                        slider_type='range', dims=dims, units=units)

In [None]:
# axis = (0, 3, 4)
# _H = utils.project(f, axis)
# _coords = [coords[i] for i in axis]
# _dims = [dims[i] for i in axis]
# mplt.interactive_proj2d(_H / _H.max(), 
#                         coords=_coords, dims=_dims,
#                         slider_type='int')

## Static 

### Projections

In [None]:
# Corner plot of the full distribution, drawn and saved once per color
# normalization (default, then 'log').
for norm in (None, 'log'):
    axes = mplt.corner(
        f,
        coords=coords,
        labels=dims_units,
        norm=norm,
        handle_log='mask',
        diag_kind='None',  # {'line', 'None'}
        prof='edges',  # {True, False, 'edges'}
        prof_kws=prof_kws,
    )
    plt.savefig(f'_output/int_corner_norm{norm}.png')
    plt.show()

### Slices

In [None]:
# Multi-dimensional index of the maximum of `f`, one entry per dimension.
ind = tuple(np.unravel_index(np.argmax(f), f.shape))
print(ind)

In [None]:
# Settings forwarded to mplt.plot_image for every slice plot below.
frac_thresh = 1e-5
prof = True

# Every triple of dimensions (k < j < i) to hold fixed, and for each triple
# the remaining dimensions that stay in view.
axes_slice = [(k, j, i) for i in range(f.ndim) for j in range(i) for k in range(j)]
axes_view = [tuple(i for i in range(f.ndim) if i not in axis)
             for axis in axes_slice]
for axis, axis_view in zip(axes_slice, axes_view):
    # Fix the sliced dimensions at the peak index `ind`. The dimension count
    # is taken from f.ndim (previously a literal 5) so it cannot disagree with
    # the axis lists built above.
    idx = utils.make_slice(f.ndim, axis, [ind[i] for i in axis])
    f_slice = f[idx]
    f_slice = f_slice / np.max(f_slice)  # scale so the slice maximum is 1

    dim1, dim2 = [dims[i] for i in axis_view]

    # Same slice drawn twice: default normalization and 'log'.
    fig, plot_axes = pplt.subplots(ncols=2)
    for ax, norm in zip(plot_axes, [None, 'log']):
        mplt.plot_image(f_slice, x=coords[axis_view[0]], y=coords[axis_view[1]],
                        ax=ax,
                        profx=prof, profy=prof, prof_kws=prof_kws,
                        frac_thresh=frac_thresh, norm=norm, colorbar=True)
    plot_axes.format(xlabel=dim1, ylabel=dim2)

    # Title lists the coordinate value of each fixed dimension, e.g.
    # 'x = 0.00 [mm],  xp = 0.00 [mrad],  y = 0.00 [mm]'.
    _dims = [dims[i] for i in axis]
    _units = [units[i] for i in axis]
    _vals = [coords[i][ind[i]] for i in axis]
    suptitle = ',  '.join(
        f'{d} = {v:.2f} [{u}]' for d, v, u in zip(_dims, _vals, _units)
    )
    plot_axes.format(suptitle=suptitle)
    plot_axes.format(suptitle_kw=dict(fontweight='normal'))

    # File name encodes each fixed dimension together with its index.
    string = '_output/int_slice_' + ''.join(f'{dims[i]}-{ind[i]}' for i in axis)
    plt.savefig(string + '.png')
    plt.show()

## Covariance matrix

Compute the 5$\times$5 covariance matrix (this will take a few minutes).

In [None]:
# Covariance matrix and means of the distribution `f` on the grid `coords`.
Sigma, means = utils.dist_cov(f, coords)
# Display the covariance matrix rounded to 3 decimals.
sympy.Matrix(np.round(Sigma, 3))

Compute the 5$\times$5 correlation matrix from the covariance matrix.

In [None]:
# Correlation matrix derived from the covariance matrix.
Corr = utils.cov2corr(Sigma)
# Display the correlation matrix rounded to 3 decimals.
sympy.Matrix(np.round(Corr, 3))

In [None]:
# Annotated heatmap of the correlation matrix, with the dimension names as
# tick labels on both axes and no colorbar.
# NOTE(review): 'grays' is not a stock matplotlib colormap name; presumably it
# resolves through the colormaps proplot registers on import — confirm.
g = sns.heatmap(Corr, xticklabels=dims, yticklabels=dims, annot=True,
                cbar=False, cmap='grays')
plt.savefig('_output/correlation_matrix.png')

In [None]:
# Write the covariance and correlation matrices to plain-text files.
np.savetxt('_output/Sigma.dat', Sigma)
np.savetxt('_output/Corr.dat', Corr)

In [None]:
# For every pair of dimensions (j < i), plot the 2D projection of `f` with
# dimension j on the horizontal axis and dimension i on the vertical axis,
# and overlay the ellipse derived from the covariance matrix.
for i in range(f.ndim):
    for j in range(i):
        angle, cx, cy = utils.rms_ellipse_dims(Sigma[j, j], Sigma[i, i], Sigma[j, i])
        center = (means[j], means[i])
        # Ellipse takes full width/height; 4x the values returned above.
        width = 4.0 * cx
        height = 4.0 * cy

        fig, ax = pplt.subplots()
        mplt.plot_image(utils.project(f, (j, i)), x=coords[j], y=coords[i], ax=ax)
        ax.add_patch(
            patches.Ellipse(
                center, width, height, angle=-np.degrees(angle),
                ec='white', fill=False,
            )
        )
        # The image is drawn with x=coords[j] and y=coords[i], so the labels
        # must follow the same order (they were previously swapped).
        ax.format(xlabel=dims_units[j], ylabel=dims_units[i])
        plt.savefig(f'_output/rms_ellipse_{dims[j]}-{dims[i]}.png')

In [None]:
# Corner plot of the full distribution with the covariance ellipse overlaid
# on each off-diagonal panel (row = vertical dimension, col = horizontal).
axes = mplt.corner(f, coords=coords, prof=False, labels=dims_units)
for row in range(5):
    for col in range(row):
        angle, cx, cy = utils.rms_ellipse_dims(
            Sigma[col, col], Sigma[row, row], Sigma[col, row]
        )
        ellipse = patches.Ellipse(
            (means[col], means[row]),
            4.0 * cx,
            4.0 * cy,
            angle=-np.degrees(angle),
            ec='white', fill=False, lw=0.75, alpha=0.9,
        )
        axes[row, col].add_patch(ellipse)
plt.savefig('_output/int_corner_cov.png')
plt.show()

In [None]:
# Twiss parameters and emittances computed from the full covariance matrix;
# these values are reused as reference lines in the energy-slice plots.
alpha_x, alpha_y, beta_x, beta_y = ba.twiss2D(Sigma)
eps_x, eps_y, eps_1, eps_2 = ba.emittances(Sigma)
print(
    f'alpha_x = {alpha_x}',
    f'alpha_y = {alpha_y}',
    f'beta_x = {beta_x}',
    f'beta_y = {beta_y}',
    f'epsx = {eps_x}',
    f'epsy = {eps_y}',
    f'eps1 = {eps_1}',
    f'eps2 = {eps_2}',
    sep='\n',
)

## Energy slices

Computing the 4D covariance matrix for each energy slice. For now, compute only the x-x' and y-y' emittances.

In [None]:
# Per-energy-slice emittance and Twiss values, computed separately for the
# (0, 1) and (2, 3) planes and then combined column-wise.
emittances, twiss = [], []
n_slices = shape[4]
for (i, j) in [(0, 1), (2, 3)]:
    f3d = utils.project(f, (i, j, 4))
    plane_coords = [coords[i], coords[j]]
    _Sigmas = np.zeros((n_slices, 2, 2))  # covariance matrix of each slice
    _means = np.zeros((n_slices, 2))  # mean of each slice
    _emittances = np.zeros(n_slices)  # rms emittance of each slice
    _twiss = np.zeros((n_slices, 2))  # rms alpha, beta of each slice
    for k in trange(n_slices):
        _Sigmas[k], _means[k] = utils.dist_cov(f3d[:, :, k], plane_coords)
        _emittances[k] = ba._emittance(_Sigmas[k])
        _twiss[k] = ba._twiss(_Sigmas[k])
    emittances.append(_emittances)
    twiss.append(_twiss)
# One row per slice; columns are ordered by plane.
emittances = np.transpose(emittances)
twiss = np.concatenate(twiss, axis=1)

In [None]:
# NOTE: `colors` rebinds the name imported from matplotlib at the top of the
# notebook; the plotting cells that follow read this list.
colors = pplt.Cycle('colorblind').by_key()['color']
labels = [r'$\varepsilon_x$', r'$\varepsilon_y$']

# Slice emittance versus energy, with the full-distribution value drawn as a
# faint horizontal reference line.
fig, ax = pplt.subplots(figsize=(4.5, 2.5))
for i, (label, eps_full) in enumerate(zip(labels, [eps_x, eps_y])):
    ax.plot(coords[4], emittances[:, i], label=label, marker='.', ms=3)
    ax.axhline(eps_full, color=colors[i], label=label+' (full)',
               alpha=0.3, ls='-')
ax.format(ylabel='[mm mrad]', xlabel='w [MeV]', title='Energy slice emittances')
ax.legend(ncols=1, loc='r')
plt.savefig('_output/energy_slice_emittances.png')

In [None]:
# Column order of `twiss`: alpha and beta of the first plane, then of the second.
twiss_labels = [r'$\alpha_x$', r'$\beta_x$', r'$\alpha_y$', r'$\beta_y$']
# Drop `cut` slices at each end of the energy axis before plotting.
cut = 20
idx = np.arange(cut, shape[4] - cut)

# Slice alpha versus energy, with the full-distribution value as a faint
# horizontal reference line.
fig, ax = pplt.subplots(figsize=(4.5, 2.5))
# Pair each curve with the first/second entry of `colors`. The previous
# colors[i-1] lookup with i in (0, 2) picked colors[-1] for alpha_x.
# NOTE(review): assumes the plotted curves follow the same color cycle as
# `colors` — confirm.
for i, _alpha, color in zip((0, 2), [alpha_x, alpha_y], colors[:2]):
    ax.plot(coords[4][idx], twiss[idx, i], label=twiss_labels[i], marker='.', ms=3)
    ax.axhline(_alpha, color=color, label=twiss_labels[i]+' (full)',
               alpha=0.3, ls='-')
ax.format(xlabel='w [MeV]', title='Energy slice rms alpha')
ax.legend(ncols=1, loc='r')
plt.savefig('_output/energy_slice_alphas.png')

In [None]:
# Slice beta versus energy (columns 1 and 3 of `twiss`), with the
# full-distribution value as a faint horizontal reference line.
fig, ax = pplt.subplots(figsize=(4.5, 2.5))
for i, _beta, color in zip((1, 3), [beta_x, beta_y], colors[:2]):
    ax.plot(coords[4][idx], twiss[idx, i], label=twiss_labels[i], marker='.', ms=3)
    ax.axhline(_beta, color=color, label=twiss_labels[i]+' (full)',
               alpha=0.3, ls='-')
# Title previously read 'alpha' (copy-paste from the alpha figure).
ax.format(xlabel='w [MeV]', title='Energy slice rms beta')
ax.legend(ncols=1, loc='r')
plt.savefig('_output/energy_slice_betas.png')

In [None]:
# Show `n` evenly spaced energy slices of the (0, 1) and (2, 3) projections,
# skipping `offset` slices at each end of the energy axis.
n = 6
ncols = 6
offset = int(0.4 * shape[4])
nrows = int(np.ceil(n / ncols))
for (i, j) in [(0, 1), (2, 3)]:
    f3d = utils.project(f, (i, j, 4))
    ks = np.linspace(offset, f3d.shape[-1] - offset, n).astype(int)
    # Common color limits so the panels of one figure are directly comparable.
    selected = f3d[:, :, ks]
    vmax, vmin = np.max(selected), np.min(selected)

    fig, axes = pplt.subplots(ncols=ncols, nrows=nrows, figwidth=9.0)
    for ax, k in zip(axes, ks):
        mplt.plot_image(f3d[:, :, k], x=coords[i], y=coords[j], ax=ax,
                        vmin=vmin, vmax=vmax)
        ax.annotate(
            f'w = {coords[4][k]:.2f} [MeV]',
            xy=(0.02, 0.98),
            xycoords='axes fraction',
            verticalalignment='top',
            fontsize='small',
            color='white',
        )
    axes.format(xlabel=dims_units[i], ylabel=dims_units[j])
    plt.savefig(f'_output/energy_slice_proj_{dims[i]}-{dims[j]}.png')
    plt.show()