Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type stability of LayerNorm and Dropout #2005

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

ToucheSir
Copy link
Member

@ToucheSir ToucheSir commented Jun 23, 2022

These two layers made use of explicit or implicit control flow (e.g. default keyword argument values) which Zygote does not like. This PR is essentially a set of small hacks to work around that.

Any ideas on how to avoid return_type in _dropout would be much appreciated, but for now it seems to work.

TODO benchmarks.

PR Checklist

  • Entry in NEWS.md

@ToucheSir
Copy link
Member Author

TTFG timings using the following snippet:

Test code
using Metalhead, Flux, Zygote
using Metalhead: ChannelLayerNorm

model = ConvNeXt(:tiny; inchannels=1, nclasses=1).layers
# ChannelLayerNorm isn't type stable yet (for the same reason as LayerNorm wasn't),
# So remove it for this demo
model = fmap(Returns(identity), model; exclude=Base.Fix2(isa, ChannelLayerNorm))

# display(model); println()

loss(m, x) = sum(m(x))

inputs = randn(Float32, 32, 32, 1, 1)
# @time loss(model, inputs)
# @time loss(model, inputs)

loss_grad(m, x) = gradient((m, x) -> loss(m, x), m, x)

@time loss_grad(model, inputs)
# @time loss_grad(model, inputs)
julia> @time loss_grad(model, inputs)
 34.835647 seconds (87.12 M allocations: 4.701 GiB, 3.14% gc time, 99.38% compilation time) # 0.13.3
 30.679322 seconds (78.88 M allocations: 4.300 GiB, 3.46% gc time, 98.96% compilation time) # this PR

Replacing the Chain{Vector} with a Chain{Tuple} creates a larger gap:

julia> @time loss_grad(model, inputs)
 79.846248 seconds (98.87 M allocations: 5.243 GiB, 1.68% gc time, 99.67% compilation time) # 0.13.3
 63.024710 seconds (79.23 M allocations: 4.245 GiB, 1.92% gc time, 99.45% compilation time) # this PR
 52.838056 seconds (70.81 M allocations: 3.745 GiB, 1.98% gc time, 99.60% compilation time) # this PR + Zygote#1248

@ToucheSir
Copy link
Member Author

ToucheSir commented Aug 1, 2022

For kicks, here is Diffractor with JuliaDiff/ChainRules.jl#644:

julia> @time loss_grad(model, inputs)
 30.442982 seconds (92.61 M allocations: 4.148 GiB, 3.18% gc time, 89.07% compilation time) # tuple chain
 23.051121 seconds (88.06 M allocations: 3.920 GiB, 3.81% gc time, 85.11% compilation time) # vector chain, requires https://github.com/JuliaDiff/Diffractor.jl/pull/82

Re-enabling ChannelLayerNorm adds but ~1s to the total. Note that even the tuple Chain here is faster than any tested Zygote configuration.

Edit: added times for vector chains using a patched Diffractor.

@theabhirath
Copy link
Member

Does Diffractor already work with most Flux models (or at least those with built-in layers)? I was under the impression that it wasn't there yet 😅

@ToucheSir
Copy link
Member Author

Not OOTB, which is why that ChainRules PR is required.

@chengchingwen
Copy link
Member

@ToucheSir Could you try running the layer norm gradient with gpu? I have try that manual broadcast fusion before but CUDA.time said it actually allocated more gpu memory

@ToucheSir
Copy link
Member Author

You're right, it allocates one more time for over 2x the memory overhead. I also found this out the hard way recently while trying to fuse the RNN cell kernels for #2023, but forgot about the change here.

@codecov-commenter
Copy link

Codecov Report

Merging #2005 (29ef2ff) into master (d66d2c4) will increase coverage by 0.27%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #2005      +/-   ##
==========================================
+ Coverage   87.10%   87.37%   +0.27%     
==========================================
  Files          20       20              
  Lines        1528     1553      +25     
==========================================
+ Hits         1331     1357      +26     
+ Misses        197      196       -1     
Impacted Files Coverage Δ
src/Flux.jl 0.00% <ø> (ø)
src/layers/normalise.jl 90.28% <100.00%> (+1.46%) ⬆️
src/layers/stateless.jl 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d66d2c4...29ef2ff. Read the comment docs.

@darsnack
Copy link
Member

Any updates on this (like benchmarks after unfusing)?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants