Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Used New Fused Ops from LuxLib #591

Merged
merged 5 commits into from
Apr 28, 2024
Merged

Used New Fused Ops from LuxLib #591

merged 5 commits into from
Apr 28, 2024

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented Apr 19, 2024

  • Fix Nested AD on CUDA
  • ReverseDiff and Tracker seems to be slowed down. Rewrite the rrule_via_ad to directly use ForwardDiff they don't hit the fused kernels anyways

@avik-pal avik-pal mentioned this pull request Apr 20, 2024
3 tasks
@avik-pal avik-pal force-pushed the ap/fused_ops branch 6 times, most recently from 16b014c to 192a886 Compare April 22, 2024 04:47
@avik-pal avik-pal force-pushed the ap/fused_ops branch 4 times, most recently from 9641171 to 5b6ae14 Compare April 24, 2024 14:46
@avik-pal avik-pal force-pushed the ap/fused_ops branch 6 times, most recently from 18b958b to 6500812 Compare April 24, 2024 19:26
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 6dfaa8b Previous: 36b362a Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3887.25 ns 3674.375 ns 1.06
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7169.333333333334 ns 5854.25 ns 1.22
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20914.5 ns 15508 ns 1.35
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9861.25 ns 9975.333333333334 ns 0.99
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8861.75 ns 8696 ns 1.02
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4501 ns 4494.625 ns 1.00
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1129.7290322580645 ns 2060.9 ns 0.55
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1187.8714285714286 ns 1664.8521126760563 ns 0.71
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1803.9107142857142 ns 1815.6923076923076 ns 0.99
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.15514809590974 ns 179.37413073713492 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17412 ns 17743 ns 0.98
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 17392 ns 18735 ns 0.93
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37120 ns 35667 ns 1.04
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 27923 ns 28753 ns 0.97
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19757 ns 19787 ns 1.00
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 16992 ns 17562.5 ns 0.97
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3859.6875 ns 4920.571428571428 ns 0.78
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3912.375 ns 5003.571428571428 ns 0.78
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4890.571428571428 ns 5028 ns 0.97
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1653.1 ns 1651.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 38572466.5 ns 48926002 ns 0.79
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57762387 ns 108271301 ns 0.53
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 75940503 ns 84036071.5 ns 0.90
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 88676393.5 ns 107192834 ns 0.83
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 72616094 ns 106869664 ns 0.68
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11613199 ns 11898560 ns 0.98
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7053661.5 ns 18820810.5 ns 0.37
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7070223 ns 18550564.5 ns 0.38
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 10480696 ns 18693425 ns 0.56
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6396029 ns 6446973 ns 0.99
vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) 116324643.5 ns 106088743.5 ns 1.10
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 720564523.5 ns 832416622 ns 0.87
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2847349591 ns 2984233767 ns 0.95
vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) 151600988 ns 146290469 ns 1.04
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 808013590 ns 1085323519.5 ns 0.74
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2540007772 ns 3036724601 ns 0.84
vgg16/cpu/reverse/Flux/(32, 32, 3, 1) 77923897.5 ns 90590491 ns 0.86
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 652488915.5 ns 733826110 ns 0.89
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2712912081 ns 3075391905 ns 0.88
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) 32452947 ns 29790496 ns 1.09
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 181021928.5 ns 212277593 ns 0.85
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 679238492 ns 781425925 ns 0.87
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) 28925183 ns 30383532 ns 0.95
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 172674307.5 ns 197512172 ns 0.87
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 667402991 ns 778085435.5 ns 0.86
vgg16/cpu/forward/Flux/(32, 32, 3, 1) 30177193 ns 29280386 ns 1.03
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 186755940.5 ns 188405608 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 711151249 ns 809432575 ns 0.88
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1053774509 ns 1147032763 ns 0.92
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1860099125 ns 1880482284 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2144009823.5 ns 2148352018 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2279485898 ns 2539276579 ns 0.90
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1822477348.5 ns 1864259381 ns 0.98
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 320067471 ns 358738282 ns 0.89
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 321883940.5 ns 405361001.5 ns 0.79
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 391754543.5 ns 412981399.5 ns 0.95
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11807011 ns 12013680.5 ns 0.98
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17876933 ns 18334339 ns 0.98
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 18983163 ns 19696096 ns 0.96
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23784937 ns 24459171 ns 0.97
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17910261 ns 18368123 ns 0.98
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1147512 ns 1168429 ns 0.98
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2040254 ns 2120374.5 ns 0.96
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2040440 ns 2133928 ns 0.96
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2055145.5 ns 2118672 ns 0.97
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 199404 ns 216398 ns 0.92
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 295513 ns 309096.5 ns 0.96
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 265086.5 ns 277497.5 ns 0.96
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 357079 ns 374849 ns 0.95
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 404948 ns 418761 ns 0.97
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 271909 ns 279090 ns 0.97
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 401382 ns 409864.5 ns 0.98
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81222 ns 93404 ns 0.87
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81873 ns 94686.5 ns 0.86
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 85931 ns 89286 ns 0.96
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104286 ns 104365 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 190382726 ns 193458437 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 325858242.5 ns 373472760.5 ns 0.87
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 404011142 ns 404755227 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 482281052 ns 454774364.5 ns 1.06
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 372092403.5 ns 372704909 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 321563812 ns 371496475.5 ns 0.87
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 44033234 ns 60350010 ns 0.73
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 44022553 ns 52074054 ns 0.85
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 50097231 ns 51366141 ns 0.98
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28279192 ns 28579992.5 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19063364 ns 20025770.5 ns 0.95
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19385305 ns 19976154 ns 0.97
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 22954436 ns 24021852 ns 0.96
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 23852218 ns 24619312 ns 0.97
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19493491 ns 19994222 ns 0.97
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6490296.5 ns 6681275.5 ns 0.97
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6478122.5 ns 6671001 ns 0.97
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6485124.5 ns 6636669 ns 0.98

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal avik-pal force-pushed the ap/fused_ops branch 6 times, most recently from dad2931 to db59691 Compare April 26, 2024 19:07
@avik-pal avik-pal force-pushed the ap/fused_ops branch 7 times, most recently from 99d8b89 to e0c9464 Compare April 27, 2024 23:30
@avik-pal avik-pal merged commit 51f2968 into main Apr 28, 2024
36 of 37 checks passed
@avik-pal avik-pal deleted the ap/fused_ops branch April 28, 2024 00:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant