# 4D scan

In [None]:
import sys
import os
from os.path import join
from pprint import pprint
import importlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
import pandas as pd
import h5py
from scipy import ndimage
import proplot as pplt

sys.path.append('..')
from tools import plotting as mplt
from tools import utils

sys.path.append('/Users/46h/Research/btf/btf-scripts/')
import scan_patterns as sp

In [None]:
# Plot styling shared by every figure below (applied in this order).
for _key, _value in {
    'grid': False,
    'cmap.sequential': 'viridis',
    'cmap.discrete': False,
}.items():
    pplt.rc[_key] = _value

## Load data 

In [None]:
# Directory holding the transverse 4D measurements from 2022-07-09.
datadir = '../Diagnostics/Data/Measurements/transverse4d/2022-07-09/'
# os.listdir returns entries in arbitrary order. Sort them so the listing is
# reproducible. The names appear to start with a timestamp, so sorting should
# also make them roughly chronological.
filenames = sorted(os.listdir(datadir))
filenames

In [None]:
# Open the measurement file read-only. It stays open on purpose, because later
# cells read datasets (e.g. 'scandata') from it.
filename = '220709153435-transverse4d'
file = h5py.File(join(datadir, f'{filename}.h5'), 'r')
print([*file])

In [None]:
# Collect the scan metadata into a flat dict. Newer files store it in the HDF5
# 'config' group. Older measurements keep it in a sidecar JSON file.
if 'config' in file:
    config = file['config']
    print(f"'config', {type(config)}")
    for key in config:
        print(f"  '{key}', {type(config[key])}")
        for name in config[key].dtype.names:
            print(f'    {name}: {config[key][name]}')
    # Make dictionary of metadata (one entry per field of config['metadata']).
    metadata = {name: config['metadata'][name] for name in config['metadata'].dtype.names}
else:
    import json  # not imported in the notebook's first cell

    # Older measurement; metadata is in json file, nested one level deep
    # (section -> {key: value}). Flatten it; later sections win on duplicates.
    with open(join(datadir, filename + '-metadata.json'), 'r') as json_file:
        nested_metadata = json.load(json_file)
    metadata = dict()
    for section in nested_metadata.values():
        metadata.update(section)
    pprint(metadata)

In [None]:
# Print the log dataset's field layout, then every entry whose level is not
# INFO (errors and warnings).
if 'log' in file:
    from datetime import datetime  # not imported in the notebook's first cell

    log = file['log']
    print(f"'log', {type(log)}")
    for item in log.dtype.fields.items():
        print('  ', item)

    print('\nErrors and warnings:')
    # Read the log into memory once. Indexing the h5py dataset row by row
    # issues a separate HDF5 read for every access.
    entries = log[:]
    for entry in entries[entries['level'] != b'INFO']:
        timestr = datetime.fromtimestamp(entry['timestamp']).strftime("%m/%d/%Y, %H:%M:%S")
        print(f"{timestr} {entry['message']}")

In [None]:
# Main scan table: a structured dataset with one record per recorded point.
data = file['scandata']

print(f"'scandata', {type(data)}")
for field_name, field_spec in data.dtype.fields.items():
    print('  ', (field_name, field_spec))

In [None]:
# Position channels of the four slits, in (xp, x, yp, y) order.
acts = [f'{coord}_PositionSync' for coord in ('xp', 'x', 'yp', 'y')]

## Scan overview 

In [None]:
# Trace camera saturation over the whole scan.
for pv in ['cam06_Saturation']:
    fig, ax = pplt.subplots(figsize=(8.0, 1.5))
    ax.plot(data[pv], color='black')
    ax.format(ylabel=pv)
    # Display the figure. A bare `plt.savefig` reference (never called) would
    # silently do nothing.
    plt.show()

### Data collection frequency

In [None]:
# Overall timing of the scan: total duration, iteration count, and rates.
n_points = len(data)
n_iterations = data[-1, 'iteration']
duration = data[-1, 'timestamp'] - data[0, 'timestamp']
iteration_duration = duration / n_iterations
points_per_iteration = n_points / n_iterations
hours = duration / 3600.0

print(f'{n_points} points recorded over {duration:.1f} seconds ({hours:.1f} hours)')
print(f"Number of iterations: {n_iterations}")
print(f'Effective rep rate: {(n_points / duration):.2f} Hz')
print(f'Time per iteration: {iteration_duration:.2f} seconds')
print(f'Points per iteration: {points_per_iteration:.2f}')

Look for long pauses during data collection.

In [None]:
# Time between consecutive points, used to spot pauses in data collection.
dt = np.diff(data[:, 'timestamp'])
rep_rate = 1.0 / np.median(dt)
print(f'reprate = {rep_rate:.2f} Hz')

# Report very long pauses, then overwrite them with a small placeholder so
# they do not dominate the histogram and pause statistics below.
long_pause = 30.0  # [seconds]
print(f'Pauses longer than {long_pause:g} seconds:')
print(dt[dt > long_pause])
dt[dt > long_pause] = 0.2  # placeholder [seconds]

# Any spacing outside the first histogram bin counts as a pause.
hist, bins = np.histogram(dt, bins=21)
idx_bins = np.digitize(dt, bins)
idx_pause, = np.where(idx_bins > 1)
median_pause = np.median(dt[idx_pause]) if idx_pause.size else None
if median_pause is None:
    print('No pauses found')
else:
    print(f'Most pauses are {median_pause:.2f} seconds')

fig, ax = pplt.subplots()
ax.bar(0.5 * (bins[1:] + bins[:-1]), hist, color='black', alpha=0.3)
if median_pause is not None:
    ax.axvline(median_pause, color='black')
ax.format(xlabel='Pause length [seconds]', ylabel='Number of points', yscale='log')
# plt.savefig('_output/pauses.png')

### BCM current

In [None]:
# Mask points with too little beam current. A point is masked when its BCM
# reading is above -bcm_limit. The readings presumably are negative, so this
# means |current| < bcm_limit. TODO confirm the sign convention.
bcm = 'bcm04_Current'
bcm_limit = 24.0  # [mA]
bcm_data = np.copy(data[bcm])

idx = np.arange(len(data))
is_masked = bcm_data > -bcm_limit
idx_mask, = np.nonzero(is_masked)
idx_valid, = np.nonzero(np.logical_not(is_masked))

print(f'Average BCM current (before masking) = {np.mean(bcm_data):.3f} +- {np.std(bcm_data):.3f} [mA]')
for point in idx_mask:
    print(f'Point {point} masked due to {bcm} current < {bcm_limit:.3f} [mA]')
print(f'Average BCM current (after masking) = {np.mean(bcm_data[idx_valid]):.3f} +- {np.std(bcm_data[idx_valid]):.3f} [mA]')

In [None]:
# BCM current for every point, with masked points overlaid as red dots.
masked_style = dict(color='red', lw=0, marker='.', label='Masked')
fig, ax = pplt.subplots(figsize=(7.0, 2.0))
ax.plot(bcm_data[idx], color='black')
ax.plot(idx_mask, bcm_data[idx_mask], **masked_style)
ax.format(xlabel='Point', ylabel='BCM current [mA]', ygrid=True)
ax.legend(loc='upper left')
# plt.savefig('_output/bcm_mask.png')
plt.show()

### Slit positions 

In [None]:
# One trace per slit actuator versus point index, with masked points in red.
for act in acts:
    # Read the whole column into NumPy once, then index it in memory. h5py
    # fancy indexing with an index list requires increasing order and is slow
    # for long lists, and `idx` covers every point.
    values = data[act]
    fig, ax = pplt.subplots(figsize=(7.0, 2.0))
    ax.plot(idx, values[idx], color='black')
    ax.plot(idx_mask, values[idx_mask], color='red', lw=0, marker='.', label='Masked')
    ax.format(xlabel='Point', ylabel=act)
    ax.legend(loc='upper left')
    # plt.savefig(f'_output/acts_mask_{act}.png')
    plt.show()

In [None]:
# The four slit actuators stacked vertically, one panel each, versus point index.
fig, axes = pplt.subplots(figsize=(7.0, 5.5), nrows=4, spany=False)
axes.format(cycle='default')
labels = ['x2', 'x1', 'y2', 'y1']
for panel, act, label in zip(axes, acts, labels):
    panel.plot(data[act], label=label, color='black', lw=1.1)
    panel.format(xlabel='Point', ylabel=act)
plt.savefig('_output/acts.png')
plt.show()

In [None]:
# Pairwise slit-actuator trajectories: the row variable (y-axis) against the
# column variable (x-axis), both in `acts` order.
fig, axes = pplt.subplots(
    nrows=len(acts),
    ncols=len(acts),
    figwidth=4.0 * len(acts) / 3,
    spanx=False,
    spany=False,
    aligny=True,
)
for i, row_act in enumerate(acts):
    for j, col_act in enumerate(acts):
        axes[i, j].plot(data[col_act], data[row_act], color='black', alpha=0.3, lw=0.75)
for k, act in enumerate(acts):
    axes[-1, k].set_xlabel(act)
    axes[k, 0].set_ylabel(act)
# plt.savefig('_output/slit_correlations.png')
plt.show()

In [None]:
# Scan settings for each slit coordinate. Table rows are
# (key, pvname, center, distance, min, max), and every coordinate uses 8 steps.
variables = {
    key: {
        'pvname': pvname,
        'center': mid,
        'distance': span,
        'steps': 8,
        'min': lo,
        'max': hi,
    }
    for key, pvname, mid, span, lo, hi in (
        ('xp', 'ITSF_Diag:Slit_VT06', 13.25, 16.0, +4.0, +22.5),
        ('x', 'ITSF_Diag:Slit_VT04', 12.5, 15.0, -50.0, +50.0),
        ('yp', 'ITSF_Diag:Slit_HZ06', 15.0, 7.0, -50.0, +50.0),
        ('y', 'ITSF_Diag:Slit_HZ04', 13.0, 25.0, -50.0, +50.0),
    )
}
keys = [*variables]

# Identity plus two off-diagonal terms: row xp gets 0.65 * x and row yp gets
# 0.85 * y. The inverse is used to un-shear slit coordinates.
M = np.identity(4)
M[[keys.index('xp'), keys.index('yp')], [keys.index('x'), keys.index('y')]] = [0.65, 0.85]
Minv = np.linalg.inv(M)

# Per-coordinate settings as arrays, ordered like `keys`.
center, distance, steps = (
    np.array([variables[key][field] for key in keys])
    for field in ('center', 'distance', 'steps')
)

In [None]:
# Slit coordinates of every point, shape (n_points, 4), columns ordered as `acts`.
points = np.stack([data[act] for act in acts], axis=1)
# "Upright" coordinates: remove the scan center, then apply the inverse shear.
# utils.apply presumably applies the matrix to each row. TODO confirm.
points_n = utils.apply(Minv, points - center)
# Scale each axis by half its scan distance.
points_nn = points_n / (0.5 * distance)

In [None]:
# 4x4 scatter matrices of the slit coordinates at each transformation stage.
dims = ['x2', 'x1', 'y2', 'y1']
for title, _points in (
    ('true', points),
    ('upright', points_n),
    ('upright + scaled', points_nn),
):
    fig, axes = pplt.subplots(ncols=4, nrows=4, figwidth=6.0, spanx=False, spany=False)
    axes.format(suptitle=title)
    for i, dim in enumerate(dims):
        for j in range(4):
            axes[i, j].scatter(_points[:, j], _points[:, i], c='black', ec='None', s=2)
        axes[i, 0].format(ylabel=dim)
        axes[-1, i].format(xlabel=dim)
    plt.show()

In [None]:
# Normalize the camera integral to [0, 1] and flag points above threshold.
signal = data['cam06_Integral'].copy()
signal = signal - np.min(signal)
peak = np.max(signal)
if peak > 0:  # a constant signal would otherwise divide 0 by 0
    signal = signal / peak
thresh = 0.005  # fraction of the maximum signal
w, = np.where(signal >= thresh)

# One figure per y-scale. Iterate over the scales only; each pass creates
# its own figure, so no pre-existing axes are needed.
for norm in [None, 'log']:
    fig, ax = pplt.subplots(figsize=(10.0, 2.0))
    ax.plot(signal, color='lightgray', lw=1, marker='.', ms=3, ec='None')
    ax.plot(w, signal[w], color='black', lw=0, marker='.', ms=3, ec='None')
    ax.format(ylabel='signal / max')
    ax.format(yscale=norm)
    plt.savefig(f'_output/thresh_{norm}.png')
    plt.show()

In [None]:
# Radius of each point in the upright + scaled coordinates.
radii = np.sqrt(np.square(points_nn).sum(axis=1))

# Radius histograms (all points vs. above threshold) on linear and log y-scales.
bins = 'auto'
with pplt.rc.context(legendfontsize='medium'):
    for yscale in (None, 'log'):
        fig, ax = pplt.subplots(figsize=(3, 1.85))
        ax.hist(radii, bins=bins, label='all', color='lightgrey')
        ax.hist(radii[w], bins=bins, label='above thresh', color='black')
        ax.format(ylabel='num. points', xlabel='radius', yscale=yscale)
        ax.legend(ncols=2, loc='top', framealpha=0)
        plt.savefig(f'_output/radii_yscale{yscale}.png')
        plt.show()

In [None]:
# Largest radius among the above-threshold points.
rmax = radii[w].max()

In [None]:
# Signal versus radius, with points sorted by radius and above-threshold points in black.
_sort = np.argsort(radii)
_radii, _signal = radii[_sort], signal[_sort]
_w = np.flatnonzero(_signal >= thresh)
kws = dict(lw=0, marker='.', ms=2, ec='None')
for norm in (None, 'log'):
    fig, ax = pplt.subplots(figsize=(3.5, 2.5))
    ax.plot(_radii, _signal, color='lightgray', **kws)
    ax.plot(_radii[_w], _signal[_w], color='black', **kws)
    ax.format(xlabel='radius', ylabel='signal / max')
    ax.format(yscale=norm)
    plt.savefig(f'_output/radii_scatter_{norm}.png')
    plt.show()

In [None]:
# For each radius r, count the above-threshold points lying beyond r.
above = signal >= thresh
rs = np.linspace(1.0, rmax, 100)
counts = np.array([np.count_nonzero(above & (radii > r)) for r in rs])

fig, ax = pplt.subplots(figsize=(3, 2))
ax.plot(rs, counts / len(w), color='black')
ax.format(xlabel='radius', ylabel='Frac. signal > radius')
ax.format(ylim=(0, ax.get_ylim()[1]))
# Red axis: ratio of utils.volume_sphere(n=4, r) to utils.volume_box(n=4, r=1),
# presumably the 4D ball volume relative to the unit scan box.
ax1 = ax.alty(color='red')
ax1.plot(rs, utils.volume_sphere(n=4, r=rs) / utils.volume_box(n=4, r=1.0), color='red')
ax1.format(ylabel='Vball / Vbox')
# ax.format(yscale='log')
plt.savefig('_output/savings.png')

In [None]:
# Above-threshold points strictly inside the bounding radius.
# NOTE(review): strict `<` drops the point(s) at exactly rmax, while the
# `frac` computation uses `<=`. Confirm which one is intended.
_idx = np.flatnonzero((signal >= thresh) & (radii < rmax))

fig, axes = pplt.subplots(figsize=(10.0, 4.0), nrows=2)
for ax, norm in zip(axes, (None, 'log')):
    ax.plot(signal, color='lightgray', lw=0, marker='.', ms=2, ec='None')
    ax.plot(_idx, signal[_idx], color='black', lw=0, marker='+', ms=0.5)
    ax.format(ylabel='signal / max')
    ax.format(yscale=norm)
plt.show()

In [None]:
# Fraction of all points lying within the bounding radius rmax.
frac = np.count_nonzero(radii <= rmax) / radii.size
frac

In [None]:
# Fraction of points whose signal is at or above the threshold.
frac_signal = len(w) / len(signal)
print(f'{frac_signal:.3f}')

In [None]:
# Ratio of utils.volume_sphere(n=4, r=rmax) to utils.volume_box(n=4, r=1.0),
# presumably the share of the normalized scan box inside the bounding ball.
volume_ratio = (
    utils.volume_sphere(n=4, r=rmax)
    / utils.volume_box(n=4, r=1.0)
)
print(volume_ratio)

In [None]:
# Scatter matrix of the normalized points with a red circle of radius rmax
# (the largest above-threshold radius) drawn in every off-diagonal panel.
_points = points_nn
_rmax = rmax

fig, axes = pplt.subplots(ncols=4, nrows=4, figwidth=5.0, spanx=False, spany=False)
# axes.format(suptitle=f'fraction (r < {rmax:.2f}) = {frac:.2f}',
#             suptitle_kw=dict(fontweight='normal'))
for i, dim in enumerate(dims):
    for j in range(4):
        ax = axes[i, j]
        ax.scatter(_points[:, j], _points[:, i], c='lightgray', ec='None', s=0.5)
        ax.scatter(_points[w, j], _points[w, i], c='black', ec='None', s=0.5)
        if i == j:
            continue
        diameter = 2.0 * _rmax
        ax.add_patch(patches.Ellipse((0.0, 0.0), diameter, diameter, color='red', fill=False))
    axes[i, 0].format(ylabel=dim)
    axes[-1, i].format(xlabel=dim)
plt.savefig('_output/bounding_ellipse.png')
plt.show()

## Snug scan

In [None]:
# Reload scan_patterns so edits to it take effect without restarting the
# kernel. importlib is already imported in the notebook's first cell.
importlib.reload(sp)

In [None]:
# Boundary ("surface") points of the above-threshold region, computed by
# utils.get_boundary_points with pad=2.0 (presumably in the units of `points`).
iterations = np.array(data['iteration'])
surface = utils.get_boundary_points(iterations, points, signal, thresh, pad=2.0)

In [None]:
# np.save('/Users/46h/Research/btf/btf-scripts/temp_data/transverse4d-surface.npy', surface)

In [None]:
# Scatter matrix of the measured points, with the boundary ("surface") points
# overlaid as pink squares.
fig, axes = pplt.subplots(ncols=4, nrows=4, figwidth=6.0, spanx=False, spany=False)
g3 = None  # stays None when `surface` is empty
for i in range(4):
    for j in range(4):
        ax = axes[i, j]
        # Gray: every point (only visible where the signal is below threshold).
        g1 = ax.scatter(points[:, j], points[:, i], c='lightgray', ec='None', s=0.5)
        for spts in surface:
            g3 = ax.scatter(
                spts[:, j],
                spts[:, i],
                marker='s',
                s=7,
                color='pink1',
            )
        # Black: above-threshold points, drawn on top.
        g2 = ax.scatter(points[w, j], points[w, i], c='black', ec='None', s=0.5)
    axes[i, 0].format(ylabel=dims[i])
    axes[-1, i].format(xlabel=dims[i])
# g1 (gray) is the noise and g2 (black) is the signal.
legend_handles = [g1, g2]
legend_labels = ['noise', 'signal']
if g3 is not None:
    legend_handles.append(g3)
    legend_labels.append('new')
axes[0, -1].legend(legend_handles, labels=legend_labels, loc='r', ncols=1,
                   ms=5, framealpha=0)
plt.savefig('_output/planned_transverse4d.png')
plt.show()

Generate new points for ScanEngine. 

In [None]:
# Reload scan_patterns to pick up edits. importlib is already imported in the
# notebook's first cell.
importlib.reload(sp)

# Run points generator
navg = 0
ndim = 4
kws = dict(
    variables=variables,
    M=M,
    reprate=5.0,
    navg=navg,
    boundary=None,  # {None, 'ellipsoid'}; pass the None object, not the string 'None'
    R=1.1,
    exclude_outside_box=True,
    surface=surface,
)
lgen = list(sp.gen(**kws))

# Reshape the generator output into an (n_points, ndim) array.
if navg > 0:
    # The first element of each generated item is taken as the point.
    new_points = np.zeros((len(lgen), ndim))
    for i, item in enumerate(lgen):
        new_points[i, :] = item[0]
else:
    # lgen stacks to (n_items, ndim, >=2). The first two values per dimension
    # become two consecutive points (row-major ravel).
    lgen = np.array(lgen)
    new_points = np.zeros((2 * lgen.shape[0], ndim))
    for i in range(ndim):
        new_points[:, i] = lgen[:, i, :2].ravel()

# Un-shear generated points.
new_points_n = utils.apply(Minv, new_points - center)