New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Pointer Networks implementation based on the Vocabulary Approach #505
Changes from all commits
0395f00
66b2a6f
70a6f7b
3d13ed2
4e5beff
6d86641
3aa8c07
f833eb8
e8ad32b
0b18d6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
VOCAB_SYMBOLS = [PAD_SYMBOL, UNK_SYMBOL, BOS_SYMBOL, EOS_SYMBOL] | ||
# reserve extra space for the EOS or BOS symbol that is added to both source and target | ||
SPACE_FOR_XOS = 1 | ||
MAX_OOV_WORDS = 50 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not reuse another variable for this, such as the maximum input length? That way we avoid creating another variable by using a good default. Is there any reason to set it to something other than that? |
||
|
||
ARG_SEPARATOR = ":" | ||
|
||
|
@@ -390,7 +391,8 @@ | |
# pointer networks | ||
POINTER_NET_RNN = "rnn" | ||
POINTER_NET_SHARED = "shared" | ||
POINTER_NET_CHOICES = [POINTER_NET_RNN] | ||
POINTER_NET_SUMMARY = "summary" | ||
POINTER_NET_CHOICES = [POINTER_NET_RNN, POINTER_NET_SUMMARY] | ||
|
||
# data sharding | ||
SHARD_NAME = "shard.%05d" | ||
|
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -470,7 +470,9 @@ def __init__(self, | |
layer_normalization: bool = False, | ||
attention_in_upper_layers: bool = False, | ||
dtype: str = C.DTYPE_FP32, | ||
enc_last_hidden_concat_to_embedding: bool = False) -> None: | ||
enc_last_hidden_concat_to_embedding: bool = False, | ||
use_pointer_nets: bool = False, | ||
pointer_nets_type: str = C.POINTER_NET_SUMMARY) -> None: | ||
|
||
super().__init__() | ||
self.max_seq_len_source = max_seq_len_source | ||
|
@@ -484,6 +486,8 @@ def __init__(self, | |
self.attention_in_upper_layers = attention_in_upper_layers | ||
self.enc_last_hidden_concat_to_embedding = enc_last_hidden_concat_to_embedding | ||
self.dtype = dtype | ||
self.use_pointer_nets = use_pointer_nets | ||
self.pointer_nets_type = pointer_nets_type | ||
|
||
|
||
@Decoder.register(RecurrentDecoderConfig, C.RNN_DECODER_PREFIX) | ||
|
@@ -582,6 +586,7 @@ def decode_sequence(self, | |
""" | ||
|
||
# target_embed: target_seq_len * (batch_size, num_target_embed) | ||
target_embed_local = target_embed | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It isn't necessary to save this and then return it. This reassignment here does not change the value in the caller. |
||
target_embed = mx.sym.split(data=target_embed, num_outputs=target_embed_max_length, axis=1, squeeze_axis=True) | ||
|
||
# Get last state from source (batch_size, num_target_embed) | ||
|
@@ -606,7 +611,7 @@ def decode_sequence(self, | |
hidden_states = [] # type: List[mx.sym.Symbol] | ||
context_vectors = [] # type: List[mx.sym.Symbol] | ||
attention_probs = [] # type: List[mx.sym.Symbol] | ||
# TODO: possible alternative: feed back the context vector instead of the hidden (see lamtram) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's keep this line in. |
||
coverage_vectors = [] # type: List[mx.sym.Symbol] | ||
self.reset() | ||
for seq_idx in range(target_embed_max_length): | ||
# hidden: (batch_size, rnn_num_hidden) | ||
|
@@ -619,11 +624,23 @@ def decode_sequence(self, | |
hidden_states.append(state.hidden) | ||
context_vectors.append(attention_state.context) | ||
attention_probs.append(attention_state.probs) | ||
coverage_vectors.append(attention_state.dynamic_source) | ||
|
||
# concatenate along time axis: (batch_size, target_embed_max_length, rnn_num_hidden) | ||
return mx.sym.Group([mx.sym.stack(*hidden_states, axis=1, name='%shidden_stack' % self.prefix), \ | ||
mx.sym.stack(*context_vectors, axis=1, name='%scontext_stack' % self.prefix), | ||
mx.sym.stack(*attention_probs, axis=1, name='%sattention_stack' % self.prefix)]) | ||
if self.rnn_config.use_pointer_nets and self.rnn_config.pointer_nets_type == C.POINTER_NET_SUMMARY: | ||
return mx.sym.Group([mx.sym.stack(*hidden_states, axis=1, name='%shidden_stack' % self.prefix), | ||
# expected size: (batch_size, trg_max_length, encoder_num_hidden) | ||
mx.sym.stack(*context_vectors, axis=1, name='%scontext_stack' % self.prefix), | ||
# expected size: (batch_size, trg_max_length, attn_len) | ||
mx.sym.stack(*attention_probs, axis=1, name='%sattention_stack' % self.prefix), | ||
# expected size: (batch_size, trg_max_length, attn_len) | ||
mx.sym.stack(*coverage_vectors, axis=1, name='%scoverage_stack' % self.prefix), | ||
# expected size: (batch_size, trg_max_length, trg_embed_len) | ||
target_embed_local]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to return this. The caller already has it. |
||
else: | ||
return mx.sym.Group([mx.sym.stack(*hidden_states, axis=1, name='%shidden_stack' % self.prefix), | ||
mx.sym.stack(*context_vectors, axis=1, name='%scontext_stack' % self.prefix), | ||
mx.sym.stack(*attention_probs, axis=1, name='%sattention_stack' % self.prefix)]) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should get rid of the |
||
def decode_step(self, | ||
step: int, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,11 +57,17 @@ def __init__(self, | |
self.use_feature_loader = use_feature_loader | ||
|
||
def decode_and_evaluate(self, | ||
use_pointer_nets: bool, | ||
max_oov_words: int, | ||
pointer_nets_type: str, | ||
checkpoint: Optional[int] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't be needed (per above). |
||
output_name: str = os.devnull) -> Dict[str, float]: | ||
""" | ||
Decodes data set and evaluates given a checkpoint. | ||
|
||
:param use_pointer_nets: Flag to indicate if pointer network is enabled (not available with captioning as of now) | ||
:param max_oov_words: Maximum number of words to consider in the extended vocabulary (with pointer networks) | ||
:param pointer_nets_type: Pointer networks implementation to use. | ||
:param checkpoint: Checkpoint to load parameters from. | ||
:param output_name: Filename to write translations to. Defaults to /dev/null. | ||
:return: Mapping of metric names to scores. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -226,7 +226,11 @@ def create_model_config(args: argparse.Namespace, | |
config_loss = loss.LossConfig(name=args.loss, | ||
vocab_size=vocab_target_size, | ||
normalization_type=args.loss_normalization_type, | ||
label_smoothing=args.label_smoothing) | ||
label_smoothing=args.label_smoothing, | ||
use_pointer_nets=False, | ||
use_coverage_loss=False, | ||
coverage_loss_weight=0, | ||
pointer_nets_type=C.POINTER_NET_SUMMARY) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that pointer nets are not going to be supported for image captioning. Could these calls rely on defaults, such that no changes need to be made to the image captioning code? |
||
model_config = model.ModelConfig(config_data=config_data, | ||
vocab_source_size=0, | ||
|
@@ -384,7 +388,10 @@ def train(args: argparse.Namespace): | |
mxmonitor_pattern=args.monitor_pattern, | ||
mxmonitor_stat_func=args.monitor_stat_func, | ||
allow_missing_parameters=args.allow_missing_params, | ||
existing_parameters=args.params) | ||
existing_parameters=args.params, | ||
use_pointer_nets=False, | ||
max_oov_words=C.MAX_OOV_WORDS, | ||
pointer_nets_type=C.POINTER_NET_SUMMARY) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These three arguments will be read by the model from the ModelConfig. You shouldn't need to pass them through at all.