Skip to content

Commit

Permalink
added optimization config to examples/seq2seq_attn
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhitingHu committed Oct 9, 2018
1 parent a73d727 commit d223165
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
8 changes: 8 additions & 0 deletions examples/seq2seq_attn/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@
'attention_layer_size': num_units
}
}
# Training-op hyperparameters: Adam optimizer with a fixed 0.001
# learning rate (no decay, clipping, or other options configured).
opt = {
    'optimizer': {
        'type': 'AdamOptimizer',
        'kwargs': {'learning_rate': 0.001},
    },
}
18 changes: 18 additions & 0 deletions examples/seq2seq_attn/config_model_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,21 @@
'max_decoding_length_infer': None,
'name': 'attention_rnn_decoder'
}
# --------------------- Optimization --------------------- #
opt = {
    # Adam optimizer with a constant learning rate of 0.001; further
    # keyword arguments of the optimizer class would go in 'kwargs'.
    'optimizer': {
        'type': 'AdamOptimizer',
        'kwargs': {'learning_rate': 0.001},
    },
    # Learning-rate decay hyperparameters (empty: no decay).
    'learning_rate_decay': {},
    # Gradient-clipping hyperparameters (empty: no clipping).
    'gradient_clip': {},
    # No gradient noise is added.
    'gradient_noise_scale': None,
    # Use the default name for the training op.
    'name': None,
}
3 changes: 2 additions & 1 deletion examples/seq2seq_attn/seq2seq_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def build_model(batch, train_data):
tx.losses.sequence_sparse_softmax_cross_entropy(
labels=batch['target_text_ids'][:, 1:],
logits=training_outputs.logits,
sequence_length=batch['target_length'] - 1))
sequence_length=batch['target_length'] - 1),
hparams=config_model.opt)

start_tokens = tf.ones_like(batch['target_length']) * \
train_data.target_vocab.bos_token_id
Expand Down

0 comments on commit d223165

Please sign in to comment.