In [None]:
# | default_exp _components.aiokafka_consumer_loop

In [None]:
# | export

import asyncio
from asyncio import iscoroutinefunction  # do not use the version from inspect
from datetime import datetime, timedelta
from os import environ
from typing import *

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream
import asyncer
from aiokafka import AIOKafkaConsumer
from aiokafka.structs import ConsumerRecord, TopicPartition
from pydantic import BaseModel, Field, HttpUrl, NonNegativeInt

from fast_kafka_api._components.logger import get_logger

In [None]:
from unittest.mock import AsyncMock, MagicMock, Mock, call

from fast_kafka_api._components.logger import supress_timestamps
from fast_kafka_api.testing import (
    create_and_fill_testing_topic,
    nb_safe_seed,
    true_after,
)

In [None]:
# Notebook-specific seed helper; it is called further down as seed(1) / seed(3) when
# creating testing topics. NOTE(review): presumably keeps generated topic names
# distinct per notebook - confirm in fast_kafka_api.testing.
seed = nb_safe_seed("_components.aiokafka_consumer_loop")

In [None]:
# | notest
# allows async calls in notebooks

import nest_asyncio

In [None]:
# | notest

# Patch asyncio so the cells below can await coroutines from inside the notebook.
nest_asyncio.apply()

In [None]:
# | export

# Module-level logger used by every exported function in this notebook.
logger = get_logger(__name__)

In [None]:
# NOTE(review): presumably removes timestamps from log records so that the recorded
# cell outputs stay identical between runs - confirm in fast_kafka_api._components.logger.
supress_timestamps()
# Rebind the logger for interactive use; level 20 corresponds to logging.INFO.
logger = get_logger(__name__, level=20)
logger.info("ok")

[INFO] __main__: ok


In [None]:
# The broker address is read from the environment; a missing variable raises KeyError here.
kafka_server_url = environ["KAFKA_HOSTNAME"]
kafka_server_port = environ["KAFKA_PORT"]

# Configuration handed to create_and_fill_testing_topic() and, through its
# "bootstrap.servers" entry, to aiokafka_consumer_loop() in the cells below.
kafka_config = {"bootstrap.servers": f"{kafka_server_url}:{kafka_server_port}"}

In [None]:
# Small pydantic model used as the message payload in every check of this notebook.
class MyMessage(BaseModel):
    # Required field (``...`` means "no default") carrying example/description metadata.
    url: HttpUrl = Field(..., example="http://www.acme.com", description="Url example")
    # Optional field; falls back to 1000 when not supplied.
    port: NonNegativeInt = Field(1000)

In [None]:
# | export


async def process_msgs(  # type: ignore
    *,
    msgs: Dict[TopicPartition, List[ConsumerRecord]],
    callbacks: Dict[str, Callable[[BaseModel], Union[None, Awaitable[None]]]],
    msg_types: Dict[str, Type[BaseModel]],
    process_f: Callable[
        [Tuple[Callable[[BaseModel], Awaitable[None]], BaseModel]], Awaitable[None]
    ],
) -> None:
    """Decodes every record in **msgs** and passes it, together with its topic's callback, to **process_f**.

    For each topic partition the callback registered for the topic is wrapped so that it
    can always be awaited and never raises; then every record is decoded with the
    topic's message type and `process_f((safe_callback, decoded_msg))` is awaited.

    Failures are logged and never propagated:
        - a record that cannot be decoded is skipped, the remaining records of the
          partition are still processed;
        - any other failure (missing topic in **callbacks** or **msg_types**, exception
          raised by **process_f**) abandons the current partition only.

    Params:
        msgs: a dictionary mapping topic partition to a list of messages, returned by `AIOKafkaConsumer.getmany`.
        callbacks: a dictionary mapping topics into a callback functions (plain or coroutine functions).
        msg_types: a dictionary mapping topics into a message type of a message.
        process_f: an awaitable sink for `(callback, message)` pairs, e.g. the `send`
            method of a stream created by `anyio.create_memory_object_stream`

    Todo:
        remove it :)
    """
    for topic_partition, topic_msgs in msgs.items():
        topic = topic_partition.topic
        try:
            # Both lookups live inside the try block so that a topic missing from either
            # mapping is reported the same way and does not abort the other partitions.
            msg_type = msg_types[topic]
            callback_raw = callbacks[topic]

            # The callback is the same for every record of the partition: wrap it once.
            if not iscoroutinefunction(callback_raw):
                c: Callable[[BaseModel], None] = callback_raw  # type: ignore
                callback: Callable[[BaseModel], Awaitable[None]] = asyncer.asyncify(c)
            else:
                callback = callback_raw  # type: ignore

            async def safe_callback(
                msg: BaseModel,
                callback: Callable[[BaseModel], Awaitable[None]] = callback,
            ) -> None:
                # The default argument binds the callback of *this* partition; the
                # wrapper is awaited later, after the loop variable may have changed.
                try:
                    await callback(msg)
                except Exception as e:
                    logger.warning(
                        f"process_msgs(): exception caugth {e!r} while awaiting '{callback}({msg})'"
                    )

            for record in topic_msgs:
                try:
                    decoded_msg = msg_type.parse_raw(record.value.decode("utf-8"))
                except Exception as e:
                    # One malformed record must not discard the rest of the batch.
                    logger.warning(
                        f"process_msgs(): Unexpected exception '{e!r}' caught while decoding, message skipped for topic='{topic_partition.topic}', partition='{topic_partition.partition}': {record}"
                    )
                    continue

                await process_f((safe_callback, decoded_msg))
        except Exception as e:
            logger.warning(
                f"process_msgs(): Unexpected exception '{e!r}' caught and ignored for topic='{topic_partition.topic}', partition='{topic_partition.partition}' and messages: {topic_msgs}"
            )

In [None]:
def create_consumer_record(topic: str, partition: int, msg: BaseModel):
    """Wraps **msg** into a `ConsumerRecord` the way it would arrive from Kafka.

    The message is serialised to JSON and UTF-8 encoded into the record value; all
    bookkeeping fields (offset, timestamps, checksum, sizes) are zeroed, the key is
    `None` and there are no headers.
    """
    payload = msg.json().encode("utf-8")
    return ConsumerRecord(
        topic=topic,
        partition=partition,
        offset=0,
        timestamp=0,
        timestamp_type=0,
        key=None,
        value=payload,
        checksum=0,
        serialized_key_size=0,
        serialized_value_size=0,
        headers=[],
    )

In [None]:
# Sanity check
# One message on one topic: process_msgs must pass the decoded message to process_f
# exactly once, both for a plain callback and for an asyncified one.

topic = "topic_0"
partition = 0
topic_part_0_0 = TopicPartition(topic, partition)
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)


# Stand-in for the stream's send(): awaits the received callback right away.
async def process_f(arg):
    callback, msg = arg
    await callback(msg)


for is_async in [False, True]:
    print(f"is_async={is_async}")
    callback_0 = Mock()
    await process_msgs(
        msgs={topic_part_0_0: [record]},
        callbacks={topic: (asyncer.asyncify(callback_0) if is_async else callback_0)},
        msg_types={topic: MyMessage},
        process_f=process_f,
    )

    # process_f receives the safe_callback wrapper, not callback_0 itself, so only the
    # underlying mock is checked. Disabled assertion kept for reference:
    #     process_f.assert_called_with((callback_0, msg))
    callback_0.assert_called_with(msg)
    assert callback_0.call_count == 1

is_async=False
is_async=True


In [None]:
# Sanity check: exception in callback
# One message on one topic, the callback raises: process_msgs must still complete
# (the exception is logged by safe_callback) and the callback is called exactly once.

topic = "topic_0"
partition = 0
topic_part_0_0 = TopicPartition(topic, partition)
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)


# Stand-in for the stream's send(): awaits the received callback right away.
async def process_f(arg):
    callback, msg = arg
    await callback(msg)


for is_async in [False, True]:
    print(f"is_async={is_async}")
    callback_0 = Mock()
    # Every call of callback_0 raises Exception("Test").
    callback_0.side_effect = Mock(side_effect=Exception("Test"))
    await process_msgs(
        msgs={topic_part_0_0: [record]},
        callbacks={topic: (asyncer.asyncify(callback_0) if is_async else callback_0)},
        msg_types={topic: MyMessage},
        process_f=process_f,
    )

    # process_f receives the safe_callback wrapper, not callback_0 itself, so only the
    # underlying mock is checked. Disabled assertion kept for reference:
    #     process_f.assert_called_with((callback_0, msg))
    callback_0.assert_called_with(msg)
    assert callback_0.call_count == 1

is_async=False
is_async=True


In [None]:
# Check different topics

# Two msg, two topics, process_f called twice with each callback called once

topic_part_0_0 = TopicPartition("topic_0", 0)
topic_part_1_0 = TopicPartition("topic_1", 0)

topic = "topic_0"
partition = 0
topic_part_0_0 = TopicPartition("topic_0", 0)
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

callback_0 = Mock()
callback_1 = AsyncMock()

await process_msgs(
    msgs={topic_part_0_0: [record], topic_part_1_0: [record]},
    callbacks={"topic_0": callback_0, "topic_1": callback_1},
    msg_types={"topic_0": MyMessage, "topic_1": MyMessage},
    process_f=process_f,
)

callback_0.assert_called_once_with(msg)
callback_1.assert_awaited_once_with(msg)
callback_0.assert_called_once_with(msg)

In [None]:
# Check multiple msgs in same topic
# Check callback not called if there are no msgs for it in the queue

# Two msg, one topic, one callback called twice, other called nonce, produce and process_f called twice

# Check different topics

# Two msg, two topics, process_f called twice with each callback called once and produce twice

topic_part_0_0 = TopicPartition("topic_0", 0)

topic = "topic_0"
partition = 0
topic_part_0_0 = TopicPartition("topic_0", 0)
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

callback_0 = Mock()
callback_1 = AsyncMock()

await process_msgs(
    msgs={topic_part_0_0: [record, record]},
    callbacks={"topic_0": callback_0, "topic_1": callback_1},
    msg_types={"topic_0": MyMessage, "topic_1": MyMessage},
    process_f=process_f,
)

callback_0.assert_has_calls([call(msg)] * 2)
callback_1.assert_not_awaited()

In [None]:
# Check multiple partitions

# Two msg, one topic, two partitions, one callback called twice, produce and process_f called twice

topic_part_0_0 = TopicPartition("topic_0", 0)
topic_part_0_1 = TopicPartition("topic_0", 1)

msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

callback_0 = AsyncMock()
callback_1 = Mock()

await process_msgs(
    msgs={
        topic_part_0_0: [create_consumer_record(topic="topic_0", partition=0, msg=msg)],
        topic_part_0_1: [create_consumer_record(topic="topic_0", partition=1, msg=msg)],
    },
    callbacks={"topic_0": callback_0, "topic_1": callback_1},
    msg_types={"topic_0": MyMessage, "topic_1": MyMessage},
    process_f=process_f,
)

callback_0.assert_has_awaits([call(msg)] * 2)
callback_1.assert_not_called()

In [None]:
# | export


async def process_message_callback(
    receive_stream: MemoryObjectReceiveStream[Any],
) -> None:
    """Consumes `(callback, message)` pairs from **receive_stream** and awaits each callback.

    The stream is used as an async context manager for the whole time it is iterated;
    the coroutine returns once the iteration over the stream ends.

    Params:
        receive_stream: stream yielding two-item tuples of an awaitable callback and
            the message it should be called with.
    """
    async with receive_stream:
        async for item in receive_stream:
            handler, payload = item
            await handler(payload)


async def _aiokafka_consumer_loop(  # type: ignore
    consumer: AIOKafkaConsumer,
    *,
    callbacks: Dict[str, Callable[[BaseModel], Union[None, Awaitable[None]]]],
    timeout_ms: int = 100,
    max_buffer_size: int = 10_000,
    msg_types: Dict[str, Type[BaseModel]],
    is_shutting_down_f: Callable[[], bool],
) -> None:
    """Polls **consumer** until shutdown is requested and dispatches each message to its callback.

    A memory object stream connects the two halves of the loop: this coroutine polls
    the consumer and pushes `(callback, message)` pairs into the stream through
    `process_msgs`, while a `process_message_callback` task started in the same task
    group reads the pairs and awaits the callbacks.

    Params:
        consumer: consumer whose `getmany` is awaited on every iteration.
        callbacks: a dictionary mapping topics into callback functions, passed to `process_msgs`.
        timeout_ms: passed to `consumer.getmany` as its `timeout_ms` argument.
        max_buffer_size: passed to `anyio.create_memory_object_stream`.
        msg_types: a dictionary mapping topics into message types, passed to `process_msgs`.
        is_shutting_down_f: called before every poll; the loop ends as soon as it returns a true value.

    Todo: add batch size if needed
    """
    send_stream, receive_stream = anyio.create_memory_object_stream(
        max_buffer_size=max_buffer_size
    )
    async with anyio.create_task_group() as tg:
        # Reader side: awaits the callbacks for everything pushed into the stream.
        tg.start_soon(process_message_callback, receive_stream)
        # Writer side: the send stream is closed when this block exits, i.e. after the
        # shutdown flag was observed; the task group then waits for the reader task.
        async with send_stream:
            while not is_shutting_down_f():
                # NOTE: exceptions raised by getmany() are not caught here and end the loop.
                msgs = await consumer.getmany(timeout_ms=timeout_ms)
                try:
                    await process_msgs(
                        msgs=msgs,
                        callbacks=callbacks,
                        msg_types=msg_types,
                        process_f=send_stream.send,
                    )
                except Exception as e:
                    logger.warning(
                        f"_aiokafka_consumer_loop(): Unexpected exception '{e}' caught and ignored for messages: {msgs}"
                    )

In [None]:
topic = "topic_0"
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

mock_consumer = MagicMock()
msgs = {TopicPartition(topic, 0): [record]}

f = asyncio.Future()
f.set_result(msgs)
mock_consumer.configure_mock(**{"getmany.return_value": f})
mock_callback = Mock()


def is_shutting_down_f(mock_func):
    def _is_shutting_down_f():
        return mock_func.called

    return _is_shutting_down_f


for is_async in [True, False]:
    await _aiokafka_consumer_loop(
        consumer=mock_consumer,
        max_buffer_size=100,
        callbacks={
            topic: asyncer.asyncify(mock_callback) if is_async else mock_callback
        },
        msg_types={topic: MyMessage},
        is_shutting_down_f=is_shutting_down_f(mock_consumer.getmany),
    )

    assert mock_consumer.getmany.call_count == 1
    mock_callback.assert_called_once_with(msg)

In [None]:
# | export


async def aiokafka_consumer_loop(  # type: ignore
    topics: List[str],
    *,
    bootstrap_servers: str,
    auto_offset_reset: str,
    max_poll_records: int = 1_000,
    timeout_ms: int = 100,
    max_buffer_size: int = 10_000,
    callbacks: Dict[str, Callable[[BaseModel], Union[None, Awaitable[None]]]],
    msg_types: Dict[str, Type[BaseModel]],
    is_shutting_down_f: Callable[[], bool],
    **kwargs,
) -> None:
    """Creates a consumer, subscribes it to **topics** and runs the consumer loop until shutdown.

    The consumer is always stopped once it has been started, also when subscribing or
    the loop itself raises.

    Params:
        topics: topics the consumer subscribes to.
        bootstrap_servers: passed to `AIOKafkaConsumer`.
        auto_offset_reset: passed to `AIOKafkaConsumer`.
        max_poll_records: passed to `AIOKafkaConsumer`.
        timeout_ms: passed to `_aiokafka_consumer_loop`.
        max_buffer_size: passed to `_aiokafka_consumer_loop`.
        callbacks: a dictionary mapping topics into callback functions.
        msg_types: a dictionary mapping topics into message types.
        is_shutting_down_f: predicate polled by `_aiokafka_consumer_loop`; the loop
            ends once it returns a true value.
        kwargs: accepted but currently not used.
    """
    # NOTE(review): **kwargs is not forwarded to AIOKafkaConsumer - confirm whether
    # extra consumer options are meant to be passed through.
    logger.info("aiokafka_consumer_loop() starting..")
    consumer = AIOKafkaConsumer(
        bootstrap_servers=bootstrap_servers,
        auto_offset_reset=auto_offset_reset,
        max_poll_records=max_poll_records,
    )
    logger.info("aiokafka_consumer_loop(): Consumer created.")

    await consumer.start()
    logger.info("aiokafka_consumer_loop(): Consumer started.")

    try:
        # Subscribing happens inside the try block so that a failure here still stops
        # the already started consumer.
        consumer.subscribe(topics)
        logger.info("aiokafka_consumer_loop(): Consumer subscribed.")

        await _aiokafka_consumer_loop(
            consumer=consumer,
            max_buffer_size=max_buffer_size,
            timeout_ms=timeout_ms,
            callbacks=callbacks,
            msg_types=msg_types,
            is_shutting_down_f=is_shutting_down_f,
        )
    finally:
        await consumer.stop()
        logger.info("aiokafka_consumer_loop(): Consumer stopped.")
        logger.info("aiokafka_consumer_loop() finished.")

In [None]:
# End-to-end check against a real broker: every message written to a fresh testing
# topic must reach the callback exactly once.
msgs_sent = 9178
# Pre-encoded JSON payloads, one per port number in range(msgs_sent).
msgs = [
    MyMessage(url="http://www.ai.com", port=port).json().encode("utf-8")
    for port in range(msgs_sent)
]
msgs_received = 0


# Callback under test: counts the messages and logs progress every 1000 messages.
async def count_msg(msg: MyMessage):
    global msgs_received
    msgs_received = msgs_received + 1
    if msgs_received % 1000 == 0:
        logger.info(f"{msgs_received=}")


async with create_and_fill_testing_topic(
    kafka_config=kafka_config, msgs=msgs, seed=seed(1)
) as topic:
    await aiokafka_consumer_loop(
        topics=[topic],
        bootstrap_servers=kafka_config["bootstrap.servers"],
        auto_offset_reset="earliest",
        callbacks={topic: count_msg},
        msg_types={topic: MyMessage},
        # NOTE(review): presumably becomes true 2 seconds after creation - confirm in
        # fast_kafka_api.testing; all messages must be consumed before it fires.
        is_shutting_down_f=true_after(2),
    )

assert msgs_sent == msgs_received, f"{msgs_sent} != {msgs_received}"

[INFO] fast_kafka_api.testing: create_missing_topics(['my_topic_5696213874']): new_topics = [NewTopic(topic=my_topic_5696213874,num_partitions=3)]
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> created.
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> started.
[INFO] fast_kafka_api.testing: Sent messages: len(sent_msgs)=9178
[INFO] __main__: aiokafka_consumer_loop() starting..
[INFO] __main__: aiokafka_consumer_loop(): Consumer created.
[INFO] __main__: aiokafka_consumer_loop(): Consumer started.
[INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'my_topic_5696213874'})
[INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'my_topic_5696213874'}
[INFO] __main__: aiokafka_consumer_loop(): Consumer subscribed.
[INFO] aiokafka.consumer.group_coordinator: Metadata for topic has changed from {} to {'my_topic_5696213874': 3}. 
[INFO] __main__: msgs_received=1000

In [None]:
# | notest

# Throughput measurement against a real broker (excluded from tests by `notest`).
msgs_sent = 100_000
# Pre-encoded JSON payloads, one per port number in range(msgs_sent).
msgs = [
    MyMessage(url="http://www.ai.com", port=port).json().encode("utf-8")
    for port in range(msgs_sent)
]
msgs_received = 0


# Callback under test: counts the messages and logs progress every 1000 messages.
async def count_msg(msg: MyMessage):
    global msgs_received
    msgs_received = msgs_received + 1
    if msgs_received % 1000 == 0:
        logger.info(f"{msgs_received=}")


# Stops the consumer loop once every sent message has been counted.
# NOTE(review): the equality check never becomes true if a message is lost or
# delivered twice, in which case the loop would not terminate.
def _is_shutting_down_f():
    return msgs_received == msgs_sent


async with create_and_fill_testing_topic(
    kafka_config=kafka_config, msgs=msgs, seed=seed(3)
) as topic:
    start = datetime.now()
    await aiokafka_consumer_loop(
        topics=[topic],
        bootstrap_servers=kafka_config["bootstrap.servers"],
        auto_offset_reset="earliest",
        callbacks={topic: count_msg},
        msg_types={topic: MyMessage},
        is_shutting_down_f=_is_shutting_down_f,
    )
    # Elapsed wall-clock time in seconds, as a float; it includes consumer start-up
    # and shutdown because the whole aiokafka_consumer_loop() call is timed.
    t = (datetime.now() - start) / timedelta(seconds=1)
    thrp = msgs_received / t

    print(f"Messages processed: {msgs_received:,d}")
    print(f"Time              : {t:.2f} s")
    print(f"Throughput.       : {thrp:,.0f} msg/s")

[INFO] fast_kafka_api.testing: create_missing_topics(['my_topic_5168585847']): new_topics = [NewTopic(topic=my_topic_5168585847,num_partitions=3)]
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> created.
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> started.
[INFO] fast_kafka_api.testing: Sent messages: len(sent_msgs)=100000
[INFO] __main__: aiokafka_consumer_loop() starting..
[INFO] __main__: aiokafka_consumer_loop(): Consumer created.
[INFO] __main__: aiokafka_consumer_loop(): Consumer started.
[INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'my_topic_5168585847'})
[INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'my_topic_5168585847'}
[INFO] __main__: aiokafka_consumer_loop(): Consumer subscribed.
[INFO] aiokafka.consumer.group_coordinator: Metadata for topic has changed from {} to {'my_topic_5168585847': 3}. 
[INFO] __main__: msgs_received=10