# Step 1

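* Load the scalar, waveform, and image datasets from the preprocessed HDF5 file.
* For each sweep (iteration), interpolate the images onto a regular y grid.
* For each y and image pixel, interpolate onto a regular x-x' grid: f(x, x', y, y3, x3).
* For each (x, x', y, x3), transform y3 -> y' and interpolate onto a regular y' grid: f(x, x', y, y', x3).
* For each (x, x', y, y'), transform x3 -> w and interpolate onto a regular w grid: f(x, x', y, y', w).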

In [None]:
import sys
import os
from os.path import join
import time
from datetime import datetime
import importlib
from pprint import pprint
import numpy as np
import pandas as pd
import h5py
import imageio
from scipy import ndimage
from scipy import interpolate
import skimage
from tqdm.notebook import tqdm
from tqdm.notebook import trange
from matplotlib import pyplot as plt
from matplotlib.patches import Ellipse
from plotly import graph_objects as go
import proplot as pplt
from ipywidgets import interactive
from ipywidgets import widgets

sys.path.append('../..')
from tools.energyVS06 import EnergyCalculate
from tools import image_processing as ip
from tools import plotting as mplt
from tools import utils

In [None]:
# proplot defaults for this notebook: no grid lines, continuous colormaps, viridis.
for _rc_key, _rc_value in (
    ('grid', False),
    ('cmap.discrete', False),
    ('cmap.sequential', 'viridis'),
):
    pplt.rc[_rc_key] = _rc_value

## Setup

In [None]:
# Directory containing `info.pkl` (loaded in the next cell).
folder = '_output'

In [None]:
# Run metadata dictionary; this notebook writes an updated copy back at the end.
_info_path = join(folder, 'info.pkl')
info = utils.load_pickle(_info_path)
print('info')
pprint(info)

In [None]:
# Open the preprocessed HDF5 file read-only and grab its three compound datasets.
# The handle stays open because later cells read `data_sc` and `data_im` lazily.
datadir = info['datadir']
filename = info['filename']
file = h5py.File(join(datadir, 'preproc-' + filename + '.h5'), 'r')
data_sc = file['/scalardata']
data_wf = file['/wfdata']
data_im = file['/imagedata']

# List the fields of each compound dataset as (name, (dtype, offset)) pairs.
print('Attributes:')
print()
for data in [data_sc, data_wf, data_im]:
    print(data.name)
    for item in data.dtype.fields.items():
        print(item)
    print()

In [None]:
variables = info['variables']
keys = list(variables)
# Number of scan steps for each scanned variable, in `keys` order.
nsteps = np.array([variables[name]['steps'] for name in keys])

acts = info['acts']
print(acts)
# Actuator readbacks: one row per shot, one column per actuator (in `acts` order).
points = np.vstack([data_sc[name] for name in acts]).T

cam = info['cam']
print(f"cam = '{cam}'")
_supported_cams = ('cam06', 'cam34')
if cam.lower() not in _supported_cams:
    raise ValueError(f"Unknown camera name '{cam}'.")

## Interpolation

### Convert to beam frame coordinates

In [None]:
# The y slit is inserted from above, so its readback has the opposite sign of beam y.
points[:, 0] *= -1

# Screen axes (x3, y3) point opposite to beam (x, y); build pixel coordinate grids in mm.
image_shape = info['image_shape']
n_pix_rows, n_pix_cols = image_shape[0], image_shape[1]
x3grid = -np.arange(n_pix_cols) * info['image_pix2mm_x']
y3grid = -np.arange(n_pix_rows) * info['image_pix2mm_y']

# VT04/VT06 share the sign of x_beam; VT34a and VT34b are flipped.
if cam.lower() == 'cam34':
    points[:, 1:] *= -1.0

### Setup interpolation grids

Build the transfer matrices between the slits and the screen. 

In [None]:
# Beamline geometry and quadrupole settings for the selected camera; used to build the
# slit-slit (Mslit) and slit-screen (Mscreen) transfer matrices via EnergyCalculate.
# Lengths appear to be in meters (slit/screen coordinates in mm are multiplied by 1e-3
# before being passed to `ecalc` elsewhere) — TODO confirm.
dipole_current = 0.0  # deviation of dipole current from nominal
# NOTE(review): the single-letter name `l` is easy to clobber with a loop index in later
# cells; it is only read here, when `ecalc` is constructed.
l = 0.129  # dipole face to screen (assume same for first/last dipole-screen)
if cam.lower() == 'cam06':
    GL05 = 0.0  # QH05 integrated field strength
    GL06 = 0.0  # QH06 integrated field strength  (1 [A] = 0.0778 [Tm])
    l1 = 0.280  # slit1 to QH05 center
    l2 = 0.210  # QH05 center to QV06 center
    l3 = 0.457  # QV06 center to slit2
    L2 = 0.599  # slit2 to dipole face
    rho_sign = +1.0  # dipole bend radius sign
    # Disabled consistency checks against machine metadata (`metadata` is not defined in
    # the visible notebook).
    # if GL05 == 0.0 and metadata['BTF_MEBT_Mag:PS_QH05:I_Set'][0] != 0.0:
    #     print('Warning: QH05 is turned on according to metadata.')
    # if GL05 != 0.0 and metadata['BTF_MEBT_Mag:PS_QH05:I_Set'][0] == 0.0:
    #     print('Warning: QH05 is turned off according to metadata.')
    # if GL06 == 0.0 and metadata['BTF_MEBT_Mag:PS_QV06:I_Set'][0] != 0.0:
    #     print('Warning: QH06 is turned on according to metadata.')
    # if GL06 != 0.0 and metadata['BTF_MEBT_Mag:PS_QV06:I_Set'][0] == 0.0:
    #     print('Warning: QH06 is turned off according to metadata.')
elif cam.lower() == 'cam34':
    GL05 = 0.0  # QH05 integrated field strength
    GL06 = 0.0  # QH06 integrated field strength
    l1 = 0.000  # slit1 to QH05 center
    l2 = 0.000  # QH05 center to QV06 center
    l3 = 0.774  # QV06 center to slit2
    L2 = 0.311  # slit2 to dipole face
    # Weird... I can only get the right answer for energy if I *do not* flip rho,
    # x1, x2, and x3. I then flip x and xp at the very end.
    rho_sign = +1.0  # dipole bend radius sign
    # These two lines undo the sign flips applied to x3grid and points[:, 1:] in the
    # beam-frame cell. NOTE(review): they rebind/mutate module-level state, so re-running
    # this cell toggles the signs again; use Restart & Run All for a consistent state.
    x3grid = -x3grid
    points[:, 1:] = -points[:, 1:]
# NOTE(review): LL is not used elsewhere in the visible notebook.
LL = l1 + l2 + l3 + L2  # distance from emittance plane to dipole entrance
ecalc = EnergyCalculate(l1=l1, l2=l2, l3=l3, L2=L2, l=l, rho_sign=rho_sign)
Mslit = ecalc.getM1(GL05=GL05, GL06=GL06)  # slit-slit
Mscreen = ecalc.getM(GL05=GL05, GL06=GL06)  # slit-screen

#### y grid

In [None]:
# Regular y grid spanning the measured (sign-flipped) y-slit range, with about
# y_scale times as many points as scan steps.
y_scale = 1.1
_n_y = int(y_scale * (nsteps[0] + 1))
_y_lo = np.min(points[:, 0])
_y_hi = np.max(points[:, 0])
ygrid = np.linspace(_y_lo, _y_hi, _n_y)

#### x-x' grid

Assume that $x$ and $x'$ are constant within each iteration (any variation in the readbacks is noise). Average the slit readbacks over each iteration to get one $(x, x')$ point per iteration; that point is shared by every $\left\{y, y_3, x_3\right\}$ measured in the iteration.

In [None]:
# One (x, x') point per iteration. x and x' are (nominally) constant within an iteration,
# so average the slit readbacks over its shots and convert (x1, x2) -> (x, x') with Mslit.
iterations = data_sc['iteration']
iteration_nums = np.unique(iterations)
n_iterations = len(iteration_nums)

XXP = np.zeros((n_iterations, 2))
for row, iteration in enumerate(iteration_nums):
    idx, = np.where(iterations == iteration)
    x2, x1 = np.mean(points[idx, 1:], axis=0)
    x = x1
    xp = 1e3 * ecalc.calculate_xp(x1 * 1e-3, x2 * 1e-3, Mslit)  # [mrad]
    # Row order follows `iteration_nums`, which is also the order the image stacks are
    # built in. This does not assume iterations are numbered 1..n without gaps.
    XXP[row] = (x, xp)

Define the $x$-$x'$ interpolation grid. Tune `x_scale` and `xp_scale` to roughly align the grid points with the measured points.

In [None]:
# Regular x-x' grid. Tune x_scale / xp_scale so the grid lines roughly pass through
# the measured (x, x') points.
x_scale = 1.1
xp_scale = 1.6

x_min, xp_min = np.min(XXP, axis=0)
x_max, xp_max = np.max(XXP, axis=0)
xgrid = np.linspace(x_min, x_max, int(x_scale * (nsteps[2] + 1)))
xpgrid = np.linspace(xp_min, xp_max, int(xp_scale * (nsteps[1] + 1)))

fig, ax = pplt.subplots(figwidth=4)
line_kws = dict(color='lightgray', lw=0.7)
for x in xgrid:
    ax.axvline(x, **line_kws)
for xp in xpgrid:
    ax.axhline(xp, **line_kws)
ax.plot(XXP[:, 0], XXP[:, 1], color='pink7', lw=0, marker='.', ms=2)
# Keep x increasing from left to right.
xlim = ax.get_xlim()
if xlim[0] > xlim[1]:
    ax.set_xlim(xlim[::-1])
ax.format(xlabel='x [mm]', ylabel='xp [mrad]')
# plt.savefig('_output/xxp_interp_grid.png')
plt.show()

#### y' grid

Define the $y'$ grid.

In [None]:
# Regular y' grid covering every y' reachable from (ygrid, y3grid) through Mscreen.
yp_scale = 1.1  # scales resolution of y' interpolation grid
_Y, _Y3 = np.meshgrid(ygrid, y3grid, indexing='ij')
_YP = 1e3 * ecalc.calculate_yp(_Y * 1e-3, _Y3 * 1e-3, Mscreen)  # [mrad]
_n_yp = int(yp_scale * len(y3grid))
ypgrid = np.linspace(np.min(_YP), np.max(_YP), _n_yp)

fig, ax = pplt.subplots(figwidth=4)
for yp in ypgrid:
    ax.axhline(yp, **line_kws)
for y in ygrid:
    ax.axvline(y, **line_kws)
ax.plot(np.ravel(_Y), np.ravel(_YP), color='pink7', lw=0, marker='.', ms=2)
ax.format(xlabel='y [mm]', ylabel='yp [mrad]')
# plt.savefig('_output/yyp_interp_grid.png')
plt.show()

#### w grid

In [None]:
# Energy deviation w for every (x, x', x3) on the grids; only its range is used, to
# define a regular w grid with about w_scale times as many points as screen columns.
w_scale = 1.1
_W = np.zeros((len(xgrid), len(xpgrid), image_shape[1]))
for i, _x_mm in enumerate(xgrid):
    for j, _xp_mrad in enumerate(xpgrid):
        _W[i, j, :] = ecalc.calculate_dE_screen(
            1e-3 * x3grid, dipole_current, 1e-3 * _x_mm, 1e-3 * _xp_mrad, Mscreen,
        )
_n_w = int(w_scale * image_shape[1])
wgrid = np.linspace(np.min(_W), np.max(_W), _n_w)

In [None]:
# Sanity check of the w grid at x' = xpgrid[0]: for each x index, draw the w range
# covered by the screen (black bar) against the w grid lines.
j = 0
_arrs = [_W[i, j, :] for i in range(len(xgrid))]

fig, ax = pplt.subplots()
for i, _arr in enumerate(_arrs):
    ax.plot([i, i], [np.min(_arr), np.max(_arr)], c='black')
for w in wgrid:
    ax.axhline(w, **line_kws)
ax.format(xlabel='x index', ylabel='w [MeV]')

In [None]:
# Map a coarse (x, x', x3) lattice to (x, x', w) and show pairwise scatter plots of the
# mean-centred points, to see how w correlates with x and x'.
_n_probe = 10
xxpw_points = []
for x in np.linspace(xgrid.min(), xgrid.max(), _n_probe):
    for xp in np.linspace(xpgrid.min(), xpgrid.max(), _n_probe):
        for x3 in np.linspace(x3grid.min(), x3grid.max(), _n_probe):
            # Same dipole_current as used to build the w grid.
            w = ecalc.calculate_dE_screen(x3 * 1e-3, dipole_current, x * 1e-3, xp * 1e-3, Mscreen)  # [MeV]
            xxpw_points.append([x, xp, w])
xxpw_points = np.array(xxpw_points)
xxpw_points = xxpw_points - np.mean(xxpw_points, axis=0)

_dim_labels = ['x', 'xp', 'w']
fig, axes = pplt.subplots(ncols=3, nrows=3, figwidth=4.5, span=False)
for i in range(3):
    for j in range(3):
        axes[i, j].scatter(xxpw_points[:, j], xxpw_points[:, i], color='black', s=1)
    axes[-1, i].set_xlabel(_dim_labels[i])
    axes[i, 0].set_ylabel(_dim_labels[i])
plt.show()

## Interpolation 

### Interpolate y 

Interpolate the image stack along the $y$ axis on each iteration.

In [None]:
# One (y3, x3) frame per shot, read from the camera's image field.
_image_field = cam + '_Image'
images = data_im[_image_field].reshape((len(points), len(y3grid), len(x3grid)))

In [None]:
# For each iteration (fixed x, x'), resample its (y, y3, x3) image stack onto `ygrid` by
# linear interpolation along y. Samples outside the measured y range become 0.
images_3D, images_3D_raw = [], []
for iteration in tqdm(iteration_nums):
    in_iteration = iterations == iteration
    y_measured = points[in_iteration, 0]
    stack = images[in_iteration, :, :]
    # np.unique returns the first index of each distinct y in ascending-y order, so the
    # selected samples are sorted and duplicate y readbacks are dropped.
    _, first_idx = np.unique(y_measured, return_index=True)
    stack_sorted = stack[first_idx]
    fint = interpolate.interp1d(
        y_measured[first_idx],
        stack_sorted,
        kind='linear',
        axis=0,
        copy=True,
        bounds_error=False,
        fill_value=0.0,
        assume_sorted=True,
    )
    images_3D.append(fint(ygrid))
    images_3D_raw.append(stack_sorted)
images_3D = np.array(images_3D)

Adjust the sliders so that the raw and interpolated images are at the same $y$ position. `yint` indexes `ygrid`; `yraw` indexes the unique measured $y$ values of the selected iteration.

In [None]:
def update(iteration, yint, yraw):
    """Plot raw vs y-interpolated (y3, x3) images for one iteration and print their y.

    Parameters
    ----------
    iteration : int
        Iteration number; the per-iteration lists are indexed with `iteration - 1`.
    yint : int
        Index into `ygrid` (the interpolated stack).
    yraw : int
        Index into the sorted unique measured y values of this iteration. The slider
        range is `len(ygrid)`, which is usually larger than the number of raw slices,
        so the index is clipped to the last raw slice instead of raising IndexError.
    """
    raw_stack = images_3D_raw[iteration - 1]
    yraw = min(yraw, len(raw_stack) - 1)
    fig, axes = pplt.subplots(ncols=2)
    for _arrs, _y, title, ax in zip([images_3D_raw, images_3D], [yraw, yint], ['Raw', 'Interpolated'], axes):
        ax.format(title=title)
        mplt.plot_image(_arrs[iteration - 1][_y].T, ax=ax)
    axes.format(xlabel='x3', ylabel='y3')
    # Sorted unique y values: the same order as the slices in images_3D_raw.
    ys = np.unique(points[iterations == iteration, 0])
    print(f'yint, yraw = ({ygrid[yint]}, {ys[yraw]})')
    print(f'x, xp = {XXP[iteration - 1]}')
    plt.show()


interactive(
    update,
    iteration=widgets.BoundedIntText(
        value=max(1, len(images_3D) // 2),  # keep the initial value within [min, max]
        min=1,
        max=len(XXP),
        step=1,
        description='Iteration',
        disabled=False,
    ),
    yint=(0, len(ygrid) - 1),
    yraw=(0, len(ygrid) - 1),
)

### Testing

y-y'

In [None]:
# Quick y-y' check: collapse the stacks over iterations and x3, then map y3 -> y' per y.
fyy3 = images_3D.sum(axis=(0, -1))  # y vs y3
fyyp = np.zeros((len(ygrid), len(ypgrid)))
for i, y in enumerate(ygrid):
    # NOTE(review): y and y3grid are passed in mm here, whereas the y'-grid cell converts
    # them to m and rescales by 1e3; the two agree only if calculate_yp is linear and
    # homogeneous in (y, y3) — TODO confirm.
    yp = ecalc.calculate_yp(y, y3grid, Mscreen)
    fyyp[i, :] = interpolate.interp1d(
        yp,
        fyy3[i, :].copy(),
        kind='linear', fill_value=0.0, bounds_error=False, assume_sorted=False,
    )(ypgrid)

In [None]:
# Save emittance for simulation.
# All work happens inside a helper so the analysis grids (`ygrid`, `ypgrid`) and
# `images_3D` used by later cells are not overwritten. Previously this cell replaced
# them with 200-point, mean-centred versions, which silently changed every downstream
# result (including y' computed from a centred y).
def export_yyp_for_simulation(n_y=200, n_yp=200, directory='_output'):
    """Write a finely sampled, normalized y-y' density to a CSV file.

    The CSV (`emittance-data-y-<filename prefix>.csv`, no header or index) has a first
    row of [None, y...] (mean-centred y grid) and then one row per y' value:
    [y' (mean-centred), f(y, y') for each y], with f normalized to a maximum of 1.

    Parameters
    ----------
    n_y, n_yp : int
        Number of points in the y and y' grids.
    directory : str
        Output directory.

    Returns
    -------
    y_sim, yp_sim : ndarray
        Uncentred y and y' grids.
    fyyp_sim : ndarray, shape (n_y, n_yp)
        Unnormalized y-y' density.
    """
    y_sim = np.linspace(np.min(points[:, 0]), np.max(points[:, 0]), n_y)
    yp_sim = np.linspace(np.min(_YP), np.max(_YP), n_yp)

    # Resample each iteration's (y, y3, x3) stack onto y_sim, summing over iterations and
    # x3 as we go to avoid materializing the full 4D stack.
    fyy3_sim = np.zeros((n_y, len(y3grid)))  # y vs y3
    for iteration in tqdm(iteration_nums):
        idx, = np.where(iterations == iteration)
        y_meas = points[idx, 0]
        _, uind = np.unique(y_meas, return_index=True)  # sorted, de-duplicated y
        fint = interpolate.interp1d(
            y_meas[uind],
            images[idx][uind],
            kind='linear',
            axis=0,
            copy=True,
            bounds_error=False,
            fill_value=0.0,
            assume_sorted=True,
        )
        fyy3_sim += fint(y_sim).sum(axis=-1)

    fyyp_sim = np.zeros((n_y, n_yp))
    for i, y in enumerate(y_sim):
        yp = ecalc.calculate_yp(y, y3grid, Mscreen)
        fint = interpolate.interp1d(
            yp, fyy3_sim[i], kind='linear', fill_value=0.0, bounds_error=False, assume_sorted=False,
        )
        fyyp_sim[i] = fint(yp_sim)

    y_centered = y_sim - np.mean(y_sim)
    yp_centered = yp_sim - np.mean(yp_sim)
    df = pd.DataFrame(
        data=np.vstack([
            np.hstack([None, y_centered]),
            np.hstack([yp_centered[:, None], fyyp_sim.T / np.max(fyyp_sim)]),
        ]),
    )
    df.to_csv(
        join(directory, f"emittance-data-y-{filename.split('-')[0]}.csv"),
        header=False,
        index=False,
    )
    return y_sim, yp_sim, fyyp_sim


y_sim, yp_sim, fyyp_sim = export_yyp_for_simulation()

In [None]:
# Log-scale y-y' density on mean-centred axes, with projected profiles.
_y_centered = ygrid - np.mean(ygrid)
_yp_centered = ypgrid - np.mean(ypgrid)
fig, ax = pplt.subplots()
mplt.plot_image(
    fyyp / fyyp.max(),
    x=_y_centered,
    y=_yp_centered,
    ax=ax,
    profx=True,
    profy=True,
    norm='log',
    handle_log='floor',
    frac_thresh=1e-3,
    colorbar=True,
)
ax.format(xlabel='y [mm]', ylabel='yp [mrad]')
plt.show()

x-x'

In [None]:
# x-x' projection: total intensity per iteration (summed over y, y3, x3), interpolated
# from the measured (x, x') points onto the regular x-x' grid.
_points = XXP
_values = images_3D.sum(axis=(1, 2, 3))
# NOTE(review): the reshape below assumes get_grid_coords orders points with x varying
# slowest; another cell passes indexing='ij' explicitly — TODO confirm the default.
_new_points = utils.get_grid_coords(xgrid, xpgrid)
_new_values = interpolate.griddata(_points, _values, _new_points, method='linear', fill_value=0.0)
fxxp = np.clip(_new_values.reshape((len(xgrid), len(xpgrid))), 0.0, None)
# Reverse both the x and x' axes.
fxxp = fxxp[::-1, ::-1]

fig, ax = pplt.subplots()
mplt.plot_image(
    fxxp / fxxp.max(),
    x=xgrid - np.mean(xgrid),
    y=xpgrid - np.mean(xpgrid),
    ax=ax,
    handle_log='floor',
    norm='log',
    frac_thresh=1e-4,
    colorbar=True,
    profx=True,
    profy=True,
)
ax.format(xlabel='x [mm]', ylabel='xp [mrad]')

In [None]:
# Side by side: measured per-iteration totals (log colour) vs the interpolated x-x' image.
_norm_values = _values / np.max(_values)
_log_offset = np.min(_norm_values[_values > 0])

fig, axes = pplt.subplots(ncols=2, figwidth=5.0, share=False)
axes[0].scatter(
    _points[:, 0],
    _points[:, 1],
    c=np.log10(_norm_values + _log_offset),
    cmap='viridis',
    s=1,
    marker='o',
)
mplt.plot_image(
    fxxp / fxxp.max(),
    x=xgrid,
    y=xpgrid,
    ax=axes[1],
    handle_log='floor',
    norm='log',
    frac_thresh=1e-4,
    colorbar=True,
)
axes.format(xlabel="x [mm]", ylabel="xp [mrad]")
for _ax, _title in zip(axes, ['Measured', 'Interpolated']):
    _ax.format(title=_title)
plt.savefig('interp2d.png')

x-x'-w

In [None]:
# x-x'-w projection:
#   1. per iteration, sum the images over shots and y3 -> profile along x3;
#   2. interpolate every x3 column onto the x-x' grid (one triangulation of XXP, reused);
#   3. for each (x, x'), map x3 -> w and resample onto `wgrid`.
pws = np.zeros((len(iteration_nums), len(x3grid)))
for row, iteration in enumerate(tqdm(iteration_nums)):
    idx, = np.where(iterations == iteration)
    # Row order follows iteration_nums, matching XXP and the image stacks.
    pws[row] = np.sum(images[idx, :, :], axis=(0, 1))

_new_points = utils.get_grid_coords(xgrid, xpgrid)
# griddata(method='linear') builds a LinearNDInterpolator on every call; building it once
# with all x3 columns as trailing value dimensions gives the same result much faster.
_interp_xxp = interpolate.LinearNDInterpolator(XXP, pws, fill_value=0.0)
fxxpx3 = _interp_xxp(_new_points).reshape((len(xgrid), len(xpgrid), len(x3grid)))

fxxpw = np.zeros((len(xgrid), len(xpgrid), len(wgrid)))
for i, x in enumerate(tqdm(xgrid)):
    for j, xp in enumerate(xpgrid):
        # Same dipole_current as used to build the w grid.
        w = ecalc.calculate_dE_screen(x3grid * 1e-3, dipole_current, x * 1e-3, xp * 1e-3, Mscreen)
        fint = interpolate.interp1d(
            w, fxxpx3[i, j, :], kind='linear', fill_value=0.0, bounds_error=False,
            assume_sorted=False, copy=True,
        )
        fxxpw[i, j, :] = fint(wgrid)

# Hack for now: reverse the x and x' axes. NOTE(review): the cam34 setup comment says x and
# x' are flipped "at the very end"; confirm this should apply for every camera.
fxxpw = fxxpw[::-1, ::-1, ...]

In [None]:
# Interactive 2D projections of the normalized x-x'-w density on mean-centred axes.
_centered_xxpw = [grid - np.mean(grid) for grid in (xgrid, xpgrid, wgrid)]
mplt.interactive_proj2d(
    fxxpw / np.max(fxxpw),
    dims=['x', 'xp', 'w'],
    units=['mm', 'mrad', 'MeV'],
    coords=_centered_xxpw,
)

x-x'-y-y'

In [None]:
# x-x'-y-y3: sum over x3, then interpolate every (y, y3) pixel onto the regular x-x' grid.
images_yy3 = np.sum(images_3D, axis=-1)  # iteration, y, y3

_new_points = utils.get_grid_coords(xgrid, xpgrid, indexing='ij')
# One triangulation of the measured (x, x') points shared by all (y, y3) columns. This is
# equivalent to griddata(method='linear') per column, and being loop-free it no longer
# reuses `l` (the dipole-to-screen length) as a loop index.
_n_iter, _n_y, _n_y3 = images_yy3.shape
_interp_xxp = interpolate.LinearNDInterpolator(
    XXP,
    images_yy3.reshape((_n_iter, _n_y * _n_y3)),
    fill_value=0.0,
)
_f4d = _interp_xxp(_new_points).reshape((len(xgrid), len(xpgrid), _n_y, _n_y3))

In [None]:
# Four 2D projections of the x-x'-y-y3 density on mean-centred axes.
_labels = ["x [mm]", "xp [mrad]", "y [mm]", "y3 [mm]"]
_coords = [grid - grid.mean() for grid in (xgrid, xpgrid, ygrid, y3grid)]
_panels = [(0, 1), (2, 3), (0, 2), (1, 3)]
fig, axes = pplt.subplots(ncols=2, nrows=2, share=False, space=7, figwidth=5)
for ax, (dim_a, dim_b) in zip(axes, _panels):
    mplt.plot_image(utils.project(_f4d, (dim_a, dim_b)), x=_coords[dim_a], y=_coords[dim_b], ax=ax)
    ax.format(xlabel=_labels[dim_a], ylabel=_labels[dim_b])

In [None]:
# x-x'-y-y': for each (x, x', y), map y3 -> y' and resample the y3 profile onto `ypgrid`.
f4d = np.zeros((len(xgrid), len(xpgrid), len(ygrid), len(ypgrid)))
for i in trange(len(xgrid)):
    for j in range(len(xpgrid)):
        for k, y in enumerate(ygrid):
            # NOTE(review): uses -y3grid (and mm units), unlike the other y3 -> y'
            # conversions in this notebook, which use +y3grid — TODO confirm the sign.
            yp = ecalc.calculate_yp(y, -y3grid, Mscreen)
            f4d[i, j, k, :] = interpolate.interp1d(
                yp,
                _f4d[i, j, k, :].copy(),
                kind='linear',
                bounds_error=False,
                fill_value=0.0,
                copy=True,
                assume_sorted=False,
            )(ypgrid)
# f4d = np.clip(f4d, 0.0, None)

In [None]:
# Show the working directory; relative output paths such as '_output/...' resolve against it.
# os.getcwd() also works without IPython automagic, unlike a bare `pwd`.
os.getcwd()

In [None]:
# Four 2D projections of the x-x'-y-y' density on mean-centred axes, saved to disk.
_labels = ["x [mm]", "xp [mrad]", "y [mm]", "yp [mrad]"]
_coords = [grid - grid.mean() for grid in (xgrid, xpgrid, ygrid, ypgrid)]
fig, axes = pplt.subplots(ncols=2, nrows=2, share=False, hspace=5, wspace=6, figwidth=5)
for ax, (dim_a, dim_b) in zip(axes, [(0, 1), (2, 3), (0, 2), (1, 3)]):
    _projection = utils.project(f4d, (dim_a, dim_b))
    mplt.plot_image(_projection, x=_coords[dim_a], y=_coords[dim_b], ax=ax)
    ax.format(xlabel=_labels[dim_a], ylabel=_labels[dim_b])
plt.savefig('_output/compare2.png', dpi=200)

In [None]:
# Interactive 2D projections of the x-x'-y-y' density on the (uncentred) grids.
mplt.interactive_proj2d(
    f4d,
    dims=['x', 'xp', 'y', 'yp'],
    units=['mm', 'mrad', 'mm', 'mrad'],
    coords=[xgrid, xpgrid, ygrid, ypgrid],
)

In [None]:
# axes = mplt.corner(
#     f4d,
#     coords=[xgrid-np.mean(xgrid), xpgrid-np.mean(xpgrid), 
#             ygrid-np.mean(ygrid), y3grid-np.mean(y3grid)],
#     diag_kind='None',
#     labels=["x [mm]", "xp [mrad]", "y [mm]", "y3 [mm]"],
#     # prof='edges',
#     prof_kws=dict(lw=0.75, alpha=0.5),
# )
# # plt.savefig("_output/corner_4d_y3.png")

In [None]:
# f4d_new = np.zeros((len(xgrid), len(xpgrid), len(ygrid), len(ypgrid)))
# for i in trange(len(xgrid)):
#     for j in range(len(xpgrid)):
#         for k in range(len(ygrid)):
#             _points = 1e3 * ecalc.calculate_yp(ygrid[k] * 1e-3, y3grid * 1e-3, Mscreen)
#             _values = f4d[i, j, k, :].copy()
#             fint = interpolate.interp1d(
#                 _points,
#                 _values,
#                 kind='linear',
#                 assume_sorted=False,
#                 bounds_error=False,
#                 fill_value=0.0,
#             )
#             f4d_new[i, j, k, :] = fint(ypgrid)
# f4d_new = np.clip(f4d_new, 0.0, None)

In [None]:
# _coords = [grid - np.mean(grid) for grid in [xgrid, xpgrid, ygrid, ypgrid]]
# mplt.interactive_proj2d(f4d_new, dims=['x', 'xp', 'y', 'yp'], coords=_coords)

In [None]:
# axes = mplt.corner(
#     f4d_new,
#     coords=_coords,
#     diag_kind='None',
#     labels=["x [mm]", "xp [mrad]", "y [mm]", "yp [mrad]"],
# )
# # plt.savefig("_output/corner_4d.png")

### Interpolate x-x'

Interpolate $x$-$x'$ for each $\left\{y, y_3, x_3\right\}$.

In [None]:
# 5D step 1: for each (y, y3, x3) voxel, interpolate the per-iteration intensity onto the
# regular x-x' grid. One Delaunay triangulation of XXP is shared by all voxels; this gives
# the same result as griddata(method='linear') per voxel, but builds the triangulation once.
shape = (len(xgrid), len(xpgrid), len(ygrid), image_shape[0], image_shape[1])
_new_points = utils.get_grid_coords(xgrid, xpgrid)
_interp_xxp = interpolate.LinearNDInterpolator(
    XXP,
    images_3D.reshape((len(images_3D), -1)),  # (iteration, y * y3 * x3)
    fill_value=0.0,  # previously `False`, which was silently cast to 0.0
)
f = _interp_xxp(_new_points).reshape(shape)

### Interpolate y'

The $y$ coordinate is already on a grid. For each $\left\{x, x', y, x_3\right\}$, transform $y_3 \rightarrow y'$ and interpolate onto `ypgrid`. 

In [None]:
# 5D step 2: for each (x, x', y), transform y3 -> y' and resample onto `ypgrid`, for all
# x3 columns at once. The last axis is still x3 at this stage (w is handled next), so the
# intermediate array is sized by len(x3grid). Sizing it by len(wgrid) made
# `f[i, j, k, :, m]` index past the end of the x3 axis.
shape = (len(xgrid), len(xpgrid), len(ygrid), len(ypgrid), len(wgrid))  # final 5D shape
shape_x3 = shape[:4] + (len(x3grid),)
f_new = np.zeros(shape_x3)
for k, y in enumerate(tqdm(ygrid)):
    # y' depends only on (y, y3), so compute it once per y.
    yp = 1e3 * ecalc.calculate_yp(y * 1e-3, y3grid * 1e-3, Mscreen)
    # f[:, :, k] has axes (x, x', y3, x3); interpolate along y3 (axis 2).
    fint = interpolate.interp1d(
        yp,
        f[:, :, k],
        axis=2,
        kind='linear',
        bounds_error=False,
        fill_value=0.0,
        assume_sorted=False,
    )
    f_new[:, :, k] = fint(ypgrid)
f = f_new

### Interpolate energy spread $w$

Interpolate $w$ for each $\left\{x, x', y, y'\right\}$.

In [None]:
# 5D step 3: for each (x, x'), transform x3 -> w and resample onto `wgrid` for all (y, y')
# at once. The result is written to a disk-backed memmap.
shape = (len(xgrid), len(xpgrid), len(ygrid), len(ypgrid), len(wgrid))
savefilename = f'_output/f_{filename}.mmp'
f_new = np.memmap(savefilename, shape=shape, dtype='float', mode='w+')
for i, x in enumerate(tqdm(xgrid)):
    for j, xp in enumerate(xpgrid):
        # w depends only on (x, x'), so compute it once per (x, x'). Uses the same
        # dipole_current as the w grid.
        w = ecalc.calculate_dE_screen(x3grid * 1e-3, dipole_current, x * 1e-3, xp * 1e-3, Mscreen)
        # f[i, j] has axes (y, y', x3); interpolate along x3 (last axis).
        fint = interpolate.interp1d(
            w,
            f[i, j],
            axis=-1,
            kind='linear',
            fill_value=0.0,
            bounds_error=False,
            assume_sorted=False,
        )
        f_new[i, j] = fint(wgrid)
f_new.flush()  # ensure all data reaches the file on disk

Center grid and save coordinates.

In [None]:
# Centre each grid on zero and save all five grids together.
coords = [grid - np.mean(grid) for grid in (xgrid, xpgrid, ygrid, ypgrid, wgrid)]
utils.save_stacked_array(f'_output/coords_{filename}.npz', coords)
info['int_shape'] = tuple(len(grid) for grid in coords)
print(info['int_shape'])

Save info.

In [None]:
print('info:')
pprint(info)

# Save as pickled dictionary for easy loading.
utils.save_pickle('_output/info.pkl', info)

# Also save a human-readable copy. The context manager closes the file even on error, and
# the distinct name avoids shadowing the open HDF5 handle `file`.
with open('_output/info.txt', 'w') as info_txt:
    for key, value in info.items():
        info_txt.write(f'{key}: {value}\n')