In [None]:
# | default_exp _components.aiokafka_producer_manager

In [None]:
# | export
from typing import *

from contextlib import contextmanager, asynccontextmanager

from aiokafka import AIOKafkaProducer
import anyio
import asyncio

from fast_kafka_api._components.logger import get_logger

In [None]:
from os import environ

import unittest.mock

from fast_kafka_api._components.logger import supress_timestamps
from fast_kafka_api.testing import true_after, create_and_fill_testing_topic, nb_safe_seed

In [None]:
# Keep this name in sync with the `default_exp` directive at the top of the notebook.
seed = nb_safe_seed("_components.aiokafka_producer_manager")

In [None]:
# | notest
# Allows async calls in notebooks: nest_asyncio patches asyncio so that an event
# loop can be re-entered while Jupyter's own loop is already running.

import nest_asyncio
nest_asyncio.apply()

In [None]:
# | export

# Module-level logger for this component, named after the module.
logger = get_logger(__name__)

In [None]:
# Drop timestamps from log lines so the cell output below stays reproducible.
supress_timestamps()
# level=20 is logging.INFO, so the info() call below is emitted.
logger = get_logger(__name__, level=20)
logger.info("ok")

[INFO] __main__: ok


In [None]:
# Broker address comes from the environment; both variables must be set,
# otherwise this cell fails with KeyError.
kafka_server_url = environ["KAFKA_HOSTNAME"]
kafka_server_port = environ["KAFKA_PORT"]

kafka_config = dict()
kafka_config["bootstrap.servers"] = ":".join((kafka_server_url, kafka_server_port))

In [None]:
# | export

@asynccontextmanager
async def _aiokafka_producer_manager(
    producer: AIOKafkaProducer,
    *,
    max_buffer_size: int = 10_000
):
    """Forward `(topic, msg)` pairs written to a memory stream to a Kafka producer.

    A background task reads `(topic, msg)` tuples from an anyio memory object
    stream and sends each one with `producer.send`. It awaits the returned
    delivery future before taking the next message. The context manager yields
    the sending end of that stream.

    On exit the send stream is closed, so the background task drains any
    messages still buffered and then finishes; the task group waits for it.
    The stream and task group are cleaned up even if the body of the caller's
    `async with` raises. In that case the task group cancels the background
    task and the exception propagates.

    The producer is neither started nor stopped here. The caller owns its
    lifecycle and must keep it running until this context manager has exited,
    otherwise buffered messages cannot be sent.

    Args:
        producer: Kafka producer used to send the messages.
        max_buffer_size: Maximum number of `(topic, msg)` items the memory
            stream can buffer. When the buffer is full, `send_nowait` on the
            yielded stream raises instead of queueing.

    Yields:
        The send end of the memory object stream; write `(topic, msg)` tuples to it.

    Todo: add batch size if needed
    """

    async def _send_messages(receive_stream):
        # Runs until the send end is closed and the buffer has been drained.
        async with receive_stream:
            async for topic, msg in receive_stream:
                # `producer.send` enqueues the message and returns a future;
                # awaiting it waits for that message's delivery result.
                delivery = await producer.send(topic, msg)
                await delivery

    send_stream, receive_stream = anyio.create_memory_object_stream(
        max_buffer_size=max_buffer_size
    )

    # `async with` (rather than calling __aenter__/__aexit__ by hand) guarantees
    # the stream is closed and the task group is exited on every exit path,
    # including when the caller's body raises.
    async with anyio.create_task_group() as task_group:
        task_group.start_soon(_send_messages, receive_stream)
        async with send_stream:
            yield send_stream

In [None]:
@contextmanager
def mock_AIOKafkaProducer_send():
    """Replace `AIOKafkaProducer.send` (looked up via `__main__`) with a mock.

    The mock's return value is one no-op asyncio task, shared by every call.
    Code that awaits the result of `send` therefore completes immediately.
    The task is created with `asyncio.create_task`, so this must be entered
    while an event loop is running.

    Yields:
        The mock installed in place of `AIOKafkaProducer.send`, for call assertions.
    """

    async def _noop():
        return None

    patcher = unittest.mock.patch("__main__.AIOKafkaProducer.send")
    with patcher as send_mock:
        send_mock.return_value = asyncio.create_task(_noop())
        yield send_mock

In [None]:
num_msgs = 1576
topic = "topic"
msg = b"msg"
msgs = [(topic, msg) for _ in range(num_msgs)]
calls = [unittest.mock.call(topic, msg) for _ in range(num_msgs)]

with mock_AIOKafkaProducer_send() as send_mock:
    producer = AIOKafkaProducer()
    producer_loop_generator = _aiokafka_producer_manager(producer)
    send_stream = await producer_loop_generator.__aenter__()
    
    for msg in msgs:
        send_stream.send_nowait(msg)
    
    await producer.stop()
    await producer_loop_generator.__aexit__(None, None, None)
    
    send_mock.assert_has_calls(calls)

In [None]:
# | export


class AIOKafkaProducerManager:
    """Own an `AIOKafkaProducer` and feed it through a non-blocking in-memory buffer.

    Messages passed to `send` are queued in an anyio memory object stream.
    A background task started by `_aiokafka_producer_manager` forwards them
    to the producer.

    Usage: `await start()` once, call `send(...)` any number of times, then
    `await stop()`.

    NOTE(review): the manager enters an anyio task group in `start()` and exits
    it in `stop()`. Both should therefore be awaited from the same task.
    """

    def __init__(
        self,
        *,
        bootstrap_servers: str,
        max_buffer_size: int = 10_000,
        **kwargs,
    ):
        """
        Args:
            bootstrap_servers: Passed to `AIOKafkaProducer(bootstrap_servers=...)`.
            max_buffer_size: Maximum number of messages buffered before `send` raises.
            **kwargs: Accepted but currently ignored.
        """
        # NOTE(review): `kwargs` are not forwarded to AIOKafkaProducer; confirm
        # whether callers expect extra producer options to take effect.
        self.producer = AIOKafkaProducer(bootstrap_servers=bootstrap_servers)
        self.max_buffer_size = max_buffer_size

    async def start(self):
        """Start the Kafka producer, then the background forwarding task."""
        # This class owns the producer (it creates it and stops it in `stop()`),
        # so it must also start it before anything is sent.
        await self.producer.start()
        self.producer_manager_generator = _aiokafka_producer_manager(
            self.producer, max_buffer_size=self.max_buffer_size
        )
        self.send_stream = await self.producer_manager_generator.__aenter__()

    async def stop(self):
        """Send every buffered message, then stop the Kafka producer."""
        try:
            # Closing the stream lets the background task drain what is still
            # buffered; the producer has to be running while that happens.
            await self.producer_manager_generator.__aexit__(None, None, None)
        finally:
            await self.producer.stop()

    def send(self, topic: str, msg: bytes):
        """Queue `msg` for `topic` without blocking.

        Raises (from the underlying anyio stream) if the buffer already holds
        `max_buffer_size` messages, or if `start()` has not been called.
        """
        self.send_stream.send_nowait((topic, msg))