
Question regarding example model and its config #2

Open
Toshihiro-Ota opened this issue Mar 7, 2024 · 3 comments

@Toshihiro-Ota

Thank you for providing the PyTorch version of the Energy Transformer originally implemented in JAX. This is really helpful for me, but I have not yet been able to reproduce results like example1.png or the images (GIF files) in the visuals directory.

Let me confirm: the example model model.pth was trained with the config file model_config.txt in the same directory, not with the config file of the same name in the parent directory, energy-transformer-torch/model_config.txt? If this is correct, some parameters ('alpha': 5.0, 'attn_beta': 16.0, 'epochs': 10000) are much larger than the hyperparameter settings reported on p. 16 of the paper (arXiv:2302.07253), and others ('b1': 0.99, 'b2': 0.999, 'batch_size': 512, 'learning_rate': 5e-05, 'weight_decay': 0.001) are still somewhat different.
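For reference, here is a small sketch (not from the repo) of how one might diff the two config files named above. It assumes each model_config.txt stores a Python dict literal, as the quoted 'key': value pairs suggest; if the files use another format, the loader would need adjusting.

```python
# Hypothetical helper: compare the example model's config against the
# default one, assuming both files contain a Python dict literal.
import ast
from pathlib import Path

def load_config(path):
    return ast.literal_eval(Path(path).read_text())

example_cfg = load_config("example_model/model_config.txt")
default_cfg = load_config("model_config.txt")

for key in sorted(set(example_cfg) | set(default_cfg)):
    a, b = example_cfg.get(key), default_cfg.get(key)
    if a != b:
        print(f"{key}: example={a!r} default={b!r}")
```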

I understand that the best hyperparameters would differ, since the paper trains the model on ImageNet-1k while this repo uses CIFAR-10/100. I would just like to know which configuration the example model was trained with and how I can reproduce the results of example1.png or the images in the visuals directory.

Thanks for any additional details you can provide; I will keep trying to reproduce the results in this repo.

@Lemon-cmd
Owner

Hey there,

Sorry, it has been a while since I've touched this code, as my focus has been on other work. Yes, I believe the provided model was trained with a different set of hyperparameters than the default ones. To be honest, it was a random model that I trained from scratch with some randomly chosen hyperparameters. Regarding the generation of GIFs, I updated the notebook; I had forgotten to include interpolation=transforms.InterpolationMode.NEAREST in the Resize transform. Finally, please definitely don't use the ImageNet-1k parameters; they might not do well at all. Also, keep in mind that the --batch-size flag is the global batch size: in my setting I was using 4 GPU devices, so a global batch of 512 works out to 128 samples per GPU.
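For anyone hitting the same blurry-GIF issue, a minimal sketch of the fix described above: without an explicit interpolation mode, torchvision's Resize defaults to bilinear, which smooths the blocky reconstruction frames, whereas NEAREST keeps pixels crisp. The 32 -> 128 sizes and the random frame are illustrative assumptions, not values from the repo.

```python
import torch
from torchvision import transforms

# Upscale without smoothing: NEAREST preserves the pixelated look of
# small reconstruction frames (the default bilinear mode blurs them).
upscale = transforms.Resize(
    128,  # target size, assumed for illustration
    interpolation=transforms.InterpolationMode.NEAREST,
)

frame = torch.rand(3, 32, 32)   # e.g. a CIFAR-sized reconstruction frame
big_frame = upscale(frame)      # (3, 128, 128), no interpolation blur
```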

Let me know if you need additional help! I will try to respond asap.

@Toshihiro-Ota
Author

Thanks for your reply.

> it was a random model that I trained from scratch with some randomly chosen hyperparameters.

So, you mean that you conducted training with randomly chosen hyperparameters, without any tuning? If so, example1.png and the GIF files in visuals look surprisingly well restored. I would greatly appreciate it if you could dig up the exact configuration and let me know, or is it really example_model/model_config.txt?

> Regarding the generation of GIFs, I updated the notebook; I had forgotten to include interpolation=transforms.InterpolationMode.NEAREST in the Resize transform.

Thanks! I will try it again.

@Lemon-cmd
Owner

Hey there, I have updated the codebase to use Accelerator from Hugging Face and fixed some minor errors (including plotting) from the previous codebase. Masking is now randomized across batches as well. I am currently searching for a new model configuration that is decent enough for visualization. I will give another update soon, hopefully :). It is surprising that the reconstructions were so good in the previous codebase, but I don't think it will be the same for this one.
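For readers following along, here is a rough sketch of what those two changes look like, not the repo's actual code: a training loop under Hugging Face Accelerate in which the mask is re-sampled for every batch instead of being fixed once. The toy model, tensor shapes, and mask ratio are placeholder assumptions.

```python
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

num_patches, dim = 64, 128           # assumed patch count and embed dim
model = nn.Linear(dim, dim)          # stand-in for the Energy Transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

data = torch.randn(512, num_patches, dim)          # dummy patch tokens
loader = torch.utils.data.DataLoader(data, batch_size=128)

# Accelerate wraps model, optimizer, and loader for multi-GPU training.
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

mask_ratio = 0.5                     # assumed masking fraction
for tokens in loader:
    # Randomized masking: a fresh Bernoulli mask per batch (and sample),
    # rather than one fixed mask reused for the whole run.
    mask = torch.rand(tokens.shape[:2], device=tokens.device) < mask_ratio
    masked = tokens * (~mask).unsqueeze(-1)
    recon = model(masked)
    loss = ((recon - tokens) ** 2)[mask].mean()    # loss on masked patches
    accelerator.backward(loss)                     # replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```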
