Skip to content

Commit

Permalink
Merge 867858b into 5a56aa7
Browse files Browse the repository at this point in the history
  • Loading branch information
ricwo committed Jun 26, 2019
2 parents 5a56aa7 + 867858b commit 1528891
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 62 deletions.
47 changes: 16 additions & 31 deletions rasa/core/agent.py
Expand Up @@ -4,10 +4,10 @@
import tempfile
import uuid
from asyncio import CancelledError
from sanic import Sanic
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union

import aiohttp
from sanic import Sanic

import rasa
import rasa.utils.io
Expand All @@ -23,15 +23,15 @@
from rasa.core.domain import Domain, InvalidDomain
from rasa.core.exceptions import AgentNotReady
from rasa.core.interpreter import NaturalLanguageInterpreter, RegexInterpreter
from rasa.core.lock_store import LockStore, CounterLockStore
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.policies.policy import Policy
from rasa.core.policies.form_policy import FormPolicy
from rasa.core.policies.ensemble import PolicyEnsemble, SimplePolicyEnsemble
from rasa.core.policies.form_policy import FormPolicy
from rasa.core.policies.memoization import MemoizationPolicy
from rasa.core.policies.policy import Policy
from rasa.core.processor import MessageProcessor
from rasa.core.tracker_store import InMemoryTrackerStore, TrackerStore
from rasa.core.trackers import DialogueStateTracker
from rasa.core.utils import LockCounter
from rasa.model import (
get_model_subdirectories,
get_latest_model,
Expand Down Expand Up @@ -284,6 +284,7 @@ def __init__(
interpreter: Optional[NaturalLanguageInterpreter] = None,
generator: Union[EndpointConfig, NaturalLanguageGenerator, None] = None,
tracker_store: Optional[TrackerStore] = None,
lock_store: Optional[LockStore] = None,
action_endpoint: Optional[EndpointConfig] = None,
fingerprint: Optional[Text] = None,
model_directory: Optional[Text] = None,
Expand All @@ -305,8 +306,8 @@ def __init__(

self.nlg = NaturalLanguageGenerator.create(generator, self.domain)
self.tracker_store = self.create_tracker_store(tracker_store, self.domain)
self.lock_store = self.create_lock_store(lock_store)
self.action_endpoint = action_endpoint
self.conversations_in_processing = {}

self._set_fingerprint(fingerprint)
self.model_directory = model_directory
Expand Down Expand Up @@ -426,34 +427,11 @@ def noop(_):

processor = self.create_processor(message_preprocessor)

# get the lock for the current conversation
lock = self.conversations_in_processing.get(message.sender_id)
if not lock:
logger.debug(
"Created a new lock for conversation '{}'".format(message.sender_id)
)
lock = LockCounter()
self.conversations_in_processing[message.sender_id] = lock

try:
async with lock:
# this makes sure that there can always only be one coroutine
# handling a conversation at any point in time
# Note: this doesn't support multi-processing, it just works
# for coroutines. If there are multiple processes handling
# messages, an external system needs to make sure messages
# for the same conversation are always processed by the same
# process.
async with self.lock_store.lock(message.sender_id):
return await processor.handle_message(message)
finally:
if not lock.is_someone_waiting():
# dispose of the lock if no one needs it to avoid
# accumulating locks
del self.conversations_in_processing[message.sender_id]
logger.debug(
"Deleted lock for conversation '{}' (unused)"
"".format(message.sender_id)
)
self.lock_store.cleanup(message.sender_id)

# noinspection PyUnusedLocal
def predict_next(self, sender_id: Text, **kwargs: Any) -> Optional[Dict[Text, Any]]:
Expand Down Expand Up @@ -841,6 +819,13 @@ def create_tracker_store(
else:
return InMemoryTrackerStore(domain)

@staticmethod
def create_lock_store(store: Optional[LockStore]) -> LockStore:
if store is not None:
return store
else:
return CounterLockStore()

@staticmethod
def _create_ensemble(
policies: Union[List[Policy], PolicyEnsemble, None]
Expand All @@ -856,7 +841,7 @@ def _create_ensemble(
raise ValueError(
"Invalid param `policies`. Passed object is "
"of type '{}', but should be policy, an array of "
"policies, or a policy ensemble".format(passed_type)
"policies, or a policy ensemble.".format(passed_type)
)

@staticmethod
Expand Down
104 changes: 104 additions & 0 deletions rasa/core/lock_store.py
@@ -0,0 +1,104 @@
import asyncio
import logging
import typing
from typing import Text, Optional

if typing.TYPE_CHECKING:
pass

logger = logging.getLogger(__name__)


class LockStore(object):
def lock(self, conversation_id: Text):
raise NotImplementedError

def cleanup(self, conversation_id: Text):
"""Dispose of the lock if no one needs it to avoid accumulating locks."""

pass


class LockCounter(asyncio.Lock):
"""Decorated asyncio lock that counts how many coroutines are waiting.
The counter can be used to discard the lock when there is no coroutine
waiting for it. For this to work, there should not be any execution yield
between retrieving the lock and acquiring it, otherwise there might be
race conditions."""

def __init__(self) -> None:
super().__init__()
self.wait_counter = 0

async def acquire(self) -> bool:
"""Acquire the lock, makes sure only one coroutine can retrieve it."""

self.wait_counter += 1
try:
return await super(LockCounter, self).acquire() # type: ignore
finally:
self.wait_counter -= 1

def is_someone_waiting(self) -> bool:
"""Check if a coroutine is waiting for this lock to be freed."""

return self.wait_counter != 0


class CounterLockStore(LockStore):
"""Store for LockCounter locks."""

def __init__(self) -> None:
self.conversation_locks = {}

def lock(self, conversation_id: Text) -> LockCounter:
lock = self._get_lock(conversation_id)
if not lock:
lock = self._create_lock(conversation_id)

return lock

def _get_lock(self, conversation_id: Text) -> Optional[LockCounter]:
return self.conversation_locks.get(conversation_id)

def _create_lock(self, conversation_id: Text) -> LockCounter:
lock = LockCounter()
self.conversation_locks[conversation_id] = lock
return lock

def _is_someone_waiting(self, conversation_id: Text) -> bool:
lock = self._get_lock(conversation_id)
if lock:
return lock.is_someone_waiting()

return False

def cleanup(self, conversation_id: Text) -> None:
if not self._is_someone_waiting(conversation_id):
del self.conversation_locks[conversation_id]
logger.debug(
"Deleted lock for conversation '{}' (unused)".format(conversation_id)
)


class RedisLockStore(LockStore):
def __init__(
self,
host: Optional[Text] = None,
port: Optional[int] = None,
db: Text = "rasa",
password: Text = None,
lock_timeout: float = 0.5,
retry_count: int = 20,
) -> None:
from aioredlock import Aioredlock

redis_instances = [{"host": host, "port": port, "db": db, "password": password}]
self.lock_manager = Aioredlock(
redis_connections=[redis_instances],
lock_timeout=lock_timeout,
retry_count=retry_count,
)

def lock(self, conversation_id: Text) -> "Aioredlock":
return self.lock_manager.lock(conversation_id)
34 changes: 3 additions & 31 deletions rasa/core/utils.py
@@ -1,29 +1,27 @@
# -*- coding: utf-8 -*-
import argparse
import asyncio
import json
import logging
import re
import sys
from pathlib import Path
from typing import Union
from asyncio import Future
from hashlib import md5, sha1
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Text, Tuple, Callable
from typing import Union

import aiohttp
from aiohttp import InvalidURL
from sanic import Sanic
from sanic.views import CompositionView

import rasa.utils.io as io_utils
from rasa.utils.endpoints import read_endpoint_config


# backwards compatibility 1.0.x
# noinspection PyUnresolvedReferences
from rasa.utils.endpoints import concat_url
from rasa.utils.endpoints import read_endpoint_config

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -452,29 +450,3 @@ def handler(fut: Future) -> None:
)

return handler


class LockCounter(asyncio.Lock):
"""Decorated asyncio lock that counts how many coroutines are waiting.
The counter can be used to discard the lock when there is no coroutine
waiting for it. For this to work, there should not be any execution yield
between retrieving the lock and acquiring it, otherwise there might be
race conditions."""

def __init__(self) -> None:
super().__init__()
self.wait_counter = 0

async def acquire(self) -> bool:
"""Acquire the lock, makes sure only one coroutine can retrieve it."""

self.wait_counter += 1
try:
return await super(LockCounter, self).acquire() # type: ignore
finally:
self.wait_counter -= 1

def is_someone_waiting(self) -> bool:
"""Check if a coroutine is waiting for this lock to be freed."""
return self.wait_counter != 0
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -47,3 +47,4 @@ SQLAlchemy~=1.3.3
kafka-python==1.4.6
sklearn-crfsuite==0.3.6
psycopg2-binary==2.8.2
aioredlock==0.3.0
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -74,6 +74,7 @@
"SQLAlchemy~=1.3.0",
"kafka-python~=1.4",
"sklearn-crfsuite~=0.3.6",
"aioredlock~=0.3.0",
]

extras_requires = {
Expand Down

0 comments on commit 1528891

Please sign in to comment.