Skip to content

Commit

Permalink
Fix initialization without running loop (issue #689) (#694)
Browse files Browse the repository at this point in the history
* Fix initialization without running loop (issue #689)

* Run test without running loop in thread
  • Loading branch information
ods committed Dec 11, 2020
1 parent 4fbea5c commit 306ae6e
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Exclude `.so` from source distribution

689.bugfix
Add `dataclasses` backport package to dependencies for Python 3.6
Fix initialization without running loop

693.doc
Update docs and examples to not use deprecated practices like passing loop explicitly
Expand Down
18 changes: 12 additions & 6 deletions aiokafka/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
UnrecognizedBrokerVersion,
StaleMetadata)
from aiokafka.util import (
create_task, create_future, parse_kafka_version, get_running_loop
create_future, create_task, parse_kafka_version, get_running_loop
)


Expand Down Expand Up @@ -148,8 +148,14 @@ def __init__(self, *, loop=None, bootstrap_servers='localhost',
self._sync_task = None

self._md_update_fut = None
self._md_update_waiter = create_future()
self._get_conn_lock = asyncio.Lock()
self._md_update_waiter = loop.create_future()
self._get_conn_lock_value = None

@property
def _get_conn_lock(self):
if self._get_conn_lock_value is None:
self._get_conn_lock_value = asyncio.Lock()
return self._get_conn_lock_value

def __repr__(self):
return '<AIOKafkaClient client_id=%s>' % self._client_id
Expand Down Expand Up @@ -344,7 +350,7 @@ def force_metadata_update(self):
# Wake up the `_md_synchronizer` task
if not self._md_update_waiter.done():
self._md_update_waiter.set_result(None)
self._md_update_fut = create_future()
self._md_update_fut = self._loop.create_future()
# Metadata will be updated in the background by syncronizer
return asyncio.shield(self._md_update_fut)

Expand All @@ -364,7 +370,7 @@ def add_topic(self, topic):
topic (str): topic to track
"""
if topic in self._topics:
res = create_future()
res = self._loop.create_future()
res.set_result(True)
else:
res = self.force_metadata_update()
Expand All @@ -381,7 +387,7 @@ def set_topics(self, topics):
if not topics or set(topics).difference(self._topics):
res = self.force_metadata_update()
else:
res = create_future()
res = self._loop.create_future()
res.set_result(True)
self._topics = set(topics)
return res
Expand Down
4 changes: 2 additions & 2 deletions aiokafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def send(self, request, expect_response=True):

if not expect_response:
return self._writer.drain()
fut = create_future()
fut = self._loop.create_future()
self._requests.append((correlation_id, request.RESPONSE_TYPE, fut))
return asyncio.wait_for(fut, self._request_timeout)

Expand All @@ -458,7 +458,7 @@ def _send_sasl_token(self, payload, expect_response=True):
if not expect_response:
return self._writer.drain()

fut = create_future()
fut = self._loop.create_future()
self._requests.append((None, None, fut))
return asyncio.wait_for(fut, self._request_timeout)

Expand Down
2 changes: 1 addition & 1 deletion aiokafka/consumer/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def __init__(self, *topics, loop=None,
self._max_poll_interval_ms = max_poll_interval_ms

self._check_crcs = check_crcs
self._subscription = SubscriptionState()
self._subscription = SubscriptionState(loop=loop)
self._fetcher = None
self._coordinator = None
self._loop = loop
Expand Down
3 changes: 2 additions & 1 deletion aiokafka/consumer/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ def __init__(
auto_offset_reset='latest',
isolation_level="read_uncommitted"):
self._client = client
self._loop = client._loop
self._key_deserializer = key_deserializer
self._value_deserializer = value_deserializer
self._fetch_min_bytes = fetch_min_bytes
Expand Down Expand Up @@ -440,7 +441,7 @@ def _create_fetch_waiter(self):
# Creating a fetch waiter is usually not that frequent of an operation,
# (get methods will return all data first, before a waiter is created)

fut = create_future()
fut = self._loop.create_future()
self._fetch_waiters.add(fut)
fut.add_done_callback(
lambda f, waiters=self._fetch_waiters: waiters.remove(f))
Expand Down
43 changes: 25 additions & 18 deletions aiokafka/consumer/subscription_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from asyncio import shield, Event, Future
from enum import Enum

from typing import Set, Pattern, Dict, List
from typing import Dict, FrozenSet, Iterable, List, Pattern, Set

from aiokafka.errors import IllegalStateError
from aiokafka.structs import OffsetAndMetadata, TopicPartition
from aiokafka.abc import ConsumerRebalanceListener
from aiokafka.util import create_future
from aiokafka.util import create_future, get_running_loop

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,7 +42,11 @@ class SubscriptionState:
_subscription = None # type: Subscription
_listener = None # type: ConsumerRebalanceListener

def __init__(self):
def __init__(self, loop=None):
if loop is None:
loop = get_running_loop()
self._loop = loop

self._subscription_waiters = [] # type: List[Future]
self._assignment_waiters = [] # type: List[Future]

Expand Down Expand Up @@ -145,7 +149,7 @@ def subscribe(self, topics: Set[str], listener=None):
isinstance(listener, ConsumerRebalanceListener))
self._set_subscription_type(SubscriptionType.AUTO_TOPICS)

self._change_subscription(Subscription(topics))
self._change_subscription(Subscription(topics, loop=self._loop))
self._listener = listener
self._notify_subscription_waiters()

Expand All @@ -165,7 +169,7 @@ def subscribe_pattern(self, pattern: Pattern, listener=None):
self._subscribed_pattern = pattern
self._listener = listener

def assign_from_user(self, partitions: Set[TopicPartition]):
def assign_from_user(self, partitions: Iterable[TopicPartition]):
""" Manually assign partitions. After this call automatic assignment
will be impossible and will raise an ``IllegalStateError``.
Expand All @@ -175,7 +179,7 @@ def assign_from_user(self, partitions: Set[TopicPartition]):
self._set_subscription_type(SubscriptionType.USER_ASSIGNED)

self._change_subscription(
ManualSubscription(partitions))
ManualSubscription(partitions, loop=self._loop))
self._notify_assignment_waiters()

def unsubscribe(self):
Expand Down Expand Up @@ -316,10 +320,13 @@ class Subscription:
* Unsubscribed
"""

def __init__(self, topics: Set[str]):
self._topics = frozenset(topics) # type: Set[str]
def __init__(self, topics: Iterable[str], loop=None):
if loop is None:
loop = get_running_loop()

self._topics = frozenset(topics) # type: FrozenSet[str]
self._assignment = None # type: Assignment
self.unsubscribe_future = create_future() # type: Future
self.unsubscribe_future = loop.create_future() # type: Future
self._reassignment_in_progress = True

@property
Expand All @@ -334,7 +341,7 @@ def topics(self):
def assignment(self):
return self._assignment

def _assign(self, topic_partitions: Set[TopicPartition]):
def _assign(self, topic_partitions: Iterable[TopicPartition]):
for tp in topic_partitions:
assert tp.topic in self._topics, \
"Received an assignment for unsubscribed topic: %s" % (tp, )
Expand All @@ -358,14 +365,10 @@ class ManualSubscription(Subscription):
""" Describes a user assignment
"""

def __init__(self, user_assignment: Set[TopicPartition]):
topics = set([])
for tp in user_assignment:
topics.add(tp.topic)

self._topics = frozenset(topics)
def __init__(self, user_assignment: Iterable[TopicPartition], loop=None):
topics = (tp.topic for tp in user_assignment)
super().__init__(topics, loop=loop)
self._assignment = Assignment(user_assignment)
self.unsubscribe_future = create_future()

def _assign(
self, topic_partitions: Set[TopicPartition]): # pragma: no cover
Expand All @@ -375,6 +378,10 @@ def _assign(
def _reassignment_in_progress(self):
return False

@_reassignment_in_progress.setter
def _reassignment_in_progress(self, value):
pass

def _begin_reassignment(self): # pragma: no cover
assert False, "Should not be called"

Expand All @@ -388,7 +395,7 @@ class Assignment:
* Unassigned
"""

def __init__(self, topic_partitions: Set[TopicPartition]):
def __init__(self, topic_partitions: Iterable[TopicPartition]):
assert isinstance(topic_partitions, (list, set, tuple))

self._topic_partitions = frozenset(topic_partitions)
Expand Down
11 changes: 7 additions & 4 deletions aiokafka/producer/message_accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aiokafka.record.legacy_records import LegacyRecordBatchBuilder
from aiokafka.record.default_records import DefaultRecordBatchBuilder
from aiokafka.structs import RecordMetadata
from aiokafka.util import create_future
from aiokafka.util import create_future, get_running_loop


class BatchBuilder:
Expand Down Expand Up @@ -247,14 +247,17 @@ class MessageAccumulator:
"""
def __init__(
self, cluster, batch_size, compression_type, batch_ttl, *,
txn_manager=None):
txn_manager=None, loop=None):
if loop is None:
loop = get_running_loop()
self._loop = loop
self._batches = collections.defaultdict(collections.deque)
self._pending_batches = set([])
self._cluster = cluster
self._batch_size = batch_size
self._compression_type = compression_type
self._batch_ttl = batch_ttl
self._wait_data_future = create_future()
self._wait_data_future = loop.create_future()
self._closed = False
self._api_version = (0, 9)
self._txn_manager = txn_manager
Expand Down Expand Up @@ -413,7 +416,7 @@ def drain_by_nodes(self, ignore_nodes, muted_partitions=set()):
# task
if not self._wait_data_future.done():
self._wait_data_future.set_result(None)
self._wait_data_future = create_future()
self._wait_data_future = self._loop.create_future()

return nodes, unknown_leaders_exist

Expand Down
9 changes: 5 additions & 4 deletions aiokafka/producer/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def __init__(self, *, loop=None, bootstrap_servers='localhost',
sasl_oauth_token_provider=None):
if loop is None:
loop = get_running_loop()
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
self._loop = loop

if acks not in (0, 1, -1, 'all', _missing):
raise ValueError("Invalid ACKS parameter")
Expand Down Expand Up @@ -256,16 +259,14 @@ def __init__(self, *, loop=None, bootstrap_servers='localhost',
self._metadata = self.client.cluster
self._message_accumulator = MessageAccumulator(
self._metadata, max_batch_size, compression_attrs,
self._request_timeout_ms / 1000, txn_manager=self._txn_manager)
self._request_timeout_ms / 1000, txn_manager=self._txn_manager,
loop=loop)
self._sender = Sender(
self.client, acks=acks, txn_manager=self._txn_manager,
retry_backoff_ms=retry_backoff_ms, linger_ms=linger_ms,
message_accumulator=self._message_accumulator,
request_timeout_ms=request_timeout_ms)

self._loop = loop
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
self._closed = False

# Warn if producer was not closed properly
Expand Down
13 changes: 13 additions & 0 deletions tests/_testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
import os

from concurrent import futures
from contextlib import contextmanager
from functools import wraps

Expand Down Expand Up @@ -43,6 +44,18 @@ def wrapper(test, *args, **kw):
return wrapper


def run_in_thread(fun):

@wraps(fun)
def wrapper(test, *args, **kw):
timeout = getattr(test, "TEST_TIMEOUT", 120)
with futures.ThreadPoolExecutor() as executor:
fut = executor.submit(fun, test, *args, **kw)
fut.result(timeout=timeout)

return wrapper


def kafka_versions(*versions):
# Took from kafka-python

Expand Down
16 changes: 15 additions & 1 deletion tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from ._testutil import (
KafkaIntegrationTestCase, StubRebalanceListener,
run_until_complete, random_string, kafka_versions)
run_until_complete, run_in_thread, random_string, kafka_versions)


class TestConsumerIntegration(KafkaIntegrationTestCase):
Expand Down Expand Up @@ -116,6 +116,20 @@ async def test_simple_consumer(self):
# will ignore, no exception expected
await consumer.stop()

@run_in_thread
def test_create_consumer_no_running_loop(self):
loop = asyncio.new_event_loop()
consumer = AIOKafkaConsumer(
self.topic, bootstrap_servers=self.hosts, loop=loop)
loop.run_until_complete(consumer.start())
try:
loop.run_until_complete(
self.send_messages(0, list(range(0, 10))))
for _ in range(10):
loop.run_until_complete(consumer.getone())
finally:
loop.run_until_complete(consumer.stop())

@run_until_complete
async def test_consumer_context_manager(self):
await self.send_messages(0, list(range(0, 10)))
Expand Down
17 changes: 16 additions & 1 deletion tests/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from kafka.cluster import ClusterMetadata

from ._testutil import (
KafkaIntegrationTestCase, run_until_complete, kafka_versions
KafkaIntegrationTestCase, run_until_complete, run_in_thread, kafka_versions
)

from aiokafka.protocol.produce import ProduceResponse
Expand Down Expand Up @@ -133,6 +133,21 @@ async def test_producer_send(self):
with self.assertRaises(ProducerClosed):
await producer.send(self.topic, b'value', key=b'KEY')

@run_in_thread
def test_create_producer_no_running_loop(self):
loop = asyncio.new_event_loop()
producer = AIOKafkaProducer(bootstrap_servers=self.hosts, loop=loop)
loop.run_until_complete(producer.start())
try:
future = loop.run_until_complete(
producer.send(self.topic, b'hello, Kafka!', partition=0))
resp = loop.run_until_complete(future)
self.assertEqual(resp.topic, self.topic)
self.assertTrue(resp.partition in (0, 1))
self.assertEqual(resp.offset, 0)
finally:
loop.run_until_complete(producer.stop())

@run_until_complete
async def test_producer_context_manager(self):
producer = AIOKafkaProducer(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_subscription_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@pytest.fixture
def subscription_state():
async def subscription_state():
return SubscriptionState()


Expand Down

0 comments on commit 306ae6e

Please sign in to comment.