In [1]:
%matplotlib widget

In [2]:
from matplotlib import pyplot as plt

In [3]:
from libertem.udf import UDF

In [4]:
from libertem.udf.sumsigudf import SumSigUDF
from libertem.udf.base import NoOpUDF
from libertem.common import Slice, Shape

In [5]:
import numpy as np
import zmq
import threading
import time
import uuid
from libertem_live.detectors.common import StoppableThreadMixin, send_serialized, recv_serialized
from libertem_live.detectors.k2is.state import EventType, SetUDFsEvent, StopEvent, EventReplicaClientThread, SetNavShapeEvent, StartProcessingEvent, StopProcessingEvent
from libertem_live.detectors.k2is.proto import ResultSink
from libertem.analysis.com import com_masks_factory, center_shifts, magnitude, curl_2d, divergence, identity, flip_y, rotate_deg
from libertem.udf.masks import ApplyMasksUDF
from libertem.udf.base import UDFMeta, UDFResults
from libertem.common import Shape
from libertem.common.buffers import BufferWrapper
from libertem.viz.bqp import BQLive2DPlot

In [6]:
class FakeDataSet:
    """
    Minimal dataset stand-in describing only geometry and dtype of the live scan.

    Exposes ``nav_shape`` (as passed in), ``shape`` (a LiberTEM ``Shape`` of
    ``nav_shape`` followed by a fixed 1860 x 2048 signal, with two sig dims)
    and ``dtype`` (uint16).
    """

    def __init__(self, nav_shape):
        frame_shape = (1860, 2048)
        self.nav_shape = nav_shape
        self.shape = Shape(nav_shape + frame_shape, sig_dims=2)
        self.dtype = np.uint16


def _get_dtype(udfs, dtype, corrections):
    if corrections is not None and corrections.have_corrections():
        tmp_dtype = np.result_type(np.float32, dtype)
    else:
        tmp_dtype = dtype
    for udf in udfs:
        tmp_dtype = np.result_type(
            udf.get_preferred_input_dtype(),
            tmp_dtype
        )
    return tmp_dtype


def _prepare_run_for_dataset(
    udfs, dataset, executor, roi, corrections, backends, dry
):
    """
    Make ``udfs`` ready to receive merged results for the whole ``dataset``.

    Every UDF gets the same ``UDFMeta``, fresh result buffers allocated for
    the full dataset (restricted by ``roi``), and — if it defines
    ``preprocess`` — dataset-wide views followed by a ``preprocess()`` call.

    ``executor``, ``backends`` and ``dry`` are accepted but not used here.
    """
    input_dtype = _get_dtype(udfs, dataset.dtype, corrections)
    meta = UDFMeta(
        partition_shape=None,
        dataset_shape=dataset.shape,
        roi=roi,
        dataset_dtype=dataset.dtype,
        input_dtype=input_dtype,
        corrections=corrections,
    )
    for udf in udfs:
        udf.set_meta(meta)
        udf.init_result_buffers()
        udf.allocate_for_full(dataset, roi)

        if not hasattr(udf, 'preprocess'):
            continue
        udf.set_views_for_dataset(dataset)
        udf.preprocess()

In [7]:
from libertem.common.buffers import PreallocBufferWrapper

In [8]:
def run_for_dataset_sync(st, udfs, dataset, executor,
                         roi=None, progress=False, corrections=None, backends=None, dry=False):
    """
    Drive live processing of ``udfs`` and yield merged results as they arrive.

    This is a generator. When first iterated it registers restart handlers on
    ``st.sub.store``, dispatches the UDFs, nav shape and a start request via
    ``st.dispatch``, and then merges every partition result polled from a
    ``ResultSink`` into local copies of ``udfs``.

    Parameters
    ----------
    st
        Event replica client; used for ``dispatch``, ``maybe_raise`` and
        ``sub.store``.
    udfs : list
        UDFs to run. They are sent to the processing side and copied locally
        for merging.
    dataset
        Object with ``shape`` (LiberTEM ``Shape``) and ``dtype``.
    roi, corrections
        Used for the local result/damage buffers and the UDF meta.
    executor, backends, dry
        Only forwarded to ``_prepare_run_for_dataset``.
    progress
        Ignored.

    Yields
    ------
    (UDFResults, list)
        Current merged results (with a boolean nav ``damage`` buffer) and the
        list of all ``(partition_slice, part_results, epoch, packet_counter)``
        tuples received so far.

    Notes
    -----
    Cleanup (dispatching ``StopProcessingEvent`` and removing the restart
    handler) runs when the generator is closed or finishes. Callers should
    close the generator explicitly (``try``/``finally`` + ``.close()``) so
    cleanup does not depend on garbage collection.
    """
    sink = ResultSink()

    def _start_processing():
        # (Re)send the UDFs and nav shape, then request processing to start.
        st.dispatch(SetUDFsEvent(udfs=udfs))
        st.dispatch(SetNavShapeEvent(nav_shape=tuple(dataset.shape.nav)))
        st.dispatch(StartProcessingEvent())

    def handle_start(state, new_state, event, effects):
        # Store listener: restart processing after a (re)connect or startup.
        _start_processing()

    def _fresh_udf_copies():
        # Local UDF copies with result buffers allocated for the full dataset.
        copies = [udf.copy() for udf in udfs]
        _prepare_run_for_dataset(
            copies, dataset, executor, roi, corrections, backends, dry
        )
        return copies

    st.sub.store.listen(EventType.CAM_CONNECTED, handle_start)
    st.sub.store.listen(EventType.STARTUP_COMPLETE, handle_start)

    # Everything after registering the listeners is inside try/finally.
    # A failure while starting or preparing therefore still stops processing
    # and unregisters the handler, instead of leaking it.
    try:
        _start_processing()

        # Number of partition results merged per nav position in the current epoch.
        damage = BufferWrapper(kind='nav', dtype=int)
        damage.set_shape_ds(dataset.shape, roi)
        damage.allocate()

        all_results = []
        udf_copies = _fresh_udf_copies()
        current_epoch = 0

        with sink:
            while True:
                # check for errors in our state replication:
                st.maybe_raise()

                # NOTE: keep any recurring checks above the poll, as they need to be executed
                # even if we don't get a result here
                result = sink.poll(timeout=500)
                if result is None:
                    # Force a reset of the local UDF copies on the next result.
                    current_epoch = -1  # FIXME: is this too aggressive?
                    continue

                partition_slice, part_results, epoch, packet_counter = result

                if epoch != current_epoch:
                    # reset UDF copies after each run:
                    udf_copies = _fresh_udf_copies()
                    current_epoch = epoch
                    damage.data[:] = 0

                all_results.append((partition_slice, part_results, epoch, packet_counter))

                for results, udf in zip(part_results.buffers, udf_copies):
                    udf.set_views_for_partition(partition_slice=partition_slice)
                    udf.merge(
                        dest=udf.results.get_proxy(),
                        src=results.get_proxy()
                    )
                    udf.clear_views()
                v = damage.get_view_for_partition(partition_slice=partition_slice)
                v[:] += 1
                # A nav position counts as complete once it received >= 8 partition
                # results. NOTE(review): 8 is presumably the number of partial results
                # per frame produced by the processing side — confirm.
                bool_damage = PreallocBufferWrapper(damage.data >= 8, kind='nav', dtype=bool)
                bool_damage.set_shape_ds(dataset.shape, roi)
                yield UDFResults(
                    buffers=tuple(
                        udf._do_get_results()
                        for udf in udf_copies
                    ),
                    damage=bool_damage,
                ), all_results
    finally:
        # Remove the handler even if dispatching the stop event fails.
        try:
            st.dispatch(StopProcessingEvent())
        finally:
            st.sub.store.remove_callback(handle_start)

In [9]:
# Start the event replica client thread. Its store (st.sub.store) is used
# below for event listeners; the thread is stopped in the teardown cell.
st = EventReplicaClientThread()
st.start()

In [10]:
def print_event(state, new_state, event, effects):
    """Debug store listener: print each event followed by the current ``time.time()``."""
    print(f"received event: {event} {time.time()}")
# Log every replicated event (debugging aid).
st.sub.store.listen_all(print_event)

In [11]:
def iDPC(x_centers, y_centers):
    """
    Integrate a 2D shift (DPC / centre-of-mass) field into a phase-like image.

    Works in Fourier space. The spectrum of the result is
    ``(kx * X + ky * Y) / (2 * pi * i * (kx**2 + ky**2))``, where ``X``/``Y``
    are the spectra of ``x_centers``/``y_centers`` and ``kx``/``ky`` are the
    spatial frequencies in cycles per scan step. The zero-frequency term is
    set to 0, so the result has zero mean.

    Parameters
    ----------
    x_centers, y_centers : numpy.ndarray
        2D arrays of identical shape ``(ny, nx)`` holding the x and y
        components of the shift field. Any ``ny``, ``nx`` >= 1, even or odd.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(ny, nx)``.
    """
    realy = x_centers.shape[0]
    realx = x_centers.shape[1]

    # Frequencies in FFT order (zero frequency at index 0). For even sizes
    # this equals ifftshift(linspace(-0.5, 0.5, n, endpoint=False)). For odd
    # sizes that construction is offset by half a bin and misplaces the zero
    # frequency; fftfreq is correct for both.
    s_ky = np.fft.fftfreq(realy).reshape((-1, 1))
    # rfft2 keeps only the first realx // 2 + 1 frequencies along x.
    half_x = realx // 2 + 1
    s_kx = np.fft.fftfreq(realx)[:half_x].reshape((1, -1))

    # The real-input FFT skips the redundant (Hermitian) half of the spectrum.
    fft_DPC_Y = np.fft.rfft2(y_centers)
    fft_DPC_X = np.fft.rfft2(x_centers)

    divider = s_kx**2 + s_ky**2
    # Avoid div0 (only the zero frequency has kx == ky == 0)
    divider[0, 0] = 1

    fft_iDPC = s_kx * fft_DPC_X + s_ky * fft_DPC_Y
    fft_iDPC = fft_iDPC / 2 / np.pi / 1j / divider
    # We can't calculate the absolute phase anyway
    fft_iDPC[0, 0] = 0

    # Pass the output shape explicitly. Without ``s``, irfft2 assumes an even
    # last axis and returns 2 * (half_x - 1) columns, which is one column
    # short (and numerically wrong) for odd widths.
    calculated = np.fft.irfft2(fft_iDPC, s=(realy, realx))

    return calculated.astype(np.float32)

In [12]:
import numba

In [13]:
a = np.ones((128, 16), dtype=bool)
a[65, 3] = 0
np.all(a, axis=1)


@numba.njit(boundscheck=True)
def get_inner_slice(arr):
    """
    Return ``(0, stop)``, where ``stop`` is the length of the leading run of
    truthy entries in the 1D array ``arr``.

    ``slice(*get_inner_slice(mask))`` selects that leading run. Entries after
    the first falsy value are ignored, even if they are truthy again.
    """
    stop = 0
    for idx in range(arr.shape[0]):
        if not arr[idx]:
            break
        stop = idx + 1
    return (0, stop)


# row 65 contains a False, so only rows 0..64 form the leading valid block
slice(*get_inner_slice(np.all(a, axis=1)))

slice(0, 65, None)

In [14]:
# NOTE(review): x, y and w are not referenced anywhere else in the visible
# notebook; presumably leftovers of an earlier crop/ROI experiment.
x = 500
y = 440
w = 921

In [15]:
# Reference beam position on the detector (pixels). Used as ref_y/ref_x for
# center_shifts and as cy/cx for the CoM mask factory.
CENTER_Y, CENTER_X = 900, 960
# Rotation (degrees) applied to the CoM shift vectors by transf().
ROTATE = 115


# Extra nav positions per scan row assumed when re-wrapping the flat result
# (force_shape in the visualize_* functions). NOTE(review): presumably
# compensates for extra frames recorded during the scan flyback; the value
# depends on the scan size.
# FLYBACK_COMPENSATION = 1  # 64x64
FLYBACK_COMPENSATION = 4  # 128x128

In [16]:
def transf(y_centers, x_centers, deg):
    """
    Flip, then rotate by ``deg`` degrees, a field of (y, x) shift vectors.

    The two component arrays must have the same number of elements. The
    transformed ``(y, x)`` components are returned, each reshaped back to
    the shape of its input.
    """
    matrix = rotate_deg(deg) @ flip_y()
    stacked = (y_centers.reshape((-1,)), x_centers.reshape((-1,)))
    new_y, new_x = matrix @ stacked
    return new_y.reshape(y_centers.shape), new_x.reshape(x_centers.shape)

In [17]:
# Sanity check of transf on two vectors: (y, x) = (1, 0) and (0, 1), rotated by 45 degrees.
transf(np.array([(1, 0)]), np.array([(0, 1)]), 45)

(array([[-0.70710678,  0.70710678]]), array([[0.70710678, 0.70710678]]))

In [18]:
from libertem.common.buffers import reshaped_view

def force_shape(data, shape):
    """
    Re-wrap the elements of ``data`` into a new zero-filled array of ``shape``.

    Elements are copied through flat views of both arrays, so the data is
    re-wrapped rather than cropped or padded per axis. If the target is
    larger, the tail stays zero; if it is smaller, the excess of ``data`` is
    dropped. The result has ``data.dtype``.
    """
    out = np.zeros(shape, dtype=data.dtype)
    src = reshaped_view(data, (-1,))
    dst = reshaped_view(out, (-1,))
    n_copy = min(len(dst), len(src))
    dst[:n_copy] = src[:n_copy]
    return out

In [19]:
def visualize_idpc(udf_result, damage):
    """
    Channel function for BQLive2DPlot: iDPC phase image from the CoM result.

    ``udf_result['intensity'].data`` holds the (sum, y, x) mask channels on
    its last axis, and ``damage.data`` marks completed nav positions. Both are
    re-wrapped to rows that are FLYBACK_COMPENSATION positions wider (see
    ``force_shape``). Only the leading block of rows in which every position
    is complete is integrated with ``iDPC``; all other rows stay zero.

    Returns
    -------
    (phase, valid_rows)
        float32 phase image, and a mask (dtype of ``damage.data``) that is
        True exactly on the integrated rows.
    """
    com_data = udf_result['intensity'].data

    wrapped_nav = np.array(com_data.shape[:2]) + (0, FLYBACK_COMPENSATION)
    com_data = force_shape(com_data, tuple(wrapped_nav) + tuple(com_data.shape[2:]))
    wrapped_damage = force_shape(damage.data, wrapped_nav)

    shift_y, shift_x = center_shifts(
        com_data[..., 0], com_data[..., 1], com_data[..., 2], CENTER_Y, CENTER_X
    )
    shift_y, shift_x = transf(shift_y, shift_x, ROTATE)

    # leading run of rows (from the top) whose positions are all complete
    row_start, row_stop = get_inner_slice(np.all(wrapped_damage, axis=1))
    rows = slice(row_start, row_stop)

    phase = np.zeros_like(shift_y, dtype=np.float32)
    if row_start != row_stop:
        phase[rows, ...] = iDPC(shift_x[rows, ...], shift_y[rows, ...])

    valid_rows = np.zeros_like(wrapped_damage)
    valid_rows[rows, ...] = True
    return phase, valid_rows

In [20]:
# Smoke test for visualize_idpc using minimal stand-ins: only the ``.data``
# attributes of the result and damage buffers are read.
class mockbuf:
    # fake 'intensity' buffer: nav shape (1, 65), 3 channels (sum, y, x)
    data = np.zeros((1, 65, 3))
    
class mockdamage:
    # every nav position marked as complete
    data = np.ones((1, 65), dtype=bool)

mock = {
    'intensity': mockbuf,
}
visualize_idpc(mock, mockdamage)

(array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0.]], dtype=float32),
 array([[False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False]]))

In [21]:
def visualize_mag(udf_result, damage):
    """
    Channel function for BQLive2DPlot: magnitude of the rotated CoM shift.

    The (sum, y, x) channels in ``udf_result['intensity'].data`` and the
    ``damage.data`` mask are first re-wrapped to rows that are
    FLYBACK_COMPENSATION positions wider (see ``force_shape``).

    Returns
    -------
    (magnitude, damage)
        Shift magnitude image and the re-wrapped damage array.
    """
    raw = udf_result['intensity'].data
    nav_wrapped = np.array(raw.shape[:2]) + (0, FLYBACK_COMPENSATION)

    raw = force_shape(raw, tuple(nav_wrapped) + tuple(raw.shape[2:]))
    damage_wrapped = force_shape(damage.data, nav_wrapped)

    centers = center_shifts(raw[..., 0], raw[..., 1], raw[..., 2], CENTER_Y, CENTER_X)
    rotated_y, rotated_x = transf(*centers, ROTATE)
    return magnitude(rotated_y, rotated_x), damage_wrapped

In [22]:
def visualize_x(udf_result, damage):
    """
    Channel function for BQLive2DPlot: x component of the rotated CoM shift.

    Input channels and damage are re-wrapped to rows that are
    FLYBACK_COMPENSATION positions wider before computing the shifts.

    Returns
    -------
    (x_shift, damage)
        Rotated x shift image and the re-wrapped damage array.
    """
    com_data = udf_result['intensity'].data
    wrapped_nav = np.array(com_data.shape[:2]) + (0, FLYBACK_COMPENSATION)

    com_data = force_shape(com_data, tuple(wrapped_nav) + tuple(com_data.shape[2:]))
    wrapped_damage = force_shape(damage.data, wrapped_nav)

    shift_y, shift_x = center_shifts(
        com_data[..., 0], com_data[..., 1], com_data[..., 2], CENTER_Y, CENTER_X
    )
    _, rotated_x = transf(shift_y, shift_x, ROTATE)
    return rotated_x, wrapped_damage

In [23]:
def visualize_y(udf_result, damage):
    """
    Channel function for BQLive2DPlot: y component of the rotated CoM shift.

    Input channels and damage are re-wrapped to rows that are
    FLYBACK_COMPENSATION positions wider before computing the shifts.

    Returns
    -------
    (y_shift, damage)
        Rotated y shift image and the re-wrapped damage array.
    """
    stack = udf_result['intensity'].data
    padded_nav = np.array(stack.shape[:2]) + (0, FLYBACK_COMPENSATION)

    stack = force_shape(stack, tuple(padded_nav) + tuple(stack.shape[2:]))
    padded_damage = force_shape(damage.data, padded_nav)

    raw_y, raw_x = center_shifts(
        stack[..., 0], stack[..., 1], stack[..., 2], CENTER_Y, CENTER_X
    )
    rotated_y, _ = transf(raw_y, raw_x, ROTATE)
    return rotated_y, padded_damage

In [24]:
# Dataset stand-in for the live scan. NOTE(review): nav_shape presumably has
# to match the scan configured on the microscope, and FLYBACK_COMPENSATION
# should be chosen to match it.
ds = FakeDataSet(nav_shape=(128, 128))
# ds = FakeDataSet(nav_shape=(64, 64))
# ds = FakeDataSet(nav_shape=(128, 128))

# CoM + iDPC

In [25]:
# Centre-of-mass masks for a 1860 x 2048 detector, centred on
# (CENTER_Y, CENTER_X) with radius 300. The visualize_* functions read three
# result channels from these masks as (sum, y, x).
# NOTE(review): detector size duplicates FakeDataSet's signal shape.
com_udf = ApplyMasksUDF(mask_factories=com_masks_factory(
    detector_y=1860,
    detector_x=2048,
    cx=CENTER_X,
    cy=CENTER_Y,
    r=300,
))

udfs = [com_udf]
# udfs = [NoOpUDF()]

In [26]:
# One live plot per derived channel, created and displayed in this order.
plots = []
for channel in (visualize_idpc, visualize_mag, visualize_y, visualize_x):
    live_plot = BQLive2DPlot(udf=udfs[0], dataset=ds, channel=channel)
    live_plot.display()
    live_plot.color_scale.scheme = 'OrRd'
    plots.append(live_plot)
plot_idpc, plot_mag, plot_y, plot_x = plots

old_damage = np.zeros(ds.shape.nav, dtype=bool)

# Close the generator explicitly. Its cleanup (StopProcessingEvent, removal of
# the restart handler) then also runs when the loop body raises or is
# interrupted, instead of waiting for garbage collection.
live_run = run_for_dataset_sync(st, udfs=udfs, dataset=ds, executor=None,
                                roi=None, progress=False, corrections=None)
try:
    for res, all_res in live_run:
        # only redraw when the set of completed nav positions changed
        if np.any(old_damage != res.damage.data):
            for plot in plots:
                plot.new_data(udf_results=res.buffers[0], damage=res.damage)

            # fewer completed positions than before: the damage counter was reset
            if old_damage.sum() > res.damage.data.sum():
                print("new epoch?")
            old_damage[:] = res.damage.data
finally:
    live_run.close()

Figure(axes=[Axis(label='x', scale=LinearScale(max=1.0, min=0.0)), Axis(label='y', orientation='vertical', sca…

Figure(axes=[Axis(label='x', scale=LinearScale(max=1.0, min=0.0)), Axis(label='y', orientation='vertical', sca…

Figure(axes=[Axis(label='x', scale=LinearScale(max=1.0, min=0.0)), Axis(label='y', orientation='vertical', sca…

Figure(axes=[Axis(label='x', scale=LinearScale(max=1.0, min=0.0)), Axis(label='y', orientation='vertical', sca…

reveived event: typ=<EventType.SET_UDFS: 'SET_UDFS'> udfs=[<libertem.udf.masks.ApplyMasksUDF object at 0x7f805e620160>] 1620391559.0243816
reveived event: typ=<EventType.SET_NAV_SHAPE: 'SET_NAV_SHAPE'> nav_shape=(128, 128) 1620391559.1963334
reveived event: typ=<EventType.START_PROCESSING: 'START_PROCESSING'> 1620391559.2439923
reveived event: typ=<EventType.CAM_CONNECTED: 'CAM_CONNECTED'> 1620391559.427524
reveived event: typ=<EventType.SET_UDFS: 'SET_UDFS'> udfs=[<libertem.udf.masks.ApplyMasksUDF object at 0x7f805e620df0>] 1620391559.5719082
reveived event: typ=<EventType.SET_NAV_SHAPE: 'SET_NAV_SHAPE'> nav_shape=(128, 128) 1620391559.7074947
reveived event: typ=<EventType.START_PROCESSING: 'START_PROCESSING'> 1620391559.7509334
new epoch?
new epoch?
reveived event: typ=<EventType.STOP_PROCESSING: 'STOP_PROCESSING'> 1620391682.2970736


KeyboardInterrupt: 

In [None]:
# Channel 0 (the summed intensity, per the visualize_* functions) of the last live result.
plt.figure()
plt.imshow(res.buffers[0]['intensity'].data[..., 0])

In [None]:
# Channel 0 of the last result, re-wrapped to rows that are
# FLYBACK_COMPENSATION positions wider (the same re-wrapping the visualize_*
# functions apply). The target shape is derived from the data, so it also
# works for scan sizes other than 128 x 128.
plt.figure()
channel0 = res.buffers[0]['intensity'].data[..., 0]
plt.imshow(force_shape(channel0, (channel0.shape[0], channel0.shape[1] + FLYBACK_COMPENSATION)))

# SumSigUDF

In [None]:
# Switch the live run to SumSigUDF. The import repeats the one at the top so
# this section can also be run on its own.
from libertem.udf.sumsigudf import SumSigUDF

udfs = [SumSigUDF()]

In [None]:
plot = BQLive2DPlot(udf=udfs[0], dataset=ds, channel='intensity')
plot.display()
plot.color_scale.scheme = 'OrRd'

old_damage = np.zeros(ds.shape.nav, dtype=bool)

# Close the generator explicitly. Its cleanup (StopProcessingEvent, removal of
# the restart handler) then also runs when the loop body raises or is
# interrupted.
live_run = run_for_dataset_sync(st, udfs=udfs, dataset=ds, executor=None,
                                roi=None, progress=False, corrections=None)
try:
    for res, all_res in live_run:
        # only redraw when the set of completed nav positions changed
        if np.any(old_damage != res.damage.data):
            plot.new_data(udf_results=res.buffers[0], damage=res.damage)

            # fewer completed positions than before: the damage counter was reset
            if old_damage.sum() > res.damage.data.sum():
                print("new epoch?")
            old_damage[:] = res.damage.data
finally:
    live_run.close()

# Sum of all frames (SumSampleUDF)

In [None]:
class SumSampleUDF(UDF):
    """
    Sum up frames, preserving the signal dimension.

    The single result buffer ``intensity`` is a ``'sig'``-kind buffer. It
    therefore has the signal shape of the dataset (``(1860, 2048)`` for the
    ``FakeDataSet`` used in this notebook).

    Parameters
    ----------
    dtype : numpy.dtype, optional
        Preferred dtype for computation, default 'float32'. The actual dtype will be determined
        from this value and the dataset's dtype using :meth:`numpy.result_type`.
        See also :ref:`udf dtype`.

    Examples
    --------
    >>> udf = SumSampleUDF()
    >>> result = ctx.run_udf(dataset=dataset, udf=udf)
    >>> np.array(result["intensity"]).shape == tuple(dataset.shape.sig)
    True
    """
    def __init__(self, dtype='float32'):
        super().__init__(dtype=dtype)

    def get_preferred_input_dtype(self):
        '''Return the ``dtype`` parameter passed to the constructor.'''
        return self.params.dtype

    def get_result_buffers(self):
        '''Declare the ``intensity`` buffer (kind 'sig', dtype ``meta.input_dtype``).'''
        return {
            'intensity': self.buffer(kind='sig', dtype=self.meta.input_dtype)
        }

    def process_tile(self, tile):
        '''Add the sum of ``tile`` over its first axis to ``intensity``.'''
        # NOTE(review): assumes axis 0 of a tile enumerates frames — confirm.
        self.results.intensity[:] += np.sum(tile, axis=0)

    def merge(self, dest, src):
        '''Merge partial sums: per-partition sums are simply added.'''
        # NOTE(review): this line originally carried an unexplained ``XXX``
        # marker; addition is the correct merge for partial sums.
        dest.intensity[:] += src.intensity

In [None]:
# Switch the live run to the SumSampleUDF defined above.
udfs = [SumSampleUDF()]

In [None]:
plot = BQLive2DPlot(udf=udfs[0], dataset=ds, channel='intensity', min_delta=5)
plot.display()
plot.color_scale.scheme = 'OrRd'

old_damage = np.zeros(ds.shape.nav, dtype=bool)

# Drain the live stream; plot updates are currently disabled. Closing the
# generator explicitly guarantees its cleanup (StopProcessingEvent, removal
# of the restart handler) even if the loop is interrupted.
live_run = run_for_dataset_sync(st, udfs=udfs, dataset=ds, executor=None,
                                roi=None, progress=False, corrections=None)
try:
    for res, all_res in live_run:
        # plot.new_data(udf_results=res.buffers[0], damage=res.damage)
        pass
finally:
    live_run.close()

# Old stuff: ad-hoc inspection of the last result

In [None]:
# Raw 'intensity' result of the last run.
# NOTE(review): vmax=9.65e4 looks like a display limit hand-picked for one particular run.
plt.figure()
plt.imshow(res.buffers[0]['intensity'].data, vmax=9.65e4)

In [None]:
# Which nav positions were marked complete (damage) in the last result.
plt.figure()
plt.imshow(res.damage.data.astype(int))

In [None]:
# Independent copy of channel 0 of the last result, plus a flat 1D view of it
# (writes through res0_flat modify res0).
res0 = res.buffers[0]['intensity'].data[..., 0].copy()
res0_flat = res0.reshape((-1,))
# res0_flat[Slice.from_shape((100,), sig_dims=0).get()] = 42

In [None]:
# Lower the colour limit to the smallest positive value so unfilled (zero)
# positions don't dominate the scale. Fall back to autoscaling when nothing
# is positive yet; min() of an empty selection would raise ValueError.
positive = res0[res0 > 0]
plt.figure()
plt.imshow(res0, vmin=positive.min() if positive.size else None)

# Teardown: stop the event replica client thread

In [None]:
# Stop the event replica client thread started at the top and wait for it to exit.
st.stop()
st.join()