Skip to content
This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Permalink
Merge 19a52e9 into d3defd0
Browse files Browse the repository at this point in the history
  • Loading branch information
ricwo committed Mar 15, 2019
2 parents d3defd0 + 19a52e9 commit 8817403
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 61 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -27,6 +27,8 @@ Added
- added ``priority`` property of policies to influence best policy in
the case of equal confidence
- added rasa command line interface and API
- Rasa Stack HTTP training endpoint at ``POST /jobs``. This endpoint
will train a combined Rasa Core and NLU model.

Changed
-------
Expand Down
10 changes: 10 additions & 0 deletions data/test_config/stack_config.yml
@@ -0,0 +1,10 @@
# Configuration for Rasa NLU.
# https://rasa.com/docs/nlu/components/
language: en
pipeline: supervised_embeddings

# Configuration for Rasa Core.
# https://rasa.com/docs/core/policies
policies:
- name: MemoizationPolicy
- name: KerasPolicy
141 changes: 140 additions & 1 deletion docs/_static/spec/server.yml
Expand Up @@ -624,6 +624,37 @@ paths:
400:
$ref: '#/components/responses/400Evaluation'

/jobs:
post:
security:
- TokenAuth: []
- JWT: []
tags:
- Model
summary: Train a Rasa Stack model
description: >-
Trains a Rasa Stack model. A stack model is a model combining a
trained dialogue model with an NLU model.
operationId: trainStack
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/StackTrainingRequest'
examples:
StackTrainingRequest:
$ref: '#/components/examples/StackTrainingRequest'
responses:
200:
description: Zipped Rasa Stack model
content:
application/octet-stream:
schema:
$ref: '#/components/schemas/StackTrainingResult'
400:
$ref: '#/components/responses/400Training'

/predict:
post:
security:
Expand Down Expand Up @@ -970,6 +1001,72 @@ components:
utter_on_it 1.00 0.50 0.67 2
avg / total 1.00 0.90 0.93 10
StackTrainingRequest:
value:
config: >-
# Configuration for Rasa NLU.
# https://rasa.com/docs/nlu/components/
language: en
pipeline: tensorflow_embedding
# Configuration for Rasa Core.
# https://rasa.com/docs/core/policies
policies:
- name: MemoizationPolicy
- name: KerasPolicy
domain: >-
intents:
- greet
- goodbye
- mood_affirm
- mood_deny
- mood_great
- mood_unhappy
actions:
- utter_greet
- utter_cheer_up
- utter_did_that_help
- utter_happy
- utter_goodbye
templates:
utter_greet:
- text: "Hey! How are you?"
utter_cheer_up:
- text: "Here is something to cheer you up:"
image: "https://i.imgur.com/nGF1K8f.jpg"
utter_did_that_help:
- text: "Did that help you?"
utter_happy:
- text: "Great carry on!"
utter_goodbye:
- text: "Bye"
nlu: >-
## intent:greet
- hey
- hello
## intent:goodbye
- cu
- good by
stories: >-
## happy path
* greet
- utter_greet
* mood_great
- utter_happy
## sad path 1
* greet
- utter_greet
* mood_unhappy
out: models
force: false

responses:
200Tracker:
Expand Down Expand Up @@ -1038,6 +1135,20 @@ components:
reason: "FailedEvaluation"
code: 400

400Training:
description: Failed Training
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
example:
version: "0.12.0"
status: "failure"
message: >-
Rasa Stack model could not be trained.
reason: "TrainingError"
code: 400

403Permissions:
description: User has insufficient permission.
content:
Expand Down Expand Up @@ -1418,6 +1529,34 @@ components:
tracker:
$ref: '#/components/schemas/Tracker'

StackTrainingRequest:
type: object
properties:
domain:
type: string
description: Rasa Core domain in plain text
config:
type: string
description: Rasa Stack config in plain text
nlu:
type: string
description: Rasa NLU training data in markdown format
stories:
type: string
description: Rasa Core stories in markdown format
out:
type: string
description: Output directory
force:
type: boolean
description: >-
Force a model training even if the data has not changed
required: ["domain", "config", "nlu", "stories"]

StackTrainingResult:
type: string
format: binary

EvaluationResult:
type: object
properties:
Expand Down Expand Up @@ -1480,7 +1619,7 @@ components:
- type: string
- type: array
items:
- type: string
type: string
SlotDescription:
type: object
properties:
Expand Down
103 changes: 56 additions & 47 deletions rasa_core/agent.py
@@ -1,15 +1,16 @@
import time
import logging
import os
import shutil
import tempfile
import typing
import uuid
from gevent.pywsgi import WSGIServer
from requests.exceptions import InvalidURL, RequestException
from threading import Thread
from typing import Text, List, Optional, Callable, Any, Dict, Union

import time
from gevent.pywsgi import WSGIServer
from requests.exceptions import InvalidURL, RequestException

from rasa_core import training, constants, utils
from rasa_core.channels import UserMessage, OutputChannel, InputChannel
from rasa_core.constants import DEFAULT_REQUEST_TIMEOUT
Expand Down Expand Up @@ -83,21 +84,29 @@ def _init_model_from_server(model_server: EndpointConfig
return fingerprint, model_directory


def _is_stack_model(model_directory: Text) -> bool:
"""Decide whether a persisted model is a stack or a core model."""
def _get_stack_model_directory(model_directory: Text) -> Optional[Text]:
"""Decide whether a persisted model is a stack or a core model.
Return the root stack model directory if it's a stack model.
"""

for root, _, files in os.walk(model_directory):
if "fingerprint.json" in files:
return root

return os.path.exists(os.path.join(model_directory, "fingerprint.json"))
return None


def _load_and_set_updated_model(agent: 'Agent',
model_directory: Text,
fingerprint: Text):
"""Load the persisted model into memory and set the model on the agent."""

if _is_stack_model(model_directory):
stack_model_directory = _get_stack_model_directory(model_directory)
if stack_model_directory:
from rasa_core.interpreter import RasaNLUInterpreter
nlu_model = os.path.join(model_directory, "nlu")
core_model = os.path.join(model_directory, "core")
nlu_model = os.path.join(stack_model_directory, "nlu")
core_model = os.path.join(stack_model_directory, "core")
interpreter = RasaNLUInterpreter(model_directory=nlu_model)
else:
interpreter = agent.interpreter
Expand Down Expand Up @@ -205,14 +214,14 @@ class Agent(object):
getting the next action, and handling a channel."""

def __init__(
self,
domain: Union[Text, Domain] = None,
policies: Union[PolicyEnsemble, List[Policy], None] = None,
interpreter: Optional[NaturalLanguageInterpreter] = None,
generator: Union[EndpointConfig, 'NLG', None] = None,
tracker_store: Optional['TrackerStore'] = None,
action_endpoint: Optional[EndpointConfig] = None,
fingerprint: Optional[Text] = None
self,
domain: Union[Text, Domain] = None,
policies: Union[PolicyEnsemble, List[Policy], None] = None,
interpreter: Optional[NaturalLanguageInterpreter] = None,
generator: Union[EndpointConfig, 'NLG', None] = None,
tracker_store: Optional['TrackerStore'] = None,
action_endpoint: Optional[EndpointConfig] = None,
fingerprint: Optional[Text] = None
):
# Initializing variables with the passed parameters.
self.domain = self._create_domain(domain)
Expand Down Expand Up @@ -297,10 +306,10 @@ def is_ready(self):
self.policy_ensemble is not None)

def handle_message(
self,
message: UserMessage,
message_preprocessor: Optional[Callable[[Text], Text]] = None,
**kwargs
self,
message: UserMessage,
message_preprocessor: Optional[Callable[[Text], Text]] = None,
**kwargs
) -> Optional[List[Text]]:
"""Handle a single message."""

Expand All @@ -323,9 +332,9 @@ def noop(_):

# noinspection PyUnusedLocal
def predict_next(
self,
sender_id: Text,
**kwargs: Any
self,
sender_id: Text,
**kwargs: Any
) -> Dict[Text, Any]:
"""Handle a single message."""

Expand All @@ -334,23 +343,23 @@ def predict_next(

# noinspection PyUnusedLocal
def log_message(
self,
message: UserMessage,
message_preprocessor: Optional[Callable[[Text], Text]] = None,
**kwargs: Any
self,
message: UserMessage,
message_preprocessor: Optional[Callable[[Text], Text]] = None,
**kwargs: Any
) -> DialogueStateTracker:
"""Append a message to a dialogue - does not predict actions."""

processor = self.create_processor(message_preprocessor)
return processor.log_message(message)

def execute_action(
self,
sender_id: Text,
action: Text,
output_channel: OutputChannel,
policy: Text,
confidence: float
self,
sender_id: Text,
action: Text,
output_channel: OutputChannel,
policy: Text,
confidence: float
) -> DialogueStateTracker:
"""Handle a single message."""

Expand All @@ -362,11 +371,11 @@ def execute_action(
confidence)

def handle_text(
self,
text_message: Union[Text, Dict[Text, Any]],
message_preprocessor: Optional[Callable[[Text], Text]] = None,
output_channel: Optional[OutputChannel] = None,
sender_id: Optional[Text] = UserMessage.DEFAULT_SENDER_ID
self,
text_message: Union[Text, Dict[Text, Any]],
message_preprocessor: Optional[Callable[[Text], Text]] = None,
output_channel: Optional[OutputChannel] = None,
sender_id: Optional[Text] = UserMessage.DEFAULT_SENDER_ID
) -> Optional[List[Dict[Text, Any]]]:
"""Handle a single message.
Expand Down Expand Up @@ -402,8 +411,8 @@ def handle_text(
return self.handle_message(msg, message_preprocessor)

def toggle_memoization(
self,
activate: bool
self,
activate: bool
) -> None:
"""Toggles the memoization on and off.
Expand Down Expand Up @@ -451,8 +460,8 @@ def _are_all_featurizers_using_a_max_history(self):
"""Check if all featurizers are MaxHistoryTrackerFeaturizer."""

for policy in self.policy_ensemble.policies:
if (policy.featurizer and not
hasattr(policy.featurizer, 'max_history')):
if (policy.featurizer and
not hasattr(policy.featurizer, 'max_history')):
return False
return True

Expand Down Expand Up @@ -684,7 +693,7 @@ def create_tracker_store(store: Optional['TrackerStore'],

@staticmethod
def _create_ensemble(
policies: Union[List[Policy], PolicyEnsemble, None]
policies: Union[List[Policy], PolicyEnsemble, None]
) -> Optional[PolicyEnsemble]:
if policies is None:
return None
Expand All @@ -703,6 +712,6 @@ def _form_policy_not_present(self) -> bool:
"""Check whether form policy is not present
if there is a form action in the domain
"""
return (self.domain and self.domain.form_names and not
any(isinstance(p, FormPolicy)
for p in self.policy_ensemble.policies))
return (self.domain and self.domain.form_names and
not any(isinstance(p, FormPolicy)
for p in self.policy_ensemble.policies))

0 comments on commit 8817403

Please sign in to comment.