In [None]:
# | default_exp _components.aiokafka_consumer_loop

In [None]:
# | export

import asyncio
from asyncio import iscoroutinefunction  # do not use the version from inspect
from datetime import datetime, timedelta
from os import environ
from typing import *

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream
import asyncer
from aiokafka import AIOKafkaConsumer
from aiokafka.structs import ConsumerRecord, TopicPartition
from pydantic import BaseModel, Field, HttpUrl, NonNegativeInt

from fast_kafka_api._components.logger import get_logger

In [None]:
from unittest.mock import AsyncMock, MagicMock, Mock, call, patch

from fast_kafka_api._components.logger import supress_timestamps
from fast_kafka_api.testing import (
    create_and_fill_testing_topic,
    nb_safe_seed,
    true_after,
)

In [None]:
# Seed factory for this notebook; `seed(n)` is passed to create_and_fill_testing_topic() below.
# NOTE(review): presumably derives reproducible seeds from the notebook name - confirm in fast_kafka_api.testing
seed = nb_safe_seed("_components.aiokafka_consumer_loop")

In [None]:
# | notest
# allows async calls in notebooks

import nest_asyncio

In [None]:
# | notest

# Apply the nest_asyncio patch so that async calls work inside the notebook's running event loop.
nest_asyncio.apply()

In [None]:
# | export

# Module-level logger used by the functions defined below.
logger = get_logger(__name__)

In [None]:
# Notebook-only logging setup (this cell is not exported).
# NOTE(review): supress_timestamps() presumably removes timestamps from log lines so that
# the stored cell outputs stay stable between runs - confirm in _components.logger
supress_timestamps()
logger = get_logger(__name__, level=20)  # 20 is the numeric value of logging.INFO
logger.info("ok")  # smoke test: an INFO line should be printed

[INFO] __main__: ok


In [None]:
# | export

# Broker location: taken from the KAFKA_HOSTNAME / KAFKA_PORT environment
# variables when they are set, otherwise localhost:9092 is used.
kafka_server_url = environ.get("KAFKA_HOSTNAME", "localhost")
kafka_server_port = environ.get("KAFKA_PORT", "9092")

# Connection settings unpacked into the aiokafka based helpers used below.
aiokafka_config = dict(
    bootstrap_servers=f"{kafka_server_url}:{kafka_server_port}",
)

In [None]:
# Minimal pydantic model used as the message payload in the test cells of this notebook.
class MyMessage(BaseModel):
    # Required field (Field(...) has no default); validated as an HTTP URL.
    url: HttpUrl = Field(..., example="http://www.acme.com", description="Url example")
    # Non-negative integer, 1000 when not supplied.
    port: NonNegativeInt = Field(1000)

In [None]:
def create_consumer_record(topic: str, partition: int, msg: BaseModel):
    """Build a `ConsumerRecord` for tests that carries `msg` as UTF-8 encoded JSON.

    Only `topic`, `partition` and `value` depend on the arguments; every other
    field is filled with a fixed placeholder (0, None or an empty list).
    """
    payload = msg.json().encode("utf-8")
    placeholder_fields = dict(
        offset=0,
        timestamp=0,
        timestamp_type=0,
        key=None,
        checksum=0,
        serialized_key_size=0,
        serialized_value_size=0,
        headers=[],
    )
    return ConsumerRecord(
        topic=topic, partition=partition, value=payload, **placeholder_fields
    )

In [None]:
# | export


def _create_safe_callback(
    callback: Callable[[BaseModel], Awaitable[None]]
) -> Callable[[BaseModel], Awaitable[None]]:
    """
    Guards an async callback so that any Exception raised while awaiting it is logged as a warning instead of being propagated

    Params:
        callback: async callable to be guarded

    Returns:
        Async callable that awaits `callback` and reports any Exception through the logger
    """

    async def safe_callback(
        msg: BaseModel,
        callback: Callable[[BaseModel], Awaitable[None]] = callback,
    ) -> None:
        try:
            await callback(msg)
        except Exception as e:
            # swallow the failure and report it, the caller never sees the exception
            warning_text = f"safe_callback(): exception caugth {e.__repr__()} while awaiting '{callback}({msg})'"
            logger.warning(warning_text)

    return safe_callback

In [None]:
# Check if callback is called when wrapped

# Happy path: the wrapper must await the underlying callback exactly once and
# pass the message through unchanged.
example_msg = "Example msg"
callback = AsyncMock()
safe_callback = _create_safe_callback(callback)

await safe_callback(f"{example_msg}")

callback.assert_awaited_once_with(f"{example_msg}")

In [None]:
# Check if exception is caught and logged when callback is called and throws an exception

# logger.warning is patched so the exact warning text can be asserted.
with patch.object(logger, "warning") as mock:
    example_msg = "Example msg"
    exception = Exception("")

    # the callback raises as soon as it is awaited
    callback = AsyncMock()
    callback.side_effect = exception
    safe_callback = _create_safe_callback(callback)

    # must not raise: the wrapper is expected to swallow the exception
    await safe_callback(f"{example_msg}")
    
    callback.assert_awaited_once_with(f"{example_msg}")
    # the expected text must match the message built in safe_callback() character by character
    mock.assert_called_once_with(
        f"safe_callback(): exception caugth {exception.__repr__()} while awaiting '{callback}({example_msg})'"
    )

In [None]:
# | export


def _prepare_callback(
    callback: Union[Callable[[BaseModel], None], Callable[[BaseModel], Awaitable[None]]]
) -> Callable[[BaseModel], Awaitable[None]]:
    """
    Turns a user supplied callback into the form used in the consumer loop.
        1. A sync callback is converted to an async one with asyncer.asyncify
        2. The async callback is wrapped into a safe callback for exception handling

    Params:
        callback: sync or async callable that will be prepared for use in consumer

    Returns:
        Async callback wrapped into a safe callback
    """
    if iscoroutinefunction(callback):
        async_callback = callback
    else:
        async_callback = asyncer.asyncify(callback)
    return _create_safe_callback(async_callback)

In [None]:
# Check if callback is called when wrapped

# Runs once with a sync callback (Mock) and once with an async one (AsyncMock);
# in both cases the prepared callback is awaited and must forward the message.
for is_async in [False, True]:
    example_msg = "Example msg"
    callback = AsyncMock() if is_async else Mock()
    prepared_callback = _prepare_callback(callback)

    await prepared_callback(f"{example_msg}")

    callback.assert_called_once_with(f"{example_msg}")

In [None]:
# | export


async def _decode_msgs_and_stream(
    msgs: Dict[TopicPartition, List[ConsumerRecord]],
    msg_types: Dict[str, Type[BaseModel]],
    send_stream: anyio.streams.memory.MemoryObjectSendStream[Any],
) -> None:
    """
    Decodes consumed records and sends them to the stream as (topic, msg) tuples.

    Records are handled one topic partition at a time: all records of the partition are
    decoded first and are sent only after every one of them has been decoded. Any Exception
    raised while handling a partition (no message type registered for the topic, decoding or
    validation error, failure while sending) is logged as a warning and the remaining
    partitions are still processed.

    Params:
        msgs: consumed records grouped by topic partition
        msg_types: mapping from topic name to the pydantic model class used to decode its messages
        send_stream: stream to which the decoded (topic, msg) tuples are sent

    Returns:
        None
    """
    for topic_partition, topic_msgs in msgs.items():
        topic = topic_partition.topic
        try:
            # looked up inside the try block so that a topic without a registered message
            # type is reported below and does not abort the remaining partitions
            msg_type = msg_types[topic]
            decoded_msgs = [
                msg_type.parse_raw(msg.value.decode("utf-8")) for msg in topic_msgs
            ]
            for msg in decoded_msgs:
                await send_stream.send((topic, msg))
        except Exception as e:
            logger.warning(
                f"_decode_msgs_and_stream(): Unexpected exception '{e.__repr__()}' caught and ignored for topic='{topic_partition.topic}', partition='{topic_partition.partition}' and messages: {topic_msgs}"
            )

In [None]:
# Sanity check: one msg, one topic

# MemoryObjectSendStream.send is patched, so nothing is written to the stream;
# the mock only records what _decode_msgs_and_stream tried to send.
with patch("anyio.streams.memory.MemoryObjectSendStream.send") as mock:
    send_stream, receive_stream = anyio.create_memory_object_stream()

    topic = "topic_0"
    partition = 0
    topic_part_0_0 = TopicPartition(topic, partition)
    msg = MyMessage(url="http://www.acme.com", port=22)
    record = create_consumer_record(topic=topic, partition=partition, msg=msg)

    await _decode_msgs_and_stream(
        msgs={topic_part_0_0: [record]},
        msg_types={topic: MyMessage},
        send_stream=send_stream,
    )

    # the raw record must arrive decoded back into the original model instance
    mock.assert_called_with((topic, msg))

In [None]:
# Check different topics

# Two msg, two topics, send called twice with each topic

with patch("anyio.streams.memory.MemoryObjectSendStream.send") as mock:
    send_stream, receive_stream = anyio.create_memory_object_stream()

    topic_partitions = [("topic_0", 0), ("topic_1", 0)]

    msg = MyMessage(url="http://www.acme.com", port=22)
    record = create_consumer_record(topic=topic, partition=partition, msg=msg)

    await _decode_msgs_and_stream(
        msgs={
            TopicPartition(topic, partition): [
                create_consumer_record(topic=topic, partition=partition, msg=msg)
            ]
            for topic, partition in topic_partitions
        },
        msg_types={topic: MyMessage for topic, _ in topic_partitions},
        send_stream=send_stream,
    )

    mock.assert_has_calls(
        [
            call(("topic_0", msg)),
            call(("topic_1", msg)),
        ]
    )

In [None]:
# Check multiple msgs in same topic

# Two msg, one topic, send called twice for same topic

with patch("anyio.streams.memory.MemoryObjectSendStream.send") as mock:
    send_stream, receive_stream = anyio.create_memory_object_stream()

    topic_partitions = [("topic_0", 0)]

    msg = MyMessage(url="http://www.acme.com", port=22)
    record = create_consumer_record(topic=topic, partition=partition, msg=msg)

    await _decode_msgs_and_stream(
        msgs={
            TopicPartition(topic, partition): [
                create_consumer_record(topic=topic, partition=partition, msg=msg),
                create_consumer_record(topic=topic, partition=partition, msg=msg),
            ]
            for topic, partition in topic_partitions
        },
        msg_types={topic: MyMessage for topic, _ in topic_partitions},
        send_stream=send_stream,
    )

    mock.assert_has_calls(
        [
            call(("topic_0", msg)),
            call(("topic_0", msg)),
        ]
    )

In [None]:
# Check multiple partitions

# Two msg, one topic, differenct partitions, send called twice for same topic

with patch("anyio.streams.memory.MemoryObjectSendStream.send") as mock:
    send_stream, receive_stream = anyio.create_memory_object_stream()

    topic_partitions = [("topic_0", 0), ("topic_0", 1)]

    msg = MyMessage(url="http://www.acme.com", port=22)
    record = create_consumer_record(topic=topic, partition=partition, msg=msg)

    await _decode_msgs_and_stream(
        msgs={
            TopicPartition(topic, partition): [
                create_consumer_record(topic=topic, partition=partition, msg=msg)
                ]
            for topic, partition in topic_partitions
        },
        msg_types={topic: MyMessage for topic, _ in topic_partitions},
        send_stream=send_stream,
    )

    mock.assert_has_calls(
        [
            call(("topic_0", msg)),
            call(("topic_0", msg)),
        ]
    )

In [None]:
# | export


async def _aiokafka_consumer_loop(  # type: ignore
    consumer: AIOKafkaConsumer,
    *,
    callbacks: Dict[str, Callable[[BaseModel], Union[None, Awaitable[None]]]],
    timeout_ms: int = 100,
    max_buffer_size: int = 10_000,
    msg_types: Dict[str, Type[BaseModel]],
    is_shutting_down_f: Callable[[], bool],
) -> None:
    """
    Polls the consumer and dispatches decoded messages to the per-topic callbacks until shutdown is requested.

    Two tasks cooperate through an in-memory stream: the polling loop fetches batches with
    consumer.getmany(), decodes them and sends (topic, msg) tuples to the stream, while a
    background task receives the tuples and awaits the callback registered for the topic.
    Every callback is passed through _prepare_callback() first, so sync callbacks are
    accepted and exceptions raised by callbacks are logged instead of stopping the loop.

    Params:
        consumer: consumer from which the messages are polled
        callbacks: mapping from topic name to the sync or async callback called with each decoded message
        timeout_ms: timeout in milliseconds passed to consumer.getmany()
        max_buffer_size: maximum number of decoded messages buffered between polling and processing
        msg_types: mapping from topic name to the pydantic model class used to decode its messages
        is_shutting_down_f: called before each poll, the loop ends once it returns True

    Returns:
        None

    Todo: add batch size if needed
    """

    prepared_callbacks = {
        topic: _prepare_callback(callback) for topic, callback in callbacks.items()
    }

    async def process_message_callback(
        receive_stream: MemoryObjectReceiveStream[Any],
    ) -> None:
        async with receive_stream:
            async for topic, msg in receive_stream:
                # dispatch through the prepared callbacks (async and exception safe),
                # never through the raw user supplied ones
                if topic in prepared_callbacks:
                    await prepared_callbacks[topic](msg)
                else:
                    logger.error(f"No callback defined for topic {topic}")

    send_stream, receive_stream = anyio.create_memory_object_stream(
        max_buffer_size=max_buffer_size
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(process_message_callback, receive_stream)
        # leaving this block closes the send stream, which lets the processing task
        # finish once the buffered messages are consumed
        async with send_stream:
            while not is_shutting_down_f():
                msgs = await consumer.getmany(timeout_ms=timeout_ms)
                try:
                    await _decode_msgs_and_stream(msgs, msg_types, send_stream)
                except Exception as e:
                    logger.warning(
                        f"_aiokafka_consumer_loop(): Unexpected exception '{e}' caught and ignored for messages: {msgs}"
                    )

In [None]:
# Runs _aiokafka_consumer_loop against a mocked consumer whose getmany() returns an
# already resolved future holding a single record for topic_0.
topic = "topic_0"
partition = 0
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

mock_consumer = MagicMock()
msgs = {TopicPartition(topic, 0): [record]}

f = asyncio.Future()
f.set_result(msgs)
mock_consumer.configure_mock(**{"getmany.return_value": f})
mock_callback = Mock()


def is_shutting_down_f(mock_func):
    # Builds a shutdown predicate that becomes True as soon as `mock_func` has been
    # called, so the loop polls once and then stops.
    def _is_shutting_down_f():
        return mock_func.called

    return _is_shutting_down_f


# NOTE(review): mock_consumer and mock_callback are shared by both iterations and are never
# reset. After the first iteration mock_consumer.getmany.called is already True, so in the
# second iteration (is_async=False) the polling loop is not entered and the assertions below
# are satisfied by the calls recorded during the first iteration. The sync callback path is
# therefore not exercised by this cell.
for is_async in [True, False]:
    await _aiokafka_consumer_loop(
        consumer=mock_consumer,
        max_buffer_size=100,
        callbacks={
            topic: asyncer.asyncify(mock_callback) if is_async else mock_callback
        },
        msg_types={topic: MyMessage},
        is_shutting_down_f=is_shutting_down_f(mock_consumer.getmany),
    )

    assert mock_consumer.getmany.call_count == 1
    mock_callback.assert_called_once_with(msg)

In [None]:
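# DISABLED TEST, kept for reference: it calls `process_msgs`, which is not defined in this
# notebook, so the cell has to stay commented out until it is ported to the current functions.
# # Sanity check: exception in callback
# # One msg, one topic, process_f called once with callback and decoded_msg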

# topic = "topic_0"
# partition = 0
# topic_part_0_0 = TopicPartition(topic, partition)
# msg = MyMessage(url="http://www.acme.com", port=22)
# record = create_consumer_record(topic=topic, partition=partition, msg=msg)


# async def process_f(arg):
#     callback, msg = arg
#     await callback(msg)


# for is_async in [False, True]:
#     print(f"is_async={is_async}")
#     callback_0 = Mock()
#     callback_0.side_effect = Mock(side_effect=Exception("Test"))
#     await process_msgs(
#         msgs={topic_part_0_0: [record]},
#         callbacks={topic: (asyncer.asyncify(callback_0) if is_async else callback_0)},
#         msg_types={topic: MyMessage},
#         process_f=process_f,
#     )

#     #     process_f.assert_called_with((callback_0, msg))
#     callback_0.assert_called_with(msg)
#     assert callback_0.call_count == 1

In [None]:
#| export

def _mask_secret(value: Any) -> Any:
    """Replaces a secret value with asterisks, None is returned unchanged"""
    if value is None:
        # nothing to hide, and len(None) would raise TypeError
        return None
    try:
        return "*" * len(value)
    except TypeError:
        # values without a length (e.g. numbers) are masked using their string form
        return "*" * len(str(value))


def sanitize_kafka_config(**kwargs: Any) -> Dict[str, Any]:
    """
    Sanitize Kafka config so that it can be logged without exposing passwords

    Params:
        kwargs: Kafka client configuration passed as keyword arguments

    Returns:
        Copy of the configuration in which the value of every key containing "pass"
        (case-insensitive) is replaced by asterisks; all other values are returned as they are
    """
    return {
        k: _mask_secret(v) if "pass" in k.lower() else v for k, v in kwargs.items()
    }

In [None]:
# Example consumer settings with placeholder credentials; only the password entry
# is expected to come back masked, with one asterisk per character.
kwargs = dict(
    bootstrap_servers="whatever.cloud:9092",
    auto_offset_reset="earliest",
    security_protocol="SASL_SSL",
    sasl_mechanism="PLAIN",
    sasl_plain_username="username",
    sasl_plain_password="password",
    ssl_context="something",
)

assert sanitize_kafka_config(**kwargs)["sasl_plain_password"] == "********"

In [None]:
# | export


async def aiokafka_consumer_loop(  # type: ignore
    topics: List[str],
    *,
    bootstrap_servers: str,
    auto_offset_reset: str,
    max_poll_records: int = 1_000,
    timeout_ms: int = 100,
    max_buffer_size: int = 10_000,
    callbacks: Dict[str, Callable[[BaseModel], Union[None, Awaitable[None]]]],
    msg_types: Dict[str, Type[BaseModel]],
    is_shutting_down_f: Callable[[], bool],
    **kwargs,
) -> None:
    """
    Creates an AIOKafkaConsumer, subscribes it to the topics and consumes messages until shutdown is requested.

    The consumer is always stopped once it has been started, whether the loop ends normally
    or with an exception. Any Exception is logged as an error and re-raised.

    Params:
        topics: topics to subscribe to
        bootstrap_servers: passed to AIOKafkaConsumer
        auto_offset_reset: passed to AIOKafkaConsumer
        max_poll_records: passed to AIOKafkaConsumer
        timeout_ms: timeout in milliseconds used when polling for messages
        max_buffer_size: maximum number of decoded messages buffered between polling and processing
        callbacks: mapping from topic name to the sync or async callback called with each decoded message
        msg_types: mapping from topic name to the pydantic model class used to decode its messages
        is_shutting_down_f: called before each poll, consuming ends once it returns True
        kwargs: additional keyword arguments passed to AIOKafkaConsumer

    Returns:
        None

    Raises:
        Exception: whatever was raised while creating, starting, running or stopping the consumer
    """
    logger.info("aiokafka_consumer_loop() starting...")
    try:
        consumer_kwargs = {
            "bootstrap_servers": bootstrap_servers,
            "auto_offset_reset": auto_offset_reset,
            "max_poll_records": max_poll_records,
            **kwargs,
        }
        consumer = AIOKafkaConsumer(
            **consumer_kwargs,
        )
        # passwords are masked before the parameters are written to the log
        logger.info(
            f"aiokafka_consumer_loop(): Consumer created using the following parameters: {sanitize_kafka_config(**consumer_kwargs)}"
        )

        await consumer.start()
        logger.info("aiokafka_consumer_loop(): Consumer started.")

        try:
            # subscribing happens inside the try block so that the started consumer
            # is stopped even if subscribe() raises
            consumer.subscribe(topics)
            logger.info("aiokafka_consumer_loop(): Consumer subscribed.")

            await _aiokafka_consumer_loop(
                consumer=consumer,
                max_buffer_size=max_buffer_size,
                timeout_ms=timeout_ms,
                callbacks=callbacks,
                msg_types=msg_types,
                is_shutting_down_f=is_shutting_down_f,
            )
        finally:
            await consumer.stop()
            logger.info("aiokafka_consumer_loop(): Consumer stopped.")
            logger.info("aiokafka_consumer_loop() finished.")
    except Exception as e:
        logger.error(f"aiokafka_consumer_loop(): unexpected exception raised: '{e.__repr__()}'")
        # bare raise re-raises the same exception with its original traceback
        raise

In [None]:
# Integration test: fill a testing topic with `msgs_sent` messages and check that
# aiokafka_consumer_loop delivers every one of them to the callback.
msgs_sent = 9178
msgs = [
    MyMessage(url="http://www.ai.com", port=port).json().encode("utf-8")
    for port in range(msgs_sent)
]
msgs_received = 0


async def count_msg(msg: MyMessage):
    # Callback under test: counts the delivered messages in the notebook-global
    # counter and logs progress every 1000 messages.
    global msgs_received
    msgs_received = msgs_received + 1
    if msgs_received % 1000 == 0:
        logger.info(f"{msgs_received=}")


# NOTE(review): create_and_fill_testing_topic presumably creates the topic, produces `msgs`
# to it and removes it on exit - confirm in fast_kafka_api.testing
async with create_and_fill_testing_topic(
    topic_prefix="my_topic_aiokafka_consumer_loop_",
    msgs=msgs,
    seed=seed(1),
    **aiokafka_config,
) as topic:
    await aiokafka_consumer_loop(
        topics=[topic],
        auto_offset_reset="earliest",
        callbacks={topic: count_msg},
        msg_types={topic: MyMessage},
        # NOTE(review): true_after(2) presumably starts returning True after 2 seconds - confirm
        is_shutting_down_f=true_after(2),
        **aiokafka_config,
    )

assert msgs_sent == msgs_received, f"{msgs_sent} != {msgs_received}"

[INFO] fast_kafka_api.helpers: create_missing_topics(['my_topic_aiokafka_consumer_loop_5696213874']): new_topics = [NewTopic(topic=my_topic_aiokafka_consumer_loop_5696213874,num_partitions=3)]


producing to 'my_topic_aiokafka_consumer_loop_5696213874':   0%|          | 0/9178 [00:00<?, ?it/s]

[INFO] __main__: aiokafka_consumer_loop() starting...
[INFO] __main__: aiokafka_consumer_loop(): Consumer created using the following parameters: {'bootstrap_servers': 'tvrtko-fast-kafka-api-kafka-1:9092', 'auto_offset_reset': 'earliest', 'max_poll_records': 1000}
[INFO] __main__: aiokafka_consumer_loop(): Consumer started.
[INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'my_topic_aiokafka_consumer_loop_5696213874'})
[INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'my_topic_aiokafka_consumer_loop_5696213874'}
[INFO] __main__: aiokafka_consumer_loop(): Consumer subscribed.
[INFO] aiokafka.consumer.group_coordinator: Metadata for topic has changed from {} to {'my_topic_aiokafka_consumer_loop_5696213874': 3}. 
[INFO] __main__: msgs_received=1000
[INFO] __main__: msgs_received=2000
[INFO] __main__: msgs_received=3000
[INFO] __main__: msgs_received=4000
[INFO] __main__: msgs_received=5000
[INFO] __main__: msgs_received=6000
[INFO] __main__:

In [None]:
# | notest

# Throughput measurement: consume `msgs_sent` messages and report messages per second.
msgs_sent = 100_000
msgs = [
    MyMessage(url="http://www.ai.com", port=port).json().encode("utf-8")
    for port in range(msgs_sent)
]
msgs_received = 0


# NOTE(review): same definition as count_msg in the previous test cell; this one shadows it
async def count_msg(msg: MyMessage):
    # Counts the delivered messages in the notebook-global counter and logs
    # progress every 1000 messages.
    global msgs_received
    msgs_received = msgs_received + 1
    if msgs_received % 1000 == 0:
        logger.info(f"{msgs_received=}")


def _is_shutting_down_f():
    # Stop consuming as soon as every produced message has been counted.
    return msgs_received == msgs_sent


async with create_and_fill_testing_topic(
    topic_prefix="my_topic_aiokafka_consumer_loop_",
    msgs=msgs,
    seed=seed(3),
    **aiokafka_config,
) as topic:
    # only the consuming part is timed, the topic has been filled before this point
    start = datetime.now()
    await aiokafka_consumer_loop(
        topics=[topic],
        auto_offset_reset="earliest",
        callbacks={topic: count_msg},
        msg_types={topic: MyMessage},
        is_shutting_down_f=_is_shutting_down_f,
        **aiokafka_config
    )
    # dividing by a one second timedelta converts the elapsed time to float seconds
    t = (datetime.now() - start) / timedelta(seconds=1)
    thrp = msgs_received / t

    print(f"Messages processed: {msgs_received:,d}")
    print(f"Time              : {t:.2f} s")
    print(f"Throughput.       : {thrp:,.0f} msg/s")

[INFO] fast_kafka_api.helpers: create_missing_topics(['my_topic_aiokafka_consumer_loop_5168585847']): new_topics = [NewTopic(topic=my_topic_aiokafka_consumer_loop_5168585847,num_partitions=3)]


producing to 'my_topic_aiokafka_consumer_loop_5168585847':   0%|          | 0/100000 [00:00<?, ?it/s]

[INFO] __main__: aiokafka_consumer_loop() starting...
[INFO] __main__: aiokafka_consumer_loop(): Consumer created using the following parameters: {'bootstrap_servers': 'tvrtko-fast-kafka-api-kafka-1:9092', 'auto_offset_reset': 'earliest', 'max_poll_records': 1000}
[INFO] __main__: aiokafka_consumer_loop(): Consumer started.
[INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'my_topic_aiokafka_consumer_loop_5168585847'})
[INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'my_topic_aiokafka_consumer_loop_5168585847'}
[INFO] __main__: aiokafka_consumer_loop(): Consumer subscribed.
[INFO] aiokafka.consumer.group_coordinator: Metadata for topic has changed from {} to {'my_topic_aiokafka_consumer_loop_5168585847': 3}. 
[INFO] __main__: msgs_received=1000
[INFO] __main__: msgs_received=2000
[INFO] __main__: msgs_received=3000
[INFO] __main__: msgs_received=4000
[INFO] __main__: msgs_received=5000
[INFO] __main__: msgs_received=6000
[INFO] __main__: