Add vae_text example #133
Conversation
* Port vae_text from texar-TF.
* Add external distribution MultivariateNormalDiag.
* Add preprocessing for data batch.
* Modify None checking condition for initial_state in RNNDecoderBase.
* Modify max_pos for config_trans_yahoo.py.
* Modify connectors mlp function.
* Refactor vae_text training & generation.
* Refactor vae_text decoder embeddings.
* Polish code.
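For readers unfamiliar with the `MultivariateNormalDiag` distribution mentioned above: PyTorch has no built-in class by that name (the name comes from TensorFlow Probability), so an external implementation is needed. A minimal sketch of an equivalent diagonal-covariance Gaussian in plain PyTorch follows; the helper name is illustrative, and the actual texar-pytorch implementation may differ:

```python
import torch
from torch.distributions import Independent, Normal

def multivariate_normal_diag(loc, scale_diag):
    # A multivariate normal with diagonal covariance is a batch of
    # independent univariate normals; Independent(..., 1) reinterprets
    # the last batch dimension as the event dimension, so log_prob sums
    # over the latent dimension instead of returning per-dimension values.
    return Independent(Normal(loc, scale_diag), 1)

mu = torch.zeros(4, 16)      # batch of 4, latent size 16
sigma = torch.ones(4, 16)
dist = multivariate_normal_diag(mu, sigma)
z = dist.rsample()           # reparameterized sample, so gradients flow
log_p = dist.log_prob(z)     # one log-probability per batch element
```

`rsample()` (rather than `sample()`) matters for a VAE, since the reparameterization trick is what lets gradients propagate through the latent sample.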
Generally, it looks good to me now. However, the performance on the Yahoo dataset is significantly worse than the TF version. Do you have any hypotheses as to why?
@huzecong thanks. One possible reason, I think, is that the checkpoint needs to store more information than the current version does, since we are using more than
Then why can't we store more info in order to recover the results? What's the difficulty here?
@ZhitingHu Yeah... it is actually used to handle the case where training is paused midway. I think we can add a parameter for continue-training checkpoints; I'll make that change.
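A minimal sketch of the kind of continue-training checkpoint discussed here, assuming plain PyTorch (the function names and saved fields are illustrative, not the actual texar code): resuming exactly requires more than model weights, e.g. the optimizer's Adam moments, the training step, the KL-annealing weight, and the RNG state.

```python
import torch

def save_checkpoint(path, model, optimizer, step, kl_weight):
    # Save everything needed to resume training exactly, not just weights.
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),  # e.g. Adam moment estimates
        "step": step,                         # position in the schedule
        "kl_weight": kl_weight,               # KL-annealing state
        "rng_state": torch.get_rng_state(),   # for reproducible sampling
    }, path)

def load_checkpoint(path, model, optimizer):
    ckpt = torch.load(path)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    torch.set_rng_state(ckpt["rng_state"])
    return ckpt["step"], ckpt["kl_weight"]
```

If a checkpoint stores only `model.state_dict()`, resuming resets the optimizer moments and annealing schedule, which can plausibly explain a gap versus uninterrupted training.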
Did you mean the current code can reproduce the TF results if training is not paused in the middle?
@ZhitingHu Yes, actually "Yahoo-LSTM" once reached
Then report the results without interrupted training.
Got it.
examples/vae_text/README.md
Outdated
@@ -42,7 +42,7 @@ Here `--model` specifies the saved model checkpoint, which is saved in `./models
 |Dataset |Metrics | VAE-LSTM |VAE-Transformer |
 |---------------|-------------|----------------|------------------------|
-|Yahoo | Test PPL<br>Test NLL | 75.21<br>336.41 |67.81<br>328.34|
+|Yahoo | Test PPL<br>Test NLL | 69.42<br>338.65 |67.81<br>328.34|
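For context on the two metrics in this table: per-token perplexity (PPL) and per-sentence negative log-likelihood (NLL) are related by `PPL = exp(NLL / tokens)`. A small illustrative check follows; the average sentence length of 80 tokens is an assumption for the example, not a number from this PR:

```python
import math

def perplexity(total_nll, total_tokens):
    # Perplexity is the exponentiated average per-token
    # negative log-likelihood.
    return math.exp(total_nll / total_tokens)

# With the reported per-sentence NLL of 338.65 and a hypothetical
# average length of 80 tokens, PPL comes out near the reported ~69,
# consistent in magnitude with the table.
ppl = perplexity(338.65, 80)
```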
What about the Transformer results? Are those still running?
This PR is not yet ready for review.
This PR is ready for review now.
Results now look reasonable to me. Let's merge this.
* Port vae_text from texar-TF.
* Add external distribution MultivariateNormalDiag.
* Add preprocessing for data batch.
* Modify None checking condition for initial_state in RNNDecoderBase.
* Modify max_pos for config_trans_yahoo.py.
* Modify connectors mlp function.
* Refactor vae_text training & generation decoder.
* Refactor vae_text decoder embeddings.
* Refactor to import texar.torch.
* Polish code.