Skip to content

Commit

Permalink
Merge 96a4671 into 6f1e54c
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Dec 5, 2019
2 parents 6f1e54c + 96a4671 commit b40a000
Show file tree
Hide file tree
Showing 85 changed files with 760 additions and 586 deletions.
5 changes: 3 additions & 2 deletions rasa/cli/data.py
Expand Up @@ -7,6 +7,7 @@
from rasa.cli.arguments import data as arguments
from rasa.cli.utils import get_validated_path
from rasa.constants import DEFAULT_DATA_PATH
from typing import NoReturn


# noinspection PyProtectedMember
Expand Down Expand Up @@ -74,7 +75,7 @@ def add_subparser(
arguments.set_validator_arguments(validate_parser)


def split_nlu_data(args):
def split_nlu_data(args) -> None:
from rasa.nlu.training_data.loading import load_data
from rasa.nlu.training_data.util import get_file_format

Expand All @@ -90,7 +91,7 @@ def split_nlu_data(args):
test.persist(args.out, filename=f"test_data.{fformat}")


def validate_files(args):
def validate_files(args) -> NoReturn:
"""Validate all files needed for training a model.
Fails with a non-zero exit code if there are any errors in the data."""
Expand Down
4 changes: 2 additions & 2 deletions rasa/cli/interactive.py
Expand Up @@ -79,7 +79,7 @@ def interactive_core(args: argparse.Namespace):
perform_interactive_learning(args, zipped_model)


def perform_interactive_learning(args, zipped_model):
def perform_interactive_learning(args, zipped_model) -> None:
from rasa.core.train import do_interactive_learning

if zipped_model and os.path.exists(zipped_model):
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_provided_model(arg_model: Text):
return model_path


def check_training_data(args):
def check_training_data(args) -> None:
training_files = [
get_validated_path(f, "data", DEFAULT_DATA_PATH, none_is_valid=True)
for f in args.data
Expand Down
3 changes: 2 additions & 1 deletion rasa/cli/utils.py
Expand Up @@ -9,6 +9,7 @@
from questionary import Question

from rasa.constants import DEFAULT_MODELS_PATH
from typing import NoReturn

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -238,6 +239,6 @@ def print_error_and_exit(message: Text, exit_code: int = 1) -> None:
sys.exit(exit_code)


def signal_handler(sig, frame):
def signal_handler(sig, frame) -> NoReturn:
print("Goodbye 👋")
sys.exit(0)
2 changes: 1 addition & 1 deletion rasa/cli/x.py
Expand Up @@ -209,7 +209,7 @@ def start_rasa_for_local_rasa_x(args: argparse.Namespace, rasa_x_token: Text):
return p


def is_rasa_x_installed():
def is_rasa_x_installed() -> bool:
"""Check if Rasa X is installed."""

# we could also do something like checking if `import rasax` works,
Expand Down
79 changes: 61 additions & 18 deletions rasa/core/actions/action.py
Expand Up @@ -30,6 +30,7 @@
BotUttered,
)
from rasa.utils.endpoints import EndpointConfig, ClientResponseError
from typing import Coroutine, Union

if typing.TYPE_CHECKING:
from rasa.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -75,7 +76,7 @@ def default_action_names() -> List[Text]:
return [a.name() for a in default_actions()]


def combine_user_with_default_actions(user_actions):
def combine_user_with_default_actions(user_actions) -> list:
# remove all user actions that overwrite default actions
# this logic is a bit reversed, you'd think that we should remove
# the action name from the default action names if the user overwrites
Expand Down Expand Up @@ -183,7 +184,7 @@ def __init__(self, name: Text, silent_fail: Optional[bool] = False):
self.action_name = name
self.silent_fail = silent_fail

def intent_name_from_action(self):
def intent_name_from_action(self) -> Text:
return self.action_name.split(RESPOND_PREFIX)[1]

async def run(
Expand Down Expand Up @@ -232,11 +233,17 @@ class ActionUtterTemplate(Action):
Both, name and utter template, need to be specified using
the `name` method."""

def __init__(self, name, silent_fail: Optional[bool] = False):
def __init__(self, name: Text, silent_fail: Optional[bool] = False):
self.template_name = name
self.silent_fail = silent_fail

async def run(self, output_channel, nlg, tracker, domain):
async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
"""Simple run implementation uttering a (hopefully defined) template."""

message = await nlg.generate(self.template_name, tracker, output_channel.name())
Expand All @@ -263,10 +270,16 @@ class ActionBack(ActionUtterTemplate):
def name(self) -> Text:
return ACTION_BACK_NAME

def __init__(self):
def __init__(self) -> None:
super().__init__("utter_back", silent_fail=True)

async def run(self, output_channel, nlg, tracker, domain):
async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
# only utter the template if it is available
evts = await super().run(output_channel, nlg, tracker, domain)

Expand All @@ -282,7 +295,13 @@ class ActionListen(Action):
def name(self) -> Text:
return ACTION_LISTEN_NAME

async def run(self, output_channel, nlg, tracker, domain):
async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
return []


Expand All @@ -294,10 +313,16 @@ class ActionRestart(ActionUtterTemplate):
def name(self) -> Text:
return ACTION_RESTART_NAME

def __init__(self):
def __init__(self) -> None:
super().__init__("utter_restart", silent_fail=True)

async def run(self, output_channel, nlg, tracker, domain):
async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
from rasa.core.events import Restarted

# only utter the template if it is available
Expand All @@ -313,10 +338,16 @@ class ActionDefaultFallback(ActionUtterTemplate):
def name(self) -> Text:
return ACTION_DEFAULT_FALLBACK_NAME

def __init__(self):
def __init__(self) -> None:
super().__init__("utter_default", silent_fail=True)

async def run(self, output_channel, nlg, tracker, domain):
async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
from rasa.core.events import UserUtteranceReverted

# only utter the template if it is available
Expand All @@ -331,7 +362,13 @@ class ActionDeactivateForm(Action):
def name(self) -> Text:
return ACTION_DEACTIVATE_FORM_NAME

async def run(self, output_channel, nlg, tracker, domain):
async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
from rasa.core.events import Form, SlotSet

return [Form(None), SlotSet(REQUESTED_SLOT, None)]
Expand Down Expand Up @@ -360,7 +397,7 @@ def _action_call_format(
}

@staticmethod
def action_response_format_spec():
def action_response_format_spec() -> Dict[Text, Any]:
"""Expected response schema for an Action endpoint.
Used for validation of the response returned from the
Expand All @@ -379,7 +416,7 @@ def action_response_format_spec():
},
}

def _validate_action_result(self, result):
def _validate_action_result(self, result: Dict[Text, Any]) -> bool:
from jsonschema import validate
from jsonschema import ValidationError

Expand Down Expand Up @@ -428,7 +465,13 @@ async def _utter_responses(

return bot_messages

async def run(self, output_channel, nlg, tracker, domain) -> List[Event]:
async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
json_body = self._action_call_format(tracker, domain)

if not self.action_endpoint:
Expand Down Expand Up @@ -504,13 +547,13 @@ class ActionExecutionRejection(Exception):
"""Raising this exception will allow other policies
to predict a different action"""

def __init__(self, action_name, message=None):
def __init__(self, action_name: Text, message: Optional[Text] = None) -> None:
self.action_name = action_name
self.message = message or "Custom action '{}' rejected to run".format(
action_name
)

def __str__(self):
def __str__(self) -> Text:
return self.message


Expand Down Expand Up @@ -632,5 +675,5 @@ class ActionDefaultAskRephrase(ActionUtterTemplate):
def name(self) -> Text:
return ACTION_DEFAULT_ASK_REPHRASE_NAME

def __init__(self):
def __init__(self) -> None:
super().__init__("utter_ask_rephrase", silent_fail=True)
12 changes: 6 additions & 6 deletions rasa/core/agent.py
Expand Up @@ -410,17 +410,17 @@ def load(
path_to_model_archive=path_to_model_archive,
)

def is_core_ready(self):
def is_core_ready(self) -> bool:
"""Check if all necessary components and policies are ready to use the agent.
"""
return self.is_ready() and self.policy_ensemble
return self.is_ready() and self.policy_ensemble is not None

def is_ready(self):
def is_ready(self) -> bool:
"""Check if all necessary components are instantiated to use agent.
Policies might not be available, if this is an NLU only agent."""

return self.tracker_store and self.interpreter
return self.tracker_store is not None and self.interpreter is not None

async def parse_message_using_nlu_interpreter(
self, message_data: Text, tracker: DialogueStateTracker = None
Expand Down Expand Up @@ -584,7 +584,7 @@ def continue_training(
self.policy_ensemble.continue_training(trackers, self.domain, **kwargs)
self._set_fingerprint()

def _max_history(self):
def _max_history(self) -> int:
"""Find maximum max_history."""

max_histories = [
Expand All @@ -595,7 +595,7 @@ def _max_history(self):

return max(max_histories or [0])

def _are_all_featurizers_using_a_max_history(self):
def _are_all_featurizers_using_a_max_history(self) -> bool:
"""Check if all featurizers are MaxHistoryTrackerFeaturizer."""

def has_max_history_featurizer(policy):
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/brokers/file_producer.py
Expand Up @@ -32,7 +32,7 @@ def from_endpoint_config(
# noinspection PyArgumentList
return cls(**broker_config.kwargs)

def _event_logger(self):
def _event_logger(self) -> logging.Logger:
"""Instantiate the file logger."""

logger_file = self.path
Expand Down
10 changes: 5 additions & 5 deletions rasa/core/brokers/kafka.py
Expand Up @@ -21,7 +21,7 @@ def __init__(
topic="rasa_core_events",
security_protocol="SASL_PLAINTEXT",
loglevel=logging.ERROR,
):
) -> None:

self.producer = None
self.host = host
Expand All @@ -43,12 +43,12 @@ def from_endpoint_config(cls, broker_config) -> Optional["KafkaProducer"]:

return cls(broker_config.url, **broker_config.kwargs)

def publish(self, event):
def publish(self, event) -> None:
self._create_producer()
self._publish(event)
self._close()

def _create_producer(self):
def _create_producer(self) -> None:
import kafka

if self.security_protocol == "SASL_PLAINTEXT":
Expand All @@ -71,8 +71,8 @@ def _create_producer(self):
security_protocol=self.security_protocol,
)

def _publish(self, event):
def _publish(self, event) -> None:
self.producer.send(self.topic, event)

def _close(self):
def _close(self) -> None:
self.producer.close()
2 changes: 1 addition & 1 deletion rasa/core/channels/botframework.py
Expand Up @@ -42,7 +42,7 @@ def __init__(
self.global_uri = f"{service_url}v3/"
self.bot = bot

async def _get_headers(self):
async def _get_headers(self) -> Optional[Dict[Text, Any]]:
if BotFramework.token_expiration_date < datetime.datetime.now():
uri = f"{MICROSOFT_OAUTH2_URL}/{MICROSOFT_OAUTH2_PATH}"
grant_type = "client_credentials"
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/channels/channel.py
Expand Up @@ -13,6 +13,7 @@
from rasa.constants import DOCS_BASE_URL
from rasa.core import utils
from sanic.response import HTTPResponse
from typing import NoReturn

try:
from urlparse import urljoin # pytype: disable=import-error
Expand Down Expand Up @@ -371,7 +372,7 @@ def __init__(self, message_queue: Optional[Queue] = None) -> None:
super().__init__()
self.messages = Queue() if not message_queue else message_queue

def latest_output(self):
def latest_output(self) -> NoReturn:
raise NotImplementedError("A queue doesn't allow to peek at messages.")

async def _persist_message(self, message) -> None:
Expand Down

0 comments on commit b40a000

Please sign in to comment.