
Commit

interact with human
RunzheYang committed Aug 19, 2019
1 parent c3c260a commit a00f7ab
Showing 18 changed files with 1,633 additions and 35 deletions.
13 changes: 8 additions & 5 deletions pydial/Agent.py
@@ -202,6 +202,7 @@ def start_call(self, session_id, domainSimulatedUsers=None, maxNumTurnsScaling=1

# SEMO:
self.prompt_str = self._agents_semo(sys_act)


self.callValidator.validate(sys_act)

@@ -268,20 +269,22 @@ def continue_call(self, asr_info, domainString=None, domainSimulatedUsers=None,

state = self.semi_belief_manager.update_belief_state(ASR_obs=asr_info, sys_act=prev_sys_act,
dstring=currentDomain, turn=self.currentTurn,hub_id = self.hub_id, sim_lvl=self.sim_level)

self._print_usr_act(state, currentDomain)

# 2. Policy -- Determine system act/response
sys_act = self.policy_manager.act_on(dstring=currentDomain,
state=state, preference=preference)

# Check ending the call:
sys_act = self._check_ENDING_CALL(state, sys_act) # NB: this may change the self.prompt_str

self._print_sys_act(sys_act)

self._print_sys_act(sys_act)
# SEMO:
# print(sys_act)
self.prompt_str = self._agents_semo(sys_act)

sys_act.prompt = self.prompt_str


6 changes: 2 additions & 4 deletions pydial/Simulate.py
@@ -130,11 +130,11 @@ def run_dialogs(self, numDialogs):
'''
for i in range(numDialogs):
logger.info('Dialogue %d' % (i+1))
self.run(session_id='simulate_dialog'+str(i), sim_level=self.sim_level, roi=roi)
self.run(session_id='simulate_dialog'+str(i), sim_level=self.sim_level)

self.agent_factory.power_down_factory() # Important! -uses FORCE_SAVE on policy- which will finalise learning and save policy.

def run(self, session_id, agent_id='Smith', sim_level='dial_act', roi=False):
def run(self, session_id, agent_id='Smith', sim_level='dial_act'):
'''
Runs one episode through the simulator
@@ -150,8 +150,6 @@ def run(self, session_id, agent_id='Smith', sim_level='dial_act', roi=False):

preference = torch.randn(2)
preference = (torch.abs(preference) / torch.norm(preference, p=1)).type(FloatTensor)
if roi:
pass
logger.dial('User\'s preference: [{}, {}]'.format(preference[0], preference[1]))

# RESET THE USER SIMULATOR:
34 changes: 25 additions & 9 deletions pydial/Texthub.py
@@ -58,6 +58,8 @@
'''
import argparse, re

import torch

from Agent import DialogueAgent
from utils import ContextLogger
from utils import Settings
@@ -67,6 +69,9 @@
__author__ = "cued_dialogue_systems_group"
__version__ = Settings.__version__

use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

class ConsoleHub(object):
'''
text based dialog system
@@ -77,15 +82,25 @@ def __init__(self):
#-----------------------------------------
self.agent = DialogueAgent(hub_id='texthub')

def run(self):
def run(self, domain):
'''
Runs one episode through Hub
:returns: None
'''

# GENERATE A USER PREFERENCE: a * Length + (1-a) * Success


preference = torch.randn(2)
preference = (torch.abs(preference) / torch.norm(preference, p=1)).type(FloatTensor)
logger.dial('User\'s preference: [{}, {}]'.format(preference[0], preference[1]))
print 'User\'s preference: [Brevity: {}, Success: {}]'.format(preference[0], preference[1])

logger.warning("NOTE: texthub is not using any error simulation at present.")
sys_act = self.agent.start_call(session_id='texthub_dialog')
sys_act = self.agent.start_call(session_id='texthub_dialog',
preference=preference)

print 'Prompt > ' + sys_act.prompt
while not self.agent.ENDING_DIALOG:
# USER ACT:
@@ -107,15 +122,16 @@ def run(self):
# print 'Semi > null() [0.001]'
#--------------------------------
'''
domain = None
if "domain(" in obs:
match = re.search("(.*)(domain\()([^\)]+)(\))(.*)",obs)
if match is not None:
domain = match.group(3)
obs = match.group(1) + match.group(5)

# domain = None
# if "domain(" in obs:
# match = re.search("(.*)(domain\()([^\)]+)(\))(.*)",obs)
# if match is not None:
# domain = match.group(3)
# obs = match.group(1) + match.group(5)

# SYSTEM ACT:
sys_act = self.agent.continue_call(asr_info = [(obs,1.0)], domainString = domain)
sys_act = self.agent.continue_call(asr_info = [(obs,1.0)], domainString = domain, preference=preference)
print 'Prompt > ' + sys_act.prompt

# Process ends. -----------------------------------------------------
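
For context, the preference vector introduced in both Simulate.run and ConsoleHub.run above is a random 2-dimensional weighting over the two reward objectives (brevity vs. success). Below is a minimal standalone sketch of that sampling pattern; it is an illustration only, not code from this commit.

import torch

use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Draw a random 2-d preference and L1-normalise it so the weights are
# non-negative and sum to 1 (one weight per reward objective).
preference = torch.randn(2)
preference = (torch.abs(preference) / torch.norm(preference, p=1)).type(FloatTensor)
print('preference over [brevity, success]:', preference)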
89 changes: 89 additions & 0 deletions pydial/_morlconfigs/envelope-00.1.train.cfg
@@ -0,0 +1,89 @@
[GENERAL]
domains = CamRestaurants
singledomain = True
tracedialog = 0
seed = 30061522

[exec_config]
domain = CamRestaurants
configdir = _morlconfigs
logfiledir = _morllogs
numtrainbatches = 5
traindialogsperbatch = 1000
numbatchtestdialogs = 10
trainsourceiteration = 0
numtestdialogs = 100
trainerrorrate = 0
testerrorrate = 0
testeverybatch = False

[logging]
usecolor = False
screen_level = results
file_level = dial
file = _morllogs/envelope-seed30061522-00.1-5.train.log

[agent]
maxturns = 25

[usermodel]
usenewgoalscenarios = True
oldstylepatience = False
patience = 5
configfile = config/sampledUM.cfg

[errormodel]
nbestsize = 1
nbestgeneratormodel = SampledNBestGenerator
confscorer = additive

[summaryacts]
maxinformslots = 5
informmask = True
requestmask = True
informcountaccepted = 4
byemask = True

[policy]
policydir = _morlpolicies
belieftype = focus
useconfreq = False
learning = True
policytype = morl
startwithhello = False
inpolicyfile = _morlpolicies/envelope-00.0
outpolicyfile = _morlpolicies/envelope-00.1

[morlpolicy]
n_rew = 2
learning_rate = 0.001
epsilon = 0.5
epsilon_decay = True
gamma = 0.999
batch_size = 64
weight_num = 32
mem_size = 1000
episode_num = 1000
optimizer = Adam
save_step = 100
update_freq = 50
training_freq = 1
algorithm = envelope
beta = 0.9
homotopy = True

[gppolicy]
kernel = polysort

[gpsarsa]
random = False
scale = 3

[eval]
rewardvenuerecommended = 0
penaliseallturns = True
wrongvenuepenalty = 0
notmentionedvaluepenalty = 0
successmeasure = objective
successreward = 20
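
The envelope-00.*.train.cfg files added in this commit differ only in the [policy] inpolicyfile/outpolicyfile pair, which chains each training batch off the policy produced by the previous one. As a rough illustration, the sketch below reads a few of the [morlpolicy] hyperparameters with Python's standard ConfigParser; this is not how PyDial itself loads configs (it goes through its utils.Settings module).

try:
    import configparser                       # Python 3
except ImportError:
    import ConfigParser as configparser       # Python 2 (PyDial's runtime)

cfg = configparser.ConfigParser()
cfg.read('_morlconfigs/envelope-00.1.train.cfg')

# A few of the [morlpolicy] hyperparameters defined above.
n_rew     = cfg.getint('morlpolicy', 'n_rew')            # 2 reward objectives
lr        = cfg.getfloat('morlpolicy', 'learning_rate')  # 0.001
gamma     = cfg.getfloat('morlpolicy', 'gamma')          # 0.999
algorithm = cfg.get('morlpolicy', 'algorithm')           # 'envelope'
homotopy  = cfg.getboolean('morlpolicy', 'homotopy')     # True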

89 changes: 89 additions & 0 deletions pydial/_morlconfigs/envelope-00.2.train.cfg
@@ -0,0 +1,89 @@
[GENERAL]
domains = CamRestaurants
singledomain = True
tracedialog = 0
seed = 30061522

[exec_config]
domain = CamRestaurants
configdir = _morlconfigs
logfiledir = _morllogs
numtrainbatches = 5
traindialogsperbatch = 1000
numbatchtestdialogs = 10
trainsourceiteration = 0
numtestdialogs = 100
trainerrorrate = 0
testerrorrate = 0
testeverybatch = False

[logging]
usecolor = False
screen_level = results
file_level = dial
file = _morllogs/envelope-seed30061522-00.1-5.train.log

[agent]
maxturns = 25

[usermodel]
usenewgoalscenarios = True
oldstylepatience = False
patience = 5
configfile = config/sampledUM.cfg

[errormodel]
nbestsize = 1
nbestgeneratormodel = SampledNBestGenerator
confscorer = additive

[summaryacts]
maxinformslots = 5
informmask = True
requestmask = True
informcountaccepted = 4
byemask = True

[policy]
policydir = _morlpolicies
belieftype = focus
useconfreq = False
learning = True
policytype = morl
startwithhello = False
inpolicyfile = _morlpolicies/envelope-00.1
outpolicyfile = _morlpolicies/envelope-00.2

[morlpolicy]
n_rew = 2
learning_rate = 0.001
epsilon = 0.5
epsilon_decay = True
gamma = 0.999
batch_size = 64
weight_num = 32
mem_size = 1000
episode_num = 1000
optimizer = Adam
save_step = 100
update_freq = 50
training_freq = 1
algorithm = envelope
beta = 0.9
homotopy = True

[gppolicy]
kernel = polysort

[gpsarsa]
random = False
scale = 3

[eval]
rewardvenuerecommended = 0
penaliseallturns = True
wrongvenuepenalty = 0
notmentionedvaluepenalty = 0
successmeasure = objective
successreward = 20

89 changes: 89 additions & 0 deletions pydial/_morlconfigs/envelope-00.3.train.cfg
@@ -0,0 +1,89 @@
[GENERAL]
domains = CamRestaurants
singledomain = True
tracedialog = 0
seed = 30061522

[exec_config]
domain = CamRestaurants
configdir = _morlconfigs
logfiledir = _morllogs
numtrainbatches = 5
traindialogsperbatch = 1000
numbatchtestdialogs = 10
trainsourceiteration = 0
numtestdialogs = 100
trainerrorrate = 0
testerrorrate = 0
testeverybatch = False

[logging]
usecolor = False
screen_level = results
file_level = dial
file = _morllogs/envelope-seed30061522-00.1-5.train.log

[agent]
maxturns = 25

[usermodel]
usenewgoalscenarios = True
oldstylepatience = False
patience = 5
configfile = config/sampledUM.cfg

[errormodel]
nbestsize = 1
nbestgeneratormodel = SampledNBestGenerator
confscorer = additive

[summaryacts]
maxinformslots = 5
informmask = True
requestmask = True
informcountaccepted = 4
byemask = True

[policy]
policydir = _morlpolicies
belieftype = focus
useconfreq = False
learning = True
policytype = morl
startwithhello = False
inpolicyfile = _morlpolicies/envelope-00.2
outpolicyfile = _morlpolicies/envelope-00.3

[morlpolicy]
n_rew = 2
learning_rate = 0.001
epsilon = 0.5
epsilon_decay = True
gamma = 0.999
batch_size = 64
weight_num = 32
mem_size = 1000
episode_num = 1000
optimizer = Adam
save_step = 100
update_freq = 50
training_freq = 1
algorithm = envelope
beta = 0.9
homotopy = True

[gppolicy]
kernel = polysort

[gpsarsa]
random = False
scale = 3

[eval]
rewardvenuerecommended = 0
penaliseallturns = True
wrongvenuepenalty = 0
notmentionedvaluepenalty = 0
successmeasure = objective
successreward = 20
