Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autodiff/training support for Nearest Interpolation #1414

Merged
merged 27 commits into from
Mar 6, 2024

Conversation

ashdtu
Copy link
Contributor

@ashdtu ashdtu commented Mar 5, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#1393

Changes

This PR implements the backward pass for nearest interpolation for all backends (Ndarray, Tch, WGPU) except Candle. This ensures that training can now be done with this operation. Nearest neighbour interpolation is generally a popular and default method for up/down sampling.

Testing

  • Test cases to verify the gradients of the backward pass of nearest interpolation operation has been written in burn-autodiff with test for both upsampling/downsampling cases. The outputs are verified against Python's PyTorch upsample2D(mode="nearest") operation.

Nikaidou-Shinku and others added 27 commits February 4, 2024 00:19
feat: bilinear interpolation for tch, ndarray and wgpu backend

fix: reduce test case size to avoid exceeding floating-point precision limits

feat: support nearest-neighbor interpolation for ndarray backend

feat: support nearest-neighbor interpolation for wgpu backend

feat: support fusion backend

fix: no-std support

build: upgrade dependencies
…ku/burn into feat/interpolation

merge upstream changes
:wq!
@ashdtu ashdtu requested review from antimora and louisfd March 5, 2024 08:16
Copy link

codecov bot commented Mar 5, 2024

Codecov Report

Attention: Patch coverage is 81.45695% with 56 lines in your changes are missing coverage. Please review.

Project coverage is 85.80%. Comparing base (4ed90a9) to head (b836c69).

Files Patch % Lines
crates/burn-tch/src/ops/module.rs 0.00% 35 Missing ⚠️
crates/burn-autodiff/src/ops/module.rs 84.09% 7 Missing ⚠️
crates/burn-candle/src/ops/module.rs 0.00% 7 Missing ⚠️
crates/burn-ndarray/src/ops/interpolate.rs 93.75% 2 Missing ⚠️
crates/burn-ndarray/src/ops/module.rs 81.81% 2 Missing ⚠️
crates/burn-wgpu/src/kernel/interpolate.rs 94.28% 2 Missing ⚠️
crates/burn-tensor/src/tensor/ops/modules/base.rs 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1414      +/-   ##
==========================================
- Coverage   85.81%   85.80%   -0.02%     
==========================================
  Files         610      611       +1     
  Lines       70417    70715     +298     
==========================================
+ Hits        60428    60674     +246     
- Misses       9989    10041      +52     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ashdtu
Copy link
Contributor Author

ashdtu commented Mar 5, 2024

Some more context:

For Tch-backend -> Backward passes for all interpolation types(nearest, bilinear, bicubic) are supported in this PR.

Bilinear, Bicubic for the backends(ndarray, wsgl) are under progress in another branch.

Copy link
Member

@louisfd louisfd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ashdtu ashdtu merged commit 0c92c8c into main Mar 6, 2024
15 checks passed
@ashdtu ashdtu deleted the feat/nearest_interp_autodiff branch March 6, 2024 05:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants