Slow Broadcasting (compared to Zygote) #1434
using BenchmarkTools, Enzyme, Zygote
gelu_act(x) = sum(abs2, sin.(x))
x = randn(Float32, 32, 32, 1, 32);
Enzyme.gradient(Reverse, gelu_act, x);
@btime Enzyme.gradient(Reverse, $gelu_act, $x);
@btime Zygote.gradient($gelu_act, $x)
So the code generated here by the broadcast, before any Enzyme AD, is actually quite awful. I wrote some primitive optimization passes to do a bit of cleanup (which may fix the runtime activity), but still the indexing pattern is really bad. @vchuravy any ideas what's happening here (besides it presumably now being > 3 dims so no specialization by Julia)?
This is post my fix btw^ Post fix timings:
The bigger issue rn imo is the fact that loop bounds aren't statically inferrable due to whatever that awful index math is. So as a result inner loops are caching the true iteration count, inside other loops, doing a bunch of unnecessary caching/etc. That's still not fixed. Pre my fix timings are slower so minor fix does ~something for perf, but again index math is likely root cause. At least others won't have to deal with runtime activity though.
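One way to probe whether the broadcast index math is really the culprit might be to compare against a hand-written loop over linear indices, which sidesteps broadcast entirely. This is only a sketch: the name `sin_abs2_loop` is hypothetical and not from this thread, and it has not been timed or differentiated with Enzyme here.

```julia
# Hypothetical loop-based equivalent of gelu_act(x) = sum(abs2, sin.(x)).
# It walks the array with linear indices, so no N-dimensional broadcast
# indexing is generated for it.
function sin_abs2_loop(x::AbstractArray{T}) where {T}
    acc = zero(T)
    @inbounds for i in eachindex(x)
        acc += abs2(sin(x[i]))
    end
    return acc
end

# Sanity check against the broadcast version on the same input shape.
x = randn(Float32, 32, 32, 1, 32)
isapprox(sin_abs2_loop(x), sum(abs2, sin.(x)); rtol = 1e-3)
```

If the loop version differentiates much faster under Enzyme than the broadcast one, that would be consistent with the indexing being the root cause, though it would not prove it.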
Module pre optimization:
This might be somewhat unfair because gelu has rrule defined.