In [None]:
# | default_exp _components.producer_decorator

In [None]:
# | export

import asyncio
import functools
import json
from asyncio import iscoroutinefunction  # do not use the version from inspect
from collections import namedtuple
from dataclasses import dataclass
from typing import *

import nest_asyncio
from aiokafka import AIOKafkaProducer
from pydantic import BaseModel

from fastkafka._components.meta import export

In [None]:
import asyncio
from contextlib import asynccontextmanager
from unittest.mock import Mock

from pydantic import Field

from fastkafka._testing.apache_kafka_broker import ApacheKafkaBroker
from fastkafka._testing.test_utils import mock_AIOKafkaProducer_send
from fastkafka.encoder import avro_encoder, json_encoder

In [None]:
# | export


BaseSubmodel = TypeVar("BaseSubmodel", bound=BaseModel)


@dataclass
@export("fastkafka")
class KafkaEvent(Generic[BaseSubmodel]):
    """
    A generic container pairing a Kafka message with an optional key.

    The type parameter ``BaseSubmodel`` is bound to ``pydantic.BaseModel``.
    The bound is only enforced by static type checkers; at runtime the
    dataclass accepts any value as ``message``.

    Attributes:
        message (BaseSubmodel): The message contained in the Kafka event, typically a pydantic.BaseModel.
        key (bytes, optional): The optional key of the Kafka event; defaults to None (no key).
    """

    message: BaseSubmodel
    key: Optional[bytes] = None

In [None]:
# A non-pydantic value is accepted as the message at runtime; key defaults to None.
event = KafkaEvent("Some message")
assert event.message == "Some message"
assert event.key is None

# Both fields can be passed positionally.
event = KafkaEvent("Some message", b"123")
assert event.message == "Some message"
assert event.key == b"123"

In [None]:
# | export

# What a producer function may return: a bare pydantic model, or a model
# wrapped in a KafkaEvent (which additionally carries an optional key).
ProduceReturnTypes = Union[BaseModel, KafkaEvent[BaseModel]]

# A producer function is either a plain callable returning one of the
# ProduceReturnTypes, or a coroutine function whose awaitable yields one.
ProduceCallable = Union[
    Callable[..., ProduceReturnTypes],
    Callable[..., Awaitable[ProduceReturnTypes]],
]

In [None]:
# # | export


# def _to_json_utf8(o: Any) -> bytes:
#     """Converts to JSON and then encodes with UTF-8"""
#     if hasattr(o, "json"):
#         return o.json().encode("utf-8")  # type: ignore
#     else:
#         return json.dumps(o).encode("utf-8")

In [None]:
# assert _to_json_utf8({"a": 1, "b": [2, 3]}) == b'{"a": 1, "b": [2, 3]}'


class A(BaseModel):
    """Small two-field pydantic model used by the ``_wrap_in_event`` checks."""

    name: str = Field()
    age: int


# assert _to_json_utf8(A(name="Davor", age=12)) == b'{"name": "Davor", "age": 12}'

In [None]:
# | export


def _wrap_in_event(message: Union[BaseModel, KafkaEvent]) -> KafkaEvent:
    """Return ``message`` as a ``KafkaEvent``.

    Args:
        message: A value returned by a producer function; either already a
            ``KafkaEvent`` (possibly carrying a key) or a bare message.

    Returns:
        ``message`` itself if it is already a ``KafkaEvent`` (including
        subclasses), otherwise a new ``KafkaEvent`` wrapping it with no key.
    """
    # isinstance rather than an exact type comparison, so that subclasses of
    # KafkaEvent are passed through instead of being wrapped a second time.
    return message if isinstance(message, KafkaEvent) else KafkaEvent(message)

In [None]:
# A bare pydantic model is wrapped in a fresh KafkaEvent without a key.
message = A(name="Davor", age=12)
wrapped = _wrap_in_event(message)

assert type(wrapped) is KafkaEvent
assert wrapped.message == message
assert wrapped.key is None

In [None]:
# An existing KafkaEvent is passed through untouched, key included.
message = KafkaEvent(A(name="Davor", age=12), b"123")
wrapped = _wrap_in_event(message)

assert type(wrapped) is KafkaEvent
assert wrapped is message
assert wrapped.message == message.message
assert wrapped.key == b"123"

In [None]:
# | export


def get_loop() -> asyncio.AbstractEventLoop:
    """Return an event loop that can drive coroutines from synchronous code.

    The current event loop is used when there is one and it is not closed.
    Otherwise a new loop is created and installed as the current loop for
    this thread, so that later calls return that same loop instead of
    creating a fresh one every time. If the returned loop is already running
    (e.g. in a notebook with top-level ``await``), ``nest_asyncio`` is applied
    to it so that ``run_until_complete`` can be re-entered.

    Returns:
        The event loop.
    """
    loop: Optional[asyncio.AbstractEventLoop]
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current loop in this thread (e.g. a non-main thread).
        loop = None

    if loop is None or loop.is_closed():
        # A closed loop cannot run anything; replace it as well. Registering
        # the new loop makes subsequent calls reuse it.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    if loop.is_running():
        nest_asyncio.apply(loop)

    return loop

In [None]:
# get_loop() must hand back an asyncio event loop object.
loop = get_loop()
assert issubclass(type(loop), asyncio.AbstractEventLoop)

In [None]:
# | export


def producer_decorator(
    producer_store: Dict[str, Any],
    func: ProduceCallable,
    topic: str,
    encoder_fn: Callable[[BaseModel], bytes],
) -> ProduceCallable:
    """Wrap a producer function so that every value it returns is sent to Kafka.

    The wrapper works in four steps:

    1. Call ``func``.
    2. Wrap its return value in a ``KafkaEvent`` if it is not one already.
    3. Encode the event's message with ``encoder_fn``.
    4. Send it to ``topic`` with the event's key, using the producer stored
       in ``producer_store[topic]``.

    The wrapper then returns ``func``'s original return value unchanged.

    Args:
        producer_store: Mapping from topic name to a 3-tuple whose middle
            element is the producer used for sending; the other two elements
            are not used here. The lookup happens on every call, so the
            producer may be put into the store after decoration.
        func: The producer function; either a regular function or a
            coroutine function.
        topic: Kafka topic the returned messages are sent to.
        encoder_fn: Turns the message into the bytes that are sent.

    Returns:
        A coroutine function if ``func`` is a coroutine function, otherwise a
        regular function. In both cases ``func``'s metadata is copied onto
        the wrapper with ``functools.wraps``.
    """
    # Captured once at decoration time; the sync wrapper runs sends on it.
    loop = get_loop()

    def release_callback(fut: asyncio.Future) -> None:
        # Placeholder: currently does nothing when the send future completes.
        pass

    def _send(return_val: ProduceReturnTypes) -> Awaitable[Any]:
        """Build the ``producer.send`` awaitable for ``return_val``."""
        wrapped_val = _wrap_in_event(return_val)
        _, producer, _ = producer_store[topic]
        return producer.send(  # type: ignore
            topic, encoder_fn(wrapped_val.message), key=wrapped_val.key
        )

    # The wrappers read ``func``, ``producer_store`` and ``loop`` from the
    # closure rather than from keyword defaults. This way every keyword
    # argument the caller passes (even one named e.g. ``f`` or ``loop``) is
    # forwarded to ``func`` instead of being captured by the wrapper.
    @functools.wraps(func)
    async def _produce_async(*args: Any, **kwargs: Any) -> ProduceReturnTypes:
        return_val = await func(*args, **kwargs)  # type: ignore
        fut = await _send(return_val)
        fut.add_done_callback(release_callback)
        return return_val

    @functools.wraps(func)
    def _produce_sync(*args: Any, **kwargs: Any) -> ProduceReturnTypes:
        return_val = func(*args, **kwargs)
        fut = loop.run_until_complete(_send(return_val))  # type: ignore
        fut.add_done_callback(release_callback)
        return return_val  # type: ignore

    return _produce_async if iscoroutinefunction(func) else _produce_sync

In [None]:
class MockMsg(BaseModel):
    """Payload model with defaults, shared by the producer_decorator tests."""

    name: str = "Micky Mouse"
    id: int = 123


# Fixtures reused by every producer_decorator test cell below.
mock_msg = MockMsg()

topic: str = "test_topic"

In [None]:
@asynccontextmanager
async def mock_producer_env(
    is_sync: bool,
) -> AsyncGenerator[Tuple[Mock, AIOKafkaProducer], None]:
    """Yield ``(send_mock, producer)`` for the producer_decorator tests.

    Starts an ``ApacheKafkaBroker`` for ``topic`` and an ``AIOKafkaProducer``
    connected to it, while ``mock_AIOKafkaProducer_send`` is active so that
    calls to ``producer.send`` are recorded on ``send_mock``.

    Args:
        is_sync: Not used by the body; the call sites pass it to document
            which variant of the decorator they exercise.
    """
    with mock_AIOKafkaProducer_send() as send_mock:
        async with ApacheKafkaBroker(topics=[topic]) as bootstrap_server:
            producer = AIOKafkaProducer(bootstrap_servers=bootstrap_server)
            # The try begins only once ``producer`` exists. As a result, a
            # failure while setting up the mock or the broker propagates as-is
            # instead of being masked by a NameError/UnboundLocalError from
            # ``producer.stop()``. The producer is also stopped while its
            # broker is still up.
            try:
                await producer.start()
                yield send_mock, producer
            finally:
                await producer.stop()

In [None]:
async def func(mock_msg: MockMsg) -> MockMsg:
    return mock_msg


async with mock_producer_env(is_sync=False) as (send_mock, producer):
    test_func = producer_decorator(
        {topic: (None, producer, None)}, func, topic, encoder_fn=json_encoder
    )

    assert iscoroutinefunction(test_func) == True

    value = await test_func(mock_msg)

    send_mock.assert_called_once_with(topic, mock_msg.json().encode("utf-8"), key=None)

    assert value == mock_msg

In [None]:
# Test with avro_encoder
async def func(mock_msg: MockMsg) -> MockMsg:
    return mock_msg


async with mock_producer_env(is_sync=False) as (send_mock, producer):
    test_func = producer_decorator(
        {topic: (None, producer, None)}, func, topic, encoder_fn=avro_encoder
    )

    assert iscoroutinefunction(test_func) == True

    value = await test_func(mock_msg)

    send_mock.assert_called_once_with(topic, avro_encoder(mock_msg), key=None)

    assert value == mock_msg

In [None]:
def func(mock_msg: MockMsg) -> MockMsg:
    return mock_msg


async with mock_producer_env(is_sync=True) as (send_mock, producer):
    test_func = producer_decorator(
        {topic: (None, producer, None)}, func, topic, encoder_fn=json_encoder
    )

    assert iscoroutinefunction(test_func) == False

    value = test_func(mock_msg)
    await asyncio.sleep(1)

    send_mock.assert_called_once_with(topic, mock_msg.json().encode("utf-8"), key=None)

    assert value == mock_msg

In [None]:
# Test with avro_encoder
def func(mock_msg: MockMsg) -> MockMsg:
    return mock_msg


async with mock_producer_env(is_sync=True) as (send_mock, producer):
    test_func = producer_decorator(
        {topic: (None, producer, None)}, func, topic, encoder_fn=avro_encoder
    )

    assert iscoroutinefunction(test_func) == False

    value = test_func(mock_msg)
    await asyncio.sleep(1)

    send_mock.assert_called_once_with(topic, avro_encoder(mock_msg), key=None)

    assert value == mock_msg

In [None]:
test_key: bytes = b"some_key"  # Kafka key attached to the KafkaEvent returns below

In [None]:
async def func(mock_msg: MockMsg) -> KafkaEvent[MockMsg]:
    return KafkaEvent(mock_msg, key=test_key)


async with mock_producer_env(is_sync=False) as (send_mock, producer):
    test_func = producer_decorator(
        {topic: (None, producer, None)}, func, topic, encoder_fn=json_encoder
    )

    assert iscoroutinefunction(test_func) == True

    value = await test_func(mock_msg)

    send_mock.assert_called_once_with(
        topic, mock_msg.json().encode("utf-8"), key=test_key
    )

    assert value == KafkaEvent(mock_msg, key=test_key)

In [None]:
# Test with avro_encoder


async def func(mock_msg: MockMsg) -> KafkaEvent[MockMsg]:
    return KafkaEvent(mock_msg, key=test_key)


async with mock_producer_env(is_sync=False) as (send_mock, producer):
    test_func = producer_decorator(
        {topic: (None, producer, None)}, func, topic, encoder_fn=avro_encoder
    )

    assert iscoroutinefunction(test_func) == True

    value = await test_func(mock_msg)

    send_mock.assert_called_once_with(topic, avro_encoder(mock_msg), key=test_key)

    assert value == KafkaEvent(mock_msg, key=test_key)

In [None]:
async def func(mock_msg: MockMsg) -> KafkaEvent[MockMsg]:
    return KafkaEvent(mock_msg, key=test_key)


async with mock_producer_env(is_sync=False) as (send_mock, producer):
    test_func = producer_decorator(
        {topic: (None, producer, None)}, func, topic, encoder_fn=json_encoder
    )

    assert iscoroutinefunction(test_func) == True

    value = await test_func(mock_msg)

    send_mock.assert_called_once_with(
        topic, mock_msg.json().encode("utf-8"), key=test_key
    )

    assert value == KafkaEvent(mock_msg, key=test_key)

In [None]:
# Test with avro_encoder


async def func(mock_msg: MockMsg) -> KafkaEvent[MockMsg]:
    return KafkaEvent(mock_msg, key=test_key)


async with mock_producer_env(is_sync=False) as (send_mock, producer):
    test_func = producer_decorator(
        {topic: (None, producer, None)}, func, topic, encoder_fn=avro_encoder
    )

    assert iscoroutinefunction(test_func) == True

    value = await test_func(mock_msg)

    send_mock.assert_called_once_with(topic, avro_encoder(mock_msg), key=test_key)

    assert value == KafkaEvent(mock_msg, key=test_key)