diff --git a/mockafka/aiokafka/aiokafka_consumer.py b/mockafka/aiokafka/aiokafka_consumer.py index 7688dd6..7b4aeec 100644 --- a/mockafka/aiokafka/aiokafka_consumer.py +++ b/mockafka/aiokafka/aiokafka_consumer.py @@ -39,6 +39,8 @@ def message_to_record(message: Message, offset: int) -> ConsumerRecord[bytes, by key = key_str.encode() if key_str is not None else None value = value_str.encode() if value_str is not None else None + headers = message.headers() + return ConsumerRecord( topic=topic, partition=partition, @@ -51,7 +53,7 @@ def message_to_record(message: Message, offset: int) -> ConsumerRecord[bytes, by checksum=None, # Deprecated, we won't support it serialized_key_size=len(key) if key else 0, serialized_value_size=len(value) if value else 0, - headers=tuple((message.headers() or {}).items()), + headers=tuple(headers) if headers else (), ) diff --git a/mockafka/aiokafka/aiokafka_producer.py b/mockafka/aiokafka/aiokafka_producer.py index 9682604..db45d91 100644 --- a/mockafka/aiokafka/aiokafka_producer.py +++ b/mockafka/aiokafka/aiokafka_producer.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Optional + from mockafka.kafka_store import KafkaStore from mockafka.message import Message @@ -46,7 +48,7 @@ async def send( key=None, partition=0, timestamp_ms=None, - headers=None, + headers: Optional[list[tuple[str, Optional[bytes]]]] = None, ) -> None: await self._produce( topic=topic, diff --git a/mockafka/message.py b/mockafka/message.py index 974ee0b..1a5feb9 100644 --- a/mockafka/message.py +++ b/mockafka/message.py @@ -13,7 +13,7 @@ class Message: def __init__(self, *args: Any, **kwargs: Any) -> None: - self._headers: Optional[dict] = kwargs.get("headers", None) + self._headers: Optional[list[tuple[str, Optional[bytes]]]] = kwargs.get("headers", None) self._key: Optional[str] = kwargs.get("key", None) self._value: Optional[str] = kwargs.get("value", None) self._topic: Optional[str] = kwargs.get("topic", None) @@ -37,7 +37,7 @@ def latency(self, *args, **kwargs): def leader_epoch(self, *args, **kwargs): return self._leader_epoch - def headers(self, *args, **kwargs): + def headers(self) -> Optional[list[tuple[str, Optional[bytes]]]]: return self._headers def key(self, *args, **kwargs): diff --git a/mockafka/producer.py b/mockafka/producer.py index b156ff6..d168f62 100644 --- a/mockafka/producer.py +++ b/mockafka/producer.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Optional, Union + from mockafka.cluster_metadata import ClusterMetadata from mockafka.kafka_store import KafkaStore from mockafka.message import Message @@ -11,10 +13,39 @@ class FakeProducer(object): def __init__(self, config: dict | None = None): self.kafka = KafkaStore() - def produce(self, topic, value=None, *args, **kwargs): + def produce( + self, + topic, + value=None, + key=None, + partition=None, + callback=None, + on_delivery=None, + timestamp=None, + headers: Union[ + # While Kafka itself supports only list[tuple[...]], confluent_kafka + # allows passing in a dict here. + dict[str, Optional[bytes]], + list[tuple[str, Optional[bytes]]], + None, + ] = None, + **kwargs, + ) -> None: + if isinstance(headers, dict): + headers = list(headers.items()) # create a message and call produce kafka - message = Message(value=value, topic=topic, *args, **kwargs) - self.kafka.produce(message=message, topic=topic, partition=kwargs["partition"]) + message = Message( + topic=topic, + value=value, + key=key, + partition=partition, + callback=callback, + on_delivery=on_delivery, + timestamp=timestamp, + headers=headers, + **kwargs, + ) + self.kafka.produce(message=message, topic=topic, partition=partition) def list_topics(self, topic=None, *args, **kwargs): return ClusterMetadata(topic) diff --git a/tests/test_aiokafka/test_aiokafka_producer.py b/tests/test_aiokafka/test_aiokafka_producer.py index 95ef24e..d410624 100644 --- a/tests/test_aiokafka/test_aiokafka_producer.py +++ b/tests/test_aiokafka/test_aiokafka_producer.py @@ -64,7 +64,11 @@ async def test_produce_fail_for_none_partition(self): async def test_produce_once(self) -> None: await self._create_mock_topic() await self.producer.send( - headers={}, + headers=[ + ("header-name1", b"header-value"), + ("header-name2", None), + ("header-name1", b"duplicate!"), + ], key=self.key, value=self.value, topic=self.topic, @@ -76,7 +80,14 @@ async def test_produce_once(self) -> None: self.assertEqual(message.key(), self.key) self.assertEqual(message.value(payload=None), self.value) self.assertEqual(message.topic(), self.topic) - self.assertEqual(message.headers(), {}) + self.assertEqual( + message.headers(), + [ + ("header-name1", b"header-value"), + ("header-name2", None), + ("header-name1", b"duplicate!"), + ], + ) self.assertEqual(message.error(), None) self.assertEqual(message.latency(), None) diff --git a/tests/test_async_mockafka.py b/tests/test_async_mockafka.py index 0a8f6eb..400d5fa 100644 --- a/tests/test_async_mockafka.py +++ b/tests/test_async_mockafka.py @@ -81,3 +81,25 @@ async def test_produce_and_consume_with_decorator(message=None): assert message.key == b"test_key" assert message.value == b"test_value" + + +@pytest.mark.asyncio +@asetup_kafka(topics=[{"topic": "test_topic1", "partition": 2}], clean=True) +async def test_produce_and_consume_with_headers(): + producer = FakeAIOKafkaProducer() + consumer = FakeAIOKafkaConsumer() + + await producer.start() + await consumer.start() + consumer.subscribe({"test_topic1"}) + + await producer.send( + topic="test_topic1", + headers=[('header_name', b"test"), ('header_name2', b"test")], + ) + await producer.stop() + + record = await consumer.getone() + assert record.headers == (('header_name', b"test"), ('header_name2', b"test")) + + await consumer.stop() diff --git a/tests/test_producer.py b/tests/test_producer.py index ee17877..53f2cdb 100644 --- a/tests/test_producer.py +++ b/tests/test_producer.py @@ -60,7 +60,10 @@ def test_produce_fail_for_none_partition(self): def test_produce_once(self) -> None: self.producer.produce( - headers={}, + headers={ + "header-name1": b"header-value", + "header-name2": None, + }, key=self.key, value=self.value, topic=self.topic, @@ -72,7 +75,13 @@ def test_produce_once(self) -> None: self.assertEqual(message.key(), self.key) self.assertEqual(message.value(payload=None), self.value) self.assertEqual(message.topic(), self.topic) - self.assertEqual(message.headers(), {}) + self.assertEqual( + message.headers(), + [ + ("header-name1", b"header-value"), + ("header-name2", None), + ], + ) self.assertEqual(message.error(), None) self.assertEqual(message.latency(), None)