In [None]:
# | default_exp testing

In [None]:
# | export
import asyncio
import hashlib
import random
import time
import unittest
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime, timedelta
from os import environ
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
)

from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
from confluent_kafka.admin import AdminClient, NewTopic

from fast_kafka_api._components.logger import get_logger

In [None]:
# | eval: false
# allows async calls in notebooks:
# nest_asyncio patches asyncio so that event-loop calls can be nested inside the
# loop Jupyter is already running. This cell has no `# | export` marker, so it
# only affects interactive execution of this notebook, not the exported module.

import nest_asyncio
nest_asyncio.apply()

In [None]:
# | export

# Module-level logger used by the exported helpers below.
logger = get_logger(__name__)

In [None]:
# Notebook-only (no `# | export` marker): rebind `logger` with level=20
# (presumably logging.INFO) so the INFO messages of the cells below are shown.
logger = get_logger(__name__, level=20)
# NOTE(review): if level=20 is INFO, this debug message is filtered out — confirm
logger.debug("ok")

In [None]:
# | export

# Broker location; override it through the KAFKA_HOSTNAME / KAFKA_PORT
# environment variables (defaults: localhost:9092).
kafka_server_url = environ.get("KAFKA_HOSTNAME", "localhost")
kafka_server_port = environ.get("KAFKA_PORT", "9092")

# Minimal client configuration shared by the admin client, producers and
# consumers created in this notebook.
kafka_config = {
    "bootstrap.servers": f"{kafka_server_url}:{kafka_server_port}",
}

In [None]:
# | export

def true_after(seconds: float) -> Callable[..., bool]:
    """Function returning True after a given number of seconds

    Builds a timeout predicate: the returned function reports whether strictly
    more than `seconds` have elapsed since `true_after` itself was called.

    Params:
        seconds: length of the waiting period

    Returns:
        A callable returning `False` until the period has elapsed and `True`
        afterwards. Its optional `seconds` argument overrides the period for a
        single check.
    """
    # time.monotonic() is unaffected by wall-clock changes (NTP corrections, or
    # DST switches of the naive local time datetime.now() returns), so the
    # deadline cannot jump forwards or backwards.
    start = time.monotonic()

    def _true_after(seconds: float = seconds) -> bool:
        return (time.monotonic() - start) > seconds

    return _true_after

In [None]:
f = true_after(1.0)
assert not f()
time.sleep(0.5)
# half-way through the period: 0.5 s of slack against scheduler delays
assert not f()
time.sleep(0.6)
# sleep() waits at least the requested time, so more than 1.0 s has elapsed now
assert f()

In [None]:
# | export

## TODO: Check that the replication factor is <= the number of brokers
## TODO: Add tests for:
#             - Replication factor (less than and greater than the number of brokers)
#             - Number of partitions

def create_missing_topics(
    admin: AdminClient,
    topic_names: List[str],
    *,
    num_partitions: Optional[int] = None,
    replication_factor: Optional[int] = None,
    **kwargs,
) -> None:
    """Create the topics from `topic_names` that do not exist yet and wait for them.

    Params:
        admin: admin client connected to the target cluster
        topic_names: topics that must exist when the function returns
        num_partitions: partitions of each new topic; if not given (or 0) it
            defaults to the replication factor
        replication_factor: replication factor of each new topic; if not given
            (or 0) it defaults to the number of brokers in the cluster metadata
        **kwargs: extra arguments forwarded unchanged to every `NewTopic`

    Raises:
        KafkaException: if the cluster rejects a topic for any reason other than
            the topic already existing (e.g. a replication factor larger than
            the number of brokers)
    """
    # local import: the error types are only needed by this function
    from confluent_kafka import KafkaError, KafkaException

    if not replication_factor:
        replication_factor = len(admin.list_topics().brokers)
    if not num_partitions:
        num_partitions = replication_factor
    existing_topics = list(admin.list_topics().topics.keys())
    logger.debug(
        f"create_missing_topics({topic_names}): existing_topics={existing_topics}, num_partitions={num_partitions}, replication_factor={replication_factor}"
    )
    already_there = set(existing_topics)
    # dict.fromkeys drops duplicate names while keeping the caller's order
    new_topics = [
        NewTopic(
            topic,
            num_partitions=num_partitions,
            replication_factor=replication_factor,
            **kwargs,
        )
        for topic in dict.fromkeys(topic_names)
        if topic not in already_there
    ]
    if not new_topics:
        return

    logger.info(f"create_missing_topics({topic_names}): new_topics = {new_topics}")
    for topic, future in admin.create_topics(new_topics).items():
        try:
            future.result()
        except KafkaException as e:
            # Losing a creation race to another client is harmless. Any other
            # failure must surface here: the topic would never appear and the
            # wait loop below would spin forever.
            if e.args[0].code() != KafkaError.TOPIC_ALREADY_EXISTS:
                raise

    # the metadata returned by list_topics may lag behind the creation itself
    while not set(topic_names).issubset(admin.list_topics().topics.keys()):
        time.sleep(1)

In [None]:
# Check that create_missing_topics creates every requested topic.
# NOTE: the topics A, B and C are deleted afterwards, even if they existed before.

kafka_admin = AdminClient(kafka_config)
topics = ["A", "B", "C"]
try:
    create_missing_topics(kafka_admin, topics)

    existing_topics = kafka_admin.list_topics().topics.keys()
    assert set(["A", "B", "C"]) <= existing_topics
finally:
    # Cleanup runs even when the check fails. Only topics that actually exist are
    # deleted, so a failed creation is not masked by an unknown-topic error.
    present = kafka_admin.list_topics().topics
    leftover = [name for name in topics if name in present]
    if leftover:
        for delete_future in kafka_admin.delete_topics(topics=leftover).values():
            delete_future.result()

22-12-23 22:40:45.356 [INFO] __main__: create_missing_topics(['A', 'B', 'C']): new_topics = [NewTopic(topic=A,num_partitions=3), NewTopic(topic=B,num_partitions=3), NewTopic(topic=C,num_partitions=3)]


[None, None, None]

In [None]:
# | export

def _delete_topic_and_wait(admin: "AdminClient", topic: str) -> None:
    """Delete `topic` (raising if the cluster rejects it) and block until it leaves the metadata."""
    for future in admin.delete_topics(topics=[topic]).values():
        future.result()
    while topic in admin.list_topics().topics.keys():
        time.sleep(1)


@contextmanager
def create_testing_topic(
    kafka_config: Dict[str, Any], topic_prefix: str, seed: Optional[int] = None
) -> Generator[str, None, None]:
    """Context manager yielding the name of a freshly created testing topic.

    The name is `topic_prefix` followed by a random number derived from `seed`,
    so the same seed always gives the same name. A leftover topic with that name
    is deleted first, and the topic is deleted again when the context exits.

    Params:
        kafka_config: configuration passed to `AdminClient`
        topic_prefix: prefix of the generated topic name
        seed: seed for the name; `None` picks a random name

    Yields:
        The name of the created topic
    """
    # A private Random instance produces the same name that seeding the module-
    # level generator did, without resetting the caller's global `random` state.
    rng = random.Random(seed)
    topic = topic_prefix + str(rng.randint(0, 10**10)).zfill(3)

    admin = AdminClient(kafka_config)
    if topic in admin.list_topics().topics.keys():
        logger.warning(f"topic {topic} exists, deleting it...")
        _delete_topic_and_wait(admin, topic)
    try:
        create_missing_topics(admin, [topic])
        while topic not in admin.list_topics().topics.keys():
            time.sleep(1)
        yield topic
    finally:
        # If creation failed there is nothing to delete, and a delete attempt
        # would presumably fail with an unknown-topic error; skip it then.
        if topic in admin.list_topics().topics.keys():
            _delete_topic_and_wait(admin, topic)

In [None]:
kafka_admin = AdminClient(kafka_config)

with create_testing_topic(kafka_config, "my_topic_", 1) as topic:
    # while the context is active the topic is listed by the cluster ...
    existing_topics = kafka_admin.list_topics().topics.keys()
    assert topic in existing_topics

# ... and once the context has been left it is gone again
existing_topics = kafka_admin.list_topics().topics.keys()
assert topic not in existing_topics

22-12-23 22:40:46.431 [INFO] __main__: create_missing_topics(['my_topic_9167024629']): new_topics = [NewTopic(topic=my_topic_9167024629,num_partitions=3)]


In [None]:
# | export

@asynccontextmanager
async def create_and_fill_testing_topic(
    msgs: List[bytes], kafka_config: Dict[str, str] = kafka_config, *, seed: int
) -> AsyncGenerator[str, None]:
    """Create a testing topic, produce `msgs` into it and yield the topic name.

    The topic comes from `create_testing_topic` (prefix "my_topic_", seeded by
    `seed`) and is deleted again when the context exits. Every message has been
    acknowledged before the name is yielded; the producer stays open until the
    context exits.

    Params:
        msgs: message payloads, produced in order
        kafka_config: passed to `create_testing_topic`; its "bootstrap.servers"
            entry is also used by the producer
        seed: seed used to derive the topic name

    Yields:
        The name of the filled topic
    """
    with create_testing_topic(kafka_config, "my_topic_", seed=seed) as topic:
        producer = AIOKafkaProducer(
            bootstrap_servers=kafka_config["bootstrap.servers"]
        )
        logger.info(f"Producer {producer} created.")

        await producer.start()
        logger.info(f"Producer {producer} started.")
        try:
            # Awaiting `send` enqueues the record and returns its delivery future.
            # Everything is enqueued before `flush()` so the flush has work to do.
            # Keys cycle through b"0".."16", presumably to spread the records
            # over several partitions.
            delivery_futures = [
                await producer.send(topic, msg, key=f"{i % 17}".encode("utf-8"))
                for i, msg in enumerate(msgs)
            ]
            await producer.flush()
            # wait until every record has been acknowledged
            for delivery_future in delivery_futures:
                await delivery_future
            logger.info(f"Sent messages: len(sent_msgs)={len(delivery_futures)}")

            yield topic
        finally:
            await producer.stop()
            logger.info(f"Producer {producer} stoped.")

In [None]:
msgs_sent = 317
msgs = [f"Hello world {i:05d}".encode("utf-8") for i in range(msgs_sent)]

async with create_and_fill_testing_topic(msgs, seed=1) as topic:
    consumer = AIOKafkaConsumer(
        topic,
        bootstrap_servers=kafka_config["bootstrap.servers"],
        auto_offset_reset="earliest",
        max_poll_records=100,
    )
    logger.info(f"Consumer {consumer} created.")
    await consumer.start()
    logger.info(f"Consumer {consumer} started.")
    is_shutting_down_f = true_after(5)
    msgs_received = 0
    try:
        while True:
            msgs = await consumer.getmany(timeout_ms=100)
            for k, v in msgs.items():
                msgs_received = msgs_received + len(v)
            if is_shutting_down_f():
                break

    finally:
        assert msgs_received == msgs_sent
        print(f"Total messages received: {msgs_received}")
        await consumer.stop()
        logger.info(f"Consumer {consumer} stopped.")

22-12-23 22:40:48.458 [INFO] __main__: create_missing_topics(['my_topic_9167024629']): new_topics = [NewTopic(topic=my_topic_9167024629,num_partitions=3)]
22-12-23 22:40:49.461 [INFO] __main__: Producer <aiokafka.producer.producer.AIOKafkaProducer object> created.
22-12-23 22:40:49.469 [INFO] __main__: Producer <aiokafka.producer.producer.AIOKafkaProducer object> started.
22-12-23 22:40:49.580 [INFO] __main__: Sent messages: len(sent_msgs)=317
22-12-23 22:40:49.581 [INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'my_topic_9167024629'})
22-12-23 22:40:49.581 [INFO] __main__: Consumer <aiokafka.consumer.consumer.AIOKafkaConsumer object> created.
22-12-23 22:40:49.586 [INFO] aiokafka.consumer.group_coordinator: Metadata for topic has changed from {} to {'my_topic_9167024629': 3}. 
22-12-23 22:40:49.588 [INFO] __main__: Consumer <aiokafka.consumer.consumer.AIOKafkaConsumer object> started.
Total messages received: 317
22-12-23 22:40:54.636 [INFO] __ma

In [None]:
# TODO: Send repeatedly?

In [None]:
# | export

def nb_safe_seed(s: str) -> Callable[[int], int]:
    """Return a seed factory specific to one notebook.

    Params:
        s: notebook name; the SHA-1 digest of its UTF-8 encoding, reduced
            modulo 10**8, becomes the base seed

    Returns:
        A function mapping an offset `x` (default 0) to `base seed + x`; its
        keyword-only `init_seed` argument can replace the base seed
    """
    digest = hashlib.sha1(s.encode("utf-8")).digest()
    base_seed = int.from_bytes(digest, "big") % 10**8

    def _seed_with_offset(x: int = 0, *, init_seed: int = base_seed) -> int:
        return init_seed + x

    return _seed_with_offset

In [None]:
seed = nb_safe_seed("999_test_utils")

# calling without an argument is the same as offset 0 ...
assert seed(0) == seed()
# ... and an offset is simply added to the notebook's base seed
assert seed(1) == seed() + 1

In [None]:
# | export

@contextmanager
def mock_AIOKafkaProducer_send():
    """Mocks **send** method of **AIOKafkaProducer**

    Replaces `AIOKafkaProducer.send` on the class for the duration of the `with`
    block and yields the mock. `unittest.mock.patch` substitutes an `AsyncMock`
    for async functions. Assuming `send` is `async def`, awaiting a call then
    returns an already-scheduled task, standing in for the real delivery future.

    Must be entered while an event loop is running, because
    `asyncio.create_task` requires one.

    Yields:
        The mock replacing `AIOKafkaProducer.send`
    """
    # `import unittest` alone does not load the `mock` submodule
    import unittest.mock

    # Patch the class object imported by this module instead of the
    # "__main__.AIOKafkaProducer" path, which only resolves when the caller's
    # __main__ itself imports AIOKafkaProducer (i.e. inside this notebook).
    with unittest.mock.patch.object(AIOKafkaProducer, "send") as mock:

        async def _f():
            pass

        mock.return_value = asyncio.create_task(_f())

        yield mock