Skip to content

Commit

Permalink
unify external class instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Dec 4, 2019
1 parent c588184 commit d10b86d
Show file tree
Hide file tree
Showing 24 changed files with 609 additions and 440 deletions.
2 changes: 2 additions & 0 deletions changelog/4801.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Allow creation of natural language interpreter and generator by classname reference
in the ``endpoints.yml``.
79 changes: 79 additions & 0 deletions rasa/core/brokers/broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import logging
import warnings
from typing import Any, Dict, Text, Optional, Union

from rasa.utils import common
from rasa.utils.endpoints import EndpointConfig

logger = logging.getLogger(__name__)


class EventBroker:
"""Base class for any event broker implementation."""

@staticmethod
def create(
obj: Union["EventBroker", EndpointConfig, None],
) -> Optional["EventBroker"]:
"""Factory to create an event broker."""

if isinstance(obj, EventBroker):
return obj
else:
return _create_from_endpoint_config(obj)

@classmethod
def from_endpoint_config(cls, broker_config: EndpointConfig) -> "EventBroker":
raise NotImplementedError(
"Event broker must implement the `from_endpoint_config` method."
)

def publish(self, event: Dict[Text, Any]) -> None:
"""Publishes a json-formatted Rasa Core event into an event queue."""

raise NotImplementedError("Event broker must implement the `publish` method.")


def _create_from_endpoint_config(
endpoint_config: Optional[EndpointConfig],
) -> Optional["EventBroker"]:
"""Instantiate an event broker based on its configuration."""

if endpoint_config is None:
broker = None
elif endpoint_config.type is None or endpoint_config.type.lower() == "pika":
from rasa.core.brokers.pika import PikaEventBroker

# default broker if no type is set
broker = PikaEventBroker.from_endpoint_config(endpoint_config)
elif endpoint_config.type.lower() == "sql":
from rasa.core.brokers.sql import SQLEventBroker

broker = SQLEventBroker.from_endpoint_config(endpoint_config)
elif endpoint_config.type.lower() == "file":
from rasa.core.brokers.file import FileEventBroker

broker = FileEventBroker.from_endpoint_config(endpoint_config)
elif endpoint_config.type.lower() == "kafka":
from rasa.core.brokers.kafka import KafkaEventBroker

broker = KafkaEventBroker.from_endpoint_config(endpoint_config)
else:
broker = _load_from_module_string(endpoint_config)

logger.debug(f"Instantiated event broker to '{broker.__class__.__name__}'.")
return broker


def _load_from_module_string(broker_config: EndpointConfig,) -> Optional["EventBroker"]:
"""Instantiate an event broker based on its class name."""

try:
event_broker_class = common.class_from_module_path(broker_config.type)
return event_broker_class.from_endpoint_config(broker_config)
except (AttributeError, ImportError) as e:
logger.warning(
f"The `EventBroker` type '{broker_config.type}' could not be found. "
f"Not using any event broker. Error: {e}"
)
return None
26 changes: 10 additions & 16 deletions rasa/core/brokers/event_channel.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import logging
from typing import Any, Dict, Text, Optional
import warnings

from rasa.utils.endpoints import EndpointConfig
from rasa.core.brokers.broker import EventBroker

logger = logging.getLogger(__name__)


class EventChannel:
@classmethod
def from_endpoint_config(cls, broker_config: EndpointConfig) -> "EventChannel":
raise NotImplementedError(
"Event broker must implement the `from_endpoint_config` method."
)

def publish(self, event: Dict[Text, Any]) -> None:
"""Publishes a json-formatted Rasa Core event into an event queue."""

raise NotImplementedError("Event broker must implement the `publish` method.")
# noinspection PyAbstractClass
class EventChannel(EventBroker):
warnings.warn(
"Deprecated, inherit from `EventBroker` instead of `EventChannel`. "
"The `EventChannel` class will be removed.",
DeprecationWarning,
stacklevel=2,
)
56 changes: 56 additions & 0 deletions rasa/core/brokers/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import json
import logging
import typing
import warnings
from typing import Optional, Text, Dict

from rasa.core.brokers.broker import EventBroker

if typing.TYPE_CHECKING:
from rasa.utils.endpoints import EndpointConfig

logger = logging.getLogger(__name__)


class FileEventBroker(EventBroker):
"""Log events to a file in json format.
There will be one event per line and each event is stored as json."""

DEFAULT_LOG_FILE_NAME = "rasa_event.log"

def __init__(self, path: Optional[Text] = None) -> None:
self.path = path or self.DEFAULT_LOG_FILE_NAME
self.event_logger = self._event_logger()

@classmethod
def from_endpoint_config(
cls, broker_config: Optional["EndpointConfig"]
) -> Optional["FileEventBroker"]:
if broker_config is None:
return None

# noinspection PyArgumentList
return cls(**broker_config.kwargs)

def _event_logger(self):
"""Instantiate the file logger."""

logger_file = self.path
# noinspection PyTypeChecker
query_logger = logging.getLogger("event-logger")
query_logger.setLevel(logging.INFO)
handler = logging.FileHandler(logger_file)
handler.setFormatter(logging.Formatter("%(message)s"))
query_logger.propagate = False
query_logger.addHandler(handler)

logger.info(f"Logging events to '{logger_file}'.")

return query_logger

def publish(self, event: Dict) -> None:
"""Write event to file."""

self.event_logger.info(json.dumps(event))
self.event_logger.handlers[0].flush()
61 changes: 9 additions & 52 deletions rasa/core/brokers/file_producer.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,12 @@
import json
import logging
import typing
from typing import Optional, Text, Dict
import warnings

from rasa.core.brokers.event_channel import EventChannel
from rasa.core.brokers.file import FileEventBroker

if typing.TYPE_CHECKING:
from rasa.utils.endpoints import EndpointConfig

logger = logging.getLogger(__name__)


class FileProducer(EventChannel):
"""Log events to a file in json format.
There will be one event per line and each event is stored as json."""

DEFAULT_LOG_FILE_NAME = "rasa_event.log"

def __init__(self, path: Optional[Text] = None) -> None:
self.path = path or self.DEFAULT_LOG_FILE_NAME
self.event_logger = self._event_logger()

@classmethod
def from_endpoint_config(
cls, broker_config: Optional["EndpointConfig"]
) -> Optional["FileProducer"]:
if broker_config is None:
return None

# noinspection PyArgumentList
return cls(**broker_config.kwargs)

def _event_logger(self):
"""Instantiate the file logger."""

logger_file = self.path
# noinspection PyTypeChecker
query_logger = logging.getLogger("event-logger")
query_logger.setLevel(logging.INFO)
handler = logging.FileHandler(logger_file)
handler.setFormatter(logging.Formatter("%(message)s"))
query_logger.propagate = False
query_logger.addHandler(handler)

logger.info(f"Logging events to '{logger_file}'.")

return query_logger

def publish(self, event: Dict) -> None:
"""Write event to file."""

self.event_logger.info(json.dumps(event))
self.event_logger.handlers[0].flush()
class FileProducer(FileEventBroker):
warnings.warn(
"Deprecated, the class `FileProducer` has been renamed to `FileEventBroker`. "
"The `FileProducer` class will be removed.",
DeprecationWarning,
stacklevel=2,
)
16 changes: 13 additions & 3 deletions rasa/core/brokers/kafka.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
import logging
import warnings
from typing import Optional

from rasa.core.brokers.event_channel import EventChannel
from rasa.core.brokers.broker import EventBroker
from rasa.utils.io import DEFAULT_ENCODING

logger = logging.getLogger(__name__)


class KafkaProducer(EventChannel):
class KafkaEventBroker(EventBroker):
def __init__(
self,
host,
Expand Down Expand Up @@ -37,7 +38,7 @@ def __init__(
logging.getLogger("kafka").setLevel(loglevel)

@classmethod
def from_endpoint_config(cls, broker_config) -> Optional["KafkaProducer"]:
def from_endpoint_config(cls, broker_config) -> Optional["KafkaEventBroker"]:
if broker_config is None:
return None

Expand Down Expand Up @@ -76,3 +77,12 @@ def _publish(self, event):

def _close(self):
self.producer.close()


class KafkaProducer(KafkaEventBroker):
warnings.warn(
"Deprecated, the class `KafkaProducer` has been renamed to "
"`KafkaEventBroker`. The `KafkaProducer` class will be removed.",
DeprecationWarning,
stacklevel=2,
)
65 changes: 60 additions & 5 deletions rasa/core/brokers/pika.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
import logging
import typing
import os
import warnings
from collections import deque
from threading import Thread
from typing import Dict, Optional, Text, Union, Deque, Callable

import time

import rasa.core.brokers.utils as rasa_broker_utils
from rasa.constants import ENV_LOG_LEVEL_LIBRARIES, DEFAULT_LOG_LEVEL_LIBRARIES
from rasa.core.brokers.event_channel import EventChannel
from rasa.core.brokers.broker import EventBroker
from rasa.utils.endpoints import EndpointConfig
from rasa.utils.io import DEFAULT_ENCODING

if typing.TYPE_CHECKING:
from pika.adapters.blocking_connection import BlockingChannel
from pika import SelectConnection, BlockingConnection, BasicProperties
from pika.channel import Channel
import pika
from pika.connection import Parameters, Connection

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,7 +97,7 @@ def _get_pika_parameters(
# it can take some time until
# RabbitMQ comes up.
retry_delay=retry_delay_in_seconds,
ssl_options=rasa_broker_utils.create_rabbitmq_ssl_options(host),
ssl_options=create_rabbitmq_ssl_options(host),
)

return parameters
Expand Down Expand Up @@ -195,7 +196,7 @@ def close_pika_connection(connection: "Connection") -> None:
logger.exception("Failed to close Pika connection with host.")


class PikaProducer(EventChannel):
class PikaEventBroker(EventBroker):
def __init__(
self,
host: Text,
Expand Down Expand Up @@ -243,7 +244,7 @@ def rasa_environment(self) -> Optional[Text]:
@classmethod
def from_endpoint_config(
cls, broker_config: Optional["EndpointConfig"]
) -> Optional["PikaProducer"]:
) -> Optional["PikaEventBroker"]:
if broker_config is None:
return None

Expand Down Expand Up @@ -359,3 +360,57 @@ def _publish(self, body: Text) -> None:
f"Published Pika events to queue '{self.queue}' on host "
f"'{self.host}':\n{body}"
)


def create_rabbitmq_ssl_options(
rabbitmq_host: Optional[Text] = None,
) -> Optional["pika.SSLOptions"]:
"""Create RabbitMQ SSL options.
Requires the following environment variables to be set:
RABBITMQ_SSL_CLIENT_CERTIFICATE - path to the SSL client certificate (required)
RABBITMQ_SSL_CLIENT_KEY - path to the SSL client key (required)
RABBITMQ_SSL_CA_FILE - path to the SSL CA file for verification (optional)
RABBITMQ_SSL_KEY_PASSWORD - SSL private key password (optional)
Details on how to enable RabbitMQ TLS support can be found here:
https://www.rabbitmq.com/ssl.html#enabling-tls
Args:
rabbitmq_host: RabbitMQ hostname
Returns:
Pika SSL context of type `pika.SSLOptions` if
the RABBITMQ_SSL_CLIENT_CERTIFICATE and RABBITMQ_SSL_CLIENT_KEY
environment variables are valid paths, else `None`.
"""

client_certificate_path = os.environ.get("RABBITMQ_SSL_CLIENT_CERTIFICATE")
client_key_path = os.environ.get("RABBITMQ_SSL_CLIENT_KEY")

if client_certificate_path and client_key_path:
import pika
import rasa.server

logger.debug(f"Configuring SSL context for RabbitMQ host '{rabbitmq_host}'.")

ca_file_path = os.environ.get("RABBITMQ_SSL_CA_FILE")
key_password = os.environ.get("RABBITMQ_SSL_KEY_PASSWORD")

ssl_context = rasa.server.create_ssl_context(
client_certificate_path, client_key_path, ca_file_path, key_password
)
return pika.SSLOptions(ssl_context, rabbitmq_host)
else:
return None


class PikaProducer(PikaEventBroker):
warnings.warn(
"Deprecated, the class `PikaProducer` has been renamed to "
"`PikaEventBroker`. The `PikaProducer` class will be removed.",
DeprecationWarning,
stacklevel=2,
)
Loading

0 comments on commit d10b86d

Please sign in to comment.