In [None]:
%matplotlib widget

In [None]:
import multiprocessing as mp
# Use the 'spawn' start method for all sub-processes created below.
# force=True keeps this cell re-runnable: without it, executing the cell a
# second time in the same kernel raises RuntimeError because the start
# method has already been set.
mp.set_start_method('spawn', force=True)

In [None]:
from matplotlib import pyplot as plt

In [None]:
import numpy as np
import uuid

In [None]:
from libertem.common import Shape, Slice
from libertem.udf.base import UDFMeta, UDFResults
from libertem.common.buffers import BufferWrapper
from libertem.io.dataset.base import DataSetMeta

In [None]:
from libertem.udf.sum import SumUDF
from libertem.udf.sumsigudf import SumSigUDF
from libertem_live.detectors.k2is.proto import MySubProcess, SyncState, PlaceholderPartition

In [None]:
# Number of frames along the navigation axis; read as a global by
# FakeDataSet and _partition_by_idx when they build their shapes.
num_frames = 1800

In [None]:
class FakeDataSet:
    """Minimal dataset stand-in that only exposes ``shape`` and ``dtype``.

    The shape is ``(num_frames, 1860, 2048)`` with two signal dimensions,
    where ``num_frames`` is the notebook-level constant, and the dtype is
    ``np.uint16``.
    """

    def __init__(self):
        full_shape = (num_frames, 1860, 2048)
        self.shape = Shape(full_shape, sig_dims=2)
        self.dtype = np.uint16

In [None]:
class FakeExecutor:
    """Stand-in executor that runs every task in its own ``MySubProcess``.

    NOTE: ``run_tasks`` reads the UDF list from the notebook-global ``udfs``,
    so that name has to be bound before the returned generator is iterated.
    """

    def run_tasks(self, tasks, cancel_id):
        """Start one sub-process per task and yield ``(result, task)`` pairs.

        Parameters
        ----------
        tasks
            Sized iterable of task identifiers; each one is handed to
            ``MySubProcess`` as ``idx``.
        cancel_id
            Accepted for interface compatibility; not used here.

        Yields
        ------
        tuple
            ``(result, task)``, where ``result`` is the next item read from
            the output queue belonging to ``task``. Pairs are yielded in the
            order the tasks were given.
        """
        ss = SyncState(num_processes=len(tasks))
        processes = []
        task_queues = []
        try:
            for task in tasks:
                oq = mp.Queue()
                p = MySubProcess(idx=task, sync_state=ss, udfs=udfs, out_queue=oq)
                p.start()
                processes.append(p)
                task_queues.append((task, oq))
            for idx, (task, q) in enumerate(task_queues):
                print(f"getting result from q {q} ({idx})")
                # Hand back the task itself rather than its position in the
                # list, so callers get the right identifier even when the
                # tasks are not simply 0..n-1.
                yield q.get(), task
        finally:
            # Always wait for the started processes, also when the consumer
            # stops iterating early or an error occurs.
            for p in processes:
                p.join()

    def ensure_sync(self):
        """Return this executor unchanged."""
        return self

In [None]:
def make_udf_tasks(dataset, roi, corrections, backends):
    """Return the task list for a K2IS live dataset: one index per sector.

    Only the plain case is handled: ``roi`` has to be ``None`` and
    ``corrections`` has to be ``None`` or report that it has no corrections.
    ``dataset`` and ``backends`` are accepted but not inspected.
    """
    assert roi is None
    if corrections is not None:
        assert not corrections.have_corrections()

    # one task per partition, i.e. per sector
    sector_count = 8
    return [sector for sector in range(sector_count)]

In [None]:
def _get_dtype(udfs, dtype, corrections):
    if corrections is not None and corrections.have_corrections():
        tmp_dtype = np.result_type(np.float32, dtype)
    else:
        tmp_dtype = dtype
    for udf in udfs:
        tmp_dtype = np.result_type(
            udf.get_preferred_input_dtype(),
            tmp_dtype
        )
    return tmp_dtype

In [None]:
def _prepare_run_for_dataset(
    udfs, dataset, executor, roi, corrections, backends, dry
):
    """Attach metadata and result buffers to each UDF, then build the tasks.

    One ``UDFMeta`` is created from the dataset's shape and dtype and shared
    by all UDFs. Each UDF gets its result buffers initialised and allocated
    for the full dataset; UDFs that define ``preprocess`` additionally get
    dataset views set and ``preprocess()`` called.

    Returns an empty list when ``dry`` is true, otherwise the list produced
    by ``make_udf_tasks``. ``executor`` is accepted but not used here.
    """
    shared_meta = UDFMeta(
        partition_shape=None,
        dataset_shape=dataset.shape,
        roi=roi,
        dataset_dtype=dataset.dtype,
        input_dtype=_get_dtype(udfs, dataset.dtype, corrections),
        corrections=corrections,
    )
    for current_udf in udfs:
        current_udf.set_meta(shared_meta)
        current_udf.init_result_buffers()
        current_udf.allocate_for_full(dataset, roi)

        if not hasattr(current_udf, 'preprocess'):
            continue
        current_udf.set_views_for_dataset(dataset)
        current_udf.preprocess()

    if dry:
        return []
    return list(make_udf_tasks(dataset, roi, corrections, backends))

In [None]:
def _partition_by_idx(idx):
    """Build the placeholder partition that belongs to sector ``idx``.

    The partition covers all ``num_frames`` frames (notebook-level constant)
    and a 1860 x 256 signal region whose origin along the last axis is
    ``256 * idx``, inside a dataset of shape ``(num_frames, 1860, 2048)``.
    For now there is a single partition per sector, sized to be at least
    what is expected during 10 seconds of runtime.
    """
    sector_width = 256
    frame_rows = 1860

    ds_meta = DataSetMeta(
        shape=Shape((num_frames, frame_rows, 2048), sig_dims=2),
        image_count=num_frames,
        raw_dtype=np.uint16,
    )
    sector_slice = Slice(
        origin=(0, 0, sector_width * idx),
        shape=Shape((num_frames, frame_rows, sector_width), sig_dims=2),
    )
    return PlaceholderPartition(
        meta=ds_meta,
        partition_slice=sector_slice,
        tiles=[],
        start_frame=0,
        num_frames=num_frames,
    )

In [None]:
def run_for_dataset_sync(udfs, dataset, executor,
                    roi=None, progress=False, corrections=None, backends=None, dry=False):
    """Run ``udfs`` over ``dataset`` and yield the merged results so far.

    This is a generator. After each task reported by the executor, the
    task's partial results are merged into every UDF's result buffers and a
    ``UDFResults`` snapshot is yielded. If there are no tasks (for example
    with ``dry=True``), exactly one snapshot is yielded.

    Parameters
    ----------
    udfs
        UDF instances; the executor must return partial results in the same
        order, since they are paired with ``zip``.
    dataset
        Object providing ``shape`` and ``dtype``.
    executor
        Object providing ``ensure_sync()`` and ``run_tasks(tasks, cancel_id)``.
    roi, corrections, backends, dry
        Forwarded to ``_prepare_run_for_dataset``.
    progress
        If true, show a ``tqdm`` progress bar with one step per task.

    Yields
    ------
    UDFResults
        ``buffers`` holds one result per UDF; ``damage`` is a boolean
        ``kind='nav'`` buffer set to ``True`` for every partition that has
        been merged so far.
    """
    tasks = _prepare_run_for_dataset(
        udfs, dataset, executor, roi, corrections, backends, dry
    )
    cancel_id = str(uuid.uuid4())

    t = None
    if progress:
        from tqdm import tqdm
        t = tqdm(total=len(tasks))

    try:
        executor = executor.ensure_sync()

        damage = BufferWrapper(kind='nav', dtype=bool)
        damage.set_shape_ds(dataset.shape, roi)
        damage.allocate()

        def _current_results():
            # Snapshot of what every UDF has accumulated up to this point.
            return UDFResults(
                buffers=tuple(
                    udf._do_get_results()
                    for udf in udfs
                ),
                damage=damage
            )

        if tasks:
            for part_results, task in executor.run_tasks(tasks, cancel_id):
                if t is not None:
                    t.update(1)
                # Build the partition once per task and reuse it for all UDFs
                # and for the damage mask.
                partition = _partition_by_idx(task)
                for results, udf in zip(part_results, udfs):
                    udf.set_views_for_partition(partition)
                    udf.merge(
                        dest=udf.results.get_proxy(),
                        src=results.get_proxy()
                    )
                    udf.clear_views()
                v = damage.get_view_for_partition(partition)
                v[:] = True
                yield _current_results()
        else:
            # yield at least one result (which should be empty):
            for udf in udfs:
                udf.clear_views()
            yield _current_results()
    finally:
        # Runs on normal completion, on errors, and when the consumer closes
        # or drops the generator early, so the progress bar is always closed.
        if t is not None:
            t.close()

# kind="sig"

In [None]:
import time

In [None]:
num_processes = 8
ss = SyncState(num_processes=num_processes)
processes = []
oqs = []
results = []
udfs = [SumUDF()]
try:
    # Start all sub-processes first; waiting and collecting happen once,
    # after the loop (previously they ran inside the loop, which slept 15 s
    # per process and reset `results` on every iteration).
    for i in range(num_processes):
        oq = mp.Queue()
        p = MySubProcess(idx=i, sync_state=ss, udfs=udfs, out_queue=oq, acqtime=10)
        p.start()
        processes.append(p)
        oqs.append(oq)

    # Wait longer than acqtime (10) and empty the queues *before* joining:
    # a process that has put items on a multiprocessing.Queue may not
    # terminate until those items have been consumed, so joining first can
    # block. (Presumably this is what the original unfinished "because"
    # comment referred to - confirm.)
    time.sleep(15)
    for q in oqs:
        while not q.empty():
            results.append(q.get())
finally:
    for p in processes:
        print(f"joining process {p}")
        p.join()
        print(f"joined process {p}")

In [None]:
# Show the `intensity` attribute of element [0][0] of the collected queue items
# (presumably the first UDF's result within the first item - TODO confirm layout).
plt.figure()
plt.imshow(results[0][0].intensity)

In [None]:
# Blocking read of the next item from the first process's output queue.
# NOTE(review): the cell that starts the processes already empties every queue
# after its sleep, so this call waits until process 0 emits another item - confirm
# that this is intended.
res0 = oqs[0].get()

# kind="nav"

In [None]:
# Same direct sub-process run as above, but with SumSigUDF.
num_processes = 8
ss = SyncState(num_processes=num_processes)
processes = []
oqs = []
udfs = [SumSigUDF()]
try:
    for i in range(num_processes):
        oq = mp.Queue()
        p = MySubProcess(idx=i, sync_state=ss, udfs=udfs, out_queue=oq, acqtime=10)
        p.start()
        processes.append(p)
        oqs.append(oq)
finally:
    # NOTE(review): the processes are joined before anything is read from their
    # queues (the queues are only read in later cells). The multiprocessing
    # guidelines warn that joining a process with unconsumed queue items can
    # block; presumably the results here are small enough - confirm.
    for p in processes:
        p.join()

In [None]:
# Blocking read of the next item from the first process's output queue.
res0 = oqs[0].get()

In [None]:
# Inspect the shape of the `intensity` attribute of the first element of res0.
res0[0].intensity.shape

In [None]:
# View the first 42 * 42 entries of the intensity result as a 42 x 42 array.
res0sq = res0[0].intensity[:42*42].reshape((42, 42))
# Smallest strictly positive value, printed as a guide for the lower colour limit.
print(res0sq[res0sq>0].min())
plt.figure()
# NOTE(review): vmin is a hard-coded value, presumably taken from the printed
# minimum of one particular run - confirm it still fits the current data.
plt.imshow(res0sq, vmin=3.317e8, vmax=res0sq.max())
plt.colorbar()

In [None]:
# 3969 == 63 * 63: view the first 3969 entries of the intensity result as a 63 x 63 array.
res0sq = res0[0].intensity[:3969].reshape((63, 63))

In [None]:
# Show the 63 x 63 view with the default colour scaling.
plt.figure()
plt.imshow(res0sq)

# run for dataset with kind='nav'

In [None]:
# Drive run_for_dataset_sync end-to-end with a SumSigUDF. One UDFResults is
# printed per finished task; `res` keeps the last one for the cells below.
# `udfs` has to be (re)bound here because FakeExecutor.run_tasks reads the
# notebook-global `udfs`.
udfs = [SumSigUDF()]
ds = FakeDataSet()
executor = FakeExecutor()
for res in run_for_dataset_sync(udfs, dataset=ds, executor=executor):
    print(res)

In [None]:
# Inspect the 'intensity' buffer of the first UDF in the last yielded result.
res.buffers[0]['intensity']

In [None]:
# 3969 == 63 * 63: view the first 3969 entries of the 'intensity' buffer as a
# 63 x 63 array.
res0sq = res.buffers[0]['intensity'].data[:3969].reshape((63, 63))
# Colour limits taken from the non-zero entries only, so that zero-valued
# entries do not stretch the colour scale.
vmin = res0sq[res0sq != 0].min()
vmax = res0sq[res0sq != 0].max()
plt.figure()
# Apply the limits computed above; previously they were calculated but never
# passed to imshow.
plt.imshow(res0sq[10:50, ...], vmin=vmin, vmax=vmax)
plt.colorbar()

# run for dataset with kind='sig'

In [None]:
# Same end-to-end run as above, but with SumUDF. `udfs` is rebound because
# FakeExecutor.run_tasks reads the notebook-global `udfs`; `res` keeps the
# last yielded UDFResults.
udfs = [SumUDF()]
ds = FakeDataSet()
executor = FakeExecutor()
for res in run_for_dataset_sync(udfs, dataset=ds, executor=executor):
    print(res)

In [None]:
# Show the 'intensity' buffer of the first UDF from the last yielded result.
plt.figure()
plt.imshow(res.buffers[0]['intensity'])