# Step 0

* Read the h5 file containing the scan data.
* Process the camera images.
    * Crop
    * Threshold
    * Downscale
* Create a new h5 file with separate scalar (0D), waveform (1D), and image (2D) data sets. The new file name is "preproc-" + the original file name.

In [None]:
from datetime import datetime
import importlib
import json
import numpy as np
import os
from pprint import pprint
import sys

import h5py
import imageio
from ipywidgets import interactive
from ipywidgets import widgets
from matplotlib import patches
from matplotlib import pyplot as plt
import proplot as pplt
import skimage.transform
from tqdm.notebook import tqdm
from tqdm.notebook import trange

sys.path.append('/home/46h/btf-data-analysis/')
from tools import image_processing as ip
from tools import utils

sys.path.append('/home/46h/ps-dist/')
from psdist import plotting as mplt

In [None]:
# Global proplot styling applied to every figure in this notebook.
_rc_settings = {
    'cmap.discrete': False,
    'cmap.sequential': 'viridis',
    'grid': False,
    'savefig.dpi': 300.0,
}
for _rc_key, _rc_value in _rc_settings.items():
    pplt.rc[_rc_key] = _rc_value

## Load data

In [None]:
# Directory holding the raw scan files for this measurement session.
datadir = '/home/share/Measurements/scan-xxpy-image-ypdE/2022-06-26/'
# os.listdir returns entries in arbitrary order; sort so the listing is stable
# and files with timestamp prefixes appear chronologically.
filenames = sorted(os.listdir(datadir))
filenames

In [None]:
# Raw scan file to process (base name, without the '.h5' extension).
filename = '220626140058-scan-xxpy-image-ypdE'
# Opened read-only and left open: later cells keep reading from it.
file = h5py.File(os.path.join(datadir, filename + '.h5'), 'r')
print([key for key in file])

Create `info` dictionary to pass between notebooks.

In [None]:
# Bookkeeping shared with later notebooks (pickled to `outdir` at the end).
info = {
    'datadir': datadir,
    'filename': filename,
}

In [None]:
# Copy every field of the structured 'config/metadata' dataset into a plain dict.
_metadata_ds = file['config']['metadata']
metadata = {name: _metadata_ds[name] for name in _metadata_ds.dtype.names}
print('metadata:')
pprint(metadata)
info['metadata'] = metadata

In [None]:
# Print the log schema, then every entry whose level is not INFO.
if 'log' in file:
    log = file['log']
    print(f"'log', {type(log)}")
    for item in log.dtype.fields.items():
        print('  ', item)
    print('\nErrors and warnings:')
    # Read the whole log into memory once. Indexing the h5py dataset per row
    # and per field issues a separate HDF5 read for every access.
    log_entries = log[()]
    for entry in log_entries:
        if entry['level'] != b'INFO':
            timestr = datetime.fromtimestamp(entry['timestamp']).strftime("%m/%d/%Y, %H:%M:%S")
            print(f"{timestr} {entry['message']}")

In [None]:
# Main per-shot table of the scan; its compound-dtype fields are listed below.
data = file['scandata']

print(f"'scandata', {type(data)}")
for field_name, field_info in data.dtype.fields.items():
    print('  ', (field_name, field_info))
print(f"nbytes = {data.nbytes:.3e}")

Set output directory for all data and figures.

In [None]:
# Directory for all processed data and figures written by this notebook.
outdir = './_output'

# exist_ok makes this a no-op when the directory already exists and avoids the
# race between an existence check and the creation call.
os.makedirs(outdir, exist_ok=True)

In [None]:
def save(figname, prefix='fig_step0', ext='png', **kws):
    """Save the current matplotlib figure into `outdir`.

    The file name is '{prefix}_{figname}', with '.{ext}' appended when `ext`
    is truthy. Extra keyword arguments are forwarded to `plt.savefig`.
    """
    stem = f'{prefix}_{figname}'
    full_name = f'{stem}.{ext}' if ext else stem
    plt.savefig(os.path.join(outdir, full_name), **kws)

## Image processing

Get camera name and settings.

In [None]:
# Find camera name: the prefix (text before the first '_') of any scan field
# whose name contains 'cam'. If several fields match, the last one wins.
cam = None
for name in data.dtype.names:
    if 'cam' in name.lower():
        cam = name.split('_')[0]
# Fail early with a clear message; every later use (e.g. `cam + '_Image'`)
# would otherwise fail far from the cause.
if cam is None:
    raise ValueError(f"No camera field found in scan data fields: {data.dtype.names}")

# Find camera zoom: the metadata value under a key containing 'Magn' is used
# as an index into the list of zoom factors below (last matching key wins).
zoom = None
for key in metadata:
    if 'Magn' in key:
        zoom = [0.25, 0.33, 0.5, 1.0][int(metadata[key])]

# Load camera settings.
cam_settings = ip.CameraSettings(cam)
cam_settings.set_zoom(zoom)

# Save info.
info['cam'] = cam_settings.name
info['cam_zoom'] = cam_settings.zoom
info['cam_pix2mm_x'] = info['cam_pix2mm_y'] = cam_settings.pix2mm
info['cam_shape'] = cam_settings.shape

print(f"cam = '{cam}'")
print(f'zoom = {zoom}')
print(f'pix2mm = {cam_settings.pix2mm} (zoom = {zoom})')
print(f'image shape = {cam_settings.shape}')

Use the brightest and dimmest frames (largest and smallest camera integral) as test images.

In [None]:
def get_image(i):
    """Return frame `i` of the camera '_Image' field reshaped to `cam_settings.shape`."""
    flat_image = data[i, cam + '_Image']
    return flat_image.reshape(cam_settings.shape)

def plot_compare_images(im1, im2, **plot_kws):
    """Plot two images side by side, each normalized to its own maximum.

    The top row uses a linear color scale and the bottom row a log scale.
    Extra keyword arguments are forwarded to `mplt.plot_image`.
    Returns the 2x2 proplot axes grid.
    """
    fig, axes = pplt.subplots(ncols=2, nrows=2, figwidth=None, sharex=False, sharey=False)
    for col, image in enumerate((im1, im2)):
        for row, norm in zip((0, 1), (None, 'log')):
            mplt.plot_image(image.T / np.max(image), ax=axes[row, col], norm=norm, **plot_kws)
    return axes

In [None]:
# Camera integral of every frame; its extrema select the test images.
signal = data[cam + '_Integral'][:]
imax, imin = np.argmax(signal), np.argmin(signal)
im_max, im_min = get_image(imax), get_image(imin)
for label, index in (('Max', imax), ('Min', imin)):
    print(f'{label} {cam}_Integral at i = {index}')

In [None]:
# Brightest and dimmest frames, each shown in linear (left) and log (right) scale.
for im, title in [(im_max, 'Max integral'), (im_min, 'Min integral')]:
    kws = {'colorbar': True}
    fig, axes = pplt.subplots(ncols=2)
    peak = np.max(im)
    mplt.plot_image(im.T / peak, ax=axes[0], **kws)
    mplt.plot_image(im.T / peak, ax=axes[1], norm='log', **kws)
    axes.format(xlabel='x3', ylabel='y3', suptitle=title)
    save(f'image_{title}')
    plt.show()

### Crop

View cropping at various frames around `imax`.

In [None]:
# Crop window edges (pixels), passed to `ip.crop` and drawn as a red box below.
image_shape = cam_settings.shape
image_crop_edges = {
    'x1': 180,
    'x2': 420,
    'y1': 75,
    'y2': 370,
}
x1 = image_crop_edges['x1']
x2 = image_crop_edges['x2']
y1 = image_crop_edges['y1']
y2 = image_crop_edges['y2']

def update(i, log, handle_log):
    """Plot frame `i`, normalized to its maximum, with the crop window outlined in red.

    `log` toggles a log color scale; `handle_log` is forwarded to `mplt.plot_image`.
    """
    im = get_image(i)
    norm = 'log' if log else None
    fig, ax = pplt.subplots()
    mplt.plot_image(im.T / np.max(im), ax=ax, colorbar=True, norm=norm, handle_log=handle_log)
    ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, ec='red'))
    plt.show()

# Frame slider around the brightest frame, clipped so it never offers a frame
# index outside the scan.
interactive(update, i=(max(imax - 99, 0), min(imax + 101, len(data) - 1)), log=True, handle_log=['floor', 'mask'])

Save cropping results for brightest image.

In [None]:
# Crop the brightest frame and outline the kept region on the uncropped image.
im = im_max.copy()
im1 = ip.crop(im, **image_crop_edges)

axes = plot_compare_images(im, im1, colorbar=True)
crop_height, crop_width = im1.shape
for ax in axes[:, 0]:
    ax.add_patch(patches.Rectangle((x1, y1), crop_width, crop_height, fill=False, ec='red'))
save('crop')

Make sure that nothing important will be cropped by looking at the mean and maximum x/y profiles across all frames. (This will take a while for large data sets.)

In [None]:
# Extract x and y profiles for every frame (2D arrays: one row per frame).
xprofs = data[:, cam + '_ProfileX']
yprofs = data[:, cam + '_ProfileY']

# Plot mean and max heights of all x/y profiles.
# Row 0 shows the mean over frames, row 1 the max over frames. Both columns
# show the same curves; the right column is switched to a log y-scale below.
fig, axes = pplt.subplots(nrows=2, ncols=2, figsize=(6, 3.5), sharey=1)
for j in range(2):
    for ax, func in zip(axes[:, j], [np.mean, np.max]):
        for profs, label in zip([xprofs, yprofs], ['x', 'y']):
            ax.plot(np.arange(profs.shape[1]), func(profs, axis=0), label=label)
axes[0, 1].legend(ncols=1, loc='r')

# Shade the kept (uncropped) x region in the first 'colorblind' cycle color
# (light blue) and the kept y region in the second (orange), on every panel.
colors = pplt.Cycle('colorblind').by_key()['color']
for _x1, _x2, _xmax, c in zip([x1, y1], [x2, y2], [xprofs.shape[1], yprofs.shape[1]], colors):
    # A negative end edge is treated as counting back from the profile length.
    if _x2 < 0:
        _x2 += _xmax
    for ax in axes:
        ax.axvspan(_x1, _x2, color=c, alpha=0.1)
axes[:, 1].format(yscale='log')

# Save figure
axes.format(leftlabels=['Mean', 'Max'], toplabels=['Normal scale', 'Log scale'],
            xlabel='Pixel', ylabel='Integrated profile')
save('crop2')

In [None]:
# Continue the test pipeline with the cropped brightest frame.
im = im1.copy()
# Record the crop used for all frames.
info['image_crop_edges'] = image_crop_edges

### Threshold

Check for a nonzero offset.

In [None]:
# Minimum pixel of each cropped frame in a window around the dimmest frame;
# these minima are used below to estimate the camera's pixel offset.
# The window is at least one frame wide and is clipped to the scan, so a dimmest
# frame near either end of the scan does not produce out-of-range indices.
window = max(int(0.01 * len(data)), 1)
frames = np.arange(max(imin - window, 0), min(imin + window, len(data)))
min_pixels = [np.min(ip.crop(get_image(i), **image_crop_edges)) for i in tqdm(frames)]

fig, ax = pplt.subplots()
ax.plot(frames, min_pixels, color='black')
ax.format(xlabel='Frame', ylabel='min_pixel')
save('offset')

Subtract the offset.

In [None]:
# Offset = average of the per-frame minimum pixels collected above.
image_offset = np.mean(min_pixels)
info['image_offset'] = image_offset
print('image_offset = ', image_offset)

# New array (not in-place), so integer-typed images are promoted safely.
im = im - image_offset

Select images that are obviously measuring noise. (Tune `width_view` to view more frames; tune `width_calc` to select the frames.)

In [None]:
steps = 64  # (approximate) number of steps in one sweep
width_view = 30 * steps  # half-width of the plotted frame range
width_calc = 2 * steps  # half-width of the frame range treated as pure noise
icenter = imin
info['noise_icenter'] = icenter

# Clip both ranges to the scan. Otherwise a center near the start of the scan
# gives negative indices, which numpy silently wraps to the end of `signal`
# (plotting unrelated frames), and a center near the end raises IndexError.
n_frames = len(data)
idx_view = np.arange(max(icenter - width_view, 0), min(icenter + width_view, n_frames))
idx_calc = np.arange(max(icenter - width_calc, 0), min(icenter + width_calc, n_frames))
ims_noise = np.array([ip.crop(get_image(i), **image_crop_edges) for i in idx_calc])
print('For selected images:')
print(f'max pixel: {np.max(ims_noise)}')
print(f'min noise: {np.min(ims_noise)}')
print(f'mean noise: {np.mean(ims_noise)}')
print(f'std noise: {np.std(ims_noise)}')

for yscale in [None, 'log']:
    fig, ax = pplt.subplots(figsize=(8.0, 1.75))
    ax.plot(idx_view, signal[idx_view], color='black', alpha=0.2)
    ax.plot(idx_calc, signal[idx_calc], color='black', label='selected')
    ax.legend(loc='upper right')
    ax.format(yscale=yscale)
    save(f'noise_region_{yscale}')
    plt.show()

Set the threshold based on these noisy images.

In [None]:
# Threshold = brightest pixel in the noise frames, measured above the offset
# (the noise frames were not offset-subtracted).
noise_peak = np.max(ims_noise)
image_thresh = 1.0 * (noise_peak - image_offset)
image_thresh_frac_peak = image_thresh / np.max(im)

print(f'image_thresh = {image_thresh}')
print(f'10^{np.log10(image_thresh_frac_peak):.2f} of max pixel in peak integral image')

Observe how the thresholding affects the images.

In [None]:
def update(i, log, discrete, handle_log, thresh):
    """Plot frame `i` after cropping, offset subtraction and thresholding at `thresh`.

    The image is normalized to its maximum only when that maximum is positive,
    so a fully thresholded (all-zero) frame is plotted as-is.
    """
    _im = get_image(i)
    _im = ip.crop(_im, **image_crop_edges)
    _im = _im - image_offset
    _im[_im <= thresh] = 0
    norm = 'log' if log else None
    fig, ax = pplt.subplots()
    # Local name chosen so it does not shadow the notebook-level `im_max`.
    peak = np.max(_im)
    if peak > 0:
        _im = _im / peak
    mplt.plot_image(_im.T, ax=ax, colorbar=True, norm=norm,
                    discrete=discrete, handle_log=handle_log)
    plt.show()

# Frame slider around the brightest frame, clipped to valid frame indices.
interactive(
    update, i=(max(imax - 99, 0), min(imax + 101, len(data) - 1)), log=True, discrete=False,
    handle_log=['mask', 'floor'],
    thresh=widgets.FloatSlider(min=0.0, max=4.0 * image_thresh,
                               step=0.1 * image_thresh, value=image_thresh),
)

Save the results for our test image.

In [None]:
# Threshold the test image and compare it with the unthresholded version.
info['image_thresh'] = image_thresh
im1 = im.copy()
im1[im1 <= image_thresh] = 0

for discrete in (False, True):
    axes = plot_compare_images(
        im / np.max(im), im1 / np.max(im1),
        colorbar=True, discrete=discrete, handle_log='mask',
    )
    exponent = np.log10(image_thresh_frac_peak)
    axes.format(suptitle=f'Threshold at 10^{exponent:.2f} peak pixel')
    save(f'thresh_discrete{discrete}')
    plt.show()

In [None]:
# Show the row and column through the brightest pixel of the offset-subtracted
# (not yet thresholded) test image `im`, with the threshold drawn as a faint line.
i, j = np.unravel_index(np.argmax(im), im.shape)
# Layout: panel 1 spans both rows on the left; panels 2/4 form the middle
# column (row profiles) and panels 3/5 the right column (column profiles).
fig, axes = pplt.subplots([[1, 2, 3], [1, 4, 5]], sharey=False)
axes[0].pcolormesh(im)
kws = dict(color='white', alpha=0.4)
# Cross-hair at the brightest pixel.
axes[0].axhline(i, **kws)
axes[0].axvline(j, **kws)
axes[1].set_title(f'Row {i}')
axes[2].set_title(f'Column {j}')
kws = dict(color='black', lw=1.0)
# Profiles are normalized to their own maximum, so the threshold line is scaled the same way.
for ax in axes[:, 1]:
    ax.plot(np.arange(im.shape[1]), im[i, :] / np.max(im[i, :]), **kws)
    ax.axhline(image_thresh / np.max(im[i, :]), color='black', alpha=0.1)
for ax in axes[:, 2]:
    ax.plot(np.arange(im.shape[0]), im[:, j] / np.max(im[:, j]), **kws)
    ax.axhline(image_thresh / np.max(im[:, j]), color='black', alpha=0.1)
# Bottom-row profile panels use a log y-scale.
axes[1:, 1:].format(yscale='log')
save('thresh2')
plt.show()

In [None]:
im = im1.copy()  # continue the test pipeline with the thresholded test image

### Downscale 

Observe the effect of downscaling.

In [None]:
def update(i, log, downscale, handle_log):
    """Plot frame `i` after crop, offset subtraction, thresholding and downscaling.

    `downscale` is the block-averaging factor applied along both axes; the
    result is normalized to its maximum and shown with x/y profiles.
    """
    _im = get_image(i)
    _im = ip.crop(_im, **image_crop_edges)
    _im = _im - image_offset
    _im[_im <= image_thresh] = 0
    _im = skimage.transform.downscale_local_mean(_im, (downscale, downscale))
    norm = 'log' if log else None
    fig, ax = pplt.subplots()
    mplt.plot_image(_im.T / np.max(_im), ax=ax, colorbar=True, norm=norm, handle_log=handle_log,
                    profx=True, profy=True)
    plt.show()


# Frame slider around the brightest frame, clipped to valid frame indices.
interactive(
    update, i=(max(imax - 99, 0), min(imax + 101, len(data) - 1)), log=True,
    downscale=widgets.IntSlider(min=1, max=10, value=3),
    handle_log=['mask', 'floor'],
)

Set the downscale factor and save the results for our test image.

In [None]:
# Chosen block-averaging factor, applied along both image axes.
image_downscale = 3
info['image_downscale'] = image_downscale

im1 = skimage.transform.downscale_local_mean(im, (image_downscale,) * 2)

for discrete in (False, True):
    axes = plot_compare_images(
        im / np.max(im), im1 / np.max(im1),
        colorbar=True, discrete=discrete, handle_log='floor',
    )
    axes.format(toplabels=['Original', f'Downscaled by factor {image_downscale}'])
    save(f'downscale_discrete{discrete}')

Save the new image shape and pixel/mm calibration.

In [None]:
# Shape and pixel calibration of the processed (cropped + downscaled) images.
# `downscale_local_mean` pads each axis with zeros up to a multiple of the
# factor, so every output pixel spans exactly `image_downscale` input pixels.
# A shape ratio such as im.shape[0] / im1.shape[0] would under-estimate that
# pitch whenever a cropped axis is not divisible by the factor.
info['image_shape'] = im1.shape
info['image_pix2mm_y'] = cam_settings.pix2mm * image_downscale
info['image_pix2mm_x'] = cam_settings.pix2mm * image_downscale

print('Original (cropped) image shape:', im.shape)
print('Downscaled image shape:', im1.shape)
print(f'Original pix2mm_x = pix2mm_y = {cam_settings.pix2mm} (zoom = {cam_settings.zoom})')
print('new pix2mm_y =', info['image_pix2mm_y'])
print('new pix2mm_x =', info['image_pix2mm_x'])

Make a function that does all of the above.

In [None]:
def process_image(im, crop_edges, offset, thresh, downscale):
    """Apply this notebook's image pipeline to one raw frame.

    Steps: crop with `crop_edges` (keyword arguments to `ip.crop`), subtract
    `offset`, zero every pixel <= `thresh`, then block-average by `downscale`
    along both axes when `downscale > 1`. The input array is not modified.
    """
    cropped = ip.crop(im, **crop_edges)
    shifted = cropped - offset  # new array, so thresholding below never touches the input
    shifted[shifted <= thresh] = 0
    if downscale <= 1:
        return shifted
    return skimage.transform.downscale_local_mean(shifted, (downscale, downscale))

### Save gifs of a few sweeps

These GIFs show the cropped frames before offset subtraction, thresholding, and downscaling.

In [None]:
window = 150  # number of frames on each side of `imax`
# Clip the frame range to the scan; an unclipped negative start would silently
# select the wrong frames (or none), and an end past the scan would raise.
gif_start = max(imax - window, 0)
gif_stop = min(imax + window, len(data))
iterations = data[gif_start:gif_stop, 'iteration']
unique_iterations = np.unique(iterations)
print('Iterations:', len(unique_iterations))
print(unique_iterations)

# Cropped raw frames, normalized to the brightest pixel across all of them.
_images = np.array([ip.crop(get_image(i), **image_crop_edges) for i in range(gif_start, gif_stop)])
_images = _images / np.max(_images)

cmap = 'mono_r'
imageio.mimwrite(
    os.path.join(outdir, f'fig_sweeps_{unique_iterations[0]}-{unique_iterations[-1]}_{cmap}.gif'),
    [ip.to_uint8(im, cmap) for im in _images],
    fps=12,
)

## Write new h5 file

Find points where the measured beam current is too low and mask them.

In [None]:
# Beam current monitor whose readings are used to reject low-current shots.
bcm = 'bcm04'
bcm_data = data[bcm][...]

In [None]:
bcm_limit = 23.5  # minimum absolute current [mA]
idx = np.arange(len(data))
# A shot is kept only if |current| is strictly above the limit.
valid_bcm, = np.where(np.abs(bcm_data) > bcm_limit)
mask_bcm, = np.where(np.abs(bcm_data) <= bcm_limit)
n_valid = len(valid_bcm)

print(f'Average BCM current (before masking) = {np.mean(bcm_data):.3f} +- {np.std(bcm_data):.3f} [mA]')
print(f'Number of valid points: {n_valid}')
for i in mask_bcm:
    # '<=' matches the masking condition above (points exactly at the limit are masked).
    print(f'Point {i} masked due to {bcm} current {np.abs(bcm_data[i]):.3f} <= {bcm_limit:.3f} [mA]')
print(f'Average BCM current (after masking) = {np.mean(bcm_data[valid_bcm]):.3f} '
      f'+- {np.std(bcm_data[valid_bcm]):.3f} [mA]')

In [None]:
# Full BCM trace, with the rejected points marked in red.
masked_style = dict(color='red', lw=0, marker='.', label='Masked')
fig, ax = pplt.subplots(figsize=(7.0, 2.0))
ax.plot(bcm_data, color='black')
ax.plot(mask_bcm, bcm_data[mask_bcm], **masked_style)
ax.format(xlabel='Point', ylabel='BCM current [mA]', ygrid=True)
ax.legend(loc='upper left')
save('bcm_mask')

In [None]:
# BCM current of the kept points and its cumulative mean: the mean of all kept
# points up to and including each one (not a fixed-width rolling window, so
# the legend says "cumulative").
fig, ax = pplt.subplots(figsize=(7.0, 2.0))
ax.plot(valid_bcm, bcm_data[valid_bcm], color='black')
bcm_rolling_average = np.cumsum(bcm_data[valid_bcm]) / np.arange(1, len(valid_bcm) + 1)
ax.plot(valid_bcm, bcm_rolling_average, color='red', label='cumulative avg.', alpha=0.8)
ax.format(xlabel='Point', ylabel='BCM current [mA]', ygrid=False)
ax.legend(loc='lower right')
save('bcm_unmasked_points_rolling_average')

Create a new H5 file with three data sets: scalar (0d), waveform (1d), and image (2d). First, collect the appropriate dtypes.

In [None]:
# `im1` is the processed (downscaled) test image; its dtype and size define the
# stored image field below.
im = im1.copy()
attrs = data.dtype.names
# sc_attrs starts as every field; image and profile fields are removed below.
sc_dtype, sc_attrs = [], list(attrs)
print(sc_attrs)
wf_dtype, wf_attrs = [], []
im_dtype, im_attrs = [], []
print('Moving the following columns:')
# Iterate backwards so `sc_attrs.pop(i)` never shifts an index not yet visited.
# Side effect: the *_dtype field lists (and wf/im attrs) end up in reverse order.
for i in reversed(range(len(attrs))):
    attr = attrs[i]
    if '_Image' in attr:
        sc_attrs.pop(i)
        im_attrs.append(attr)
        im_dtype.append((attr, data.dtype[attr]))
        print(attr)
    elif 'Profile' in attr:
        sc_attrs.pop(i)
        wf_attrs.append(attr)
        wf_dtype.append((attr, data.dtype[attr]))
        print(attr)
    else:
        sc_dtype.append((attr, data.dtype[attr]))
        
sc_dtype = np.dtype(sc_dtype)
wf_dtype = np.dtype(wf_dtype)
im_dtype = np.dtype(im_dtype)

# Override the image dtype: a single field holding the flattened processed
# image. NOTE(review): assumes exactly one '_Image' field (this camera's);
# other image fields in im_attrs would have no matching field here.
im_dtype = np.dtype([(cam + '_Image', str(im.dtype), (im.size,))])

print('\nscalars:')
print(sc_dtype)
print('\nwaveforms:')
print(wf_dtype)
print('\nimage:')
print(im_dtype)

Write the data to the new files.

In [None]:
# Write the masked, split data. Row i of every new dataset corresponds to row
# valid_bcm[i] of the original 'scandata'. The context manager closes the file
# even if processing fails partway through.
with h5py.File(os.path.join(outdir, 'preproc-' + filename + '.h5'), 'w') as writer:
    data_sc = writer.create_dataset('scalardata', (n_valid,), dtype=sc_dtype)
    data_wf = writer.create_dataset('wfdata', (n_valid,), dtype=wf_dtype)
    data_im = writer.create_dataset('imagedata', (n_valid,), dtype=im_dtype)
    for i, j in enumerate(tqdm(valid_bcm)):
        # One HDF5 read per shot instead of one per field.
        row = data[j]
        for attr in sc_attrs:
            data_sc[i, attr] = row[attr]
        for attr in wf_attrs:
            data_wf[i, attr] = row[attr]
        if im_attrs:
            # Same as get_image(j), taken from the row already in memory; the
            # camera image is processed once per shot.
            image = row[cam + '_Image'].reshape(cam_settings.shape)
            image = process_image(image, image_crop_edges, image_offset,
                                  image_thresh, image_downscale)
            for attr in im_attrs:
                data_im[i, attr] = image.ravel()

Pass info dict to future notebooks. 

In [None]:
# Persist `info` so the following notebooks can load it.
info_path = os.path.join(outdir, 'info.pkl')
print('info:')
pprint(info)
utils.save_pickle(info_path, info)

Save HTML of this notebook.

In [None]:
# Export this notebook to HTML directly into `outdir`. Using `--output-dir`
# avoids a separate `mv`, which could move a stale HTML file from an earlier run
# if the export failed. An argument list avoids going through a shell.
import subprocess
_nbconvert = subprocess.run(
    ['jupyter', 'nbconvert', 'scan-xxpy-image-ypdE_step0.ipynb', '--to', 'html',
     f'--output-dir={outdir}'],
    check=False,
)
if _nbconvert.returncode != 0:
    print(f'nbconvert failed with exit code {_nbconvert.returncode}')

In [None]:
# List the files written to `outdir`. Output from `os.system('ls ...')` goes to
# the Jupyter server's terminal rather than the notebook, so print from Python.
print('\n'.join(sorted(os.listdir(outdir))))