feat: improve batched jacobian #848

Merged
merged 5 commits into from
Aug 17, 2024

avik-pal (Member) commented Aug 17, 2024

  • batched_jacobian calls on functions that do not use Lux layers are now differentiable
  • Support Tracker.jl for the above case (uses Zygote internally)
  • Support ReverseDiff.jl for the above case (uses Zygote internally)
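
A minimal sketch of the case described above, assuming a Lux release that includes this change and that ADTypes, ForwardDiff, and Zygote are installed. The function `f`, the input `x`, and the `loss` helper are illustrative names, not taken from the PR; the choice of `AutoZygote()` as the inner backend is likewise an assumption.

```julia
using Lux, ForwardDiff, Zygote
using ADTypes: AutoForwardDiff, AutoZygote

# A plain Julia function, not a Lux layer. The last dimension is the batch.
f(x) = x .^ 2

x = Float32[1 2 3; 4 5 6]   # 2 features, 3 samples

# One Jacobian per sample, stacked into an array of size (2, 2, 3).
J = batched_jacobian(f, AutoForwardDiff(), x)

# Differentiating through batched_jacobian for a non-Lux function is the
# case this PR targets. For this f, the analytic gradient is 8 .* x.
loss(x) = sum(abs2, batched_jacobian(f, AutoZygote(), x))
g = only(Zygote.gradient(loss, x))
```

Per the list above, the same kind of loss taken under Tracker.jl or ReverseDiff.jl is routed through Zygote internally; that path is not exercised in this sketch.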

github-actions bot (Contributor) commented Aug 17, 2024

Benchmark Results (ASV)

| Benchmark | main | 2b795ce... | main/2b795ce3fa4a08... |
|---|---|---|---|
| basics/overhead | 0.0883 ± 0.0016 μs | 0.0912 ± 0.0019 μs | 0.968 |
| time_to_load | 1.02 ± 0.0049 s | 1.02 ± 0.0018 s | 1 |

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

test/autodiff/batched_autodiff_tests.jl (outdated, resolved)
test/autodiff/batched_autodiff_tests.jl (outdated, resolved)
github-actions bot (Contributor) left a comment

Benchmark Results

Benchmark suite Current: 2b795ce Previous: b6171a6 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3741.875 ns 3675.625 ns 1.02
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 6728.285714285715 ns 8093.5 ns 0.83
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 21050 ns 21210 ns 0.99
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9856.4 ns 9748.2 ns 1.01
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9185 ns 9167.2 ns 1.00
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4667.5 ns 4470.875 ns 1.04
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 4988.125 ns 4956.875 ns 1.01
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1046.3151515151515 ns 2373.4 ns 0.44
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1062.6564417177915 ns 2270.3 ns 0.47
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1793.1967213114754 ns 1790.017543859649 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.0140056022409 ns 179.70239774330042 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17312 ns 17562.5 ns 0.99
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 12673 ns 24787 ns 0.51
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 36298 ns 38393 ns 0.95
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28804 ns 29025 ns 0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20047 ns 21590 ns 0.93
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17183 ns 17092 ns 1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 25788 ns 25648 ns 1.01
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 1415.6 ns 20248 ns 0.07
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 1445.6 ns 14448 ns 0.10
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4889.142857142857 ns 4846.285714285715 ns 1.01
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1660 ns 1659.2 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 117792312 ns 77690170 ns 1.52
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 48750528 ns 76782338 ns 0.63
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 190814626.5 ns 155414925 ns 1.23
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 205639818.5 ns 167638289.5 ns 1.23
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 210638995 ns 142842293.5 ns 1.47
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12019986 ns 11557321.5 ns 1.04
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 182404684 ns 199234044.5 ns 0.92
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 5924764 ns 15528408.5 ns 0.38
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 5917701 ns 15540189 ns 0.38
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 47268995 ns 30661456 ns 1.54
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6384243 ns 6376663 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 1023131234 ns 1064055959.5 ns 0.96
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2847075600 ns 2970205700 ns 0.96
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 216104427 ns 178121161 ns 1.21
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 1374030570 ns 1320655778 ns 1.04
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 4091197249 ns 3516351096 ns 1.16
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 401783545 ns 344809509 ns 1.17
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 1515463169 ns 1431616033 ns 1.06
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 4404884213 ns 4058579611 ns 1.09
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 429373736.5 ns 436008182 ns 0.98
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 417449321 ns 381866129 ns 1.09
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 923421065.5 ns 905256978 ns 1.02
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 54364697.5 ns 54567006.5 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 402764024.5 ns 382293897 ns 1.05
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 922893504 ns 870357323.5 ns 1.06
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29382300 ns 54472914.5 ns 0.54
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 534852007 ns 551222188 ns 0.97
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 1447265873 ns 1387168504 ns 1.04
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 166049072 ns 164122645 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1284285410.5 ns 1180058919 ns 1.09
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1573918068 ns 1610297742 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2266896153 ns 2289727615.5 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2524001196 ns 2640437136 ns 0.96
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 2264631726 ns 2193753011.5 ns 1.03
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 2425508620 ns 2122924359 ns 1.14
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 286158756 ns 282003619 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 284746326 ns 286261947 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 458549355.5 ns 437257287 ns 1.05
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11691323 ns 11806435 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 15153808 ns 34527638 ns 0.44
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 16445650 ns 16364743 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 21065487 ns 21004093 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 15201542 ns 15284140 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1146511 ns 1148921.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 19185166 ns 35777843.5 ns 0.54
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 1846077 ns 4500694 ns 0.41
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 1843918.5 ns 4506207 ns 0.41
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 1979827 ns 2045686 ns 0.97
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 193021 ns 196300 ns 0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 375492 ns 378068 ns 0.99
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 208168.5 ns 314462 ns 0.66
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 374379 ns 377972 ns 0.99
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 524910 ns 520691 ns 1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 290211.5 ns 289716 ns 1.00
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 404816 ns 401777 ns 1.01
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 425615 ns 425321 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 56655 ns 157406 ns 0.36
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 60202 ns 162456 ns 0.37
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 92022 ns 91953 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104375 ns 104407 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 339878291.5 ns 297649242 ns 1.14
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 269324213 ns 287837994 ns 0.94
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 588762738.5 ns 545531151.5 ns 1.08
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 685038653 ns 655809148 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 682011247 ns 554893727 ns 1.23
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 325025762.5 ns 316084028.5 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 622892766 ns 583442251.5 ns 1.07
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 38430370.5 ns 40159465 ns 0.96
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 38442335 ns 40173961.5 ns 0.96
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 111238041.5 ns 96663497 ns 1.15
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28030484 ns 28321531 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 21176816 ns 21078472 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19294631 ns 17393481 ns 1.11
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 22824568.5 ns 22657728 ns 1.01
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 28022340 ns 28019412 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19353731 ns 19298592.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 20894939 ns 20720819 ns 1.01
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6585712.5 ns 6086608 ns 1.08
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6540915 ns 6101998 ns 1.07
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6561562 ns 6509879.5 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

avik-pal force-pushed the ap/bjac branch 2 times, most recently from e25b6ee to 29d28ba, August 17, 2024 18:43

codecov bot commented Aug 17, 2024

Codecov Report

Attention: Patch coverage is 92.30769% with 2 lines in your changes missing coverage. Please review.

Project coverage is 95.34%. Comparing base (12bbeaf) to head (2b795ce).
Report is 5 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| src/autodiff/batched_autodiff.jl | 91.30% | 2 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #848      +/-   ##
==========================================
+ Coverage   95.10%   95.34%   +0.24%     
==========================================
  Files          58       58              
  Lines        2840     2859      +19     
==========================================
+ Hits         2701     2726      +25     
+ Misses        139      133       -6     


avik-pal merged commit 3e77701 into main Aug 17, 2024
66 of 76 checks passed
avik-pal deleted the ap/bjac branch August 17, 2024 20:56