# Interpolate 5D beam density

$x_3$ = x position at VS34 \
$y_3$ = y position at VS34 \
$x_2$ = position of VT06 slit \
$y_1$ = position of HZ04 slit \
$x_1$ = position of VT04 slit

$$
\begin{aligned}
x &= x_1 \\
y &= y_1 \\
x' &= \frac{x_2 - 0.35 x_1}{s_2 - s_1} \\
y' &= \frac{y_3 - y_1}{s_3 - s_1} \\
w  &= f(x_3, x_2, x_1) \\ 
\end{aligned}
$$

In [None]:
import sys
import os
from os.path import join
import time
from datetime import datetime
import importlib
import numpy as np
import pandas as pd
import h5py
import imageio
from scipy import ndimage
from scipy import interpolate
import skimage
from tqdm import tqdm
from tqdm import trange
from matplotlib import pyplot as plt
from matplotlib import colors
import plotly.graph_objs as go
from ipywidgets import interact
import proplot as pplt

sys.path.append('../..')
from tools import energyVS06 as energy
from tools import image_processing as ip
from tools import plotting as mplt
from tools import utils
from tools.utils import project

In [None]:
# Global proplot style for every figure in this notebook: no grid lines,
# no discrete colormap levels (presumably continuous colorbars), and
# viridis as the default sequential colormap.
pplt.rc['grid'] = False
pplt.rc['cmap.discrete'] = False
pplt.rc['cmap.sequential'] = 'viridis'

## Load 5D array 

In [None]:
folder = '.'
# Show which raw-grid and slit-coordinate files are available in `folder`.
filenames = os.listdir(folder)
wanted_prefixes = ('rawgrid', 'slit_coordinates')
for filename in filenames:
    if filename.startswith(wanted_prefixes):
        print(filename)

In [None]:
# Raw 5D image grid (memory-mapped below) and the matching slit-coordinate grid.
filename = 'rawgrid_220429190854-scan-xxpy-image-ypdE.mmp'
coordfilename = 'slit_coordinates_220429190854-scan-xxpy-image-ypdE.npy'

In [None]:
# Shape of the raw grid, stored as text numbers in rawgrid_shape.txt.
shape = tuple(np.loadtxt(join(folder, 'rawgrid_shape.txt')).astype(int))
print(shape)

In [None]:
# Read the image dtype string (passed to np.memmap below) from the first line
# of im_dtype.txt. The context manager guarantees the file is closed.
with open(join(folder, 'im_dtype.txt'), 'r') as file:
    # strip() removes a trailing newline, which would otherwise make the
    # string an invalid dtype name.
    dtype = file.readline().strip()
print(dtype)

In [None]:
# Open the raw 5D grid as a read-only memory map (data is read from disk on access).
a5d = np.memmap(join(folder, filename), shape=shape, dtype=dtype, mode='r')
# np.info prints its report itself and returns None; wrapping it in print()
# only adds a stray "None" line.
np.info(a5d)

Flip the y3 axis (the image comes in upside-down). (*Why flip the x3 axis? We want the beam moving out of the screen?*) Also, the y1 slit coordinates are the reverse of the y vector, so flip the y1 axis.

In [None]:
# Reverse axes 2, 3 and 4 (named y1, y3, x3 in `dims`); this is a view, no copy.
a5d = a5d[:, :, ::-1,::-1,::-1]

## View 5D array in slit-screen coordinates

NOTE: These plots do not show correlations between the slits. For example, the x1-x2 distribution should be significantly tilted.

### Projections 

In [None]:
# Names of the axes of `a5d`, in storage order.
dims = ['x1', 'x2', 'y1', 'y3', 'x3']
# Passed as `frac_thresh` to the plotting helpers; presumably pixels below this
# fraction of the maximum are masked — TODO confirm in tools.plotting.
frac_thresh = 1e-5
# frac_thresh = None

In [None]:
# Corner plots of the raw 5D array with a linear and a log color scale,
# each saved to _output/ and then shown.
os.makedirs('_output', exist_ok=True)  # savefig fails if the folder is missing
for norm in [None, 'log']:
    axes = mplt.corner(
        a5d,
        labels=dims,
        norm=norm,
        # NOTE(review): this is the string 'None', not None — confirm mplt.corner expects it.
        diag_kind='None',
        prof='edges',
        prof_kws=dict(lw=1.0),
        fig_kws=dict(),
        frac_thresh=frac_thresh,
    )
    plt.savefig(f"_output/corner_log{norm == 'log'}.png")
    plt.show()

### Slices

In [None]:
# Lookup tables between axis names in `dims` and axis indices of `a5d`.
int_to_dim = dict(enumerate(dims))
dim_to_int = {name: axis for axis, name in int_to_dim.items()}

In [None]:
# Multi-index of the brightest voxel; used as the slice position below.
ind = tuple(np.unravel_index(np.argmax(a5d), a5d.shape))
print(ind)

In [None]:
# For every triple of axes (k < j < i), fix those axes at the peak index `ind`
# (presumably what utils.slice_array does) and plot the remaining 2D plane,
# normalized to its maximum, with linear and log color scales.
os.makedirs('_output', exist_ok=True)  # savefig fails if the folder is missing
axes_slice = [(k, j, i) for i in range(a5d.ndim) for j in range(i) for k in range(j)]
for axes in axes_slice:
    axes_not_slice = [axis for axis in range(a5d.ndim) if axis not in axes]
    a5d_slice = utils.slice_array(a5d, axes, [ind[axis] for axis in axes])
    a5d_slice = a5d_slice / np.max(a5d_slice)

    fig, plot_axes = pplt.subplots(ncols=2)
    for ax, norm in zip(plot_axes, [None, 'log']):
        mplt.plot_image(a5d_slice, ax=ax, frac_thresh=frac_thresh, norm=norm, colorbar=True)
    dim1, dim2 = [dims[axis] for axis in axes_not_slice]
    plot_axes.format(xlabel=dim1, ylabel=dim2)
    # Save the figure; the file name records each fixed axis and its index.
    string = '_output/slice_'
    for axis in axes:
        string += f'_{int_to_dim[axis]}-{ind[axis]}'
    plt.savefig(string + '.png')
    plt.show()

## Interactive 

In [None]:
# Colormap choices offered by the interactive widget.
cmaps = ['viridis', 'dusk_r', 'mono_r', 'grays', 'plasma', 'blues_r', 'rocket', 'mako', 'stellar_r',]

In [None]:
def update_projection(a5d, dim1='y3', dim2='x3', cfix=False, log=False, reverse=False, **plot_kws):
    """Plot the normalized 2D projection of ``a5d`` onto two named axes.

    Parameters
    ----------
    a5d : ndarray
        Array whose axes are named by the module-level ``dim_to_int`` mapping.
    dim1, dim2 : str
        Axis names for the horizontal and vertical plot axes; must differ.
    cfix : bool
        Unused; kept for backward compatibility.
    log : bool
        Use a logarithmic color scale if True.
    reverse : bool
        Reverse the colormap if True.
    **plot_kws
        Forwarded to ``mplt.plot_image``. ``cmap`` (default ``'viridis'``, the
        notebook's sequential default) is converted to a proplot Colormap first.

    Returns
    -------
    None
        The figure is drawn and shown with matplotlib/proplot.

    Raises
    ------
    ValueError
        If ``dim1 == dim2``.
    """
    if dim1 == dim2:
        raise ValueError('dim1 == dim2')
    norm = 'log' if log else None
    # Default cmap so callers that omit it do not hit a KeyError.
    plot_kws['cmap'] = pplt.Colormap(plot_kws.pop('cmap', 'viridis'), reverse=reverse)
    image = project(a5d, axis=[dim_to_int[dim1], dim_to_int[dim2]])
    image = image / image.max()
    fig, ax = pplt.subplots()
    mplt.plot_image(image, ax=ax, profx=True, profy=True, prof_kws=dict(scale=0.15), 
                    colorbar=True, norm=norm, **plot_kws)
    ax.format(xlabel=dim1, ylabel=dim2)
    plt.show()

In [None]:
# Interactive projection viewer. update_projection draws with matplotlib/proplot
# and returns None, so there is no data to hand to a plotly widget; just call it.
@interact(dim1=dims, dim2=dims, cmap=cmaps, reverse=False, log=False, discrete=False)
def update(dim1='y3', dim2='x3', cmap='viridis', reverse=False, log=False, discrete=False):
    """Redraw the 2D projection of ``a5d`` with the options selected in the widget."""
    # NOTE(review): `discrete` is forwarded via **plot_kws to mplt.plot_image —
    # confirm that function accepts it.
    update_projection(
        a5d, dim1=dim1, dim2=dim2, log=log, discrete=discrete,
        cmap=cmap, frac_thresh=frac_thresh, reverse=reverse,
    )

## Coordinate transform

Load the slit coordinates 3D grid (x1-x2-y1)

In [None]:
# Slit coordinates on the 3D scan grid. The first axis holds [X1, X2, Y1].
# Load from `folder`, like the raw grid, so changing `folder` moves both inputs.
coords_3d = np.load(join(folder, coordfilename))  # [X1, X2, Y1]
coords_3d.shape

In [None]:
# Pairwise scatter of the three slit-coordinate grids (x1, x2, y1); red lines mark the means.
fig, axes = pplt.subplots(nrows=3, ncols=3, figwidth=6, spanx=False, spany=False)
for row, V in enumerate(coords_3d):
    for col, U in enumerate(coords_3d):
        ax = axes[row, col]
        ax.scatter(U.ravel(), V.ravel(), s=1, color='black')
        ax.axvline(np.mean(U), color='red', alpha=0.15)
        ax.axhline(np.mean(V), color='red', alpha=0.15)
for k in range(3):
    axes[k, 0].format(ylabel=dims[k])
    axes[-1, k].format(xlabel=dims[k])
plt.show()

In [None]:
# 1D grid vectors: y1 from the Y1 coordinate grid (at x1 and x2 index 0);
# y3 and x3 are plain pixel indices.
y1_gv = coords_3d[2, 0, 0, :]
y3_gv = np.arange(shape[3])
x3_gv = np.arange(shape[4])
# Only used by the commented-out plotting cell; x1/x2 have no grid vector here.
gvs = [None, None, y1_gv, y3_gv, x3_gv]

In [None]:
# kws = dict(color='black', lw=0.25, alpha=1)
# fig, axes = pplt.subplots(nrows=5, ncols=5, figwidth=10.0, spanx=False, spany=False)
# for i in range(5):
#     for j in range(5):
#         u, v, ax = gvs[j], gvs[i], axes[i, j]
#         for gv in gvs[j]:
#             ax.axvline(gv, **kws)
#         for gv in gvs[i]:
#             ax.axhline(gv, **kws)
#     axes[i, 0].format(ylabel=dims[i])
#     axes[-1, i].format(xlabel=dims[i])
# plt.show()

Copy the grids into new dimensions.

In [None]:
# Broadcast each 3D slit grid (x1, x2, y1) across the two screen axes so it
# matches the full 5D shape.
X1, X2, Y1 = (
    utils.copy_into_new_dim(grid3d, shape[3:], axis=-1)
    for grid3d in coords_3d
)

In [None]:
for label, grid in (('X1.shape =', X1), ('X2.shape =', X2), ('Y1.shape =', Y1)):
    print(label, grid.shape)

In [None]:
# Screen pixel grids (y3, x3), prepended with the three slit axes to reach 5D.
Y3, X3 = (
    utils.copy_into_new_dim(grid2d, shape[:3], axis=0)
    for grid2d in np.meshgrid(y3_gv, x3_gv, indexing='ij')
)

In [None]:
for label, grid in (('Y3.shape =', Y3), ('X3.shape =', X3)):
    print(label, grid.shape)

Center each coordinate grid on its mean and collect the grids in a list. `coords_` holds the coordinates in raw slit/screen units (before conversion to phase-space coordinates).

In [None]:
# Subtract the mean from each raw coordinate grid. Each line rebinds the name
# to a new array, so only one extra 5D array exists at a time.
X1 = X1 - np.mean(X1)
X2 = X2 - np.mean(X2)
Y1 = Y1 - np.mean(Y1)
Y3 = Y3 - np.mean(Y3)
X3 = X3 - np.mean(X3)
# Same order as `dims`.
coords_ = [X1, X2, Y1, Y3, X3]

In [None]:
# For each axis, print the coordinate values along that axis with every other
# axis held at index 0.
for i, (dim, grid) in enumerate(zip(dims, coords_)):
    print('dim =', dim)
    other_axes = [k for k in range(grid.ndim) if k != i]
    print(grid[utils.make_slice(grid.ndim, other_axes, ind=[0, 0, 0, 0])])
    print()

Build the transfer matrices between slits and the screen.

In [None]:
# Beamline parameters for tools.energyVS06.EnergyCalculate.
# NOTE(review): lengths presumably in meters; rho, GL05, GL06 and LL are not
# used in the visible cells — TODO confirm whether they are still needed.
a2mm = 1.009  # assume same as first dipole
amp2meter = a2mm * 1e3
rho = 0.3556
GL05 = 0
GL06 = 0.0
l1 = 0.0
l2 = 0.0
l3 = 0.774
L2 = 0.311  # slit2 to dipole face
# `l` is a physics length used by ecalc; do not reuse this name as a loop index.
l = 0.129  # dipole face to VS06 screen (assume same for first/last dipole-screen)
LL = l1 + l2 + l3 + L2  # distance from emittance plane to dipole entrance

ecalc = energy.EnergyCalculate(l1=l1, l2=l2, l3=l3, L2=L2, l=l, amp2meter=amp2meter)

# Transfer matrices used to compute x', y' and the energy deviation.
Mslit = ecalc.getM1()  # slit-slit
Mscreen = ecalc.getM()  # slit-screen

Compute x', y', and energy w.

In [None]:
# Phase-space coordinates from slit/screen positions. Positions are scaled by
# 1e-3 (presumably mm -> m) before being passed to the energy calculator.
# NOTE(review): Y3 and X3 are (centered) pixel indices, so the 1e-3 factor treats
# one screen pixel as 1 mm unless a calibration is applied elsewhere — TODO confirm.
Y = Y1.copy()
# y' from y1 (slit) and y3 (screen) using the slit-to-screen matrix.
YP = ecalc.calculate_yp(Y.ravel() * 1e-3, Y3.ravel() * 1e-3, Mscreen)
YP = YP.reshape(shape)

# x' from x1 and x2 (both slits) using the slit-to-slit matrix.
X = X1.copy()
XP = ecalc.calculate_xp(X.ravel() * 1e-3, X2.ravel() * 1e-3, Mslit)
XP = XP.reshape(shape)

# Energy deviation from the screen position x3 together with x and x'.
# NOTE(review): the second argument is a literal 0.0 — its meaning is defined in tools.energyVS06.
W = ecalc.calculate_dE_screen(X3.ravel() * 1e-3, 0.0, X.ravel() * 1e-3, XP.ravel(), Mscreen)
W = W.reshape(shape)

# Convert from m-rad to mm-mrad
# NOTE(review): W is scaled by 1e3 as well; its unit is unclear (plot label says 'w [keV?]').
YP *= 1e3
XP *= 1e3
W *= 1e3

In [None]:
# Take independent copies of the phase-space grids.
# NOTE(review): X and Y were already created with .copy(), and XP, YP, W are new
# arrays from the calculator, so these copies look redundant. Each one allocates
# a full 5D copy — consider removing.
X = X.copy()
Y = Y.copy()
XP = XP.copy()
YP = YP.copy()
W = W.copy()
# del(X1, X2, Y1, X3, Y3)

Make list of centered phase space coordinate grids.

In [None]:
coords = [X, XP, Y, YP, W]
# Center each grid on its mean *in place*. Rebinding the loop variable
# (`coord = coord - mean`) would leave both `coords` and X, XP, ... unchanged.
# Subtracting in place updates the shared arrays, so cells that use either
# `coords` or the names X, XP, Y, YP, W see the same centered values.
# (All five are float arrays: they were scaled in place by 1e3 or come from
# mean-subtracted grids.)
for coord in tqdm(coords):
    coord -= np.mean(coord)

Put energy on an evenly spaced grid.

In [None]:
# Plot axis labels, in the same order as `coords`.
pdims = ["x [mm]", "x' [mrad]", "y [mm]", "y' [mrad]", "w [keV?]"]

Test interpolation.

In [None]:
# Set up a 2D interpolation test on the projection onto axes `axis_view`
# ((0, 1) -> x, x'). These module-level names are read by interp2d below.
axis_view = (0, 1)
# NOTE(review): `method` is not used below; the plotting loop uses its own `meth`.
method = '1D'
# Keyword arguments forwarded to scipy.interpolate.griddata.
kws = dict(fill_value=0.0, method='linear')

# Measured projection, normalized to a peak of 1.
H = utils.project(a5d, axis_view)
H = H / np.max(H)

# Coordinates of each projected pixel, taken from the slice where every other
# axis is at index 0 (assumes they barely vary along those axes — TODO confirm).
axis_slice = [i for i in range(a5d.ndim) if i not in axis_view]
idx = utils.make_slice(a5d.ndim, axis_slice, ind=[0, 0, 0])
U = coords[axis_view[0]][idx]
V = coords[axis_view[1]][idx]
# Regular target grid spanning the measured coordinate range.
u_new = np.linspace(np.min(U), np.max(U), shape[axis_view[0]])
v_new = np.linspace(np.min(V), np.max(V), shape[axis_view[1]])

def interp2d(meth, **kws):
    """Resample the image ``H`` from its measured grid onto (``u_new``, ``v_new``).

    Reads module-level globals: ``H`` (2D image), ``U`` and ``V`` (same shape as
    ``H``; the u/v coordinate of each pixel), and the 1D target grid vectors
    ``u_new`` and ``v_new``.

    Parameters
    ----------
    meth : {'2D', '1D'}
        '2D': scattered-data interpolation of all pixels at once.
        '1D': separable interpolation. First, along axis 0 onto ``u_new`` for
        every column. The V coordinate is interpolated along with it, since V
        may vary along both axes. Second, along axis 1 onto ``v_new`` for every
        new row.
    **kws
        Forwarded to `scipy.interpolate.griddata` (e.g. ``method``, ``fill_value``).

    Returns
    -------
    ndarray, shape (len(u_new), len(v_new))
        Interpolated image with negative values set to zero.

    Raises
    ------
    ValueError
        If ``meth`` is not '1D' or '2D'.
    """
    if meth == '2D':
        points = (U.ravel(), V.ravel())
        values = H.ravel()
        U_new, V_new = np.meshgrid(u_new, v_new, indexing='ij')
        new_points = (U_new.ravel(), V_new.ravel())
        H_new = interpolate.griddata(points, values, new_points, **kws)
        H_new = H_new.reshape(len(u_new), len(v_new))
    elif meth == '1D':
        fill_value = kws.get('fill_value', np.nan)  # griddata's default fill
        n_u, n_v = len(u_new), len(v_new)
        n_cols = H.shape[1]
        # Pass 1: along axis 0 for each column. V is carried along (NaN where
        # u_new is out of range) so pass 2 knows where each new sample sits in v.
        H_tmp = np.empty((n_u, n_cols))
        V_tmp = np.empty((n_u, n_cols))
        for j in range(n_cols):
            H_tmp[:, j] = interpolate.griddata(U[:, j], H[:, j], u_new, **kws)
            V_tmp[:, j] = interpolate.griddata(
                U[:, j], V[:, j], u_new, method='linear', fill_value=np.nan
            )
        # Pass 2: along axis 1 for each new row, using the pass-1 result (not
        # the original H) and the carried V coordinates.
        H_new = np.full((n_u, n_v), fill_value, dtype=float)
        for i in range(n_u):
            valid = np.isfinite(V_tmp[i, :])
            if np.count_nonzero(valid) < 2:
                continue  # too few samples on this row to interpolate
            H_new[i, :] = interpolate.griddata(
                V_tmp[i, valid], H_tmp[i, valid], v_new, **kws
            )
    else:
        raise ValueError(f"meth must be '1D' or '2D', got {meth!r}")
    H_new[H_new < 0] = 0
    return H_new


# Compare the measured projection with the '1D' and '2D' interpolations.
# U_new/V_new are needed for the plot coordinates but exist only as locals
# inside interp2d, so build them here at cell level.
U_new, V_new = np.meshgrid(u_new, v_new, indexing='ij')

fig, axes = pplt.subplots(ncols=3)
pkws = dict(colorbar=True, norm='log', frac_thresh=1e-5)
mplt.plot_image(H / np.max(H), x=U.T, y=V.T, ax=axes[0], **pkws)
for ax, meth in zip(axes[1:], ['1D', '2D']):
    H_new = interp2d(meth, **kws)
    mplt.plot_image(H_new / np.max(H_new), x=U_new.T, y=V_new.T, ax=ax, **pkws)
    # Overlay contours of the measured image for reference.
    ax.contour(U.T, V.T, H.T, color='white', lw=1, alpha=0.4, label='original')
axes.format(
    xlabel=pdims[axis_view[0]],
    ylabel=pdims[axis_view[1]],
    toplabels=['Measured', 'interp 1D', 'interp 2D']
)
plt.show()

Gridding the x-x'-w distribution for each y-y' is possible but would take about 120 hours. We will try gridding the w distribution first.

In [None]:
# Output grid shape: same as the raw grid, but 40 points along axis 0 (x).
new_shape = (40,) + tuple(shape[1:])
print(new_shape)

In [None]:
# Evenly spaced target grid vectors spanning the full range of x, x' and w.
x_gv_new = np.linspace(np.min(X), np.max(X), new_shape[0])
xp_gv_new = np.linspace(np.min(XP), np.max(XP), new_shape[1])
w_gv_new = np.linspace(np.min(W), np.max(W), new_shape[4])

In [None]:
# Target points: every (x, x', w) combination on the regular grid, flattened in
# 'ij' order so results can be reshaped to (nx, nx', nw).
new_points = tuple([G.ravel() for G in np.meshgrid(x_gv_new, xp_gv_new, w_gv_new, indexing='ij')])
# Source points: measured (x, x', w) of each sample, taken at y1 index 0 and
# y3 index 0. Assumes x, x' and w do not depend on those axes — TODO confirm.
points = (
    coords[0][:, :, 0, 0, :].ravel(),
    coords[1][:, :, 0, 0, :].ravel(),
    coords[4][:, :, 0, 0, :].ravel(),
)

In [None]:
# Interpolate the (x, x', w) density onto the regular grid for every (y1, y3)
# slice. The scattered sample positions `points` are the same for every slice,
# so the Delaunay triangulation (the expensive part of a linear `griddata`) is
# built once and reused. This gives the same result as calling
# `interpolate.griddata(points, values, new_points, fill_value=0.0)` per slice.
from scipy.spatial import Delaunay

tri = Delaunay(np.column_stack(points))
out_shape_3d = (new_shape[0], new_shape[1], new_shape[4])
a5d_new = np.zeros(new_shape)
# Loop indices are k/m (not `l`) so the beamline length `l` is not overwritten;
# a single progress bar covers all slices.
n_slices = a5d.shape[2] * a5d.shape[3]
for k, m in tqdm(np.ndindex(a5d.shape[2], a5d.shape[3]), total=n_slices):
    values = a5d[:, :, k, m, :].ravel()
    dens3d = interpolate.LinearNDInterpolator(tri, values, fill_value=0.0)(new_points)
    a5d_new[:, :, k, m, :] = dens3d.reshape(out_shape_3d)