This repository has been archived by the owner on Aug 22, 2019. It is now read-only.

Commit

Merge 1e60feb into 15a4c3c
Ghostvv committed Sep 12, 2018
2 parents 15a4c3c + 1e60feb commit 466d5a4
Showing 3 changed files with 382 additions and 250 deletions.
2 changes: 1 addition & 1 deletion docs/policies.rst
@@ -171,7 +171,7 @@ This policy has a pre-defined architecture, which comprises the following steps:
- sum this raw recurrent embedding of a dialogue with the system attention vector to create a dialogue-level embedding,
  this step allows the algorithm to repeat a previous system action by copying its embedding vector directly to the current time output;
- weight previous LSTM states with the system attention probabilities to get the embedding of the previous action that the policy most likely paid attention to;
- if the similarity between this action embedding and current time dialogue embedding is high,
- if the similarity between this previous action embedding and current time dialogue embedding is high,
overwrite current LSTM state with the one from the time when this action happened;
- for each LSTM time step, calculate the similarity between the dialogue embedding and the embedded system actions (see the sketch below).
  This step is based on the starspace idea from: `<https://arxiv.org/abs/1709.03856>`_.
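
A toy numpy sketch of the starspace-style similarity step above, assuming
cosine similarity; the names and shapes are illustrative, not the policy's
actual code:

    import numpy as np

    def starspace_sim(dial_embed, action_embeds):
        # cosine similarity between one dialogue-level embedding and
        # every embedded system action
        dial = dial_embed / np.linalg.norm(dial_embed)
        acts = action_embeds / np.linalg.norm(action_embeds,
                                              axis=-1, keepdims=True)
        return acts.dot(dial)

    sim = starspace_sim(np.random.rand(20), np.random.rand(5, 20))
    print(sim.argmax())  # index of the most similar system action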
212 changes: 131 additions & 81 deletions rasa_core/policies/embedding_policy.py
@@ -152,7 +152,8 @@ def __init__(
dial_embed=None, # type: Optional[tf.Tensor]
rnn_embed=None, # type: Optional[tf.Tensor]
attn_embed=None, # type: Optional[tf.Tensor]
copy_attn_debug=None # type: Optional[tf.Tensor]
copy_attn_debug=None, # type: Optional[tf.Tensor]
all_time_masks=None # type: Optional[tf.Tensor]
):
# type: (...) -> None
if featurizer:
@@ -205,6 +206,8 @@ def __init__(
self.attn_embed = attn_embed
self.copy_attn_debug = copy_attn_debug

self.all_time_masks = all_time_masks

# internal tf instances
self._train_op = None
self._is_training = None
@@ -476,12 +479,12 @@ def _create_rnn_cell(self):
)

@staticmethod
def num_mem_units(memory):
def _num_units(memory):
return memory.shape[-1].value

def _create_attn_mech(self, memory, real_length):
attn_mech = tf.contrib.seq2seq.BahdanauAttention(
num_units=self.num_mem_units(memory),
num_units=self._num_units(memory),
memory=memory,
memory_sequence_length=real_length,
normalize=True,
@@ -493,14 +496,85 @@ def _create_attn_mech(self, memory, real_length):
)
return attn_mech
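
For reference, a minimal TF1 sketch of what `_num_units` reads off the memory
tensor (the static size of its last dimension); the tensor is illustrative:

    import tensorflow as tf

    memory = tf.zeros([2, 10, 20])  # (batch, memory_time, units)
    # in TF1, shape[-1] is a Dimension object; .value yields the int
    print(memory.shape[-1].value)   # 20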

def cell_input_fn(self, rnn_inputs, attention,
num_cell_input_memory_units):
"""Combine rnn inputs and attention into cell input
Args:
rnn_inputs: Tensor, first output from `rnn_and_attn_inputs_fn`.
attention: Tensor, concatenated all attentions for one time step.
num_cell_input_memory_units: int, number of the first units in
`attention` that are responsible for
enhancing cell input.
Returns:
A Tensor `cell_inputs` to feed to an rnn cell
"""

if num_cell_input_memory_units:
if num_cell_input_memory_units == self.embed_dim:
                # since `attention` may concatenate several attention
                # mechanisms, only the attention over the previous user
                # input is used as an input for the rnn cell, and only
                # if the memory before the rnn has the same size as
                # embed_utter
return tf.concat([rnn_inputs[:, :self.embed_dim] +
attention[:, :num_cell_input_memory_units],
rnn_inputs[:, self.embed_dim:]], -1)
else:
                # in the current implementation this branch cannot be
                # reached, but the exception is kept in case the
                # attention before the rnn is changed
raise ValueError("Number of memory units {} is not "
"equal to number of utter units {}. "
"Please modify cell input function "
"accordingly."
"".format(num_cell_input_memory_units,
self.embed_dim))
else:
return rnn_inputs
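
A toy numpy sketch of the concat above, assuming embed_dim = 4, a 10-unit rnn
input and an attention vector whose first 4 units come from the memory before
the rnn (all names and sizes are illustrative):

    import numpy as np

    embed_dim = 4
    rnn_inputs = np.arange(10, dtype=np.float32).reshape(1, 10)
    attention = np.ones((1, 6), dtype=np.float32)

    cell_inputs = np.concatenate(
        [rnn_inputs[:, :embed_dim] + attention[:, :embed_dim],
         rnn_inputs[:, embed_dim:]], -1)
    # same shape as rnn_inputs: only the utterance part is enhanced
    print(cell_inputs.shape)  # (1, 10)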

def rnn_and_attn_inputs_fn(self, inputs, cell_state):
"""Construct rnn input and attention mechanism input
Args:
inputs: Tensor, concatenated all embeddings for one time step:
[embed_utter, embed_slots, embed_prev_action].
cell_state: Tensor, state of an rnn cell
Returns:
Tuple of Tensors `rnn_inputs, attn_inputs` to feed to
rnn and attention mechanisms
"""

        # the hidden state c and the slots are not included, in the hope
        # that the algorithm will learn correct attention regardless of
        # the lstm hidden state c and the slots
if isinstance(cell_state, tf.contrib.rnn.LSTMStateTuple):
attn_inputs = tf.concat([inputs[:, :self.embed_dim],
cell_state.h], -1)
else:
attn_inputs = tf.concat([inputs[:, :self.embed_dim],
cell_state], -1)

# include slots in inputs but exclude previous action, since
# rnn should get previous action from its hidden state
rnn_inputs = inputs[:, :(self.embed_dim +
self.embed_dim)]

return rnn_inputs, attn_inputs
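
A numpy sketch of the slicing above, assuming each of the three embeddings is
embed_dim wide and the LSTM hidden state h has 5 units (illustrative sizes):

    import numpy as np

    embed_dim = 3
    embed_utter = np.full((1, embed_dim), 1.)
    embed_slots = np.full((1, embed_dim), 2.)
    embed_prev_action = np.full((1, embed_dim), 3.)
    h = np.zeros((1, 5))  # lstm hidden state h

    inputs = np.concatenate([embed_utter, embed_slots,
                             embed_prev_action], -1)
    rnn_inputs = inputs[:, :2 * embed_dim]  # utterance + slots
    attn_inputs = np.concatenate([inputs[:, :embed_dim], h], -1)
    print(rnn_inputs.shape, attn_inputs.shape)  # (1, 6) (1, 8)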

def _create_attn_cell(self, cell, embed_utter, embed_prev_action,
real_length, embed_for_no_intent,
embed_for_no_action, embed_for_action_listen):
"""Wrap cell in attention wrapper with given memory"""

if self.attn_before_rnn:
# create attention over previous user input
num_mem_units = self.num_mem_units(embed_utter)
num_memory_units_before_rnn = self._num_units(embed_utter)
attn_mech = self._create_attn_mech(embed_utter, real_length)

            # create a mask so that no attention is paid to empty user inputs
@@ -512,7 +586,7 @@ def _create_attn_cell(self, cell, embed_utter, embed_prev_action,
else:
attn_mech = None
ignore_mask = None
num_mem_units = 0
num_memory_units_before_rnn = None
attn_shift_range = None

if self.attn_after_rnn:
@@ -552,59 +626,17 @@
else:
index_of_attn_to_copy = None

num_utter_units = self.num_mem_units(embed_utter)

def cell_input_fn(inputs, attention):
"""Combine rnn inputs and attention into cell input"""
if num_mem_units > 0:
if num_mem_units == num_utter_units:
# since attention can contain additional
# attention mechanisms, only attention
# from previous user input is used as an input
# for rnn cell and only if memory before rnn
# is the same size as embed_utter
return tf.concat([inputs[:, :num_utter_units] +
attention[:, :num_utter_units],
inputs[:, num_utter_units:]], -1)
else:
# in current implementation it cannot fall here,
# but this Exception exists in case
# attention before rnn is changed
raise ValueError("Number of memory units {} is not "
"equal to number of utter units {}. "
"Please modify cell input function "
"accordingly.".format(num_mem_units,
num_utter_units))
else:
return inputs

# noinspection PyUnusedLocal
def inputs_and_attn_inputs_fn(inputs, cell_state):
"""Construct rnn input and attention mechanism input"""

# the hidden state and slots are not included,
# in hope that algorithm would learn correct attention
# regardless of the hidden state of an rnn and slots
attn_inputs = tf.concat([inputs[:, :num_utter_units],
inputs[:, (num_utter_units +
num_utter_units):]], 1)

# include slots in inputs but exclude previous action, since
# rnn should get previous action from its hidden state
inputs = inputs[:, :(num_utter_units +
num_utter_units)]

return inputs, attn_inputs

attn_cell = TimeAttentionWrapper(
cell=cell,
attention_mechanism=attn_mech,
sequence_len=self._dialogue_len,
attn_shift_range=attn_shift_range,
sparse_attention=self.sparse_attention,
inputs_and_attn_inputs_fn=inputs_and_attn_inputs_fn,
rnn_and_attn_inputs_fn=self.rnn_and_attn_inputs_fn,
ignore_mask=ignore_mask,
cell_input_fn=cell_input_fn,
cell_input_fn=lambda inputs, attention: (
self.cell_input_fn(inputs, attention,
num_memory_units_before_rnn)),
index_of_attn_to_copy=index_of_attn_to_copy,
likelihood_fn=lambda emb_1, emb_2: (
self._tf_sim(emb_1, emb_2, None)),
@@ -614,6 +646,33 @@ def inputs_and_attn_inputs_fn(inputs, cell_state):
)
return attn_cell
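
The cell_input_fn kwarg still receives two arguments; the lambda above only
fixes the extra num_memory_units_before_rnn parameter at construction time. A
generic sketch of this partial application (all names illustrative):

    from functools import partial

    def cell_input_fn(inputs, attention, num_units):
        return inputs, attention, num_units

    two_arg_fn = partial(cell_input_fn, num_units=4)
    print(two_arg_fn('i', 'a'))  # ('i', 'a', 4)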

def _create_tf_dial_embed(self, embed_utter, embed_slots,
embed_prev_action, mask,
embed_for_no_intent, embed_for_no_action,
embed_for_action_listen):
"""Create rnn for dialogue level embedding"""

cell_input = tf.concat([embed_utter, embed_slots,
embed_prev_action], -1)

cell = self._create_rnn_cell()

real_length = tf.cast(tf.reduce_sum(mask, 1), tf.int32)

if self.is_using_attention():
cell = self._create_attn_cell(cell, embed_utter,
embed_prev_action,
real_length, embed_for_no_intent,
embed_for_no_action,
embed_for_action_listen)

return tf.nn.dynamic_rnn(
cell, cell_input,
dtype=tf.float32,
sequence_length=real_length,
scope='rnn_decoder'
)
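
How real_length falls out of the padding mask, as a small numpy sketch (the
mask values are illustrative):

    import numpy as np

    # mask is 1 for real dialogue turns and 0 for padding
    mask = np.array([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])
    real_length = mask.sum(1).astype(np.int32)
    print(real_length)  # [3 5], the sequence_length fed to dynamic_rnn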

def _alignments_history_from(self, final_state):
"""Extract alignments history form final rnn cell state"""

@@ -632,6 +691,17 @@

return tf.concat(alignment_history, -1)

def _all_time_masks_from(self, final_state):
"""Extract all time masks form final rnn cell state"""

if not self.is_using_attention():
return None

        # reshape to (batch, time, memory_time) and drop the last time
        # step, because the time mask is created for the next time step
return tf.transpose(final_state.all_time_masks.stack(),
[1, 0, 2])[:, :-1, :]
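
A numpy sketch of that transpose-and-slice, assuming 3 time steps, batch size
2 and memory_time 4 (shapes are illustrative):

    import numpy as np

    # a stacked TensorArray has shape (time, batch, memory_time)
    stacked = np.arange(3 * 2 * 4).reshape(3, 2, 4)
    masks = np.transpose(stacked, [1, 0, 2])[:, :-1, :]
    print(masks.shape)  # (2, 2, 4): (batch, time, memory_time)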

def _sim_rnn_to_max_from(self, cell_output):
"""Save intermediate tensors for debug purposes"""

@@ -676,33 +746,6 @@ def _embed_dialogue_from(self, cell_output):

return embed_dialogue

def _create_tf_dial_embed(self, embed_utter, embed_slots,
embed_prev_action, mask,
embed_for_no_intent, embed_for_no_action,
embed_for_action_listen):
"""Create rnn for dialogue level embedding"""

cell_input = tf.concat([embed_utter, embed_slots,
embed_prev_action], -1)

cell = self._create_rnn_cell()

real_length = tf.cast(tf.reduce_sum(mask, 1), tf.int32)

if self.is_using_attention():
cell = self._create_attn_cell(cell, embed_utter,
embed_prev_action,
real_length, embed_for_no_intent,
embed_for_no_action,
embed_for_action_listen)

return tf.nn.dynamic_rnn(
cell, cell_input,
dtype=tf.float32,
sequence_length=real_length,
scope='rnn_decoder'
)

def _tf_sim(self, embed_dialogue, embed_action, mask):
"""Define similarity
this method has two roles:
@@ -811,7 +854,7 @@ def _tf_loss(self, sim, sim_act, sim_rnn_to_max, mask):
def train(self,
training_trackers, # type: List[DialogueStateTracker]
domain, # type: Domain
**kwargs # type: **Any
**kwargs # type: Any
):
# type: (...) -> None
"""Trains the policy on given training trackers."""
@@ -930,6 +973,8 @@ def train(self,
self.alignment_history = \
self._alignments_history_from(final_state)

self.all_time_masks = self._all_time_masks_from(final_state)

sim_rnn_to_max = self._sim_rnn_to_max_from(cell_output)
self.dial_embed = self._embed_dialogue_from(cell_output)

@@ -1129,7 +1174,7 @@ def _calc_train_acc(self, session_data, mask):
_mask) / np.sum(_mask)

def continue_training(self, training_trackers, domain, **kwargs):
# type: (List[DialogueStateTracker], Domain, **Any) -> None
# type: (List[DialogueStateTracker], Domain, Any) -> None
"""Continues training an already trained policy."""

batch_size = kwargs.get("batch_size", 5)
@@ -1274,6 +1319,8 @@ def persist(self, path):
self._persist_tensor('attn_embed', self.attn_embed)
self._persist_tensor('copy_attn_debug', self.copy_attn_debug)

self._persist_tensor('all_time_masks', self.all_time_masks)

saver = tf.train.Saver()
saver.save(self.session, checkpoint)
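
_persist_tensor and load_tensor themselves are not shown in this diff; one
plausible TF1 pattern for persisting named tensors, offered only as a sketch
of the idea, is a graph collection per name:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, name='x')
    tf.add_to_collection('all_time_masks', x)          # persist side
    # load side, after tf.train.import_meta_graph(...):
    restored = tf.get_collection('all_time_masks')[0]
    print(restored is x)  # True within the same graph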

@@ -1335,6 +1382,8 @@ def load(cls, path):
attn_embed = cls.load_tensor('attn_embed')
copy_attn_debug = cls.load_tensor('copy_attn_debug')

all_time_masks = cls.load_tensor('all_time_masks')

encoded_actions_file = os.path.join(
path, "{}.encoded_all_actions.pkl".format(file_name))

@@ -1361,4 +1410,5 @@ def load(cls, path):
dial_embed=dial_embed,
rnn_embed=rnn_embed,
attn_embed=attn_embed,
copy_attn_debug=copy_attn_debug)
copy_attn_debug=copy_attn_debug,
all_time_masks=all_time_masks)
