Record Broadcast.broadcasted instead of Broadcast.broadcast #215

Open · torfjelde opened this issue Jan 16, 2023 · 0 comments
IIUC, ReverseDiff records broadcast and, whenever possible, uses ForwardDiff to specialize on the broadcasted statement as a single operation, leading to much better performance than tracing through every scalar operation with ReverseDiff.TrackedReal.
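For context, Broadcast.broadcasted doesn't compute anything by itself: it builds a lazy Base.Broadcast.Broadcasted container that is only evaluated once it's materialized, e.g. via Broadcast.materialize or a reduction like sum. A minimal non-AD illustration of that laziness:

julia> bc = Broadcast.broadcasted(exp, [0.0, 1.0]);

julia> bc isa Base.Broadcast.Broadcasted
true

julia> Broadcast.materialize(bc)
2-element Vector{Float64}:
 1.0
 2.718281828459045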

Unfortunately, this means that once one makes use of Broadcast.broadcasted, i.e. lazy broadcasting, the call is not recorded, and we end up taking the less desirable path of tracing through the broadcast with ReverseDiff.TrackedReal:

julia> using ReverseDiff

julia> f(x) = sum(exp.(x))
f (generic function with 1 method)

julia> f_tape = ReverseDiff.GradientTape(f, (rand(10, ),))
typename(ReverseDiff.GradientTape)(f)

julia> g(x) = sum(Broadcast.instantiate(Broadcast.broadcasted(exp, x)))
g (generic function with 1 method)

julia> g_tape = ReverseDiff.GradientTape(g, (rand(10, ),))
typename(ReverseDiff.GradientTape)(g)

julia> length(g_tape.tape)
19

julia> g_tape.tape
19-element Vector{ReverseDiff.AbstractInstruction}:
 ScalarInstruction(exp):
  input:  TrackedReal<76u>(0.5057895559423533, 0.0, 6vX, 1, 5d7)
  output: TrackedReal<AW2>(1.6582943198420965, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.6582943198420965)
 ScalarInstruction(exp):
  input:  TrackedReal<KUi>(0.13213345349262395, 0.0, 6vX, 2, 5d7)
  output: TrackedReal<JmS>(1.141260614319831, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.141260614319831)
 ScalarInstruction(+):
  input:  (TrackedReal<AW2>(1.6582943198420965, 0.0, 6vX, ---),
           TrackedReal<JmS>(1.141260614319831, 0.0, 6vX, ---))
  output: TrackedReal<Cng>(2.7995549341619275, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<GdN>(0.034478177830953305, 0.0, 6vX, 3, 5d7)
  output: TrackedReal<DNm>(1.03507944045113, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.03507944045113)
 ScalarInstruction(+):
  input:  (TrackedReal<Cng>(2.7995549341619275, 0.0, 6vX, ---),
           TrackedReal<DNm>(1.03507944045113, 0.0, 6vX, ---))
  output: TrackedReal<DEh>(3.8346343746130573, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<CEJ>(0.04867133616730335, 0.0, 6vX, 4, 5d7)
  output: TrackedReal<GVI>(1.0498752380105207, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.0498752380105207)
 ScalarInstruction(+):
  input:  (TrackedReal<DEh>(3.8346343746130573, 0.0, 6vX, ---),
           TrackedReal<GVI>(1.0498752380105207, 0.0, 6vX, ---))
  output: TrackedReal<7HD>(4.884509612623578, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<Ft8>(0.8637862831888328, 0.0, 6vX, 5, 5d7)
  output: TrackedReal<5lH>(2.3721252497772487, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(2.3721252497772487)
 ScalarInstruction(+):
  input:  (TrackedReal<7HD>(4.884509612623578, 0.0, 6vX, ---),
           TrackedReal<5lH>(2.3721252497772487, 0.0, 6vX, ---))
  output: TrackedReal<HpK>(7.256634862400826, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<Gmo>(0.0039196786165185404, 0.0, 6vX, 6, 5d7)
  output: TrackedReal<1R0>(1.0039273706035023, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.0039273706035023)
 ScalarInstruction(+):
  input:  (TrackedReal<HpK>(7.256634862400826, 0.0, 6vX, ---),
           TrackedReal<1R0>(1.0039273706035023, 0.0, 6vX, ---))
  output: TrackedReal<8hX>(8.260562233004329, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<L8R>(0.9153223594101434, 0.0, 6vX, 7, 5d7)
  output: TrackedReal<4qL>(2.4975802406432295, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(2.4975802406432295)
 ScalarInstruction(+):
  input:  (TrackedReal<8hX>(8.260562233004329, 0.0, 6vX, ---),
           TrackedReal<4qL>(2.4975802406432295, 0.0, 6vX, ---))
  output: TrackedReal<17T>(10.758142473647558, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<JYS>(0.15063946146751517, 0.0, 6vX, 8, 5d7)
  output: TrackedReal<6gL>(1.1625774285521715, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.1625774285521715)
 ScalarInstruction(+):
  input:  (TrackedReal<17T>(10.758142473647558, 0.0, 6vX, ---),
           TrackedReal<6gL>(1.1625774285521715, 0.0, 6vX, ---))
  output: TrackedReal<Dvu>(11.92071990219973, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<B28>(0.3010502862135006, 0.0, 6vX, 9, 5d7)
  output: TrackedReal<5Ku>(1.3512772904478805, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.3512772904478805)
 ScalarInstruction(+):
  input:  (TrackedReal<Dvu>(11.92071990219973, 0.0, 6vX, ---),
           TrackedReal<5Ku>(1.3512772904478805, 0.0, 6vX, ---))
  output: TrackedReal<2ZC>(13.27199719264761, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])
 ScalarInstruction(exp):
  input:  TrackedReal<Cpk>(0.02748173107946794, 0.0, 6vX, 10, 5d7)
  output: TrackedReal<9uF>(1.0278628369912384, 0.0, 6vX, ---)
  cache:  Base.RefValue{Float64}(1.0278628369912384)
 ScalarInstruction(+):
  input:  (TrackedReal<2ZC>(13.27199719264761, 0.0, 6vX, ---),
           TrackedReal<9uF>(1.0278628369912384, 0.0, 6vX, ---))
  output: TrackedReal<J8N>(14.299860029638849, 0.0, 6vX, ---)
  cache:  Base.RefValue{StaticArraysCore.SVector{2, Float64}}([1.0, 1.0])

julia> @benchmark ReverseDiff.gradient($f, $(randn(1000)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.308 μs …  1.218 ms  ┊ GC (min … max): 0.00% … 95.95%
 Time  (median):     10.105 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.275 μs ± 36.408 μs  ┊ GC (mean ± σ):  9.10% ±  3.05%

  ▇██▇▆▅▅▄▃▁                                       ▁▁▁▁▁▁     ▂
  ██████████▇▅▄▅▄▄▃▃▄▅▅▅▄▅▄▄▁▃▄▁▄▃▃▁▁▁▁▁▁▃▄▃▁▃▄▅▅▆████████▇▇▇ █
  9.31 μs      Histogram: log(frequency) by time      27.8 μs <

 Memory estimate: 55.83 KiB, allocs estimate: 15.

julia> @benchmark ReverseDiff.gradient($g, $(randn(1000)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  151.875 μs …   3.504 ms  ┊ GC (min … max):  0.00% … 94.23%
 Time  (median):     158.858 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   183.434 μs ± 253.436 μs  ┊ GC (mean ± σ):  12.42% ±  8.45%

  ▄▆▆█▇▇▅▄▃▂▁▁▁                                                 ▂
  ██████████████▇▆▇▆▆▇▆▆▅▅▄▅▆▃▄▄▃▁▄▄▃▁▃▄▁▁▁▁▁▁▄▃▃▅▄▄▅▄▄▃▅▅▃▆▆▆▆ █
  152 μs        Histogram: log(frequency) by time        259 μs <

 Memory estimate: 374.53 KiB, allocs estimate: 8009.

The overhead can of course be lowered by compiling the tape:

julia> x = randn(1000);

julia> inputs = (x,); results = (similar(x),); cfg = ReverseDiff.GradientConfig(inputs);

julia> g_tape = ReverseDiff.GradientTape(g, inputs);

julia> compiled_g_tape = ReverseDiff.compile(g_tape)
typename(ReverseDiff.CompiledTape)(g)

julia> @benchmark ReverseDiff.gradient!($results, $compiled_g_tape, $inputs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  51.798 μs … 117.899 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     54.679 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   55.396 μs ±   4.165 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▃▂▅▁▁▇█▂ ▃▃       ▁              ▂    ▁▁                   ▂
  ▇▇████████▄███▆▇▇▆▆▇█▆▄▄▇▇▄▃▃▆▇▃▁▁██▆▁▁▃██▄▁▄▃▆▅▁▅▃▁▅▅▅▅▃▃▄▅ █
  51.8 μs       Histogram: log(frequency) by time      73.1 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

but it's still slower than ForwardDiff (this of course varies with input size, etc., but I'd guess the performance gap is well established, given that essentially all reverse-mode AD frameworks in Julia rely on ForwardDiff for broadcasting).
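For reference, the kind of ForwardDiff baseline I have in mind is just the following (timings omitted here; they depend on the machine and on the chunk size picked by the config):

using ForwardDiff, BenchmarkTools

x = randn(1000)
cfg = ForwardDiff.GradientConfig(f, x)  # reuse the config, as one would in practice
@benchmark ForwardDiff.gradient($f, $x, $cfg)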

Would it be possible to record broadcasted instead of broadcast in ReverseDiff.jl? (This is how it's done in Zygote.jl.)
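Very roughly, I'd imagine something along these lines (a pure sketch, not working code: record_broadcast! is a hypothetical placeholder for whatever machinery currently backs the eager broadcast overload, and the dispatch would need care so that untracked arguments keep fusing as usual):

# Hypothetical sketch only; none of these definitions exist in ReverseDiff today.
# Idea: intercept lazy broadcasting on tracked arrays so it takes the same fast
# ForwardDiff-based path as the existing eager `broadcast` recording.
function Base.Broadcast.broadcasted(f, x::ReverseDiff.TrackedArray, args...)
    # `record_broadcast!` is a placeholder name: push a single broadcast
    # instruction onto the tape instead of tracing scalar TrackedReal ops.
    return record_broadcast!(f, x, args...)
end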
