In [None]:
# | default_exp _components.task_streaming

In [None]:
# | export 

import asyncio
import sys
from abc import ABC, abstractmethod

from asyncio import Task
from contextlib import asynccontextmanager
from typing import *

import anyio
from aiokafka import ConsumerRecord

from logging import Logger
from fastkafka._components.logger import get_logger

In [None]:
from datetime import datetime, timedelta

from anyio import create_task_group, create_memory_object_stream, ExceptionGroup
from unittest.mock import Mock, MagicMock, AsyncMock

import asyncer
import pytest
from aiokafka import ConsumerRecord, TopicPartition
from pydantic import BaseModel, Field, HttpUrl, NonNegativeInt
from tqdm.notebook import tqdm
from types import CoroutineType

from fastkafka._components.logger import supress_timestamps

In [None]:
# | export

logger = get_logger(__name__)

In [None]:
supress_timestamps()
logger = get_logger(__name__, level=20)
logger.info("ok")

[INFO] __main__: ok


## anyio stream is not running tasks in parallel
> A memory object stream buffers the messages, but they are consumed one at a time: the next message is picked up only after the previous one has finished processing.

In [None]:
# Demo: awaiting every item inside the receive loop serialises the work, so the
# total run time is at least num_msgs * latency (asserted on the last line).
num_msgs = 5
latency = 0.2

receive_pbar = tqdm(total=num_msgs*2)

async def latency_task():
    # One progress tick when the coroutine starts and one when it finishes.
    receive_pbar.update(1)
    await asyncio.sleep(latency)
    receive_pbar.update(1)

async def process_message_callback(
        receive_stream,
) -> None:
    async with receive_stream:
        async for task in receive_stream:
            # The loop does not fetch the next item until this await returns.
            await task

send_stream, receive_stream = anyio.create_memory_object_stream(
    max_buffer_size=num_msgs
)

t0 = datetime.now()
async with anyio.create_task_group() as tg:
    tg.start_soon(process_message_callback, receive_stream)
    async with send_stream:
        for i in tqdm(range(num_msgs)):
            # Sends a coroutine object (not yet running); the receiver awaits it.
            await send_stream.send(latency_task())
            
assert datetime.now() - t0 >= timedelta(seconds=latency*num_msgs)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

To solve this, we can create tasks from the coroutines and let them run in the background, so the loop reading from `receive_stream` keeps spawning new tasks without being blocked by the previous ones.

In [None]:
# Demo: wrapping each received coroutine function call in asyncio.create_task lets
# all of them sleep concurrently, so the whole run is bounded by roughly one
# latency period instead of num_msgs * latency (see the asserts below).
num_msgs = 10_000
latency = 4.0

receive_pbar = tqdm(total=num_msgs*2)

async def latency_task():
    # One progress tick when the coroutine starts and one when it finishes.
    receive_pbar.update(1)
    await asyncio.sleep(latency)
    receive_pbar.update(1)

# Strong references to the running tasks; entries are removed by the done-callback.
tasks = set()

async def process_message_callback(
        receive_stream,
) -> None:
    async with receive_stream:
        async for f in receive_stream:
            task: asyncio.Task = asyncio.create_task(f())
            tasks.add(task)
            # asyncio calls a done-callback with the finished task as its single
            # positional argument, which lands in the `task` parameter of the lambda.
            task.add_done_callback(lambda task=task, tasks=tasks: tasks.remove(task))

send_stream, receive_stream = anyio.create_memory_object_stream(
    max_buffer_size=num_msgs
)

t0 = datetime.now()
async with anyio.create_task_group() as tg:
    tg.start_soon(process_message_callback, receive_stream)
    async with send_stream:
        for i in tqdm(range(num_msgs)):
            # Sends the coroutine *function*; the receiver calls it and schedules a task.
            await send_stream.send(latency_task)

# Half-way through the sleep every task should have ticked once and none twice.
await asyncio.sleep(latency/2)
receive_pbar.refresh()
assert receive_pbar.n == num_msgs, receive_pbar.n

# Yield to the event loop until every task has removed itself from `tasks`.
while len(tasks) > 0:
    await asyncio.sleep(0)
# NOTE(review): send_stream was already closed by the `async with` above; this
# second close is presumably a no-op - confirm against anyio.
await send_stream.aclose()
    
receive_pbar.close()
assert datetime.now() - t0 <= timedelta(seconds=latency+5.0)
assert receive_pbar.n == num_msgs*2, receive_pbar.n

print("ok")

  0%|          | 0/20000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

ok


## Keeping track of tasks

In [None]:
# | export


class TaskPool:
    """Bounded set of in-flight asyncio tasks that forgets each task once it is done.

    Tasks are registered with `add`, which attaches `discard` as a done-callback.
    Used as an async context manager, leaving the block waits until every
    registered task has finished and been removed from the pool.

    Args:
        size: Maximum number of tasks held at once; `add` waits while the pool
            is at this limit.
        on_error: Optional callback invoked with the exception of every task
            that finished by raising. Cancelled tasks are not reported.
    """

    def __init__(
        self,
        size: int = 100_000,
        on_error: Optional[Callable[[BaseException], None]] = None,
    ):
        self.size = size
        self.pool: Set[Task] = set()
        self.on_error = on_error
        self.finished = False

    async def add(self, item: Task) -> None:
        """Register `item`, first yielding to the event loop while the pool is full."""
        # Cooperative back-pressure: give running tasks a chance to finish and
        # free a slot before accepting another one.
        while len(self.pool) >= self.size:
            await asyncio.sleep(0)
        self.pool.add(item)
        item.add_done_callback(self.discard)

    def discard(self, task: Task) -> None:
        """Done-callback: report the task's exception (if any) and drop it from the pool."""
        # Task.exception() raises CancelledError for a cancelled task. Letting that
        # escape would skip the removal below and leave the task in the pool
        # forever, so __aexit__ (and add() on a full pool) would never return.
        if not task.cancelled():
            e = task.exception()
            if e is not None and self.on_error is not None:
                try:
                    self.on_error(e)
                except Exception as ee:
                    # A failing error handler must not prevent the removal below.
                    logger.warning(
                        f"Exception {ee} raised when calling on_error() callback: {e}"
                    )

        self.pool.discard(task)

    def __len__(self) -> int:
        """Number of tasks currently tracked."""
        return len(self.pool)

    async def __aenter__(self) -> "TaskPool":
        self.finished = False
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        # Yield to the event loop until every tracked task has completed and
        # removed itself through the done-callback.
        while len(self) > 0:
            await asyncio.sleep(0)
        self.finished = True

    @staticmethod
    def log_error(logger: Logger) -> Callable[[Exception], None]:
        """Build an `on_error` callback that logs the exception as a warning on `logger`."""

        def _log_error(e: Exception, logger: Logger = logger) -> None:
            logger.warning(f"{e=}")

        return _log_error

In [None]:
# Smoke test: an empty pool can be entered and exited without waiting on anything.
async with TaskPool() as tp:
    pass

In [None]:
# Leaving the `async with` block waits for the registered task to finish,
# so the pool is empty again afterwards.
async def f():
    await asyncio.sleep(2)

pool = TaskPool()
assert len(pool) == 0

async with pool:
    task = asyncio.create_task(f())
    await pool.add(task)
    assert len(pool) == 1

assert len(pool) == 0, len(pool)

In [None]:
async def f():
    raise RuntimeError("funny error")

        
    return _log_error
    
pool = TaskPool(on_error=TaskPool.log_error(logger))

async with pool:
    task = asyncio.create_task(f())
    await pool.add(task)



In [None]:
# | export

class ExceptionMonitor:
    """Collects exceptions reported via `on_error` and re-raises them on exit.

    `exception_found` is set the first time an exception is reported and is
    not cleared by this class. When the async context is left, the oldest
    pending exception (if any) is removed from the queue and raised.
    """

    def __init__(self) -> None:
        self.exceptions: List[Exception] = []
        self.exception_found = False

    def on_error(self, e: Exception) -> None:
        """Queue `e` and mark that at least one exception has been seen."""
        self.exceptions.append(e)
        self.exception_found = True

    def _monitor_step(self) -> None:
        """Raise the oldest queued exception, or do nothing when the queue is empty."""
        if not self.exceptions:
            return
        raise self.exceptions.pop(0)

    async def __aenter__(self) -> "ExceptionMonitor":
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        # _monitor_step raises as soon as there is a queued exception, so in
        # practice the first pending exception leaves this loop by propagating.
        while self.exceptions:
            self._monitor_step()
            await asyncio.sleep(0)

In [None]:
# The task's RuntimeError is handed to the monitor through the pool's on_error
# hook and is expected to be re-raised when the `async with` block is left.
no_tasks = 1

async def f():
    raise RuntimeError(f"very funny error.")


exception_monitor = ExceptionMonitor()
pool = TaskPool(on_error=exception_monitor.on_error)

async def create_tasks():
    for _ in range(no_tasks):
        task = asyncio.create_task(f())
        await pool.add(task)
        await asyncio.sleep(0.1) # otherwise the tasks get created before any of them throws an exception
        # Stop spawning as soon as the monitor has seen a failure.
        if exception_monitor.exception_found:
            break
        
with pytest.raises(RuntimeError) as e:
    async with exception_monitor, pool:
        async with asyncer.create_task_group() as tg:
            tg.soonify(create_tasks)()
            
print(f"{e=}")
# The raised exception must have been removed from the monitor's queue.
assert exception_monitor.exceptions == [], len(exception_monitor.exceptions)

e=<ExceptionInfo RuntimeError('very funny error.') tblen=4>


# Streaming

In [None]:
# | export


class StreamExecutor(ABC):
    """Abstract base for strategies that move messages from a producer to a consumer."""

    @abstractmethod
    async def run(  # type: ignore
        self,
        is_shutting_down_f: Callable[[], bool],
        produce_func: Callable[[], Awaitable[ConsumerRecord]],
        consume_func: Callable[[ConsumerRecord], Awaitable[None]],
    ) -> None:
        """Pump messages from `produce_func` into `consume_func`.

        Args:
            is_shutting_down_f: Callable returning True once the run should stop.
            produce_func: Awaitable factory supplying messages. The executors
                defined in this file iterate over its awaited result, so it is
                expected to yield an iterable of messages.
            consume_func: Coroutine function called with each message.
        """

## Streaming tasks

In [None]:
# Prototype of the task-based pipeline: every item read from the stream is
# wrapped in a task and tracked by the pool; the mock must end up being called.
mock = Mock()
async_mock = asyncer.asyncify(mock)

async def process_items(receive_stream):
    async with receive_stream:
        async for item in receive_stream:
            task = asyncio.create_task(async_mock(item))
            # `pool` is the module-level TaskPool created below, looked up at call time.
            await pool.add(task)

send_stream, receive_stream = create_memory_object_stream()
pool = TaskPool()

async with pool:
    async with create_task_group() as tg:
        tg.start_soon(process_items, receive_stream)
        async with send_stream:
            await send_stream.send(f"hi")

mock.assert_called()

In [None]:
# | export


def _process_items_task(  # type: ignore
    consume_func: Callable[[ConsumerRecord], Awaitable[None]],
    task_pool: TaskPool
) -> Callable[
    [
        anyio.streams.memory.MemoryObjectReceiveStream,
        Callable[[ConsumerRecord], Awaitable[None]],
        bool,
    ],
    Coroutine[Any, Any, Awaitable[None]],
]:
    """Build the receive loop that runs every message as its own asyncio task.

    Args:
        consume_func: Coroutine function called with each received message.
        task_pool: Pool that tracks the created tasks; `add` is awaited for
            each one, so the loop waits while the pool is full.

    Returns:
        A coroutine function taking the receive stream. It reads until the
        stream is exhausted and closes the stream when done.
    """

    async def _process_items_wrapper(  # type: ignore
        receive_stream: anyio.streams.memory.MemoryObjectReceiveStream,
        consume_func: Callable[[ConsumerRecord], Awaitable[None]] = consume_func,
        task_pool= task_pool,
    ):
        async with receive_stream:
            async for record in receive_stream:
                # Schedule the handler in the background and register it with the pool.
                handler: asyncio.Task = asyncio.create_task(consume_func(record))  # type: ignore
                await task_pool.add(handler)

    return _process_items_wrapper

In [None]:
# | export


class DynamicTaskExecutor(StreamExecutor):
    """Executor that runs `consume_func` for each message as a separate asyncio task.

    Args:
        throw_exceptions: When True, exceptions raised by consumer tasks are
            collected by an `ExceptionMonitor`, stop the produce loop and are
            re-raised from `run`. When False they are logged as warnings.
        max_buffer_size: Buffer size of the in-memory stream between the
            produce loop and the task-spawning loop.
        size: Maximum number of tasks tracked by the task pool at once.
    """

    def __init__(  # type: ignore
        self,
        throw_exceptions: bool = False,
        max_buffer_size=100_000,
        size=100_000,
    ):
        self.throw_exceptions = throw_exceptions
        self.max_buffer_size = max_buffer_size
        self.exception_monitor = ExceptionMonitor()

        # Route task failures either to the monitor (to be re-raised) or to the log.
        if throw_exceptions:
            error_handler = self.exception_monitor.on_error
        else:
            error_handler = TaskPool.log_error(logger)
        self.task_pool = TaskPool(on_error=error_handler, size=size)  # type: ignore

    async def run(  # type: ignore
        self,
        is_shutting_down_f: Callable[[], bool],
        produce_func: Callable[[], Awaitable[ConsumerRecord]],
        consume_func: Callable[[ConsumerRecord], Awaitable[None]],
    ) -> None:
        """Produce messages until shutdown and hand each one to a background task.

        The loop asks `is_shutting_down_f` before every `produce_func` call.
        With `throw_exceptions` enabled it also stops once the monitor has
        recorded an exception. On exit the task pool is drained first, then
        the monitor raises a pending exception, if any.
        """
        send_stream, receive_stream = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size
        )

        async with self.exception_monitor, self.task_pool:
            async with anyio.create_task_group() as tg:
                spawn_tasks = _process_items_task(consume_func, self.task_pool)
                tg.start_soon(spawn_tasks, receive_stream)
                async with send_stream:
                    while not is_shutting_down_f():
                        failed = self.exception_monitor.exception_found
                        if self.throw_exceptions and failed:
                            break
                        for msg in await produce_func():
                            await send_stream.send(msg)

In [None]:
def is_shutting_down_f(call_count: int = 1) -> Callable[[], bool]:
    """Build a shutdown predicate that returns False `call_count` times, then True.

    Each returned callable carries its own counter, so separate calls to this
    factory produce independent predicates.
    """
    state = {"count": 0}

    def _is_shutting_down_f(count=state, call_count: int = call_count):
        # Still within the allowed number of "keep running" answers.
        if count["count"] < call_count:
            count["count"] += 1
            return False
        return True

    return _is_shutting_down_f

In [None]:
# With the default call_count=1 the predicate allows exactly one iteration.
f = is_shutting_down_f()
assert f() == False
assert f() == True

In [None]:
# End-to-end run with plain coroutine functions: one produce call, one message printed.
async def produce():
    return ["msg"]


async def consume(msg):
    print(msg)


stream = DynamicTaskExecutor()

await stream.run(
    is_shutting_down_f(),
    produce_func=produce,
    consume_func=consume,
)

msg


In [None]:
# Same flow with mocks: both the producer and the consumer must have been awaited.
mock_produce = AsyncMock(spec=CoroutineType, return_value=["msg"])
mock_consume = AsyncMock(spec=CoroutineType)

stream = DynamicTaskExecutor()

await stream.run(
    is_shutting_down_f(),
    produce_func=mock_produce,
    consume_func=mock_consume,
)

mock_produce.assert_awaited()
mock_consume.assert_awaited_with("msg")

In [None]:
# Variant of the previous cell that checks the mocks were called (not only awaited).
mock_produce = AsyncMock(spec=CoroutineType, return_value=["msg"])
mock_consume = AsyncMock(spec=CoroutineType)

stream = DynamicTaskExecutor()

await stream.run(
    is_shutting_down_f(),
    produce_func=mock_produce,
    consume_func=mock_consume,
)

mock_produce.assert_called()
mock_consume.assert_called_with("msg")

In [None]:
# With throw_exceptions=True a consumer failure is expected to surface from run()
# as the original RuntimeError.
num_msgs = 13

mock_produce = AsyncMock(spec=CoroutineType, return_value=["msg"])
mock_consume = AsyncMock(spec=CoroutineType)
mock_consume.side_effect = RuntimeError()

stream = DynamicTaskExecutor(throw_exceptions=True)

with pytest.raises(RuntimeError) as e:
    await stream.run(
        is_shutting_down_f(num_msgs),
        produce_func=mock_produce,
        consume_func=mock_consume,
    )

mock_produce.assert_called()
mock_consume.assert_awaited_with("msg")

In [None]:
# With the default throw_exceptions=False consumer failures are only logged,
# so run() is expected to complete without raising.
num_msgs = 13

mock_produce = AsyncMock(spec=CoroutineType, return_value=["msg"])
mock_consume = AsyncMock(spec=CoroutineType)
mock_consume.side_effect = RuntimeError()

stream = DynamicTaskExecutor()

await stream.run(
    is_shutting_down_f(num_msgs),
    produce_func=mock_produce,
    consume_func=mock_consume,
)

mock_produce.assert_called()
mock_consume.assert_awaited_with("msg")



## Awaiting coroutines

In [None]:
# | export


def _process_items_coro(  # type: ignore
    consume_func: Callable[[ConsumerRecord], Awaitable[None]],
    throw_exceptions: bool,
) -> Callable[
    [
        anyio.streams.memory.MemoryObjectReceiveStream,
        Callable[[ConsumerRecord], Awaitable[None]],
        bool,
    ],
    Coroutine[Any, Any, Awaitable[None]],
]:
    """Build the receive loop that awaits `consume_func` for one message at a time.

    Args:
        consume_func: Coroutine function awaited with each received message.
        throw_exceptions: When True, an exception from `consume_func` is
            re-raised and ends the loop; when False it is logged as a warning
            and the loop moves on to the next message.

    Returns:
        A coroutine function taking the receive stream. It reads until the
        stream is exhausted and closes the stream when done.
    """

    async def _process_items_wrapper(  # type: ignore
        receive_stream: anyio.streams.memory.MemoryObjectReceiveStream,
        consume_func: Callable[[ConsumerRecord], Awaitable[None]] = consume_func,
        throw_exceptions: bool = throw_exceptions,
    ) -> Awaitable[None]:
        async with receive_stream:
            async for msg in receive_stream:
                try:
                    await consume_func(msg)
                except Exception as e:
                    if not throw_exceptions:
                        # Best-effort mode: report the failure and keep consuming.
                        logger.warning(f"{e=}")
                        continue
                    raise e

    return _process_items_wrapper

In [None]:
# | export


class SequentialExecutor(StreamExecutor):
    """Executor that awaits `consume_func` for each message in turn.

    Args:
        throw_exceptions: When True, an exception from the consumer is
            re-raised inside the task group; when False it is logged as a
            warning and processing continues.
        max_buffer_size: Buffer size of the in-memory stream between the
            produce loop and the consuming loop.
    """

    def __init__(  # type: ignore
        self,
        throw_exceptions: bool = False,
        max_buffer_size=100_000,
    ):
        self.throw_exceptions = throw_exceptions
        self.max_buffer_size = max_buffer_size

    async def run(  # type: ignore
        self,
        is_shutting_down_f: Callable[[], bool],
        produce_func: Callable[[], Awaitable[ConsumerRecord]],
        consume_func: Callable[[ConsumerRecord], Awaitable[None]],
    ) -> None:
        """Produce messages until shutdown and feed them to a single consuming loop.

        `is_shutting_down_f` is asked before every `produce_func` call; each
        message of the awaited result is sent to the stream individually.
        """
        send_stream, receive_stream = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size
        )

        async with anyio.create_task_group() as tg:
            consume_loop = _process_items_coro(consume_func, self.throw_exceptions)
            tg.start_soon(consume_loop, receive_stream)
            async with send_stream:
                while not is_shutting_down_f():
                    for msg in await produce_func():
                        await send_stream.send(msg)

In [None]:
# With throw_exceptions=True the consumer's error is raised inside the anyio task
# group, so this test expects it to arrive wrapped in an ExceptionGroup.
num_msgs = 13

mock_produce = AsyncMock(spec=CoroutineType, return_value=["msg"])
mock_consume = AsyncMock(spec=CoroutineType)
mock_consume.side_effect = RuntimeError("Funny error")

stream = SequentialExecutor(throw_exceptions=True)

with pytest.raises(ExceptionGroup) as e:
    await stream.run(is_shutting_down_f(num_msgs), produce_func=mock_produce, consume_func=mock_consume)

mock_produce.assert_called()
mock_consume.assert_awaited_with("msg")

In [None]:
# With the default throw_exceptions=False consumer failures are only logged,
# so run() is expected to complete without raising.
num_msgs = 13

mock_produce = AsyncMock(spec=CoroutineType, return_value=["msg"])
mock_consume = AsyncMock(spec=CoroutineType)
mock_consume.side_effect = RuntimeError("Funny error")

stream = SequentialExecutor()

await stream.run(
    is_shutting_down_f(num_msgs),
    mock_produce,
    mock_consume,
)

mock_produce.assert_called()
mock_consume.assert_awaited_with("msg")



In [None]:
# | export


def get_executor(executor: Union[str, StreamExecutor, None] = None) -> StreamExecutor:
    """Resolve `executor` to a `StreamExecutor` instance.

    Args:
        executor: An existing instance (returned unchanged), the name of an
            attribute of the `fastkafka._components.task_streaming` module to
            call with no arguments, or None for "SequentialExecutor".

    Returns:
        The given instance, or the result of calling the named attribute.
    """
    if isinstance(executor, StreamExecutor):
        return executor

    name = "SequentialExecutor" if executor is None else executor
    # The name is looked up on the installed module, not on the current namespace.
    module = sys.modules["fastkafka._components.task_streaming"]
    return getattr(module, name)()  # type: ignore

In [None]:
# None, the class name and an instance must all resolve to a SequentialExecutor.
# Only the qualified class name is compared, not the class object itself.
for executor in [None, "SequentialExecutor", SequentialExecutor()]:
    actual = get_executor(executor)
    assert actual.__class__.__qualname__ == "SequentialExecutor"

In [None]:
# The class name and an instance must both resolve to a DynamicTaskExecutor.
for executor in ["DynamicTaskExecutor", DynamicTaskExecutor()]:
    actual = get_executor(executor)
    assert actual.__class__.__qualname__ == "DynamicTaskExecutor"