# Step 4: visualization

In [None]:
import sys
import os
from os.path import join
import time
from datetime import datetime
import importlib
import numpy as np
import pandas as pd
import h5py
import imageio
from scipy import ndimage
from scipy import interpolate
import skimage
from tqdm import tqdm
from tqdm import trange
from matplotlib import pyplot as plt
from matplotlib import colors
import plotly.graph_objs as go
from ipywidgets import interact
from ipywidgets import interactive
from ipywidgets import interactive_output
from ipywidgets import widgets
from IPython.display import display
from IPython.display import clear_output
import proplot as pplt

sys.path.append('../..')
from tools import energyVS06 as energy
from tools import image_processing as ip
from tools import plotting as mplt
from tools import utils
from tools.utils import project

In [None]:
# Notebook-wide proplot defaults, applied in the order listed.
rc_overrides = {
    'grid': False,
    'cmap.discrete': False,
    'cmap.sequential': 'viridis',
    'figure.facecolor': 'white',
}
for rc_key, rc_value in rc_overrides.items():
    pplt.rc[rc_key] = rc_value

## Load data 

In [None]:
# Directory that the input data files are read from ('.' = current directory).
folder = '.'

In [None]:
# Read the array shape that was saved as text next to the memory-mapped data.
shape_path = join(folder, 'a5d_new_shape.txt')
shape = tuple(np.loadtxt(shape_path).astype(int))
shape

In [None]:
# Load the coordinate arrays (one per dimension of `f`). The path is resolved
# against `folder`, like the other input files, so changing `folder` moves
# every input consistently.
coords = utils.load_stacked_arrays(join(folder, 'coords.npz'))
for c in coords:
    print(c.shape)

In [None]:
# Open the array file as a read-only memory map using the shape loaded above.
mmp_path = join(folder, 'a5d_new.mmp')
f = np.memmap(mmp_path, shape=shape, dtype='float', mode='r')

In [None]:
# Display the shape of the loaded array as the cell output.
f.shape

In [None]:
# Report the smallest value, then replace any negative entries with zero.
f_min = np.min(f)
print(f'min(f) = {f_min}')
f = np.clip(f, 0.0, None)

## Static plots 

In [None]:
# Names and units of the five phase-space dimensions, in array-axis order.
dims = ["x", "xp", "y", "yp", "w"]
units = ["mm", "mrad", "mm", "mrad", "MeV"]
# Axis labels of the form "name [unit]".
dims_units = [f'{name} [{unit}]' for name, unit in zip(dims, units)]
# Reverse lookup: dimension name -> axis index.
dim_to_int = dict(zip(dims, range(len(dims))))

### Projections

In [None]:
# Corner plot of the projections of `f`, drawn once with a linear and once
# with a logarithmic color scale; each figure is saved under `_output/`.
os.makedirs('_output', exist_ok=True)  # savefig does not create missing directories
for norm in [None, 'log']:
    axes = mplt.corner(
        # NOTE(review): the last axis of `f` is reversed here while `coords`
        # is passed unreversed -- confirm this flip is intended.
        f[:, :, :, :, ::-1],
        coords=coords,
        fig_kws=dict(figwidth=1.5*4, space=1),
        diag_kind='None',
        prof='edges',
        prof_kws=dict(alpha=0.5, lw=0.9),
        labels=dims_units,
        norm=norm,
        frac_thresh=1e-6,
    )
    plt.savefig(f'_output/corner_int_norm{norm}.png')
    plt.show()

### Slices

In [None]:
# Multi-dimensional index of the largest entry of `f`.
flat_peak = np.argmax(f)
ind = tuple(np.unravel_index(flat_peak, f.shape))
print(ind)

In [None]:
# 2D slices through the peak of `f`: for every choice of three dimensions,
# fix them at the peak indices `ind` and image the two remaining dimensions
# with a linear and a logarithmic color scale.
frac_thresh = 1e-5
os.makedirs('_output', exist_ok=True)  # savefig does not create missing directories
axes_slice = [(k, j, i) for i in range(f.ndim) for j in range(i) for k in range(j)]
axes_view = [tuple(i for i in range(f.ndim) if i not in axis)
             for axis in axes_slice]
for axis, axis_view in zip(axes_slice, axes_view):
    # Use f.ndim (not a literal 5) so the slicer agrees with the loops above.
    idx = utils.make_slice(f.ndim, axis, [ind[i] for i in axis])
    f_slice = f[idx]
    f_slice = f_slice / np.max(f_slice)  # normalize the slice to unit peak

    dim1, dim2 = [dims[i] for i in axis_view]

    fig, plot_axes = pplt.subplots(ncols=2)
    for ax, norm in zip(plot_axes, [None, 'log']):
        mplt.plot_image(
            f_slice, x=coords[axis_view[0]], y=coords[axis_view[1]], ax=ax,
            frac_thresh=frac_thresh, norm=norm, colorbar=True,
        )
    plot_axes.format(xlabel=dim1, ylabel=dim2)

    # Title lists the coordinate value at which each fixed dimension is held.
    title = ',  '.join(
        f'{dims[i]} = {coords[i][ind[i]]:.2f} [{units[i]}]' for i in axis
    )
    plot_axes.format(suptitle=title)
    plot_axes.format(suptitle_kw=dict(fontweight='normal'))

    # File name encodes the fixed dimensions and their indices.
    filename = '_output/slice_' + ''.join(f'{dims[i]}-{ind[i]}' for i in axis) + '.png'
    plt.savefig(filename)
    plt.show()

## Interactive plots

In [None]:
# Colormap names offered by the 'cmap' dropdown of the interactive plots.
cmaps = ['viridis', 'dusk_r', 'mono_r', 'plasma']

Pre-compute the full 2D projections.

In [None]:
# projections[i][j] holds utils.project(f, (i, j)) for every ordered pair of
# axes, so the interactive callbacks can look it up instead of recomputing.
projections = [
    [utils.project(f, (i, j)) for j in range(5)]
    for i in trange(5)
]

### Projections 

In [None]:
def update_projection(dim1, dim2, cmap, thresh, log, discrete, contour, profiles):
    """Plot the pre-computed 2D projection onto the (dim1, dim2) plane.

    Does nothing when both dimension names are the same.

    Parameters
    ----------
    dim1, dim2 : str
        Dimension names (keys of `dim_to_int`) for the horizontal and
        vertical axes.
    cmap : str
        Colormap name forwarded to `mplt.plot_image`.
    thresh : float
        Base-10 exponent; `10.0**thresh` is passed as `frac_thresh`.
    log : bool
        Use a logarithmic color norm when true.
    discrete, contour : bool
        Forwarded to `mplt.plot_image`.
    profiles : bool
        Forwarded as both `profx` and `profy`.
    """
    if dim1 == dim2:
        return
    i = dim_to_int[dim1]
    j = dim_to_int[dim2]
    image = projections[i][j]

    image_kws = dict(
        norm='log' if log else None,
        frac_thresh=10.0**thresh,
        discrete=discrete,
        cmap=cmap,
        colorbar=True,
        profx=profiles,
        profy=profiles,
        prof_kws=dict(lw=1.0, alpha=0.75, color='white', scale=0.15),
        contour=contour,
        contour_kws=None,
    )
    fig, ax = pplt.subplots()
    # The image is normalized to unit peak before plotting.
    mplt.plot_image(image / np.max(image), x=coords[i], y=coords[j], ax=ax, **image_kws)
    ax.format(xlabel=dims_units[i], ylabel=dims_units[j])
    plt.show()

In [None]:
def update_projection_panels(dim1, dim2, cmap, thresh, log, discrete, contour, profiles, colorbar=True):
    """Plot a 2D projection with optional 1D profiles on panel axes (top/right).

    Does nothing when both dimension names are the same.

    Parameters
    ----------
    dim1, dim2 : str
        Dimension names (keys of `dim_to_int`) for the horizontal and
        vertical axes.
    cmap : str
        Colormap name forwarded to `mplt.plot_image`.
    thresh : float
        Base-10 exponent; `10.0**thresh` is passed as `frac_thresh`.
    log : bool
        Use a logarithmic color norm when true.
    discrete, contour : bool
        Forwarded to `mplt.plot_image`.
    profiles : bool
        Draw the summed 1D profiles in panels above and to the right.
    colorbar : bool, optional
        Attach a colorbar for the image. This used to be read from an
        undefined name (raising NameError); it is now an explicit option
        that defaults to True.
    """
    if dim1 == dim2:
        return
    i, j = [dim_to_int[dim] for dim in [dim1, dim2]]
    x = coords[i]
    y = coords[j]
    H = projections[i][j] / np.max(projections[i][j])  # normalize to unit peak

    fig, ax = pplt.subplots()
    if profiles:
        # Sum over the other axis to get the 1D profile along each dimension.
        px = np.sum(H, axis=1)
        py = np.sum(H, axis=0)
        paxes = [ax.panel_axes(loc, space=0, width='3em') for loc in 'tr']
        for pax in paxes:
            pax.format(xspineloc='neither', yspineloc='neither')
        kw = dict(color='black', lw=1.0)
        paxes[0].plot(x, px, **kw)
        paxes[1].plotx(y, py, **kw)
    ax, mesh = mplt.plot_image(
        H, x=x, y=y, ax=ax,
        frac_thresh=10.0**thresh,
        contour=contour,
        contour_kws=None,
        return_mesh=True,
        norm='log' if log else None,
        discrete=discrete,
        cmap=cmap,
    )
    if colorbar:
        # Leave extra room when the right-hand profile panel is present.
        space = 2 if profiles else None
        ax.colorbar(mesh, space=space)

    ax.format(xlabel=dims_units[i], ylabel=dims_units[j])
    plt.show()

In [None]:
# Controls for the interactive plots.
dim1 = widgets.Dropdown(
    options=dims, index=2, description='dim 1',
)
dim2 = widgets.Dropdown(
    options=dims, index=3, description='dim 2',
)
cmap = widgets.Dropdown(
    options=cmaps, description='cmap',
)
# `thresh` is a base-10 exponent (the callbacks use 10.0**thresh).
thresh = widgets.FloatSlider(
    value=-5.0, min=-8.0, max=0.0, step=0.1, description='thresh',
)
# Boolean toggles that start unchecked.
discrete, log, contour = [
    widgets.Checkbox(value=False, description=label)
    for label in ('discrete', 'log', 'contour')
]
profiles = widgets.Checkbox(value=True, description='profiles')

In [None]:
# kwargs = dict(dim1=dim1, dim2=dim2, cmap=cmap,
#               thresh=thresh, log=log, discrete=discrete, contour=contour, profiles=profiles)
# interactive(update_projection, **kwargs)

### Corner (compact) 

In [None]:
# For each of the five dimensions, build the full coordinate range (extended
# by half a grid step on both sides) and a range slider spanning it.
ranges = []
range_sliders = []
for i in range(5):
    axis_coords = coords[i]
    delta = np.diff(axis_coords)[0]  # spacing between the first two grid points
    lower = np.min(axis_coords) - 0.5 * delta
    upper = np.max(axis_coords) + 0.5 * delta
    limits = [lower, upper]
    ranges.append(limits)
    range_sliders.append(
        widgets.FloatRangeSlider(
            value=limits,
            min=lower,
            max=upper,
            step=delta,
            description=f'{dims[i]}',
        )
    )

In [None]:
# Line style for the 1D profile overlays.
# prof_kws = dict(color='white', lw=1.0, alpha=0.75)
prof_kws = {'lw': 1.0, 'alpha': 0.75, 'color': 'red'}
# Keyword arguments passed to pcolormesh.
plot_kws = {'ec': 'None'}

In [None]:
%%capture

# Build the compact (n-1) x (n-1) corner figure. Panel (i, j) with j <= i
# shows the projection onto dimensions (j, i + 1). The mesh and two empty,
# hidden profile lines are stored per panel so `update` can modify them.
n = 5
fig, axes = pplt.subplots(
    nrows=n - 1, ncols=n - 1, spanx=False, spany=False,
    aligny=True, figwidth=1.5 * (n - 1), space=1.1,
)
meshes = [[] for _ in range(n - 1)]
fx_lines = [[] for _ in range(n - 1)]
fy_lines = [[] for _ in range(n - 1)]
for i in range(n - 1):
    axes[0, i].format(xlabel=dims_units[i])
    axes[i, 0].format(ylabel=dims_units[i + 1])

    # Panels above the diagonal are unused.
    for j in range(i + 1, n - 1):
        axes[i, j].axis('off')

    for j in range(i + 1):
        ax = axes[i, j]
        x = coords[j]
        y = coords[i + 1]

        H = utils.project(f, (j, i + 1))
        meshes[i].append(ax.pcolormesh(x, y, H.T, **plot_kws))

        # Placeholder lines; they stay invisible until given data.
        fx_line, = ax.plot([], [], **prof_kws)
        fy_line, = ax.plotx([], [], **prof_kws)
        fx_line.set_visible(False)
        fy_line.set_visible(False)
        fx_lines[i].append(fx_line)
        fy_lines[i].append(fy_line)
            
        
def update(xrange, xprange, yrange, yprange, wrange, prof=True):
    """Redraw the compact corner plot for the selected sub-volume of `f`.

    Entries of `f` whose coordinate lies outside any of the given ranges are
    masked before the 2D projections are recomputed.

    Parameters
    ----------
    xrange, xprange, yrange, yprange, wrange : (float, float)
        (min, max) limits compared against `coords[0]` .. `coords[4]`.
    prof : bool, optional
        Show the profile lines in the bottom row of panels. When false the
        lines are hidden again (previously they stayed visible once drawn).
    """
    # Slice the array.
    print('Making slice...')
    ranges = [xrange, xprange, yrange, yprange, wrange]
    mask = np.full(f.shape, False)
    for i, (umin, umax) in enumerate(ranges):
        # Mask every entry whose i-th coordinate falls outside [umin, umax].
        idx = 5 * [slice(None)]
        idx[i] = np.logical_or(coords[i] < umin, coords[i] > umax)
        mask[tuple(idx)] = True
    f_slice = np.ma.masked_where(mask, f)

    print('Updating display...')
    for i in range(n - 1):
        for j in range(i + 1):
            H = utils.project(f_slice, (j, i + 1))
            H = H / np.max(H)
            meshes[i][j].set_array(H.T)
            meshes[i][j].set_norm(colors.Normalize(np.min(H), np.max(H)))
            if i != n - 2:
                continue
            # Only the bottom row carries profile lines.
            line = fx_lines[i][j]
            if prof:
                y = coords[i + 1]
                fx = np.sum(H, axis=1)
                # Rescale the profile to 15% of the panel's vertical extent,
                # starting from the first y coordinate.
                scale = 0.15
                fx = y[0] + scale * np.abs(y[-1] - y[0]) * fx / np.max(fx)
                line.set_data(coords[j], fx)
            line.set_visible(bool(prof))
    display(fig)
    clear_output(wait=True)

In [None]:
# `update` only has the five range parameters and `prof`, so pass exactly
# those: the range sliders plus the 'profiles' checkbox wired to `prof`.
# The cmap/thresh/discrete/log/contour widgets have no matching parameter.
kwargs = dict(xrange=range_sliders[0], xprange=range_sliders[1], yrange=range_sliders[2],
              yprange=range_sliders[3], wrange=range_sliders[4],
              prof=profiles)
widgets.interact(update, **kwargs)

### Corner

In [None]:
# %%capture

# Build the full 5 x 5 corner figure: 1D profiles on the diagonal and 2D
# projections below it. The artists are stored so the update callback
# defined below can modify them in place.
prof_kws = dict(color='black', lw=1.0)
meshes = [[] for _ in range(5)]
prof_lines = []

fig, axes = pplt.subplots(nrows=5, ncols=5, spanx=False, spany=False,
                          sharex=1, sharey=1, aligny=True, figwidth=1.5*5, space=1.1)
for i in range(5):
    # Diagonal panels show profiles normalized to unit peak.
    axes[i, i].format(ylim=(-0.01, 1.25), yticks=[])
for ax, label in zip(axes[-1, :], dims_units):
    ax.format(xlabel=label)
# Row i plots coords[i] on its vertical axis, so rows 1..4 take the labels
# of dimensions 1..4 in order (not reversed).
for ax, label in zip(axes[1:, 0], dims_units[1:]):
    ax.format(ylabel=label)
for i in range(5):
    for j in range(5):
        ax = axes[i, j]
        if j > i:
            ax.axis('off')
            continue
        if j > 0:
            ax.format(yticklabels=[])
        if i < 4:
            ax.format(xticklabels=[])
        x = coords[j]
        y = coords[i]
        if i == j:
            prof = utils.project(f, i)
            # Normalize so the curve fits the fixed ylim set above; the
            # update callback applies the same normalization.
            prof = prof / np.max(prof)
            prof_line, = ax.plot(x, prof, **prof_kws)
            prof_lines.append(prof_line)
        else:
            H = utils.project(f, (j, i))
            mesh = ax.pcolormesh(x, y, H.T, **plot_kws)
            meshes[i].append(mesh)
        
def update(xrange, xprange, yrange, yprange, wrange):
    """Redraw the full corner plot for the selected sub-volume of `f`.

    Entries of `f` whose coordinate lies outside any of the given ranges are
    masked; every diagonal profile and off-diagonal projection is then
    recomputed, normalized to unit peak, and written into the existing
    artists before the figure is displayed.

    Parameters
    ----------
    xrange, xprange, yrange, yprange, wrange : (float, float)
        (min, max) limits compared against `coords[0]` .. `coords[4]`.
    """
    print('Making slice...')
    limits = (xrange, xprange, yrange, yprange, wrange)
    outside = np.full(f.shape, False)
    for axis, (lo, hi) in enumerate(limits):
        selector = [slice(None)] * 5
        selector[axis] = np.logical_or(coords[axis] < lo, coords[axis] > hi)
        outside[tuple(selector)] = True
    f_slice = np.ma.masked_where(outside, f)

    print('Updating display...')
    for i in range(5):
        # Off-diagonal panels of this row (j < i).
        for j in range(i):
            H = utils.project(f_slice, (j, i))
            H = H / np.max(H)
            panel_mesh = meshes[i][j]
            panel_mesh.set_array(H.T)
            panel_mesh.set_norm(colors.Normalize(np.min(H), np.max(H)))
        # Diagonal panel: 1D profile.
        prof = utils.project(f_slice, i)
        prof = prof / np.max(prof)
        prof_lines[i].set_data(coords[i], prof)
    display(fig)
    # clear_output(wait=True)

In [None]:
# Full range in the first four dimensions; narrow window in w.
x_lim, xp_lim, y_lim, yp_lim = ranges[:4]
update(xrange=x_lim, xprange=xp_lim, yrange=y_lim, yprange=yp_lim, wrange=(-0.02, 0.02))

In [None]:
# kwargs = dict(xrange=range_sliders[0], xprange=range_sliders[1], yrange=range_sliders[2], 
#               yprange=range_sliders[3], wrange=range_sliders[4],
#               cmap=cmap, thresh=thresh, discrete=discrete, log=log, 
#               contour=contour)
# widgets.interact(update, **kwargs)

## Emittance vs. energy

In [None]:
def sigma(x, xp, H):
    np.