# Step 1

* Load scalar, waveform and image h5 files.
* Interpolate each image onto y'-w grid.
* For each sweep, interpolate y-y'-w image onto y grid.
* For each y-y'-w image pixel, interpolate onto x-x' grid.

In [None]:
from datetime import datetime
import importlib
import itertools
import json
import numpy as np
import os
from pprint import pprint
import sys

import h5py
import imageio
from ipywidgets import interactive
from ipywidgets import widgets
from matplotlib import patches
from matplotlib import pyplot as plt
import pandas as pd
from plotly import graph_objects as go
import proplot as pplt
from scipy import ndimage
from scipy import interpolate
import skimage
from tqdm.notebook import tqdm
from tqdm.notebook import trange

sys.path.append('/home/46h/btf-data-analysis/')
from tools import image_processing as ip
from tools import utils
from tools.energyVS06 import EnergyCalculate

sys.path.append('/home/46h/ps-dist/')
from psdist import plotting as mplt

In [None]:
# Plot defaults: continuous viridis colormaps, no grid lines, 300-dpi saved figures.
rc_settings = {
    'cmap.discrete': False,
    'cmap.sequential': 'viridis',
    'grid': False,
    'savefig.dpi': 300,
}
for rc_key, rc_value in rc_settings.items():
    pplt.rc[rc_key] = rc_value

## Load data

In [None]:
folder = "_output/"  # directory holding info.pkl, the preprocessed h5 file, and all outputs of this step

In [None]:
# Scan metadata written by the preprocessing step.
info_path = os.path.join(folder, 'info.pkl')
info = utils.load_pickle(info_path)
print('info')
pprint(info)

In [None]:
# Open the preprocessed h5 file; it is kept open for the rest of the notebook.
filename = info['filename']
file = h5py.File(os.path.join(folder, 'preproc-' + filename + '.h5'), 'r')
data_sc, data_wf, data_im = (
    file[group] for group in ('/scalardata', '/wfdata', '/imagedata')
)

# List the fields of each compound dataset.
print('Attributes:')
print()
for data in (data_sc, data_wf, data_im):
    print(data.name)
    for field_name, field_info in data.dtype.fields.items():
        print((field_name, field_info))
    print()

In [None]:
def save(figname, path=folder, prefix='fig_step1', ext='png', **kws):
    """Save the current matplotlib figure into `path`.

    The file name is ``{prefix}_{figname}``, followed by ``.{ext}`` when `ext`
    is truthy. Remaining keyword arguments are forwarded to `plt.savefig`.
    """
    stem = f'{prefix}_{figname}'
    suffix = f'.{ext}' if ext else ''
    plt.savefig(os.path.join(path, stem + suffix), **kws)

## Scan overview

### Data collection frequency

In [None]:
# Summarize scan timing from the scalar data. Count points with `data_sc`
# explicitly rather than the stale loop variable `data` left over from the
# field-listing loop (which ends bound to the image dataset).
n_points = len(data_sc)
last_iteration = data_sc[-1, 'iteration']  # NOTE(review): treated as the iteration count, i.e. numbering starts at 1
duration = data_sc[-1, 'timestamp'] - data_sc[0, 'timestamp']
iteration_duration = duration / last_iteration
points_per_iteration = n_points / last_iteration
print(f'{n_points} points recorded over {duration:.1f} seconds ({(duration / 3600.0):.1f} hours)')
print(f"Number of iterations: {last_iteration}")
print(f'Effective rep rate: {(n_points / duration):.2f} Hz')
print(f'Time per iteration: {iteration_duration:.2f} seconds')
print(f'Points per iteration: {points_per_iteration:.2f}')

Look for long pauses during data collection.

In [None]:
# Intervals between consecutive points; the median sets the nominal rep rate.
timestamps = data_sc['timestamp'][:]
dt = np.diff(timestamps)
rep_rate = 1.0 / np.median(dt)
print(f'reprate = {rep_rate:.2f} Hz')

# List pauses longer than `long_pause`, then overwrite them with 0.2 s so a few
# very long gaps do not stretch the histogram below.
print('Pauses longer than 30 seconds:')
long_pause = 30.0
is_long = dt > long_pause
pprint(dt[is_long])
dt[is_long] = 0.2

# Intervals falling outside the first of 21 bins count as pauses.
hist, bins = np.histogram(dt, bins=21)
idx_bins = np.digitize(dt, bins)
idx_pause = np.flatnonzero(idx_bins > 1)
median_pause = np.median(dt[idx_pause])
print(f'Most pauses are {median_pause:.2f} seconds')

bin_centers = 0.5 * (bins[:-1] + bins[1:])
fig, ax = pplt.subplots()
ax.bar(bin_centers, hist, color='black', alpha=0.3)
ax.axvline(median_pause, color='black')
ax.format(
    xlabel='Pause length [seconds]',
    ylabel='Number of points',
    yscale='log',
)
save('pauses')

### Camera integral and saturation

In [None]:
# Per-point camera saturation, plotted against point index (upper y limit fixed at 1).
cam = info['cam']
saturation = data_sc[f'{cam}_Saturation'][:]

fig, ax = pplt.subplots(figsize=(8.0, 2.0))
ax.plot(saturation, color='lightgrey', lw=0.8)
ymin, _ = ax.get_ylim()
ax.format(ylabel='saturation', xlabel='Point', ylim=(ymin, 1.0))
save(f'{cam}_saturation')

Camera integral: define `signal` as the camera integral normalized to the range [0, 1] (the minimum is subtracted first). By tuning `thresh` and viewing the signal on a logarithmic scale, we can estimate the dynamic range.

In [None]:
# `signal`: camera integral rescaled to [0, 1] (minimum subtracted, then divided by maximum).
raw_integral = data_sc[cam + '_Integral'][:]
shifted = raw_integral - np.min(raw_integral)
signal = shifted / np.max(shifted)
thresh = 0.0004  # fraction of max signal
valid = np.flatnonzero(signal >= thresh)
invalid = np.flatnonzero(signal < thresh)

print(f'Fractional signal thresh = {thresh}')
print(f'Fraction of points above thresh: {len(valid) / len(signal)}')
# Same plot on linear and log y scales; points above threshold drawn in black.
for yscale in (None, 'log'):
    fig, ax = pplt.subplots(figsize=(8.0, 2.35))
    ax.plot(signal, color='lightgrey', lw=0.8)
    ax.plot(valid, signal[valid], lw=0, marker='.', ms=2, alpha=1, ec='None', color='black')
    ax.format(yscale=yscale, ylabel='Signal', xlabel='Point')
    save(f'signal_thresh_{yscale}')

## Setup interpolation grids

Input `nsteps`, the number of x1, x2, y1 steps during the scan. This does not need to be exact. (It will not be exact for y1 since this is the sweeping variable.)

In [None]:
# Slit actuator readback fields (x1, x2, y1) and the approximate step count for each.
acts = [f'{axis_name}_PositionSync' for axis_name in ('x', 'xp', 'y')]
nsteps = [64] * 3

Generate `points`, a list of {x1, x2, y1} slit coordinates.

In [None]:
# One row per recorded point; columns are the slit coordinates listed in `acts`.
points = np.column_stack([data_sc[act] for act in acts])
print('points.shape =', points.shape)

Convert to the beam-frame coordinates.

In [None]:
# VT04/VT06 are same sign as x_beam; VT34a and VT34b are opposite. This
# is due to the 180-degree bend in the BTF lattice.
if cam.lower() == 'cam34':
    points[:, :2] = -points[:, :2]

# The horizontal slit coordinate y1 changes sign (slit inserted from above).
points[:, 2] = -points[:, 2]

# The screen coordinates (x3, y3) also change sign (right hand rule).
image_shape = info['image_shape']
x3grid = -np.arange(image_shape[1]) * info['image_pix2mm_x'] 
y3grid = -np.arange(image_shape[0]) * info['image_pix2mm_y']

Build the transfer matrices between the slits and the screen. 

In [None]:
dipole_current = 0.0  # deviation of dipole current from nominal
l = 0.129  # dipole face to screen (assume same for first/last dipole-screen)
metadata = info['metadata']
if cam.lower() == 'cam06':
    GL05 = 0.0  # QH05 integrated field strength (1 [A] = 0.0778 [Tm])
    GL06 = 0.0  # QV06 integrated field strength (1 [A] = 0.0778 [Tm])
    l1 = 0.280  # slit1 to QH05 center
    l2 = 0.210  # QH05 center to QV06 center
    l3 = 0.457  # QV06 center to slit2
    L2 = 0.599  # slit2 to dipole face
    rho_sign = +1.0  # dipole bend radius sign
    # Cross-check the assumed quadrupole strengths against the recorded
    # power-supply current setpoints. The second quadrupole is QV06, matching
    # its metadata key.
    for quad, quad_GL in (('QH05', GL05), ('QV06', GL06)):
        quad_current = metadata[f'BTF_MEBT_Mag:PS_{quad}:I_Set']
        if quad_GL == 0.0 and quad_current != 0.0:
            print(f'Warning: {quad} is turned on according to metadata.')
        if quad_GL != 0.0 and quad_current == 0.0:
            print(f'Warning: {quad} is turned off according to metadata.')
elif cam.lower() == 'cam34':
    GL05 = 0.0  # QH05 integrated field strength
    GL06 = 0.0  # QV06 integrated field strength
    l1 = 0.000  # slit1 to QH05 center
    l2 = 0.000  # QH05 center to QV06 center
    l3 = 0.774  # QV06 center to slit2
    L2 = 0.311  # slit2 to dipole face
    # Weird... I can only get the right answer for energy if I *do not* flip rho,
    # x1, x2, and x3. I then flip x and xp at the very end.
    rho_sign = +1.0  # dipole bend radius sign
    x3grid = -x3grid
    points[:, :2] = -points[:, :2]
else:
    # No beamline geometry is defined for other cameras, so the transfer
    # matrices below cannot be built.
    raise ValueError(f"Unknown camera {cam!r}; expected 'cam06' or 'cam34'.")
LL = l1 + l2 + l3 + L2  # distance from emittance plane to dipole entrance
ecalc = EnergyCalculate(l1=l1, l2=l2, l3=l3, L2=L2, l=l, rho_sign=rho_sign)
Mslit = ecalc.getM1(GL05=GL05, GL06=GL06)  # slit-slit
Mscreen = ecalc.getM(GL05=GL05, GL06=GL06)  # slit-screen

Convert the second slit coordinate (x2) to x' [mrad] using the slit-to-slit transfer matrix.

In [None]:
# Replace the x2 column with x'. Inputs are scaled by 1e-3 and the result by 1e3
# (presumably mm -> m and rad -> mrad).
x1_scaled = 1.0e-3 * points[:, 0]
x2_scaled = 1.0e-3 * points[:, 1]
points[:, 1] = 1.0e3 * ecalc.calculate_xp(x1_scaled, x2_scaled, Mslit)

Center points at zero.

Make grids.

In [None]:
scales = [1.1, 1.6, 1.1]  # grid resolution relative to `nsteps`

# Center the slit coordinates at zero and span each column's full range.
points -= np.mean(points, axis=0)
mins = points.min(axis=0)
maxs = points.max(axis=0)
ns = (np.asarray(scales) * (np.asarray(nsteps) + 1)).astype(int)
xgrid, xpgrid, ygrid = (
    np.linspace(lo, hi, n) for lo, hi, n in zip(mins, maxs, ns)
)

# y' for every (y, y3) pair: one row of YP per y grid point.
YP = np.zeros((len(ygrid), len(y3grid)))
for yp_row, y in zip(YP, ygrid):
    yp_row[:] = 1e3 * ecalc.calculate_yp(1e-3 * y, 1e-3 * y3grid, Mscreen)
ypgrid = np.linspace(YP.min(), YP.max(), int(1.1 * len(y3grid)))

# w (energy) on the screen x3 grid for every (x, x') grid point.
W = np.zeros((len(xgrid), len(xpgrid), len(x3grid)))
for i, j in np.ndindex(W.shape[:2]):
    W[i, j] = ecalc.calculate_dE_screen(
        1e-3 * x3grid, dipole_current, 1e-3 * xgrid[i], 1e-3 * xpgrid[j], Mscreen
    )
wgrid = np.linspace(W.min(), W.max(), int(1.1 * len(x3grid)))

## Interpolate 

In [None]:
# Sweep (iteration) number of every point, and the distinct sweep numbers.
iterations = np.array(data_sc['iteration'])  # iteration == sweep
iteration_nums = np.unique(iterations)
n_iterations = len(iteration_nums)
# Shared interp1d options: linear, zero outside the sampled range, unsorted input allowed.
kws = {
    'kind': 'linear',
    'copy': True,
    'bounds_error': False,
    'fill_value': 0.0,
    'assume_sorted': False,
}

### Interpolate y

In [None]:
# For each sweep, resample its stack of (y3, x3) camera images onto the y grid.
images_yy3x3 = []
image_field = cam + '_Image'
frame_shape = (len(y3grid), len(x3grid))
for iteration in tqdm(iteration_nums):
    sel = np.flatnonzero(iterations == iteration)
    y_sweep = points[sel, 2]
    frames = data_im[sel, image_field].reshape((len(sel),) + frame_shape)
    # Keep only the first occurrence of each y so interp1d sees unique samples.
    _, first = np.unique(y_sweep, return_index=True)
    fint = interpolate.interp1d(y_sweep[first], frames[first], axis=0, **kws)
    images_yy3x3.append(fint(ygrid))

### Interpolate y'

In [None]:
# For each sweep and each y, resample the y3 axis onto the regular y' grid
# using that y's precomputed y3 -> y' mapping (row of YP).
images_yypx3 = []
for image_yy3x3 in tqdm(images_yy3x3):
    image_yypx3 = np.zeros((len(ygrid), len(ypgrid), len(x3grid)))
    for k, (yp_row, image_y3x3) in enumerate(zip(YP, image_yy3x3)):
        fint = interpolate.interp1d(yp_row, image_y3x3, axis=0, **kws)
        image_yypx3[k] = fint(ypgrid)
    images_yypx3.append(image_yypx3)
del images_yy3x3

### Interpolate w

In [None]:
# For each sweep, compute its mean (x, x') and resample the image from screen
# x3 onto the energy grid w. `images_yypx3` was built in `iteration_nums`
# order, so pair each image with its actual iteration number instead of
# assuming the numbers run exactly 1..n (gaps or 0-based numbering would
# otherwise select the wrong points or none at all).
XXP = []
images_yypw = []
for iteration, image_yypx3 in tqdm(zip(iteration_nums, images_yypx3), total=len(images_yypx3)):
    in_sweep = iterations == iteration
    x, xp = np.mean(points[in_sweep, :2], axis=0)
    w_screen = ecalc.calculate_dE_screen(1.0e-3 * x3grid, dipole_current, 1.0e-3 * x, 1.0e-3 * xp, Mscreen)
    fint = interpolate.interp1d(w_screen, image_yypx3, axis=-1, **kws)
    images_yypw.append(fint(wgrid))
    XXP.append([x, xp])
del images_yypx3
# Row r of XXP / images_yypw corresponds to iteration_nums[r].
XXP = np.array(XXP)
images_yypw = np.array(images_yypw)

### Interpolate x-x'

Since we moved along vertical lines in the x-x' plane, we can separate the x and x' interpolations. If we had moved along diagonal lines in the x-x' plane, we would need to use 2D interpolation. The variable `xxp_interp` determines which method to use.

In [None]:
xxp_interp = "1D"  # '1D': interpolate x' then x separately; '2D': griddata over (x, x')

Group the iterations by x step. Loop through each iteration and check if x has changed significantly. (Within each x step, x should only change by a small amount due to noise in the readback.) Find a good cutoff `max_abs_delta`.

In [None]:
max_abs_delta = 0.03  # Max absolute change in x within a group.

# Group consecutive iterations into x steps: start a new group whenever the
# per-iteration mean x jumps by more than `max_abs_delta`. Entries of `steps`
# are 1-based row positions in XXP.
X, steps = [], []
x_last = np.inf
for iteration in trange(1, n_iterations + 1):
    x, xp = XXP[iteration - 1]
    if np.abs(x - x_last) > max_abs_delta:
        X.append(x)
        steps.append([])
    steps[-1].append(iteration)
    x_last = x

# Histogram the quantity the grouping actually thresholds (change in mean x
# between consecutive iterations), so the red line can be tuned against it.
fig, ax = pplt.subplots(figsize=(4, 2))
ax.hist(np.abs(np.diff(XXP[:, 0])), bins=75, color='black')
ax.axvline(max_abs_delta, color='red')
ax.format(yscale='log', xlabel=r'$\Delta x$ [mm]', ylabel='Number of steps')
save('scan_delta_x')
plt.show()

# One color per x group in the x-x' plane.
fig, ax = pplt.subplots(figsize=(4, 2))
for group in steps:
    rows = np.array(group) - 1
    ax.scatter(XXP[rows, 0], XXP[rows, 1], s=1)
ax.format(xlabel="x [mm]", ylabel="x' [mrad]")
save('scan_x_groups')
plt.show()

Interpolate:

In [None]:
# Disk-backed 5D output over (x, x', y, y', w); mode='w+' creates or overwrites the file.
shape = tuple(len(grid) for grid in (xgrid, xpgrid, ygrid, ypgrid, wgrid))
mmap_path = os.path.join(folder, f'f_{filename}.mmp')
f = np.memmap(mmap_path, dtype='float', mode='w+', shape=shape)

In [None]:
if xxp_interp == '1D':
    # Interpolate each x step's x'-y-y'-w stack onto the regular x' grid.
    print("Interpolating xp.")
    images_xpyypw = []
    for _iterations in tqdm(steps):
        idx = np.array(_iterations) - 1  # 1-based positions -> rows of XXP / images_yypw
        fint = interpolate.interp1d(XXP[idx, 1], images_yypw[idx], axis=0, **kws)
        images_xpyypw.append(fint(xpgrid))
    del images_yypw
    # Shape: (len(X), len(xpgrid), len(ygrid), len(ypgrid), len(wgrid)).
    images_xpyypw = np.array(images_xpyypw)

    # Interpolate the x-x'-y-y'-w image stack onto a regular x grid.
    print("Interpolating x.")

    # Could get memory errors for large arrays. In this case, break down
    # into smaller interpolations. If `n_loop = 1`, interpolate the x-y-y'-w
    # image for each x'; if `n_loop == 2`, interpolate the x-y'-w image for
    # each x'-y; etc. If `n_loop == 0`, interpolate the x-x'-y-y'-w image
    # directly.
    n_loop = 2
    _points = X
    if n_loop == 0:
        # Fill the whole output array at once.
        fint = interpolate.interp1d(_points, images_xpyypw, axis=0, **kws)
        f[:] = fint(xgrid)
    else:
        axis = list(range(1, n_loop + 1))
        ranges = [range(s) for s in shape[1: n_loop + 1]]
        n_chunks = int(np.prod(shape[1: n_loop + 1]))
        for ind in tqdm(itertools.product(*ranges), total=n_chunks):
            # Index fixing axes 1..n_loop at `ind` (assumes make_slice builds
            # such an index over a 5D array — see tools.utils).
            idx = utils.make_slice(5, axis=axis, ind=ind)
            fint = interpolate.interp1d(_points, images_xpyypw[idx], axis=0, **kws)
            f[idx] = fint(xgrid)
    del images_xpyypw
else:
    # 2D interpolation of x-x' for each {y, y', w}. Loop indices are named so
    # they do not overwrite the module-level `l` (dipole face to screen).
    print("Interpolating x-xp.")
    _points = XXP
    _new_points = utils.get_grid_coords(xgrid, xpgrid)
    for i_y in trange(shape[2]):
        for i_yp in trange(shape[3], leave=False):
            for i_w in range(shape[4]):
                new_values = interpolate.griddata(
                    _points,
                    images_yypw[:, i_y, i_yp, i_w],
                    _new_points,
                    method='linear',
                    fill_value=0.0,  # match the zero fill used by the 1D path
                )
                f[:, :, i_y, i_yp, i_w] = new_values.reshape((shape[0], shape[1]))

## Shutdown

Hack: flip x-x' if we are at Cam34.

In [None]:
if cam.lower() == 'cam34':
    # Reverse both the x and x' axes. Work one y slice at a time: reversing the
    # whole memmap in a single assignment makes NumPy copy the entire
    # overlapping right-hand side into memory, which can raise a MemoryError.
    for i_y in trange(shape[2]):
        f[:, :, i_y, :, :] = f[::-1, ::-1, i_y, :, :]

Save the grid coordinates.

In [None]:
# Grid coordinates for (x, x', y, y', w), each shifted to zero mean.
coords = [grid - grid.mean() for grid in (xgrid, xpgrid, ygrid, ypgrid, wgrid)]
coords_path = os.path.join(folder, f'coords_{filename}.npz')
utils.save_stacked_array(coords_path, coords)

Briefly examine the interpolated array.

In [None]:
# Axis names and units of the interpolated array, in axis order.
_axis_labels = [("x", 'mm'), ("x'", 'mrad'), ("y", 'mm'), ("y'", 'mrad'), ("w", 'MeV')]
dims = [name for name, _ in _axis_labels]
units = [unit for _, unit in _axis_labels]
dims_units = [f'{name} [{unit}]' for name, unit in _axis_labels]
prof_kws = {'kind': 'step'}

In [None]:
# Interactive 1D projections of the interpolated array.
mplt.interactive_proj1d(f, dims=dims, units=units, coords=coords)

In [None]:
# Interactive 2D projections of the interpolated array, with step-style profiles.
mplt.interactive_proj2d(f, dims=dims, units=units, coords=coords, prof_kws=prof_kws)

Write changes to the memory map.

In [None]:
# Write any pending changes in the memmap `f` to its file on disk.
f.flush()

Save info.

In [None]:
# Record axis names, units, and the interpolated array shape in info.pkl.
info.update(dims=dims, units=units, int_shape=shape)
utils.save_pickle(os.path.join(folder, 'info.pkl'), info)

Save static html of this notebook.

In [None]:
# Export this notebook as static HTML and move it into the output folder.
# check=True raises if nbconvert fails instead of silently continuing.
import shutil
import subprocess

_nb_stem = 'scan-xxpy-image-ypdE_step1'
subprocess.run(['jupyter', 'nbconvert', f'{_nb_stem}.ipynb', '--to', 'html'], check=True)
# Give the full destination path so a rerun overwrites an existing copy (like `mv`).
shutil.move(f'{_nb_stem}.html', os.path.join(folder, f'{_nb_stem}.html'));

In [None]:
os.system("ls " + folder);  # list the files written to the output folder

In [None]:
# Close the preprocessed h5 file opened when the data were loaded.
file.close()