In [None]:
# | default_exp _components.aiokafka_loop

In [None]:
# | export
from typing import *

from os import environ
import asyncio
from unittest.mock import MagicMock, Mock, call
from datetime import datetime, timedelta

from aiokafka import AIOKafkaConsumer
from aiokafka.structs import TopicPartition, ConsumerRecord
from pydantic import BaseModel, HttpUrl, NonNegativeInt, Field
import asyncer
import anyio

from fast_kafka_api.logger import get_logger, supress_timestamps
from fast_kafka_api.testing import true_after, create_and_fill_testing_topic, nb_safe_seed

In [None]:
# Seed factory tied to this notebook's name; `seed(i)` is passed to
# `create_and_fill_testing_topic` in the Kafka tests below (presumably so the
# generated test topic names are reproducible -- confirm in fast_kafka_api.testing).
seed = nb_safe_seed("_components.aiokafka_loop")

In [None]:
# | notest
# Patch asyncio via nest_asyncio so event-loop calls can be nested inside the
# loop Jupyter is already running (lets the async cells below run in-notebook).

import nest_asyncio
nest_asyncio.apply()

In [None]:
# | export

# Module-level logger; aiokafka_consumer_loop below reports consumer
# lifecycle steps (created / started / subscribed / stopped) through it.
logger = get_logger(__name__)

In [None]:
supress_timestamps()
# Notebook-only: rebind `logger` at level 20 (== logging.INFO) so the consumer
# lifecycle INFO messages show up in the cell outputs below.
logger = get_logger(__name__, level=20)

In [None]:
# Broker address for the integration tests, read from the environment
# (a KeyError here means KAFKA_HOSTNAME or KAFKA_PORT is not set).
kafka_server_url, kafka_server_port = environ["KAFKA_HOSTNAME"], environ["KAFKA_PORT"]

kafka_config = {"bootstrap.servers": ":".join((kafka_server_url, kafka_server_port))}

In [None]:
class MyMessage(BaseModel):
    # Sample payload model used by the tests in this notebook.
    url: HttpUrl = Field(
        ...,
        description="Url example",
        example="http://www.acme.com",
    )
    port: NonNegativeInt = Field(default=1000)

In [None]:
# | export


async def process_msgs(
    *,
    msgs: Dict[TopicPartition, List[ConsumerRecord]],
    callbacks: Dict[str, Callable[[BaseModel], Awaitable[None]]],
    msg_types: Dict[str, Type[BaseModel]],
    process_f: Callable[
        [Tuple[Callable[[BaseModel], Awaitable[None]], BaseModel]], Awaitable[None]
    ],
):
    """Decode a batch of raw Kafka records and hand each message to `process_f`.

    For every partition in `msgs`, each record's value is decoded as UTF-8 JSON
    into the model registered for the partition's topic. Then `process_f` is
    awaited once per message with a single `(callback, msg)` tuple, where
    `callback` is that topic's entry in `callbacks`. The callback itself is NOT
    called here; `process_f` decides what to do with the pair (e.g. enqueue it).

    Args:
        msgs: records grouped by partition, as returned by `getmany`.
        callbacks: topic name -> async callback for that topic's messages.
        msg_types: topic name -> pydantic model used to decode that topic.
        process_f: async function receiving one `(callback, msg)` tuple.

    Raises:
        KeyError: if a topic present in `msgs` is missing from `msg_types`, or
            missing from `callbacks` while it has at least one record.
        Any error raised by `msg_type.parse_raw` for a malformed payload.
    """
    for topic_partition, topic_msgs in msgs.items():
        topic = topic_partition.topic
        msg_type = msg_types[topic]
        # The whole partition batch is decoded before dispatching, so a
        # malformed record aborts before any message of this partition is sent.
        decoded_msgs = [
            msg_type.parse_raw(msg.value.decode("utf-8")) for msg in topic_msgs
        ]
        for msg in decoded_msgs:
            await process_f((callbacks[topic], msg))

In [None]:
def create_consumer_record(topic: str, partition: int, msg: BaseModel) -> ConsumerRecord:
    """Build an in-memory ConsumerRecord whose value is `msg` as UTF-8 JSON.

    Test helper: offset, timestamps, checksum, key and headers are fixed dummy
    values. Only topic, partition and value (plus its size) reflect the inputs.
    """
    value = msg.json().encode("utf-8")
    return ConsumerRecord(
        topic=topic,
        partition=partition,
        offset=0,
        timestamp=0,
        timestamp_type=0,
        key=None,
        value=value,
        checksum=0,
        serialized_key_size=0,
        # Report the real payload size rather than a placeholder 0.
        serialized_value_size=len(value),
        headers=[],
    )

In [None]:
# Sanity check
# One msg, one topic, process_f called once with callback and decoded_msg

topic="topic_0"
partition=0
topic_part_0_0 = TopicPartition(topic, partition)
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

process_f = Mock()
callback_0 = Mock()

await process_msgs(
    msgs={topic_part_0_0: [record]},
    callbacks={topic: callback_0},
    msg_types={topic: MyMessage},
    process_f=asyncer.asyncify(process_f)
)

process_f.assert_called_with((callback_0, msg))
assert process_f.call_count==1

In [None]:
# Check different topics

# Two msg, two topics, process_f called twice with each callback called once

topic_part_0_0 = TopicPartition("topic_0", 0)
topic_part_1_0 = TopicPartition("topic_1", 0)

topic="topic_0"
partition=0
topic_part_0_0 = TopicPartition("topic_0", 0)
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

process_f = Mock()
callback_0 = Mock()
callback_1 = Mock()

await process_msgs(
    msgs={topic_part_0_0: [record], topic_part_1_0: [record]},
    callbacks={"topic_0": callback_0, "topic_1": callback_1},
    msg_types={"topic_0": MyMessage, "topic_1": MyMessage},
    process_f=asyncer.asyncify(process_f)
)

process_f.assert_has_calls([call((callback_0, msg)), call((callback_1, msg))], any_order=True)
assert process_f.call_count==2

In [None]:
# Check multiple msgs in same topic
# Check callback not called if there are no msgs for it in the queue

# Two msg, one topic, one callback called twice, other called nonce, produce and process_f called twice

# Check different topics

# Two msg, two topics, process_f called twice with each callback called once and produce twice

topic_part_0_0 = TopicPartition("topic_0", 0)

topic="topic_0"
partition=0
topic_part_0_0 = TopicPartition("topic_0", 0)
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

process_f = Mock()
callback_0 = Mock()
callback_1 = Mock()

await process_msgs(
    msgs={topic_part_0_0: [record, record]},
    callbacks={"topic_0": callback_0, "topic_1": callback_1},
    msg_types={"topic_0": MyMessage, "topic_1": MyMessage},
    process_f=asyncer.asyncify(process_f)
)

process_f.assert_has_calls([call((callback_0, msg)), call((callback_0, msg))], any_order=True)
assert process_f.call_count==2

In [None]:
# Check multiple partitions

# Two msg, one topic, two partitions, one callback called twice, produce and process_f called twice

topic_part_0_0 = TopicPartition("topic_0", 0)
topic_part_0_1 = TopicPartition("topic_0", 1)

msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

process_f = Mock()
callback_0 = Mock()
callback_1 = Mock()

await process_msgs(
    msgs={topic_part_0_0: [
            create_consumer_record(topic="topic_0", partition=0, msg=msg)],
          topic_part_0_1: [
            create_consumer_record(topic="topic_0", partition=1, msg=msg)]
         },
    callbacks={"topic_0": callback_0, "topic_1": callback_1},
    msg_types={"topic_0": MyMessage, "topic_1": MyMessage},
    process_f=asyncer.asyncify(process_f)
)

process_f.assert_has_calls([call((callback_0, msg)), call((callback_0, msg))], any_order=True)
assert process_f.call_count==2

In [None]:
# | export


async def process_message_callback(receive_stream):
    """Drain `receive_stream`, awaiting each queued callback on its message.

    Every item read from the stream is a `(callback, msg)` pair. The callback is
    awaited with `msg` before the next item is read, so callbacks run one at a
    time in arrival order. The stream is entered with `async with` (and so
    closed on exit), and the coroutine returns once iteration over it ends.
    """
    async with receive_stream:
        async for pair in receive_stream:
            handler, payload = pair
            await handler(payload)


async def _aiokafka_consumer_loop(
    *,
    consumer: AIOKafkaConsumer,
    max_buffer_size: int,
    callbacks: Dict[str, Callable[[BaseModel], Awaitable[None]]],
    msg_types: Dict[str, Type[BaseModel]],
    is_shutting_down_f: Callable[[], bool],
    timeout_ms: int = 100,
):
    """Poll `consumer` and feed decoded messages to per-topic async callbacks.

    Polling and callback execution are decoupled through an anyio memory object
    stream. This coroutine pushes `(callback, msg)` pairs into it, while a
    `process_message_callback` task in the same task group pulls them out and
    awaits each callback in turn.

    Args:
        consumer: an already-started consumer; only `getmany(timeout_ms=...)`
            is used here.
        max_buffer_size: stream capacity; once this many pairs are pending,
            sending blocks, so slow callbacks throttle polling.
        callbacks: topic name -> async callback awaited with each decoded message.
        msg_types: topic name -> pydantic model used to decode that topic's records.
        is_shutting_down_f: checked after every poll; the loop stops once it
            returns True, so at least one poll always happens.
        timeout_ms: poll timeout passed to `consumer.getmany`; the default
            keeps the previously hard-coded 100 ms.
    """
    send_stream, receive_stream = anyio.create_memory_object_stream(
        max_buffer_size=max_buffer_size
    )
    async with anyio.create_task_group() as tg:
        tg.start_soon(process_message_callback, receive_stream)
        # Leaving this block closes the send side. The receiver then drains
        # whatever is still buffered and exits, and only after that does the
        # task group (and hence this coroutine) finish.
        async with send_stream:
            while True:
                msgs = await consumer.getmany(timeout_ms=timeout_ms)
                await process_msgs(
                    msgs=msgs,
                    callbacks=callbacks,
                    msg_types=msg_types,
                    process_f=send_stream.send,
                )
                if is_shutting_down_f():
                    break

In [None]:
topic="topic_0"
msg = MyMessage(url="http://www.acme.com", port=22)
record = create_consumer_record(topic=topic, partition=partition, msg=msg)

mock_consumer = MagicMock()
msgs = {
    TopicPartition(topic, 0): [
        record
    ]
}

f = asyncio.Future()
f.set_result(msgs)
mock_consumer.configure_mock(**{'getmany.return_value': f})
mock_callback = Mock()

def is_shutting_down_f(mock_func):
    def _is_shutting_down_f():
        return mock_func.called
    return _is_shutting_down_f

await _aiokafka_consumer_loop(
    consumer= mock_consumer,
    max_buffer_size= 100,
    callbacks= {
        topic: asyncer.asyncify(mock_callback)
    },
    msg_types={
        topic: MyMessage
    },
    is_shutting_down_f = is_shutting_down_f(mock_consumer.getmany)
)

assert mock_consumer.getmany.call_count == 1
mock_callback.assert_called_once_with(msg)

In [None]:
# | export


async def aiokafka_consumer_loop(
    topics: List[str],
    *,
    bootstrap_servers: str,
    auto_offset_reset: str,
    max_poll_records: int,
    max_buffer_size: int,
    callbacks: Dict[str, Callable[[BaseModel], Awaitable[None]]],
    msg_types: Dict[str, Type[BaseModel]],
    is_shutting_down_f: Callable[[], bool],
    **kwargs,
):
    """Create, start and subscribe an AIOKafkaConsumer, then run the consume loop.

    Args:
        topics: topics to subscribe to.
        bootstrap_servers: forwarded to AIOKafkaConsumer.
        auto_offset_reset: forwarded to AIOKafkaConsumer.
        max_poll_records: forwarded to AIOKafkaConsumer.
        max_buffer_size: forwarded to `_aiokafka_consumer_loop`.
        callbacks: topic name -> async callback; forwarded to `_aiokafka_consumer_loop`.
        msg_types: topic name -> pydantic model; forwarded to `_aiokafka_consumer_loop`.
        is_shutting_down_f: stop predicate; forwarded to `_aiokafka_consumer_loop`.
        **kwargs: accepted but currently ignored (not forwarded anywhere).

    Once `consumer.start()` has succeeded, the consumer is always stopped, even
    if subscribing or the consume loop raises.
    """
    consumer = AIOKafkaConsumer(
        bootstrap_servers=bootstrap_servers,
        auto_offset_reset=auto_offset_reset,
        max_poll_records=max_poll_records,
    )
    logger.info("Consumer created.")

    await consumer.start()
    logger.info("Consumer started.")
    # Everything after a successful start() sits inside try/finally, so a
    # failing subscribe() no longer leaks a running consumer.
    try:
        consumer.subscribe(topics)
        logger.info("Consumer subscribed.")

        await _aiokafka_consumer_loop(
            consumer=consumer,
            max_buffer_size=max_buffer_size,
            callbacks=callbacks,
            msg_types=msg_types,
            is_shutting_down_f=is_shutting_down_f,
        )
    finally:
        await consumer.stop()
        logger.info("Consumer stopped.")

In [None]:
# %%time

msgs_sent = 9178
msgs = [
    MyMessage(url="http://www.ai.com", port=port).json().encode("utf-8")
    for port in range(msgs_sent)
]
msgs_received = 0

async def count_msg(msg: MyMessage):
    global msgs_received
    msgs_received = msgs_received + 1

async with create_and_fill_testing_topic(kafka_config=kafka_config, msgs=msgs, seed=seed(1)) as topic:
    await aiokafka_consumer_loop(
        topics = [topic],
        bootstrap_servers = kafka_config["bootstrap.servers"],
        auto_offset_reset="earliest",
        max_poll_records=100,
        max_buffer_size= 100,
        callbacks = {topic: count_msg},
        msg_types= {topic: MyMessage},
        is_shutting_down_f= true_after(5),
    )

assert msgs_sent == msgs_received, f"{msgs_sent} != {msgs_received}"

[INFO] fast_kafka_api.testing: create_missing_topics(['my_topic_928922829']): new_topics = [NewTopic(topic=my_topic_928922829,num_partitions=3)]
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> created.
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> started.
[INFO] fast_kafka_api.testing: Sent messages: len(sent_msgs)=9178
[INFO] __main__: Consumer created.
[INFO] __main__: Consumer started.
[INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'my_topic_928922829'})
[INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'my_topic_928922829'}
[INFO] __main__: Consumer subscribed.
[INFO] aiokafka.consumer.group_coordinator: Metadata for topic has changed from {} to {'my_topic_928922829': 3}. 
[INFO] __main__: Consumer stopped.
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> stoped.


In [None]:
# | notest

msgs_sent = 100000
msgs = [
    MyMessage(url="http://www.ai.com", port=port).json().encode("utf-8")
    for port in range(msgs_sent)
]
msgs_received = 0

async def count_msg(msg: MyMessage):
    global msgs_received
    msgs_received = msgs_received + 1

async with create_and_fill_testing_topic(kafka_config=kafka_config, msgs=msgs, seed=seed(3)) as topic:
    start = datetime.now()
    await aiokafka_consumer_loop(
        topics = [topic],
        bootstrap_servers = kafka_config["bootstrap.servers"],
        auto_offset_reset="earliest",
        max_poll_records=100,
        max_buffer_size=100,
        callbacks = {topic: count_msg},
        msg_types= {topic: MyMessage},
        is_shutting_down_f= true_after(5),
    )
    t = (datetime.now() - start) / timedelta(seconds=1)
    thrp = msgs_received / t
    
    print(f"Messages processed: {msgs_received:,d}")
    print(f"Time              : {t:.2f} s")
    print(f"Throughput.       : {thrp:,.0f} msg/s")

[INFO] fast_kafka_api.testing: create_missing_topics(['my_topic_1849625992']): new_topics = [NewTopic(topic=my_topic_1849625992,num_partitions=3)]
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> created.
[INFO] fast_kafka_api.testing: Producer <aiokafka.producer.producer.AIOKafkaProducer object> started.
[INFO] fast_kafka_api.testing: Sent messages: len(sent_msgs)=100000
[INFO] __main__: Consumer created.
[INFO] __main__: Consumer started.
[INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'my_topic_1849625992'})
[INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'my_topic_1849625992'}
[INFO] __main__: Consumer subscribed.
[INFO] aiokafka.consumer.group_coordinator: Metadata for topic has changed from {} to {'my_topic_1849625992': 3}. 
[INFO] __main__: Consumer stopped.
Messages processed: 89,800
Time              : 5.01 s
Throughput.       : 17,937 msg/s
[INFO] fast_kafka_api.testing: Producer <a