Skip to content
Permalink
Browse files

Changed how checklist balance is used in action prediction (#1204)

* changed action projection in nlvr

* removed +=
Loading branch information
pdasigi committed May 13, 2018
1 parent 1554fb0; commit 2dda95beede0322bf9bde1d99803113aaabf30bc
@@ -26,7 +26,7 @@ class NlvrDecoderStep(DecoderStep[NlvrDecoderState]):
attention_function : ``SimilarityFunction``
dropout : ``float``
Dropout to use on decoder outputs and before action prediction.
use_coverage : ``bool``
use_coverage : ``bool``, optional (default=False)
Is this DecoderStep being used in a semantic parser trained using coverage? We need to know
this to define a learned parameter for using checklist balances in action prediction.
"""
@@ -95,7 +95,7 @@ def take_step(self, # type: ignore
# (group_size, decoder_input_dim)
decoder_input = self._input_projection_layer(torch.cat([attended_sentence,
previous_action_embedding], -1))

decoder_input = torch.nn.functional.tanh(decoder_input)
hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))

hidden_state = self._dropout(hidden_state)
@@ -138,10 +138,11 @@ def take_step(self, # type: ignore
action_query = torch.cat([hidden_state, attended_sentence], dim=-1)
# (group_size, action_embedding_dim)
predicted_action_embedding = self._output_projection_layer(action_query)
predicted_action_embedding = self._dropout(torch.nn.functional.relu(predicted_action_embedding))
predicted_action_embedding = self._dropout(torch.nn.functional.tanh(predicted_action_embedding))
if state.checklist_state[0] is not None:
embedding_addition = self._get_predicted_embedding_addition(state)
predicted_action_embedding += self._checklist_embedding_multiplier * embedding_addition
addition = embedding_addition * self._checklist_embedding_multiplier
predicted_action_embedding = predicted_action_embedding + addition
# We'll do a batch dot product here with `bmm`. We want `dot(predicted_action_embedding,
# action_embedding)` for each `action_embedding`, and we can get that efficiently with
# `bmm` and some squeezing.
@@ -147,7 +147,8 @@ def take_step(self,
embedding_addition = self._get_predicted_embedding_addition(state,
self._unlinked_terminal_indices,
unlinked_balance)
predicted_action_embedding += self._unlinked_checklist_multiplier * embedding_addition
addition = embedding_addition * self._unlinked_checklist_multiplier
predicted_action_embedding = predicted_action_embedding + addition

# We'll do a batch dot product here with `bmm`. We want `dot(predicted_action_embedding,
# action_embedding)` for each `action_embedding`, and we can get that efficiently with

0 comments on commit 2dda95b

Please sign in to comment.
You can’t perform that action at this time.