# Auto Batching

## Bacthed Queue: `.get()` a batched output 
### Single Consumer - Multiple Publisher

`get multiple items not more than the specified batch size` **and** `get the items within the timeout`

In [3]:
from queue import Queue, Empty
import time
from threading import Thread
from threading import Event
from typing import Callable, List
import random
import uuid

from typing import Any
from dataclasses import dataclass, field
import uuid
from threading import Event

In [4]:
class BatchedQueue:
    def __init__(self, timeout=1.0, bs=1):
        self.timeout = timeout
        self.bs = bs
        self._queue: Queue = Queue()
        self._result = []
        self._event = Event()

    def get(self):
        entered_at = time.time()
        timeout = self.timeout
        bs = self.bs

        if self._queue.qsize() >= bs:
            return [self._queue.get_nowait() for _ in range(bs)]

        while (
            self._event.wait(timeout - (time.time() - entered_at))
            and self._queue.qsize() < bs
        ):
            True

        result = []
        try:
            for _ in range(bs):
                result.append(self._queue.get_nowait())
            return result
        except Empty:
            return result

    def put(self, item):
        self._queue.put(item)
        if self._event.is_set() and self.size >= self.bs:
            self._event.set()

    @property
    def size(self):
        return self._queue.qsize()

In [3]:
q = BatchedQueue(timeout=2, bs=4)

q.put(1)
q.size

1

In [4]:
t0 = time.time()
print("size", q.size)
q.get()
print("size", q.size)
time.time() - t0

size 1
size 0


2.005953073501587

### Test with a publisher

In [5]:
import random

In [6]:
q = BatchedQueue(timeout=2, bs=4)


def publisher():
    for i in range(16):
        time.sleep(random.randint(0, 1))
        q.put(random.randint(1000, 100000))


thread1 = Thread(target=publisher, daemon=True)
thread2 = Thread(target=publisher, daemon=True)
thread3 = Thread(target=publisher, daemon=True)

thread1.start()
thread2.start()
thread3.start()

In [7]:
q.size

4

In [8]:
for i in range(12):
    t0 = time.time()
    items = q.get()
    print(items)
    t1 = time.time()
    print(f"consumed in {t1-t0:.2f}")

[51250, 40816, 7903, 26584]
consumed in 0.00
[87359, 66977, 26592, 39962]
consumed in 2.01
[93101, 41389, 98259, 70794]
consumed in 0.00
[68374, 40223, 48053, 6779]
consumed in 0.00
[39984, 49248, 43617, 37299]
consumed in 2.00
[66219, 54735, 50403, 76381]
consumed in 0.00
[77249, 19677, 32807, 18666]
consumed in 0.00
[54201, 19437, 64951, 82883]
consumed in 2.01
[65781, 19257, 38511, 6116]
consumed in 0.00
[77809, 68513, 59265, 36806]
consumed in 0.00
[76478, 51127, 54841, 48616]
consumed in 2.00
[92335, 76757, 11130, 48192]
consumed in 0.00


## Batched Processor

In [58]:
@dataclass
class WaitedObject:
    item: Any = None
    completed: bool = False
    result: Any = None
    _event: Event = None

    def __post_init__(self):
        self._event = Event()

    def set_result(self, result):
        self.result = result
        self.completed = True
        self._event.set()

    def get(self, timeout: float = None):
        if self.completed:
            return self.result

        if self._event.is_set():
            raise Exception("Already waiting!")
        self._event.wait(timeout)
        return self.result

    def __repr__(self):
        return f"WaitedOjb({dict(item=self.item, completed=self.completed, result=self.result, signal=self._event.is_set())})"

In [59]:
a = WaitedObject(item=1)
a.set_result(2)
a.get(1)

2

In [42]:
class BatchProcessor:
    def __init__(
        self,
        func: Callable,
        timeout=4.0,
        bs=1,
    ):
        self._batched_queue = BatchedQueue(timeout=timeout, bs=bs)
        self.func = func
        self._event = Event()
        self._cancel_signal = Event()

        self._thread = Thread(target=self._process_queue)
        self._thread.start()

    def _process_queue(self):
        print("Started processing")
        while True:
            if self._cancel_signal.is_set():
                print("Stopped batch processor")
                return
            t0 = time.time()
            batch: List[WaitedObject] = self._batched_queue.get()
            t1 = time.time()
            # print(f"waited {t1-t0:.2f}s for batch")
            if not batch:
                # print("no batch")
                continue
            batch_items = [b.item for b in batch]
            # print(batch_items)
            results = self.func(batch_items)
            for b, result in zip(batch, results):
                b.set_result(result)

    def process(self, item: Any):
        waited_obj = WaitedObject(item=item)
        self._batched_queue.put(waited_obj)
        return waited_obj

    def cancel(self):
        self._cancel_signal.set()
        self._thread.join()

In [45]:
def fake_ml_api(X):
    n = len(X)
    print(f"{n} items")
    time.sleep(random.randint(0, 2))
    return [random.randint(0, 1) for _ in range(n)]

In [62]:
processor = BatchProcessor(fake_ml_api, timeout=4, bs=16)

Started processing


In [63]:
results = []
for i in range(32):
    x = processor.process(i + 100)
    results.append(x)

In [64]:
results

[WaitedOjb({'item': 100, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 101, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 102, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 103, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 104, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 105, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 106, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 107, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 108, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 109, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 110, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 111, 'completed': False, 'result': None, 'signal': False}),
 WaitedOjb({'item': 112, 'completed': Fa

In [65]:
print(results[0].get())
print(results[31].get())

16 items
16 items
0
1


In [66]:
results

[WaitedOjb({'item': 100, 'completed': True, 'result': 0, 'signal': True}),
 WaitedOjb({'item': 101, 'completed': True, 'result': 1, 'signal': True}),
 WaitedOjb({'item': 102, 'completed': True, 'result': 0, 'signal': True}),
 WaitedOjb({'item': 103, 'completed': True, 'result': 1, 'signal': True}),
 WaitedOjb({'item': 104, 'completed': True, 'result': 1, 'signal': True}),
 WaitedOjb({'item': 105, 'completed': True, 'result': 0, 'signal': True}),
 WaitedOjb({'item': 106, 'completed': True, 'result': 1, 'signal': True}),
 WaitedOjb({'item': 107, 'completed': True, 'result': 0, 'signal': True}),
 WaitedOjb({'item': 108, 'completed': True, 'result': 0, 'signal': True}),
 WaitedOjb({'item': 109, 'completed': True, 'result': 1, 'signal': True}),
 WaitedOjb({'item': 110, 'completed': True, 'result': 1, 'signal': True}),
 WaitedOjb({'item': 111, 'completed': True, 'result': 0, 'signal': True}),
 WaitedOjb({'item': 112, 'completed': True, 'result': 0, 'signal': True}),
 WaitedOjb({'item': 113, 