/
interactive.py
124 lines (95 loc) · 3.88 KB
/
interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import os
from typing import List, Text
import rasa.cli.train as train
from rasa.cli.arguments import interactive as arguments
from rasa import data, model
# noinspection PyProtectedMember
from rasa.cli.utils import get_validated_path, print_error
from rasa.constants import (
DEFAULT_DATA_PATH,
DEFAULT_MODELS_PATH,
DEFAULT_ENDPOINTS_PATH,
)
from rasa.model import get_latest_model
def add_subparser(
    subparsers: argparse._SubParsersAction, parents: List[argparse.ArgumentParser]
):
    """Register the ``interactive`` command and its ``interactive core`` variant.

    Args:
        subparsers: Sub-parser collection the ``interactive`` command is added to.
        parents: Parsers whose arguments both new commands inherit.
    """
    # Both commands are built with the same parser options.
    common_options = dict(
        conflict_handler="resolve",
        parents=parents,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    learning_parser = subparsers.add_parser(
        "interactive",
        help="Starts an interactive learning session to create new training data for a "
        "Rasa model by chatting.",
        **common_options,
    )
    learning_parser.set_defaults(func=interactive)
    learning_parser.add_argument(
        "--e2e",
        action="store_true",
        help="Save story files in e2e format. In this format user messages will be "
        "included in the stories.",
    )

    nested_commands = learning_parser.add_subparsers()
    core_parser = nested_commands.add_parser(
        "core",
        help="Starts an interactive learning session model to create new training data "
        "for a Rasa Core model by chatting. Uses the 'RegexInterpreter', i.e. "
        "`/<intent>` input format.",
        **common_options,
    )
    core_parser.set_defaults(func=interactive_core)

    # The remaining options are attached only after both parsers exist, as before.
    arguments.set_interactive_arguments(learning_parser)
    arguments.set_interactive_core_arguments(core_parser)
def interactive(args: argparse.Namespace):
    """Handle ``rasa interactive``: obtain a model, then start interactive learning.

    When no ``--model`` was given, the training data is checked and a new model
    is trained from it. Otherwise the provided model path is resolved.
    """
    _set_not_required_args(args)

    if args.model is not None:
        zipped_model = get_provided_model(args.model)
    else:
        # Stops the run when story or NLU training data is missing.
        check_training_data(args)
        zipped_model = train.train(args)

    perform_interactive_learning(args, zipped_model)
def _set_not_required_args(args: argparse.Namespace) -> None:
    """Fill in training options the interactive commands do not expose.

    Sets ``fixed_model_name`` to ``None`` and ``store_uncompressed`` to
    ``False`` on ``args``, in place.
    """
    for option_name, forced_value in (
        ("fixed_model_name", None),
        ("store_uncompressed", False),
    ):
        setattr(args, option_name, forced_value)
def interactive_core(args: argparse.Namespace):
    """Handle ``rasa interactive core``: the Core-only interactive session.

    Without ``--model`` a Core model is trained first. Unlike ``interactive``,
    no training-data check runs beforehand.
    """
    _set_not_required_args(args)
    zipped_model = (
        train.train_core(args) if args.model is None else get_provided_model(args.model)
    )
    perform_interactive_learning(args, zipped_model)
def perform_interactive_learning(args, zipped_model) -> None:
    """Start an interactive learning session on top of a packed model.

    Args:
        args: Parsed CLI arguments. ``model``, ``core``, ``nlu`` and
            ``endpoints`` are overwritten on this namespace before it is passed
            to ``do_interactive_learning``.
        zipped_model: Path to the packed model. If it is falsy or does not
            exist on disk, an error is printed and nothing else happens.
    """
    from rasa.core.train import do_interactive_learning

    if not zipped_model or not os.path.exists(zipped_model):
        print_error(
            "Interactive learning process cannot be started as no initial model was "
            "found. Use 'rasa train' to train a model."
        )
        return

    args.model = zipped_model
    # The session runs inside the `with`, while the unpacked model is in scope.
    with model.unpack_model(zipped_model) as unpacked_dir:
        args.core, args.nlu = model.get_model_subdirectories(unpacked_dir)
        stories_directory = data.get_core_directory(args.data)
        args.endpoints = get_validated_path(
            args.endpoints, "endpoints", DEFAULT_ENDPOINTS_PATH, True
        )
        do_interactive_learning(args, stories_directory)
def get_provided_model(arg_model: Text):
    """Resolve the ``--model`` argument to a model path.

    The argument is validated first (with ``DEFAULT_MODELS_PATH`` as default).
    If the result is a directory, the value of ``get_latest_model`` for that
    directory is returned instead.
    """
    validated_path = get_validated_path(arg_model, "model", DEFAULT_MODELS_PATH)
    return (
        get_latest_model(validated_path)
        if os.path.isdir(validated_path)
        else validated_path
    )
def check_training_data(args) -> None:
    """Make sure both Core (story) and NLU training data are available.

    Prints an error and terminates the process with exit status 1 when either
    kind of training data is missing.

    Args:
        args: Parsed CLI arguments. ``args.data`` holds the training data
            path(s); a list of paths or a single path string is accepted.
    """
    import sys

    data_paths = args.data
    if isinstance(data_paths, str):
        # Iterating a bare string would validate each character as a path.
        data_paths = [data_paths]

    training_files = [
        get_validated_path(f, "data", DEFAULT_DATA_PATH, none_is_valid=True)
        for f in data_paths
    ]
    story_files, nlu_files = data.get_core_nlu_files(training_files)
    if not story_files or not nlu_files:
        print_error(
            "Cannot train initial Rasa model. Please provide NLU and Core data "
            "using the '--data' argument."
        )
        # Use `sys.exit`, not the `site`-provided `exit()`. That one is meant for
        # the interactive shell, is absent under `python -S`, and closes stdin
        # before raising.
        sys.exit(1)