
Double Softmax in PyTorch image estimator for test cases. #2227

Open
GiulioZizzo opened this issue Jul 26, 2023 · 2 comments
Labels
improvement Improve implementation

Comments

@GiulioZizzo
Collaborator

Many tests use the PyTorch image estimator defined in the test utils.

By default this estimator does not use logits, e.g. the function signature is:

get_image_classifier_pt(from_logits=False, load_init=True, use_maxpool=True)

However, the loss function is loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

torch.nn.CrossEntropyLoss by default expects logits and will re-apply a softmax: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
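A minimal sketch of the double-softmax effect (the tensor values here are illustrative, not from the ART test model): feeding softmax probabilities into CrossEntropyLoss distorts the loss, because the loss applies log-softmax a second time.

```python
import torch

# Illustrative logits and target for a single 3-class sample.
logits = torch.tensor([[2.0, -1.0, 0.5]])
target = torch.tensor([0])

loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

# Correct: feed raw logits; CrossEntropyLoss applies log-softmax internally.
loss_from_logits = loss_fn(logits, target)

# Incorrect: the model already applied softmax, so the loss
# effectively computes softmax(softmax(logits)) -- a double softmax.
probs = torch.softmax(logits, dim=1)
loss_double_softmax = loss_fn(probs, target)

print(loss_from_logits.item(), loss_double_softmax.item())
# The two losses differ, and the double-softmax gradients are squashed,
# which is why training behaves differently from the TensorFlow model.
```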

Hence, we should aim to make the default configuration mathematically correct. We could:

  1. Have the default from_logits=True
  2. Additionally, make the loss depend on from_logits by using CrossEntropyLoss when it is True and NLLLoss when it is False

This may require updating certain ART tests.
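Option 2 above could be sketched roughly as follows; `make_loss` is a hypothetical helper, not an existing ART function, shown only to illustrate pairing the loss with the from_logits flag:

```python
import torch

def make_loss(from_logits: bool):
    # Hypothetical helper: pick the loss that matches the model's output head.
    if from_logits:
        # Model outputs raw logits; CrossEntropyLoss applies log-softmax itself.
        return torch.nn.CrossEntropyLoss(reduction="sum")
    # Model already outputs softmax probabilities; take the log and use NLLLoss.
    nll = torch.nn.NLLLoss(reduction="sum")
    return lambda probs, target: nll(torch.log(probs), target)

# Illustrative check: both pairings compute the same cross-entropy value.
logits = torch.tensor([[2.0, -1.0, 0.5]])
target = torch.tensor([0])

loss_logits_head = make_loss(True)(logits, target)
loss_softmax_head = make_loss(False)(torch.softmax(logits, dim=1), target)
```

With this pairing, the default from_logits value becomes a presentation choice rather than a correctness issue, since either head gets a mathematically consistent loss.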

System information:

  • OS: MacOS
  • Python version: 3.9
  • ART version or commit number: 1.15
  • TensorFlow / Keras / PyTorch / MXNet version: Torch 1.13
@beat-buesser
Collaborator

Hi @GiulioZizzo, thank you very much for raising this issue! Have you found any tests where the wrong value for from_logits has been used?

GiulioZizzo added a commit to GiulioZizzo/adversarial-robustness-toolbox that referenced this issue Aug 14, 2023
Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
@GiulioZizzo
Collaborator Author

Hi @beat-buesser! This issue is prevalent in virtually all of the ART test cases that use get_image_classifier_pt, as they all use the default parameters as far as I can see. In many cases this doesn't cause a huge problem when the tests just do forward passes and then compute things like accuracy (the argmax is unaffected, since softmax is monotonic).

However, it does start to cause problems when the neural network is trained and an exact result is expected. I came across the problem when refactoring test_adversarial_trainer for issue #2225 (including Hugging Face support in ART). The TensorFlow and PyTorch models would train in totally different ways even though they ought to converge to almost identical results (allowing for framework-specific numerical deltas). When you change the PyTorch classifier to use the correct logits/loss function combination, the model trains as it should and the framework results match.

There could well be other tests affected by this, so it would be worth investigating and correcting them for current and future tests.

GiulioZizzo added a commit to GiulioZizzo/adversarial-robustness-toolbox that referenced this issue Aug 31, 2023
Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
@beat-buesser beat-buesser added the improvement Improve implementation label Sep 1, 2023