In [1]:
import numpy as np
import zmq
import threading
import time
import uuid
from libertem_live.detectors.common import StoppableThreadMixin, send_serialized, recv_serialized
from libertem_live.detectors.k2is.state import EventType, SetUDFsEvent, StopEvent, EventReplicaClientThread, SetNavShapeEvent, StartProcessingEvent, StopProcessingEvent
from libertem_live.detectors.k2is.proto import ResultSink
from libertem.analysis.com import com_masks_factory, center_shifts
from libertem.udf.masks import ApplyMasksUDF
from libertem.udf.base import UDFMeta, UDFResults
from libertem.common import Shape
from libertem.common.buffers import BufferWrapper
from libertem.viz.bqp import BQLive2DPlot

In [2]:
class FakeDataSet:
    def __init__(self, nav_shape):
        self.nav_shape = nav_shape
        self.shape = Shape(nav_shape + (1860, 2048), sig_dims=2)
        self.dtype = np.uint16


def _get_dtype(udfs, dtype, corrections):
    if corrections is not None and corrections.have_corrections():
        tmp_dtype = np.result_type(np.float32, dtype)
    else:
        tmp_dtype = dtype
    for udf in udfs:
        tmp_dtype = np.result_type(
            udf.get_preferred_input_dtype(),
            tmp_dtype
        )
    return tmp_dtype


def _prepare_run_for_dataset(
    udfs, dataset, executor, roi, corrections, backends, dry
):
    meta = UDFMeta(
        partition_shape=None,
        dataset_shape=dataset.shape,
        roi=roi,
        dataset_dtype=dataset.dtype,
        input_dtype=_get_dtype(udfs, dataset.dtype, corrections),
        corrections=corrections,
    )
    for udf in udfs:
        udf.set_meta(meta)
        udf.init_result_buffers()
        udf.allocate_for_full(dataset, roi)

        if hasattr(udf, 'preprocess'):
            udf.set_views_for_dataset(dataset)
            udf.preprocess()

In [3]:
from libertem.common.buffers import PreallocBufferWrapper

In [4]:
def run_for_dataset_sync(st, udfs, dataset, executor,
                         roi=None, progress=False, corrections=None, backends=None, dry=False):

    cancel_id = str(uuid.uuid4())
    
    sink = ResultSink()

    def _start_processing():
        st.dispatch(SetUDFsEvent(udfs=udfs))
        st.dispatch(SetNavShapeEvent(nav_shape=tuple(dataset.shape.nav)))
        st.dispatch(StartProcessingEvent())
        
    def handle_start(state, new_state, event, effects):
        _start_processing()
    st.sub.store.listen(EventType.CAM_CONNECTED, handle_start)
    st.sub.store.listen(EventType.STARTUP_COMPLETE, handle_start)
    
    _start_processing()

    damage = BufferWrapper(kind='nav', dtype=int)
    damage.set_shape_ds(dataset.shape, roi)
    damage.allocate()

    all_results = []

    udf_copies = [
        udf.copy()
        for udf in udfs
    ]

    _prepare_run_for_dataset(
        udf_copies, dataset, executor, roi, corrections, backends, dry
    )

    current_epoch = 0
    
    try:
        with sink:
            while True:
                # check for errors in our state replication:
                st.maybe_raise()

                # NOTE: keep any recurring checks above the poll, as they need to be executed
                # even if we don't get a result here
                result = sink.poll(timeout=500)
                if result is None:
                    current_epoch = -1  # FIXME: is this too aggressive?
                    continue
                    
                partition_slice, part_results, epoch, packet_counter = result

                if epoch != current_epoch:
                    # reset UDF copies after each run:
                    udf_copies = [
                        udf.copy()
                        for udf in udfs
                    ]
                    _prepare_run_for_dataset(
                        udf_copies, dataset, executor, roi, corrections, backends, dry
                    )
                    current_epoch = epoch
                    damage.data[:] = 0

                all_results.append((partition_slice, part_results, epoch, packet_counter))

                for results, udf in zip(part_results.buffers, udf_copies):
                    udf.set_views_for_partition(partition_slice=partition_slice)
                    udf.merge(
                        dest=udf.results.get_proxy(),
                        src=results.get_proxy()
                    )
                    udf.clear_views()
                v = damage.get_view_for_partition(partition_slice=partition_slice)
                v[:] += 1
                bool_damage = PreallocBufferWrapper(damage.data >= 8, kind='nav', dtype=bool)
                bool_damage.set_shape_ds(dataset.shape, roi)
                yield UDFResults(
                    buffers=tuple(
                        udf._do_get_results()
                        for udf in udf_copies
                    ),
                    damage=bool_damage,
                ), all_results
    finally:
        st.dispatch(StopProcessingEvent())
        st.sub.store.remove_callback(handle_start)

In [5]:
st = EventReplicaClientThread()
st.start()

In [6]:
def print_event(state, new_state, event, effects):
    print(f"reveived event: {event} {time.time()}")
st.sub.store.listen_all(print_event)

In [7]:
ds = FakeDataSet(nav_shape=(40, 20))
# ds = FakeDataSet(nav_shape=(128, 128))

In [8]:
from libertem.udf import UDF

In [9]:
from libertem.udf.sumsigudf import SumSigUDF
from libertem.udf.base import NoOpUDF

In [10]:
def iDPC(x_centers, y_centers):
    realy = x_centers.shape[0]
    realx = x_centers.shape[1]
    
    ky = np.linspace(-0.5, 0.5, realy, endpoint=False).reshape((-1, 1))
    kx = np.linspace(-0.5, 0.5, realx, endpoint=False).reshape((1, -1))
    
    # We shift the arrays instead of the FFT results since they are smaller
    # and it is convenient to have the zero frequency at (0, 0)
    s_kx = np.fft.ifftshift(kx)
    s_ky = np.fft.ifftshift(ky)
    
    half_x = int(np.ceil((realx + 1) / 2))

    # Instead of convertig the complex result to a Hermitian,
    # we just take the FFT for real values that doesn't even calculate
    # those values
    
    fft_DPC_Y = np.fft.rfft2(y_centers)
    fft_DPC_X = np.fft.rfft2(x_centers)
    
    divider = (s_kx[:, :half_x]**2 + s_ky**2)
    # Avoid div0
    divider[0, 0] = 1
    
    fft_iDPC = s_kx[:, :half_x] * fft_DPC_X + s_ky * fft_DPC_Y
    fft_iDPC = fft_iDPC / 2 / np.pi / 1j / divider
    # We can't calculate the absolute phase anyway
    fft_iDPC[0, 0] = 0
    
    return np.fft.irfft2(fft_iDPC)

In [17]:
import numba

In [25]:
a = np.ones((128, 16), dtype=bool)
a[65, 3] = 0
np.all(a, axis=1)

@numba.njit(boundscheck=True)
def get_inner_slice(arr):
    first_idx = 0
    last_idx = 0
    while last_idx < arr.shape[0] and arr[last_idx]:
        last_idx += 1
    return (first_idx, last_idx)

slice(*get_inner_slice(np.all(a, axis=1)))

slice(0, 65, None)

In [37]:
def visualize_idpc(udf_result, damage):
    # FIXME: ref_x, ref_y
    ref_y = 1860 / 2
    ref_x = 2048 / 2
    
    data = udf_result['intensity'].data
    
    img_sum, img_y, img_x = (
        data[..., 0],
        data[..., 1],
        data[..., 2],
    )
    y_centers, x_centers = center_shifts(img_sum, img_y, img_x, ref_y, ref_x)
    
    # y slice with all-valid values
    inner_slice = slice(*get_inner_slice(np.all(damage, axis=1)))
    
    result = np.zeros_like(y_centers, dtype=np.float32)
    
    if inner_slice.start != inner_slice.stop:
        result[inner_slice, ...] = iDPC(x_centers[inner_slice, ...], y_centers[inner_slice, ...])
    
    new_damage = np.zeros_like(damage)
    new_damage[inner_slice, ...] = True
    
    return result, new_damage

In [35]:
%pdb off

Automatic pdb calling has been turned OFF


In [38]:
com_udf = ApplyMasksUDF(mask_factories=com_masks_factory(
    detector_y=1860,
    detector_x=2048,
    cx=2048/2,
    cy=1860/2,
    r=np.inf,
))

udfs = [com_udf]
# udfs = [NoOpUDF()]

plot = BQLive2DPlot(udf=udfs[0], dataset=ds, channel=visualize_idpc)
plot.display()

plot.color_scale.scheme = 'OrRd'

old_damage = np.zeros(ds.shape.nav, dtype=bool)

for res, all_res in run_for_dataset_sync(st, udfs=udfs, dataset=ds, executor=None,
                                roi=None, progress=False, corrections=None):
    if np.any(old_damage != res.damage.data):
        plot.new_data(udf_results=res.buffers[0], damage=res.damage)

        if old_damage.sum() > res.damage.data.sum():
            print("new epoch?")
        old_damage[:] = res.damage.data

Figure(axes=[Axis(label='x', scale=LinearScale(max=1.0, min=0.0)), Axis(label='y', orientation='vertical', sca…

reveived event: typ=<EventType.SET_UDFS: 'SET_UDFS'> udfs=[<libertem.udf.masks.ApplyMasksUDF object at 0x7f27f80747c0>] 1620317730.3108945
reveived event: typ=<EventType.SET_NAV_SHAPE: 'SET_NAV_SHAPE'> nav_shape=(40, 20) 1620317730.387529
reveived event: typ=<EventType.START_PROCESSING: 'START_PROCESSING'> 1620317730.465409
reveived event: typ=<EventType.STOPPED: 'STOPPED'> 1620317733.2544231
reveived event: typ=<EventType.STARTING: 'STARTING'> 1620317739.5169349
reveived event: typ=<EventType.STARTUP_COMPLETE: 'STARTUP_COMPLETE'> 1620317739.564932
reveived event: typ=<EventType.SET_UDFS: 'SET_UDFS'> udfs=[<libertem.udf.masks.ApplyMasksUDF object at 0x7f27f8aac820>] 1620317739.6154518
reveived event: typ=<EventType.SET_NAV_SHAPE: 'SET_NAV_SHAPE'> nav_shape=(40, 20) 1620317739.7207613
reveived event: typ=<EventType.START_PROCESSING: 'START_PROCESSING'> 1620317739.8311477
reveived event: typ=<EventType.CAM_CONNECTED: 'CAM_CONNECTED'> 1620317747.6021523
new epoch?
reveived event: typ=<Eve

KeyboardInterrupt: 

In [None]:
all_slice_origins = set([s.origin[0] for s in res.buffers[1]['slices'].data[0]])
for i in range(20*40):
    assert i in all_slice_origins

In [None]:
all_res

In [None]:
2*32*100

In [None]:
print(all_res[0][0].get(nav_only=True))
print(all_res[8][0].get(nav_only=True))

In [None]:
all_res[8][1].buffers[0].intensity[..., 2]

In [None]:
[(r[0].origin[0], r[3]) for r in all_res]

In [None]:
%matplotlib widget

In [None]:
from matplotlib import pyplot as plt

In [None]:
hmm = np.zeros(ds.shape.nav)
hmm_flat = hmm.reshape((-1,))
for s_, _, _, _ in all_res:
    hmm_flat[s_.get(nav_only=True)] += 1

In [None]:
assert np.allclose(hmm, 8)

In [None]:
plt.figure()
plt.imshow(res.buffers[0]['intensity'].data[..., 0])

In [None]:
plt.figure()
plt.imshow(res.damage.data.astype(int))

In [None]:
from libertem.common import Slice, Shape

In [None]:
res0 = res.buffers[0]['intensity'].data[..., 0].copy()
res0_flat = res0.reshape((-1,))
# res0_flat[Slice.from_shape((100,), sig_dims=0).get()] = 42

In [None]:
plt.figure()
plt.imshow(res0, vmin=res0[res0 > 0].min())

# TEARDOWN!

In [None]:
st.stop()
st.join()