Optimize PyTorch Classifiers and Object Detectors #2180

Merged
merged 13 commits into from
Jun 20, 2023

Conversation

f4str
Collaborator

@f4str f4str commented Jun 6, 2023

Description

Use PyTorch datasets and dataloaders to optimize the fit and predict methods for the following PyTorch estimators: PyTorchClassifier, PyTorchRegressor, PyTorchRandomizedSmoothing, PyTorchObjectDetector, PyTorchYolo. This not only speeds up batching but also moves only the current batch to the GPU rather than the entire dataset, which significantly reduces VRAM usage.
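
The batching pattern described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code from this PR: the function name `predict_in_batches`, the toy `torch.nn.Linear` model, and the batch size are all hypothetical. It shows the dataset staying in host memory while only the current batch is moved to the model's device.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def predict_in_batches(model: torch.nn.Module, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
    """Run inference batch by batch (illustrative sketch, not ART's implementation)."""
    # Assumption: the model has at least one parameter, so its device can be read from it.
    device = next(model.parameters()).device

    # The full dataset stays in host memory as a CPU tensor.
    dataset = TensorDataset(torch.from_numpy(x))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model.eval()
    outputs = []
    with torch.no_grad():
        for (x_batch,) in loader:
            # Only the current batch is moved to the model's device.
            x_batch = x_batch.to(device)
            outputs.append(model(x_batch).cpu().numpy())
    return np.concatenate(outputs, axis=0)


# Toy usage: a linear model standing in for a real classifier.
model = torch.nn.Linear(4, 3)
x = np.random.rand(10, 4).astype(np.float32)
predictions = predict_in_batches(model, x, batch_size=4)
```

With `shuffle=False` the concatenated outputs keep the input order, so the result should match a single full-batch forward pass up to floating-point tolerance.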

Optimize the predict method for the TensorFlowV2Classifier by using a TensorFlow dataset and the @tf.function decorator to speed up inference.

Significant changes to the PyTorchObjectDetector estimator:

  • The channels_first property is now properly used to transpose inputs accordingly.
  • Unify the input preprocessing with the approach used by the PyTorchYolo estimator. This simplifies the implementation and reduces redundant code.
  • Remove the extraneous loop when computing the loss gradients. The underlying bug has been fixed as of torch>=1.10, so the entire batch can now be processed at once. Because the batch is processed together rather than one sample at a time, the loss gradients have changed. The unit tests have been adjusted accordingly.
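
The channels_first handling in the first bullet can be pictured with a small NumPy sketch. The helper name `to_channels_first` is hypothetical and is not taken from ART; it only illustrates the transpose from a channels-last (NHWC) batch to the channels-first (NCHW) layout, assuming 4-dimensional image batches.

```python
import numpy as np


def to_channels_first(x: np.ndarray, channels_first: bool) -> np.ndarray:
    """Return a batch in NCHW layout (hypothetical helper, not ART's API)."""
    if channels_first:
        # Input is already NCHW; nothing to do.
        return x
    # Input is NHWC; move the channel axis in front of height and width.
    return np.transpose(x, (0, 3, 1, 2))


# A batch of two 32x48 RGB images in channels-last layout.
x_nhwc = np.zeros((2, 32, 48, 3), dtype=np.float32)
x_nchw = to_channels_first(x_nhwc, channels_first=False)
```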

Also removed the PyTorch version check from the PyTorchYolo estimator, as the bug only applies to the torchvision models and does not affect the external YOLO models.

Fixes #2157
Fixes #1637
Fixes #2173

Type of change

Please check all relevant options.

  • Improvement (non-breaking)
  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Testing

Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.

  • Unit tests for PyTorch and TensorFlow estimators work as expected
  • Modified unit tests for PyTorchObjectDetector
  • Modified unit tests for PyTorchFasterRCNN

Test Configuration:

  • OS
  • Python version
  • ART version or commit number
  • TensorFlow / Keras / PyTorch / MXNet version

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

f4str added 8 commits May 30, 2023 10:45
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
@f4str
Collaborator Author

f4str commented Jun 6, 2023

This may also resolve #1943, although other issues may still persist.

@codecov-commenter

codecov-commenter commented Jun 6, 2023

Codecov Report

Merging #2180 (f3fcf19) into dev_1.15.0 (011ab1e) will increase coverage by 4.71%.
The diff coverage is 74.86%.

@@              Coverage Diff               @@
##           dev_1.15.0    #2180      +/-   ##
==============================================
+ Coverage       80.96%   85.68%   +4.71%     
==============================================
  Files             306      306              
  Lines           27067    27022      -45     
  Branches         4980     4962      -18     
==============================================
+ Hits            21914    23153    +1239     
+ Misses           3927     2599    -1328     
- Partials         1226     1270      +44     
| Impacted Files | Coverage Δ |
|---|---|
| art/estimators/regression/pytorch.py | 14.20% <0.00%> (-0.55%) ⬇️ |
| art/estimators/object_detection/utils.py | 67.92% <48.48%> (-32.08%) ⬇️ |
| ...mators/object_detection/pytorch_object_detector.py | 86.54% <84.50%> (+9.75%) ⬆️ |
| ...tors/certification/randomized_smoothing/pytorch.py | 85.54% <100.00%> (+1.45%) ⬆️ |
| art/estimators/classification/pytorch.py | 86.30% <100.00%> (+0.42%) ⬆️ |
| art/estimators/object_detection/pytorch_yolo.py | 80.66% <100.00%> (+3.08%) ⬆️ |

... and 36 files with indirect coverage changes

Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
@f4str f4str marked this pull request as ready for review June 6, 2023 23:33
@beat-buesser beat-buesser self-requested a review June 9, 2023 09:29
@beat-buesser beat-buesser self-assigned this Jun 9, 2023
@beat-buesser beat-buesser added this to Pull request open in ART 1.15.0 via automation Jun 9, 2023
@beat-buesser beat-buesser added this to the ART 1.15.0 milestone Jun 9, 2023
Collaborator

@beat-buesser beat-buesser left a comment

Hi @f4str, thank you very much for these improvements to the estimators! They will be very useful for many users of ART. I have one suggestion; what do you think?


# Convert labels into tensor
if y is not None and isinstance(y, list) and isinstance(y[0]["boxes"], np.ndarray):
    y_tensor = []
Collaborator

Would we be able to remove the "type: ignore" in line 139 if we add more specific typing here in line 122 like

Suggested change
- y_tensor = []
+ y_tensor: List[Dict[str, "torch.Tensor"]] = []

If yes, should we move this initialization before this if-block?

Collaborator Author

The mypy issue is actually due to the None type that y can be. The type: ignore was already here before; I've just abstracted this segment of code out (since it is now reused) and made it a function.

Correcting this through typing changes would require going a couple of levels up, which can make things overcomplicated.
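
To illustrate the point about `None`: mypy complains when an `Optional` value is indexed or iterated before it has been narrowed. The sketch below shows one common way such an error is avoided, combining the annotation from the suggestion with an early return. The helper `labels_to_tensors` is hypothetical and is not the function introduced in this PR.

```python
from typing import Dict, List, Optional

import numpy as np
import torch


def labels_to_tensors(
    y: Optional[List[Dict[str, np.ndarray]]],
) -> Optional[List[Dict[str, torch.Tensor]]]:
    """Convert NumPy label dicts to tensors (hypothetical helper, not the PR's code)."""
    if y is None:
        # The early return narrows y from Optional[...] to List[...] for mypy.
        return None
    y_tensor: List[Dict[str, torch.Tensor]] = []
    for y_i in y:
        # Convert every array in the label dict (e.g. "boxes", "labels") to a tensor.
        y_tensor.append({key: torch.from_numpy(value) for key, value in y_i.items()})
    return y_tensor


labels = [{"boxes": np.zeros((2, 4), dtype=np.float32), "labels": np.array([1, 2])}]
converted = labels_to_tensors(labels)
```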

Collaborator Author

This should now be resolved

Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
@f4str f4str changed the title from "Optimize PyTorch/TensorFlow Classifiers and Object Detectors" to "Optimize PyTorch Classifiers and Object Detectors" Jun 14, 2023
@f4str
Collaborator Author

f4str commented Jun 14, 2023

Hi @beat-buesser thank you for approving the PR. I have reverted the changes to the TensorFlowV2Classifier as the tests/attacks/evasion/test_sign_opt.py test case runs for over 4 hours with this optimization. Since this is difficult to debug, it is simpler to omit this change for now (since the main focus of this PR is the PyTorch classifiers) and revisit it later.

I've also responded to your comment about the mypy issue.

@beat-buesser
Collaborator

@f4str Thank you very much!

@beat-buesser beat-buesser merged commit 4bfed67 into Trusted-AI:dev_1.15.0 Jun 20, 2023
37 checks passed
ART 1.15.0 automation moved this from Pull request open to Pull request done Jun 20, 2023
@f4str f4str deleted the torch-dataloaders branch June 20, 2023 23:27
Labels
improvement Improve implementation
Projects
No open projects
ART 1.15.0
Pull request done
3 participants