diff --git a/rasa/core/agent.py b/rasa/core/agent.py index 24fd7868ec95..428869a90be1 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -4,10 +4,10 @@ import tempfile import uuid from asyncio import CancelledError -from sanic import Sanic from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union import aiohttp +from sanic import Sanic import rasa import rasa.utils.io @@ -23,15 +23,15 @@ from rasa.core.domain import Domain, InvalidDomain from rasa.core.exceptions import AgentNotReady from rasa.core.interpreter import NaturalLanguageInterpreter, RegexInterpreter +from rasa.core.lock_store import LockStore, CounterLockStore from rasa.core.nlg import NaturalLanguageGenerator -from rasa.core.policies.policy import Policy -from rasa.core.policies.form_policy import FormPolicy from rasa.core.policies.ensemble import PolicyEnsemble, SimplePolicyEnsemble +from rasa.core.policies.form_policy import FormPolicy from rasa.core.policies.memoization import MemoizationPolicy +from rasa.core.policies.policy import Policy from rasa.core.processor import MessageProcessor from rasa.core.tracker_store import InMemoryTrackerStore, TrackerStore from rasa.core.trackers import DialogueStateTracker -from rasa.core.utils import LockCounter from rasa.model import ( get_model_subdirectories, get_latest_model, @@ -284,6 +284,7 @@ def __init__( interpreter: Optional[NaturalLanguageInterpreter] = None, generator: Union[EndpointConfig, NaturalLanguageGenerator, None] = None, tracker_store: Optional[TrackerStore] = None, + lock_store: Optional[LockStore] = None, action_endpoint: Optional[EndpointConfig] = None, fingerprint: Optional[Text] = None, model_directory: Optional[Text] = None, @@ -305,8 +306,8 @@ def __init__( self.nlg = NaturalLanguageGenerator.create(generator, self.domain) self.tracker_store = self.create_tracker_store(tracker_store, self.domain) + self.lock_store = self.create_lock_store(lock_store) self.action_endpoint = action_endpoint - 
self.conversations_in_processing = {} self._set_fingerprint(fingerprint) self.model_directory = model_directory @@ -426,34 +427,11 @@ def noop(_): processor = self.create_processor(message_preprocessor) - # get the lock for the current conversation - lock = self.conversations_in_processing.get(message.sender_id) - if not lock: - logger.debug( - "Created a new lock for conversation '{}'".format(message.sender_id) - ) - lock = LockCounter() - self.conversations_in_processing[message.sender_id] = lock - try: - async with lock: - # this makes sure that there can always only be one coroutine - # handling a conversation at any point in time - # Note: this doesn't support multi-processing, it just works - # for coroutines. If there are multiple processes handling - # messages, an external system needs to make sure messages - # for the same conversation are always processed by the same - # process. + async with self.lock_store.lock(message.sender_id): return await processor.handle_message(message) finally: - if not lock.is_someone_waiting(): - # dispose of the lock if no one needs it to avoid - # accumulating locks - del self.conversations_in_processing[message.sender_id] - logger.debug( - "Deleted lock for conversation '{}' (unused)" - "".format(message.sender_id) - ) + self.lock_store.cleanup(message.sender_id) # noinspection PyUnusedLocal def predict_next(self, sender_id: Text, **kwargs: Any) -> Optional[Dict[Text, Any]]: @@ -841,6 +819,13 @@ def create_tracker_store( else: return InMemoryTrackerStore(domain) + @staticmethod + def create_lock_store(store: Optional[LockStore]) -> LockStore: + if store is not None: + return store + else: + return CounterLockStore() + @staticmethod def _create_ensemble( policies: Union[List[Policy], PolicyEnsemble, None] @@ -856,7 +841,7 @@ def _create_ensemble( raise ValueError( "Invalid param `policies`. 
"""Lock stores used to serialise message handling per conversation."""

import asyncio
import logging
import typing
from typing import Optional, Text

logger = logging.getLogger(__name__)


class LockStore(object):
    """Interface for objects that hand out one lock per conversation."""

    def lock(self, conversation_id: Text):
        """Return the lock guarding `conversation_id`.

        Subclasses must override this; the base implementation always raises
        `NotImplementedError`.
        """
        raise NotImplementedError

    def cleanup(self, conversation_id: Text):
        """Dispose of the lock if no one needs it to avoid accumulating locks.

        The base implementation does nothing.
        """
        pass


class LockCounter(asyncio.Lock):
    """Decorated asyncio lock that counts how many coroutines are waiting.

    The counter can be used to discard the lock when there is no coroutine
    waiting for it. For this to work, there should not be any execution yield
    between retrieving the lock and acquiring it, otherwise there might be
    race conditions.
    """

    def __init__(self) -> None:
        super().__init__()
        # number of coroutines currently inside `acquire()`
        self.wait_counter = 0

    async def acquire(self) -> bool:
        """Acquire the lock, makes sure only one coroutine can retrieve it."""
        self.wait_counter += 1
        try:
            return await super().acquire()  # type: ignore
        finally:
            # runs on success as well as on cancellation, so the counter
            # never stays inflated by a coroutine that gave up waiting
            self.wait_counter -= 1

    def is_someone_waiting(self) -> bool:
        """Check if a coroutine is waiting for this lock to be freed."""
        return self.wait_counter != 0


class CounterLockStore(LockStore):
    """Store for LockCounter locks."""

    def __init__(self) -> None:
        # maps conversation id -> LockCounter
        self.conversation_locks = {}

    def lock(self, conversation_id: Text) -> LockCounter:
        """Return the lock for `conversation_id`, creating it on first use."""
        lock = self._get_lock(conversation_id)
        if lock is None:
            lock = self._create_lock(conversation_id)
        return lock

    def _get_lock(self, conversation_id: Text) -> Optional[LockCounter]:
        """Return the stored lock for `conversation_id`, or `None`."""
        return self.conversation_locks.get(conversation_id)

    def _create_lock(self, conversation_id: Text) -> LockCounter:
        """Create, register and return a fresh lock for `conversation_id`."""
        lock = LockCounter()
        self.conversation_locks[conversation_id] = lock
        return lock

    def _is_someone_waiting(self, conversation_id: Text) -> bool:
        """Tell whether a coroutine waits on the lock of `conversation_id`.

        Returns `False` when no lock is stored for that conversation.
        """
        lock = self._get_lock(conversation_id)
        if lock is None:
            return False
        return lock.is_someone_waiting()

    def cleanup(self, conversation_id: Text) -> None:
        """Drop the lock of `conversation_id` once it is no longer needed.

        The lock is kept while it is held or while a coroutine waits for it:
        discarding a lock that is still in use would let the next caller
        create a second lock for the same conversation. Calling this for a
        conversation without a stored lock is a no-op.
        """
        lock = self._get_lock(conversation_id)
        if lock is None:
            # nothing stored (never created or already cleaned up)
            return
        if lock.locked() or lock.is_someone_waiting():
            return

        self.conversation_locks.pop(conversation_id, None)
        logger.debug(
            "Deleted lock for conversation '%s' (unused)", conversation_id
        )


class RedisLockStore(LockStore):
    """Lock store that delegates to an `aioredlock.Aioredlock` manager."""

    def __init__(
        self,
        host: Optional[Text] = None,
        port: Optional[int] = None,
        db: Text = "rasa",
        password: Optional[Text] = None,
        lock_timeout: float = 0.5,
        retry_count: int = 20,
    ) -> None:
        """Create the lock manager for a single Redis instance.

        Args:
            host: Redis host placed in the connection description.
            port: Redis port placed in the connection description.
            db: Redis database placed in the connection description.
            password: Redis password placed in the connection description.
            lock_timeout: Forwarded to `Aioredlock` as `lock_timeout`.
            retry_count: Forwarded to `Aioredlock` as `retry_count`.
        """
        # imported lazily so the module can be loaded without aioredlock
        from aioredlock import Aioredlock

        # NOTE(review): Redis databases are usually numeric indices; the
        # textual default "rasa" is kept for compatibility - confirm.
        redis_instances = [{"host": host, "port": port, "db": db, "password": password}]
        self.lock_manager = Aioredlock(
            # `redis_instances` already is the list of connections; it must
            # not be wrapped in another list
            redis_connections=redis_instances,
            lock_timeout=lock_timeout,
            retry_count=retry_count,
        )

    def lock(self, conversation_id: Text) -> typing.Any:
        """Return the result of `Aioredlock.lock` for `conversation_id`.

        NOTE(review): presumably `Aioredlock.lock` is a coroutine function, in
        which case the result has to be awaited before it can be used as a
        lock - confirm against the aioredlock documentation.
        """
        return self.lock_manager.lock(conversation_id)
PyUnresolvedReferences from rasa.utils.endpoints import concat_url +from rasa.utils.endpoints import read_endpoint_config logger = logging.getLogger(__name__) @@ -452,29 +450,3 @@ def handler(fut: Future) -> None: ) return handler - - -class LockCounter(asyncio.Lock): - """Decorated asyncio lock that counts how many coroutines are waiting. - - The counter can be used to discard the lock when there is no coroutine - waiting for it. For this to work, there should not be any execution yield - between retrieving the lock and acquiring it, otherwise there might be - race conditions.""" - - def __init__(self) -> None: - super().__init__() - self.wait_counter = 0 - - async def acquire(self) -> bool: - """Acquire the lock, makes sure only one coroutine can retrieve it.""" - - self.wait_counter += 1 - try: - return await super(LockCounter, self).acquire() # type: ignore - finally: - self.wait_counter -= 1 - - def is_someone_waiting(self) -> bool: - """Check if a coroutine is waiting for this lock to be freed.""" - return self.wait_counter != 0 diff --git a/requirements.txt b/requirements.txt index cff22f35bc60..cdf34587f860 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,3 +47,4 @@ SQLAlchemy~=1.3.3 kafka-python==1.4.6 sklearn-crfsuite==0.3.6 psycopg2-binary==2.8.2 +aioredlock==0.3.0 diff --git a/setup.py b/setup.py index 7efa8ec2e8cf..ea4a96a868e6 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,7 @@ "SQLAlchemy~=1.3.0", "kafka-python~=1.4", "sklearn-crfsuite~=0.3.6", + "aioredlock~=0.3.0", ] extras_requires = {