/
Texthub.py
159 lines (119 loc) · 5.29 KB
/
Texthub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
###############################################################################
# PyDial: Multi-domain Statistical Spoken Dialogue System Software
###############################################################################
#
# Copyright 2015 - 2017
# Cambridge University Engineering Department Dialogue Systems Group
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
###############################################################################
'''
Texthub.py - text level dialog system.
====================================================
Copyright CUED Dialogue Systems Group 2015 - 2017
**Basic Execution**:
>>> python texthub.py [-h] -C CONFIG [-s SEED] [--nocolor]
Optional arguments/flags [default values]::
-C, -c, --config   config file to load (required)
-s, --seed         random seed [None]
--nocolor          disable coloured logging output [colour on]
-h                 help
**Relevant config variables** (values are defaults)::
[semi_DOMAIN]
semitype = PassthroughSemI
[semo_DOMAIN]
semotype = PassthroughSemO
.. seealso:: CUED Imports/Dependencies:
import :mod:`utils.ContextLogger` |.|
import :mod:`utils.Settings` |.|
import :class:`Agent.DialogueAgent` |.|
import :mod:`ontology.Ontology`
************************
'''
import argparse, re
import torch
from Agent import DialogueAgent
from utils import ContextLogger
from utils import Settings
from ontology import Ontology
# Module-wide logger (obtained with an empty name).
logger = ContextLogger.getLogger('')

__author__ = "cued_dialogue_systems_group"
__version__ = Settings.__version__

# Pick the float tensor type once at import time: the CUDA variant when a GPU
# is visible to torch, the CPU variant otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    FloatTensor = torch.cuda.FloatTensor
else:
    FloatTensor = torch.FloatTensor
class ConsoleHub(object):
    '''
    Text-based (console) dialogue system.

    Reads user utterances from standard input, hands each one to a
    :class:`Agent.DialogueAgent` as a single hypothesis with confidence 1.0,
    and prints the system prompt after every turn.
    '''

    def __init__(self):
        # Dialogue Agent:
        #-----------------------------------------
        self.agent = DialogueAgent(hub_id='texthub')

    def run(self, domain=None):
        '''
        Runs one episode (dialogue) through the hub.

        A random user preference over [Brevity, Success] is drawn at the start
        of the episode and passed to the agent on every call.

        :param domain: forwarded unchanged as ``domainString`` to
            ``continue_call`` on every turn. ``None`` (the default) passes no
            explicit domain; presumably the agent then picks the domain
            itself - TODO confirm against DialogueAgent.
        :returns: None
        '''
        # Python 2 reads a line with raw_input; Python 3 renamed it to input.
        try:
            read_line = raw_input
        except NameError:
            read_line = input

        # GENERATE A USER PREFERENCE: weights for [Brevity, Success].
        # abs(x) / ||x||_1 yields two non-negative weights that sum to 1.
        preference = torch.randn(2)
        preference = (torch.abs(preference) / torch.norm(preference, p=1)).type(FloatTensor)
        # float() so the log and console show plain numbers, whether indexing
        # returns a Python float or a 0-dim tensor.
        brevity, success = float(preference[0]), float(preference[1])
        logger.dial('User\'s preference: [{}, {}]'.format(brevity, success))
        print('User\'s preference: [Brevity: {}, Success: {}]'.format(brevity, success))

        logger.warning("NOTE: texthub is not using any error simulation at present.")
        sys_act = self.agent.start_call(session_id='texthub_dialog',
                                        preference=preference)
        print('Prompt > ' + sys_act.prompt)

        while not self.agent.ENDING_DIALOG:
            # USER ACT:
            try:
                obs = read_line('User > ')
            except EOFError:
                # End of input (Ctrl-D, or a piped file ran out): stop the
                # episode cleanly instead of crashing with a traceback.
                print('')
                break
            # SYSTEM ACT: the typed text is the only hypothesis, with
            # confidence 1.0 (no error simulation is applied).
            sys_act = self.agent.continue_call(asr_info=[(obs, 1.0)],
                                               domainString=domain,
                                               preference=preference)
            print('Prompt > ' + sys_act.prompt)
        # Process ends. -----------------------------------------------------
        # NB: Can add evaluation things here - possibly useful to check things by hand with texthub ...
if __name__ == '__main__':
    # Command-line entry point: parse arguments, initialise settings, logging
    # and the ontology, then run a single interactive text episode.
    parser = argparse.ArgumentParser(description='TextHub')
    parser.add_argument('-C', '-c', '--config', help='set config file', required=True,
                        type=argparse.FileType('r'))
    parser.add_argument('-s', '--seed', help='set random seed', type=int)
    parser.set_defaults(use_color=True)
    parser.add_argument('--nocolor', dest='use_color', action='store_false',
                        help='no color in logging. best to turn off if dumping to file. '
                             'Will be overridden by [logging] config setting of "usecolor=".')
    args = parser.parse_args()

    # FileType('r') is used here to validate that the config file exists and is
    # readable; Settings.init receives only its path, so release our handle
    # (but never close stdin, which the console loop still needs).
    config_path = args.config.name
    if config_path != '<stdin>':
        args.config.close()

    seed = Settings.init(config_file=config_path, seed=args.seed)
    ContextLogger.createLoggingHandlers(config=Settings.config, use_color=args.use_color)
    logger.info("Random Seed is {}".format(seed))
    Ontology.init_global_ontology()
    hub = ConsoleHub()
    # run() takes a domain argument; None forces no particular domain.
    hub.run(None)
#END OF FILE