In [1]:
# | default_exp _components.helper_structs

In [2]:
# | export 

import asyncio
from asyncio import Task
from contextlib import asynccontextmanager
import anyio
from typing import Set, AsyncGenerator, Callable

In [3]:
from anyio import create_task_group, create_memory_object_stream
from unittest.mock import Mock, MagicMock
import asyncer
from pydantic import BaseModel, Field, HttpUrl, NonNegativeInt
from aiokafka import ConsumerRecord, TopicPartition

In [4]:
# | export

class TaskPool():
    
    def __init__(self, pool_size: int = 100_000):
        self.pool_size = pool_size
        self.pool: Set[Task] = set()
        
    async def add(self, item: Task) -> None:
        while len(self.pool) >= self.pool_size:
            await asyncio.sleep(0.1)
        self.pool.add(item)
        item.add_done_callback(self.pool.discard)
        
    def __len__(self) -> int:
        return len(self.pool)   

In [5]:
# | export

@asynccontextmanager
async def pool_guard(pool: TaskPool) -> AsyncGenerator[TaskPool, None]:
    try:
        yield pool
    finally:
        while len(pool) > 0:
            await asyncio.sleep(1)

In [6]:
async def f():
    await asyncio.sleep(5)

pool = TaskPool()
assert len(pool) == 0

async with pool_guard(pool):
    task = asyncio.create_task(f())
    await pool.add(task)
    task.add_done_callback(print)
    assert len(pool) == 1
    

assert len(pool) == 0, len(pool)

<Task finished name='Task-5' coro=<f() done, defined at /tmp/ipykernel_691689/1873161791.py:1> result=None>


In [7]:
async def f():
    await asyncio.sleep(5)
    raise RuntimeError

pool = TaskPool()
assert len(pool) == 0

async with pool_guard(pool):
    task = asyncio.create_task(f())
    await pool.add(task)
    task.add_done_callback(lambda task: task.result())
    assert len(pool) == 1
    

assert len(pool) == 0, len(pool)
print("This should not print")

Exception in callback <lambda>(<Task finishe...untimeError()>) at /tmp/ipykernel_691689/1509985980.py:11
handle: <Handle <lambda>(<Task finishe...untimeError()>) at /tmp/ipykernel_691689/1509985980.py:11>
Traceback (most recent call last):
  File "/usr/lib/python3.11/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/tmp/ipykernel_691689/1509985980.py", line 11, in <lambda>
    task.add_done_callback(lambda task: task.result())
                                        ^^^^^^^^^^^^^
  File "/tmp/ipykernel_691689/1509985980.py", line 3, in f
    raise RuntimeError
RuntimeError


This should not print


In [8]:
mock = Mock()
async_mock = asyncer.asyncify(mock)


async def process_items(receive_stream):
    async with receive_stream:
        async for item in receive_stream:
            task = asyncio.create_task(async_mock(item))
            await pool.add(task)


send_stream, receive_stream = create_memory_object_stream()

async with pool_guard(pool):
    async with create_task_group() as tg:
        tg.start_soon(process_items, receive_stream)
        async with send_stream:
            await send_stream.send(f"hi")

mock.assert_called()

In [19]:
def is_shutting_down_f(mock_func: Mock, num_calls: int = 1) -> Callable[[], bool]:
    def _is_shutting_down_f():
        return mock_func.call_count == num_calls

    return _is_shutting_down_f

msgs = {TopicPartition("topic", 0): ["record"]}

f = asyncio.Future()
f.set_result(msgs)
mock_consumer = MagicMock()
mock_consumer.configure_mock(**{"getmany.return_value": f})
mock_callback = Mock()

shutting_down = is_shutting_down_f(mock_consumer.getmany)

# async def async_mock(item):
#     mock(item)

async_mock = asyncer.asyncify(mock_callback)

async def process_items(receive_stream):
    async with receive_stream:
        async for item in receive_stream:
            task = asyncio.create_task(async_mock(item))
            await pool.add(task)


send_stream, receive_stream = create_memory_object_stream()

async with create_task_group() as tg:
    tg.start_soon(process_items, receive_stream)
    async with pool_guard(pool), send_stream:
        while not shutting_down():
            msgs = await mock_consumer.getmany()
            await send_stream.send(msgs.values())

mock.assert_called()

print("ok")

ok


In [13]:
mock_consumer.getmany()

<Future finished result={TopicPartitio..., partition=0): ['record']}>