In [1]:
import numpy as np
import zmq
import threading
import time
import uuid
from libertem_live.detectors.common import StoppableThreadMixin, send_serialized, recv_serialized
from libertem_live.detectors.k2is.state import EventType, SetUDFsEvent, StopEvent, EventReplicaClientThread, SetNavShapeEvent, StartProcessingEvent, StopProcessingEvent
from libertem_live.detectors.k2is.proto import ResultSink
from libertem.analysis.com import com_masks_factory
from libertem.udf.masks import ApplyMasksUDF
from libertem.udf.base import UDFMeta, UDFResults
from libertem.common import Shape
from libertem.common.buffers import BufferWrapper
from libertem.viz.bqp import BQLive2DPlot

In [2]:
class FakeDataSet:
    """Minimal stand-in for a LiberTEM dataset, exposing `nav_shape`, `shape` and `dtype`.

    Only these three attributes are provided; they are what the helpers in this
    notebook read from a dataset.

    Parameters
    ----------
    nav_shape : tuple of int
        Navigation (scan) shape; converted to a tuple.
    sig_shape : tuple of int, optional
        Signal (detector frame) shape. Defaults to (1860, 2048).
    dtype : numpy dtype, optional
        Raw data dtype. Defaults to ``np.uint16``.
    """

    def __init__(self, nav_shape, sig_shape=(1860, 2048), dtype=np.uint16):
        self.nav_shape = tuple(nav_shape)
        sig_shape = tuple(sig_shape)
        # sig_dims follows the length of sig_shape (2 for the default detector)
        self.shape = Shape(self.nav_shape + sig_shape, sig_dims=len(sig_shape))
        self.dtype = dtype


def _get_dtype(udfs, dtype, corrections):
    if corrections is not None and corrections.have_corrections():
        tmp_dtype = np.result_type(np.float32, dtype)
    else:
        tmp_dtype = dtype
    for udf in udfs:
        tmp_dtype = np.result_type(
            udf.get_preferred_input_dtype(),
            tmp_dtype
        )
    return tmp_dtype


def _prepare_run_for_dataset(
    udfs, dataset, executor, roi, corrections, backends, dry
):
    """Attach shared metadata to each UDF and allocate its full-dataset buffers.

    A single `UDFMeta` is built from `dataset`, `roi` and `corrections` (input
    dtype via `_get_dtype`) and set on every UDF. Each UDF then initialises and
    allocates its result buffers for the whole dataset. UDFs that have a
    `preprocess` attribute get dataset-wide views set and `preprocess()` called.

    `executor`, `backends` and `dry` are accepted but not used by this body.
    """
    input_dtype = _get_dtype(udfs, dataset.dtype, corrections)
    shared_meta = UDFMeta(
        partition_shape=None,
        dataset_shape=dataset.shape,
        roi=roi,
        dataset_dtype=dataset.dtype,
        input_dtype=input_dtype,
        corrections=corrections,
    )
    for current in udfs:
        current.set_meta(shared_meta)
        current.init_result_buffers()
        current.allocate_for_full(dataset, roi)
        if not hasattr(current, 'preprocess'):
            continue
        current.set_views_for_dataset(dataset)
        current.preprocess()

In [3]:
from libertem.common.buffers import PreallocBufferWrapper

In [15]:
def run_for_dataset_sync(st, udfs, dataset, executor,
                         roi=None, progress=False, corrections=None, backends=None, dry=False,
                         damage_threshold=8):
    """Generator streaming merged UDF results from the live processing pipeline.

    Dispatches `udfs` and the nav shape of `dataset` to the replica client `st`
    and starts processing. It then polls a `ResultSink` for partition results and
    merges each one into copies of `udfs` made with `udf.copy()`. After every
    received partition it yields ``(UDFResults, all_results)``. `all_results` is
    the growing list of every raw
    ``(partition_slice, part_results, epoch, packet_counter)`` tuple received so far.

    The UDF copies and the per-position damage counts are reset whenever a
    result arrives with an epoch different from the current one. A poll timeout
    (500 ms without a result) forces such a reset on the next result.

    Parameters
    ----------
    st
        Replica client (e.g. `EventReplicaClientThread`). Used to dispatch
        events, register start listeners and re-raise replication errors.
    udfs : list
        UDFs to run; only copies of them are merged into.
    dataset
        Object providing `.shape` and `.dtype` (e.g. `FakeDataSet`).
    executor, backends, dry
        Passed through to `_prepare_run_for_dataset`.
    roi, corrections
        Used for buffer allocation and UDF metadata.
    progress
        Accepted for interface compatibility; unused.
    damage_threshold : int, optional
        A nav position is set in the yielded damage mask once at least this
        many partition results have covered it since the last reset. Default 8.

    The loop never ends on its own. When the generator is closed or raises, it
    dispatches `StopProcessingEvent` and removes its start listener.
    """
    sink = ResultSink()

    def _start_processing():
        st.dispatch(SetUDFsEvent(udfs=udfs))
        st.dispatch(SetNavShapeEvent(nav_shape=tuple(dataset.shape.nav)))
        st.dispatch(StartProcessingEvent())

    def handle_start(state, new_state, event, effects):
        # re-send the configuration whenever the camera connects or startup completes
        _start_processing()

    def _fresh_udf_copies():
        # copies of the caller's UDFs with freshly prepared full-dataset buffers
        copies = [udf.copy() for udf in udfs]
        _prepare_run_for_dataset(
            copies, dataset, executor, roi, corrections, backends, dry
        )
        return copies

    st.sub.store.listen(EventType.CAM_CONNECTED, handle_start)
    st.sub.store.listen(EventType.STARTUP_COMPLETE, handle_start)

    # Everything after listener registration lives inside `try`. A failure while
    # starting or preparing then still dispatches StopProcessingEvent and removes
    # the listener, instead of leaking it.
    try:
        _start_processing()

        # per-nav-position count of partition results received in the current epoch
        damage = BufferWrapper(kind='nav', dtype=int)
        damage.set_shape_ds(dataset.shape, roi)
        damage.allocate()

        all_results = []
        udf_copies = _fresh_udf_copies()
        current_epoch = 0

        with sink:
            while True:
                # check for errors in our state replication:
                st.maybe_raise()

                # NOTE: keep any recurring checks above the poll, as they need to be executed
                # even if we don't get a result here
                result = sink.poll(timeout=500)
                if result is None:
                    current_epoch = -1  # FIXME: is this too aggressive?
                    continue

                partition_slice, part_results, epoch, packet_counter = result

                if epoch != current_epoch:
                    # reset UDF copies and damage counts after each run:
                    udf_copies = _fresh_udf_copies()
                    current_epoch = epoch
                    damage.data[:] = 0

                all_results.append((partition_slice, part_results, epoch, packet_counter))

                for part_buffers, udf in zip(part_results.buffers, udf_copies):
                    udf.set_views_for_partition(partition_slice=partition_slice)
                    udf.merge(
                        dest=udf.results.get_proxy(),
                        src=part_buffers.get_proxy()
                    )
                    udf.clear_views()

                # count this partition for every nav position it covers
                v = damage.get_view_for_partition(partition_slice=partition_slice)
                v[:] += 1
                bool_damage = PreallocBufferWrapper(
                    damage.data >= damage_threshold, kind='nav', dtype=bool
                )
                bool_damage.set_shape_ds(dataset.shape, roi)
                yield UDFResults(
                    buffers=tuple(
                        udf._do_get_results()
                        for udf in udf_copies
                    ),
                    damage=bool_damage,
                ), all_results
    finally:
        st.dispatch(StopProcessingEvent())
        st.sub.store.remove_callback(handle_start)

In [13]:
# Replica client for the detector state, used as `st` by the cells below.
# It is started here and stopped/joined in the teardown cell at the end.
st = EventReplicaClientThread()
st.start()

In [6]:
def print_event(state, new_state, event, effects):
    """Debug listener: print each replicated event with a wall-clock timestamp.

    Only `event` is used. The other parameters match the store listener
    signature also used by `handle_start` in `run_for_dataset_sync`.
    """
    print(f"received event: {event} {time.time()}")
# Debug: print every replicated event (with a timestamp) via `print_event`.
# NOTE(review): this listener is never removed; re-running this cell registers
# it again, which presumably leads to duplicate prints — confirm listen_all semantics.
st.sub.store.listen_all(print_event)

In [7]:
# Fake dataset describing the live stream: a 40 x 20 scan of 1860 x 2048 uint16
# frames (see FakeDataSet). The commented line is a larger alternative scan.
ds = FakeDataSet(nav_shape=(40, 20))
# ds = FakeDataSet(nav_shape=(128, 128))

In [8]:
from libertem.udf import UDF

In [9]:
class NonZeroUDF(UDF):
    """Debugging UDF that sums tiles and records which slices were processed.

    Result buffers:

    - ``intensity`` (kind='nav'): accumulates ``np.sum(tile, axis=(1, 2))``.
      Axis 0 of a tile is presumably the frame axis — TODO confirm.
    - ``slices`` (kind='single', dtype=object): element 0 holds a Python list of
      every ``self.meta.slice`` seen by ``process_tile``.

    ``process_tile`` asserts that every tile sum is strictly positive.
    """

    def get_result_buffers(self):
        buffers = {}
        buffers['intensity'] = self.buffer(kind='nav')
        buffers['slices'] = self.buffer(kind='single', dtype=object)
        return buffers

    def preprocess(self):
        # store a fresh list in the single object slot; process_tile/merge extend it
        self.results.slices[0] = []

    def process_tile(self, tile):
        summed_axes = (1, 2)
        tile_sums = np.sum(tile, axis=summed_axes)
        self.results.intensity[:] += tile_sums
        self.results.slices[0].append(self.meta.slice)
        assert np.all(tile_sums > 0)

    def merge(self, dest, src):
        # add the nav sums element-wise and concatenate the recorded slice lists
        dest.intensity[:] += src.intensity
        dest.slices[0].extend(src.slices[0])

In [10]:
from libertem.udf.sumsigudf import SumSigUDF
from libertem.udf.base import NoOpUDF

In [None]:
from contextlib import closing

# Take the detector geometry from the dataset, so the masks always match `ds`.
detector_y, detector_x = tuple(ds.shape.sig)

com_udf = ApplyMasksUDF(mask_factories=com_masks_factory(
    detector_y=detector_y,
    detector_x=detector_x,
    cx=detector_x/2,
    cy=detector_y/2,
    r=np.inf,
))
# udfs = [com_udf, NonZeroUDF()]
udfs = [com_udf]
# udfs = [NoOpUDF()]

plot = BQLive2DPlot(udf=udfs[0], dataset=ds, channel=('intensity', lambda x: x[..., 0]))
plot.display()

old_damage = np.zeros(ds.shape.nav, dtype=bool)

# The live loop runs until interrupted. `closing` makes the generator's cleanup
# (StopProcessingEvent, listener removal) run right away on interrupt. Otherwise
# it waits until the generator is garbage-collected, which may be much later
# because the kept traceback can hold a reference to it.
live_results = run_for_dataset_sync(st, udfs=udfs, dataset=ds, executor=None,
                                    roi=None, progress=False, corrections=None)
with closing(live_results):
    for res, all_res in live_results:
        # only redraw when the damage mask has changed
        if np.any(old_damage != res.damage.data):
            plot.new_data(udf_results=res.buffers[0], damage=res.damage)
            old_damage[:] = res.damage.data

Figure(axes=[Axis(label='x', scale=LinearScale(max=1.0, min=0.0)), Axis(label='y', orientation='vertical', sca…

In [None]:
# Check that every flat nav index appears as the first origin coordinate of some
# slice recorded by NonZeroUDF. This needs NonZeroUDF in `udfs`. Without it there
# is no 'slices' buffer, and the check fails with a clear message instead of an
# IndexError.
slices_results = [buf for buf in res.buffers if 'slices' in buf]
assert slices_results, "no 'slices' buffer in results; add NonZeroUDF() to `udfs` to run this check"
all_slice_origins = {s.origin[0] for s in slices_results[0]['slices'].data[0]}
n_nav = int(np.prod(tuple(ds.shape.nav)))
missing_origins = set(range(n_nav)) - all_slice_origins
assert not missing_origins, f"nav indices never seen as origin: {sorted(missing_origins)[:10]}"

In [None]:
# Raw (partition_slice, part_results, epoch, packet_counter) tuples collected
# during the last live run.
all_res

In [None]:
# NOTE(review): scratch arithmetic with unexplained constants — document what
# 2*32*100 represents or remove this cell.
2*32*100

In [None]:
# Nav-only index slices of the 1st and 9th received partition results.
print(all_res[0][0].get(nav_only=True))
print(all_res[8][0].get(nav_only=True))

In [None]:
# Channel 2 of the first UDF's `intensity` buffer in the 9th raw partition result.
all_res[8][1].buffers[0].intensity[..., 2]

In [None]:
# (first origin coordinate, packet_counter) for each raw partition result
[(part_slice.origin[0], packet_counter)
 for part_slice, _part_results, _epoch, packet_counter in all_res]

In [None]:
%matplotlib widget

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Count how many raw partition results covered each nav position.
# Index the tuple instead of unpacking into `_`: assigning `_` in the user
# namespace stops IPython from updating its last-output variable.
hmm = np.zeros(ds.shape.nav)
hmm_flat = hmm.reshape((-1,))
for raw_result in all_res:
    hmm_flat[raw_result[0].get(nav_only=True)] += 1

In [None]:
# Every nav position is expected to be covered exactly 8 times. 8 matches the
# damage threshold applied in run_for_dataset_sync; presumably it is the number
# of partition results per frame — TODO confirm.
assert np.allclose(hmm, 8)

In [None]:
# Channel 0 of the first UDF's merged `intensity` result from the last live step.
fig, ax = plt.subplots()
ax.imshow(res.buffers[0]['intensity'].data[..., 0])
ax.set_title("udfs[0] 'intensity', channel 0 (last live result)");

In [None]:
# Damage mask of the last live result, cast to int for display.
fig, ax = plt.subplots()
ax.imshow(res.damage.data.astype(int))
ax.set_title("Damage mask (last live result)");

In [None]:
from libertem.common import Slice, Shape

In [None]:
# Independent copy of channel 0 of the first UDF's intensity result. `res0_flat`
# is a flat reshape of it, used by the commented-out experiment below.
res0 = res.buffers[0]['intensity'].data[..., 0].copy()
res0_flat = res0.reshape((-1,))
# res0_flat[Slice.from_shape((100,), sig_dims=0).get()] = 42

In [None]:
# Start the colour scale at the smallest positive value, so unfilled (zero)
# positions don't dominate it. If nothing is positive yet, fall back to
# autoscaling instead of failing on `.min()` of an empty array.
positive_values = res0[res0 > 0]
fig, ax = plt.subplots()
ax.imshow(res0, vmin=positive_values.min() if positive_values.size else None)
ax.set_title("udfs[0] 'intensity', channel 0 (scale from smallest positive value)");

# Teardown

Stop the event replica client thread `st` and wait for it to exit.

In [12]:
# Stop the replica client started at the top of the notebook and wait for it to finish.
st.stop()
st.join()