
Adding complex broadcasting for gradients on the GPU #1324

Merged: 25 commits merged into FluxML:master on Jan 10, 2023

Conversation

@ptiede (Contributor) commented Oct 25, 2022

This is a first attempt to add support for taking gradients of complex numbers when broadcasting and on the GPU. It targets issues #961, #1121, and #1215.

A nice side effect of this pull request is that complex broadcasting doesn't have to take the slow route anymore on the CPU, which fixes the performance issues in #1276.

On the current Zygote.jl release, I get:

f1(x) = sum(abs2, cispi.(x))
f2(x) = sum(abs2, cis.(x))

@btime Zygote.gradient($f1, $(ones((1024,1024))))
# 926.936 ms (18874448 allocations: 640.00 MiB)

@btime Zygote.gradient($f2, $(ones((1024,1024))))
# 39.655 ms (38 allocations: 160.00 MiB)

With this pull request, I get:

@btime Zygote.gradient($f1, $(ones((1024,1024))))
#  19.668 ms (31 allocations: 72.00 MiB)

@btime Zygote.gradient($f2, $(ones((1024,1024))))
#  15.781 ms (31 allocations: 72.00 MiB)

Approach

To fix these issues, I changed how broadcast_forward and dual_function work. This was inspired by @mcabbott's comment, but with some changes to ensure there are no dynamic dispatches or type instabilities. Specifically, I had to change dual_function, since

if any(a isa Complex for a in args)
   ...
else 
  ...
end

was leading to some type instability warnings on the GPU and some other strange issues.

On top of the change to dual, I also changed how broadcast_forward works. I had to write four separate functions depending on the output and argument types of the broadcast. I am not sure if there is a better way to do this, but it currently works and passes all tests on my machine. One concern was what to do for complex -> complex functions. For this, I just followed the convention described in https://juliadiff.org/ChainRulesCore.jl/stable/maths/complex.html, but maybe we don't want to follow that?
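
Roughly, each complex argument gets two of the 2N partial slots, one seeding its real part and one its imaginary part, so a single forward pass recovers both sensitivities. A minimal sketch of that dualization (illustrative helper name, not the exact merged code):

using ForwardDiff: Dual

# Argument i of N broadcast arguments: slot i seeds the real part, slot N+i the imaginary part.
function dual_complex(x::Complex, i, N)
    re_dual = Dual(real(x), ntuple(==(i), 2N))
    im_dual = Dual(imag(x), ntuple(==(N + i), 2N))
    return Complex(re_dual, im_dual)
end

dual_complex(1f0 + 2f0im, 1, 2)  # Complex{<:Dual} carrying 4 partial slots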

Testing

In terms of testing, I have added some small tests to cuda.jl to ensure that the gradient is not nothing and that the gradients on the GPU and CPU agree. Since I also changed broadcast_forward on the CPU (it now always takes the fast path), I believe there is already sufficient testing done there.
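
For reference, the CPU-vs-GPU comparison the new tests perform is roughly the sketch below; the actual helper in test/cuda.jl is gradcheck_gpu, and the name, tolerance, and nothing-handling here are placeholders rather than the exact test code.

using CUDA, Zygote

function gradcheck_gpu_sketch(f, xs...)
    grads_cpu = Zygote.gradient(f, xs...)
    grads_gpu = Zygote.gradient(f, map(cu, xs)...)
    all(zip(grads_cpu, grads_gpu)) do (gc, gg)
        # A nothing gradient on one device should be nothing on the other too.
        gc === nothing ? gg === nothing : isapprox(gc, Array(gg); rtol=1f-4)
    end
end

# e.g. gradcheck_gpu_sketch(x -> sum(abs2, cis.(x)), rand(Float32, 50))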

PR Checklist

  • Tests are added
  • Documentation, if applicable

@ToucheSir (Member) left a comment:

Mostly leaving this here to say the test failures are expected, but a couple suggestions while I'm at it:

src/lib/broadcast.jl: three resolved review threads (outdated)
@mcabbott (Member) left a comment:

This looks good, thanks for tackling it!

I meant to take a closer read but haven't yet, sorry.

I believe there is already sufficient testing done there.

Sadly I would not assume this. There may be very few tests of complex broadcasting, not sure (maybe I missed a section). It might be worth trying to come up with some evil test cases, including e.g. fused broadcasts where only parts are complex.

out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
T = eltype(out)
T <: Union{Dual, Complex} || return (out, _ -> nothing)
Member:

Should this be Union{Dual, Dual{<:Complex}}? You'd have to try pretty hard but I think the Complex path expects Dual inside.

Contributor Author (@ptiede):

I thought it was the other way around? At least, that is what I am constructing in dual_function. ForwardDiff.jl also defines Dual <: Real, so I think defining it the other way would break things. However, I probably want to be a little more specific here and do:

Suggested change
T <: Union{Dual, Complex} || return (out, _ -> nothing)
T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
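
A quick illustration of why the nesting goes this way (a sketch, assuming ForwardDiff is loaded):

using ForwardDiff: Dual

# Dual <: Real, so a Complex can hold Duals in its real and imaginary fields ...
z = Complex(Dual(1f0, (true, false)), Dual(2f0, (false, true)))
z isa Complex{<:Dual}  # true

# ... but not the other way around: a Dual's value type must be <: Real,
# so something like Dual(1f0 + 2f0im, (true, false)) errors.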

Member:

Yes, sorry, that's what I was thinking but didn't type...

test/cuda.jl (outdated):
@testset "CUDA complex broadcasting" begin
# Issue 961 and 1121 and 1215
x = rand(Float32, 50)
y = complex(rand(Float32, 50))
Member:

Why define x here at all?

Also, this y has zero imaginary part. rand(ComplexF64, 50) would be a stronger test.

julia> complex(rand(Float32, 50))
50-element Vector{ComplexF32}:
  0.89825445f0 + 0.0f0im
  0.40070343f0 + 0.0f0im
  0.29411656f0 + 0.0f0im
  0.44503874f0 + 0.0f0im

Contributor Author (@ptiede):

Oops! That x was left over from a test on my machine. I think the testing overall could still be a bit better, so I've added another test that uses both real and complex arguments. I probably need to add some additional tests.

Member:

Cool. I think x.^2 .* y .+ y uses only functions which have special rules, and ought to work without this PR. I think even broadcasting trivial functions like add(x,y) = x + y will change the path it takes. But messy examples (e.g. with trig, conj/real/imag, in all sorts of ways) are much more likely to expose mistakes like a missing conj somewhere.

Member:

Trying to invent some functions (I did not try them on the GPU):

r3 = Float32.(inv.(2:4))
c3 = ComplexF32.(inv.(5:7) .+ im ./ (8:10))

@test gradient(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)[1] ≈ [0.2077734, 0.15268978, 0.11885023]
@test gradient(c -> sum(abs2, imag.(sqrt.(c .+ im))), c3)[1] ≈ [-0.4124833f0 + 0.49228126f0im, -0.4258298f0 + 0.49446818f0im, -0.43560573f0 + 0.49583605f0im]
@test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈ [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im]

But locally, with this branch, I expected them to use the new code... but adding printing doesn't seem to work?

(jl_S8DfLf) pkg> st Zygote
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_S8DfLf/Project.toml`
  [e88e6eb3] Zygote v0.6.49 `https://github.com/ptiede/Zygote.jl#pt-complexbroadcast`

julia> @eval Zygote function dual(x::Complex, i, N)  # from PR, with printing
            @show x
            re_dual = Dual(real(x), ntuple(==(i), 2N))
            im_dual = Dual(imag(x), ntuple(==(N+i), 2N))
            return Complex(re_dual, im_dual)
        end;

julia> Zygote.refresh()

julia> @test gradient(r -> sum(abs2, log.(1 .+ im .* r)./2), r3)[1] ≈ [0.2077734, 0.15268978, 0.11885023]
Test Passed

Contributor Author (@ptiede):

So I looked into this, and it occurred because I hadn't added a Complex method for _dual_safearg. When I added one, some issues started to appear. One of them was that the partials for the real and imaginary parts had different lengths.

However, that is not the big issue. The big issue is that certain functions seem to be causing some type instabilities during the evaluation of the dual numbers. For instance,

x = rand(Complex{Float32}, 100)
f(x) = sum(abs2, log.(x))
@code_warntype Zygote.dual_function(f).(x)

MethodInstance for (::var"##dotfunction#314#7")(::Vector{ComplexF32})
  from (::var"##dotfunction#314#7")(x1) in Main
Arguments
  #self#::Core.Const(var"##dotfunction#314#7"())
  x1::Vector{ComplexF32}
Body::Union{Vector{ForwardDiff.Dual{Float32, Float32, 2}}, Vector{ForwardDiff.Dual{Float32, V, 2} where V}, Vector{ForwardDiff.Dual{Float32, Float64, 2}}}
1 ─ %1 = Zygote.dual_function::Core.Const(Zygote.dual_function)
│   %2 = (%1)(Main.f)::Core.Const(Zygote.var"#944#946"{typeof(f)}(f))
│   %3 = Base.broadcasted(%2, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#944#946"{typeof(f)}, Tuple{Vector{ComplexF32}}}
│   %4 = Base.materialize(%3)::Union{Vector{ForwardDiff.Dual{Float32, Float32, 2}}, Vector{ForwardDiff.Dual{Float32, V, 2} where V}, Vector{ForwardDiff.Dual{Float32, Float64, 2}}}
└──      return %4

The problem is that the broadcast can't figure out that the eltype of the partials field in the Dual should be Float32. What is really annoying is that this does not occur for Float64, where I get:

x64 = Complex{Float64}.(x)
@code_warntype Zygote.dual_function(f).(x64)

MethodInstance for (::var"##dotfunction#313#6")(::Vector{ComplexF64})
  from (::var"##dotfunction#313#6")(x1) in Main
Arguments
  #self#::Core.Const(var"##dotfunction#313#6"())
  x1::Vector{ComplexF64}
Body::Vector{ForwardDiff.Dual{Float64, Float64, 2}}
1 ─ %1 = Zygote.dual_function::Core.Const(Zygote.dual_function)
│   %2 = (%1)(Main.f)::Core.Const(Zygote.var"#944#946"{typeof(f)}(f))
│   %3 = Base.broadcasted(%2, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, Zygote.var"#944#946"{typeof(f)}, Tuple{Vector{ComplexF64}}}
│   %4 = Base.materialize(%3)::Vector{ForwardDiff.Dual{Float64, Float64, 2}}
└──      return %4


Contributor Author (@ptiede):

OK, looking into this more: it appears that log with Complex{Dual{Float32}} arguments is type unstable. My guess is that this happens because the call isn't hitting the specific forward rule for complex log, and likely not for other common functions either.

@mcabbott (Member) commented Nov 1, 2022:

That is weird, @code_warntype log(Dual(1f0, 1f0) + im) is bad. Inside Base.ssqs, it looks like ldexp(Dual(1f0, 2f0), 3) makes a Float64 dual, by a method from ForwardDiff.

Anyway not this PR's problem! Maybe make an issue on ForwardDiff (or DiffRules) and test inference etc. with other functions here?
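
For reference, a minimal way to check that promotion (a sketch; it assumes the ldexp method ForwardDiff provides for Dual and one of the affected ForwardDiff/DiffRules versions):

using ForwardDiff: Dual

d = Dual(1f0, 2f0)
typeof(d)            # Dual{Nothing, Float32, 1}
typeof(ldexp(d, 3))  # reportedly Dual{Nothing, Float64, 1} on affected versions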

Contributor Author (@ptiede):

Ok sounds good! I'll skip log for now and make tests for other functions.

Contributor Author (@ptiede):

Alright, I was able to add the last test:

@test gradient((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)[2] ≈ [2.9423256f0 + 63.7845f0im, -2.7483354f0 + 55.08628f0im, -9.976982f0 + 48.902283f0im]

and everything passes! The other two suggested tests both run into the ldexp problem with Float32. I have opened an issue, JuliaDiff/ForwardDiff.jl#604, detailing the problem. The good news is that when I fix the problem locally, all the tests pass!

Contributor Author (@ptiede):

Here are a couple of updates on my end. First, I just realized I was running the previous test on the CPU. When I run it on the GPU, I get a scalar indexing error. The stack trace is

julia>     @test gradcheck_gpu((r,c) -> sum(abs2, @. sin(conj(c)/r' - im) - imag(c + tanh(r/c'))), r3, c3)
Error During Test at /home/ptiede/.julia/dev/Zygote/test/cuda.jl:186
  Test threw exception
  Expression: gradcheck_gpu(((r, c)->begin
            sum(abs2, #= /home/ptiede/.julia/dev/Zygote/test/cuda.jl:186 =# @__dot__(sin(conj(c) / r' - im) - imag(c + tanh(r / c'))))
        end), r3, c3)
  Scalar indexing is disallowed.
  Invocation of getindex resulted in scalar indexing of a GPU array.
  This is typically caused by calling an iterating implementation of a method.
  Such implementations *do not* execute on the GPU, but very slowly on the CPU,
  and therefore are only permitted from the REPL for prototyping purposes.
  If you did intend to index this array, annotate the caller with @allowscalar.
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] assertscalar(op::String)
      @ GPUArraysCore ~/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:87
    [3] getindex(::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
      @ GPUArrays ~/.julia/packages/GPUArrays/fqD8z/src/host/indexing.jl:9
    [4] getindex
      @ ~/.julia/juliaup/julia-1.8.2+0.x64/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:180 [inlined]
    [5] _unsafe_getindex_rs
      @ ./reshapedarray.jl:250 [inlined]
    [6] _unsafe_getindex
      @ ./reshapedarray.jl:247 [inlined]
    [7] getindex
      @ ./reshapedarray.jl:235 [inlined]
    [8] iterate
      @ ./abstractarray.jl:1167 [inlined]
    [9] iterate
      @ ./abstractarray.jl:1165 [inlined]
   [10] iterate
      @ ./generator.jl:44 [inlined]
   [11] _collect(c::Base.ReshapedArray{ComplexF32, 1, LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, itr::Base.Generator{Base.ReshapedArray{ComplexF32, 1, LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
      @ Base ./array.jl:807
   [12] collect_similar
      @ ./array.jl:716 [inlined]
   [13] map
      @ ./abstractarray.jl:2933 [inlined]
   [14] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::LinearAlgebra.Adjoint{ComplexF32, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}})
      @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/projection.jl:236
   [15] ProjectTo
      @ ~/.julia/packages/ChainRulesCore/C73ay/src/projection.jl:414 [inlined]
   [16] _project
      @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:184 [inlined]
   [17] unbroadcast(x::LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, x̄::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/dev/Zygote/src/lib/broadcast.jl:58
   [18] (::Zygote.var"#857#858"{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/dev/Zygote/src/lib/broadcast.jl:97
   [19] (::Zygote.var"#3669#back#859"{Zygote.var"#857#858"{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}}})(Δ::CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer})
      @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
   [20] Pullback
      @ ./none:0 [inlined]
   [21] (::typeof(∂(#13)))(Δ::Float32)
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
   [22] (::Zygote.var"#60#61"{typeof(∂(#13))})(Δ::Float32)
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
   [23] gradient(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
      @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
   [24] gradcheck_gpu(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
      @ Main ~/.julia/dev/Zygote/test/cuda.jl:9
   [25] top-level scope

From the look of the stack trace, this isn't due to this pull request. In fact, if I change the function definition to

sin(conj(c)/$(transpose(r)) - im) - imag(c + tanh(r/c'))

then everything is fine, so my guess is that this is some funkiness related to the pullback of an adjoint of a real vector. I'll take a look into this, but I am not sure if that's part of this pull request.

Second, I have added some additional tests to ensure we hit every one of the _broadcast_forward branches.

@ptiede (Contributor Author) commented Nov 10, 2022

@mcabbott mostly good news. The ldexp type instability was fixed in JuliaDiff/DiffRules.jl#89.
However, now I am getting a really annoying issue with the following:

r3 = Float32.(inv.(2:4))
f(r) = sum(abs2, log.(1 .+ im .* r)./2)
Zygote.gradient(f, r3)

which gives a "Reason: unsupported dynamic function invocation (call to exponent)" error. This occurs only on Julia 1.6; on Julia 1.7 and 1.8 everything works fine.

Digging into this issue a bit more, on 1.6 I can create the following MWE:

julia> rd3 = first.(Zygote.dualize(r3)) # CuArray{ForwardDiff.Dual{Nothing, Float32, 1}, 1, CUDA.Mem.DeviceBuffer}
julia> log.(1im .* rd3) 
ERROR: InvalidIRError: compiling kernel #broadcast_kernel#17(CUDA.CuKernelContext, CuDeviceVector{Complex{ForwardDiff.Dual{Nothing, Float32, 1}}, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(log), Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(*), Tuple{Complex{Int64}, Base.Broadcast.Extruded{CuDeviceVector{ForwardDiff.Dual{Nothing, Float32, 1}, 1}, Tuple{Bool}, Tuple{Int64}}}}}}, Int64) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to exponent)
Stacktrace:
 [1] ssqs
   @ ./complex.jl:474
 [2] log
   @ ./complex.jl:594
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:648
 [4] _broadcast_getindex
   @ ./broadcast.jl:621
 [5] getindex
   @ ./broadcast.jl:575
 [6] broadcast_kernel
   @ ~/.julia/packages/GPUArrays/fqD8z/src/host/broadcast.jl:57
...

which suggests the problem is in Base.ssqs. This looks like an issue outside the scope of this pull request, so I am not sure what we want to do here.

@mcabbott (Member)

Does more inlining help at all, e.g. @inline function dual_function?

@ptiede (Contributor Author) commented Nov 10, 2022

Sadly, no :( The MWE also shouldn't have an inlining issue, right?

@ptiede (Contributor Author) commented Nov 11, 2022

Does more inlining help at all, e.g. @inline function dual_function?

OK, figured it out! Analyzing Base.ssqs with Cthulhu, it looks like the code would sometimes venture into the following call:

%282  = call #exponent(::ForwardDiff.Dual{Nothing, Float32, 1})::Union{}

which errors because exponent(::Dual) is never defined. This is what was causing the dynamic call. If I define the following method

Base.exponent(x::ForwardDiff.Dual{<:Real}) = Base.exponent(ForwardDiff.value(x))

everything works and the tests pass on Julia 1.6. I believe this definition makes sense, since exponent: Real -> Int, so we only care about the value. I don't really understand why this didn't cause an issue on 1.7/1.8; maybe it got optimized away?
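
As a sanity check of that reasoning (a sketch; the helper below is hypothetical and dispatches on any Dual, which is broader than the method above):

using ForwardDiff: Dual, value

# exponent depends only on the primal value, so forwarding to value(x) matches
# calling exponent on the underlying Real directly.
exponent_of_dual(x::Dual) = exponent(value(x))

exponent_of_dual(Dual(0.75f0, 1f0)) == exponent(0.75f0) == -1  # true, since 0.75 == 1.5 * 2^-1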

@ptiede (Contributor Author) commented Nov 11, 2022

Alright, the dual exponent issue has been fixed. Once a new version of ForwardDiff.jl is released, the 1.6 tests should pass.

@devmotion (Collaborator)

DynamicPPL test failures are caused by JuliaDiff/ForwardDiff.jl#606.

@ptiede (Contributor Author) commented Nov 16, 2022

@mcabbott I think this is finally ready to review again. All the tests are passing, and I have added some additional tests to ensure that every branch is getting hit.

@ptiede (Contributor Author) commented Dec 1, 2022

Is this ready to merge?

@ptiede (Contributor Author) commented Jan 10, 2023

Just a bump to see if this is ready to be merged or if there are some outstanding items that I still need to fix.

@CarloLucibello merged commit 616bf6c into FluxML:master on Jan 10, 2023
@CarloLucibello (Member)

Thanks! I'll tag a new release shortly.

@CarloLucibello (Member)

@ptiede which of the issues mentioned in the OP should be closed?

@ptiede (Contributor Author) commented Jan 10, 2023

This should fix #961, #1121, #1215, and #1276, i.e. all of them, since they were all the same problem in disguise.

@CarloLucibello (Member) commented Jan 10, 2023

Do you think we need some extra tests, or do the ones in this PR cover them all?

@ptiede (Contributor Author) commented Jan 10, 2023

I think the tests should cover all of those cases.

Zygote.jl/test/cuda.jl, lines 191-192 at 616bf6c:

@test gradcheck_gpu(x->sum(real, cis.(x)), xgpu)
@test gradcheck_gpu(x->sum(real, cispi.(x)), xgpu)

should cover #1276 implicitly, because the type instability that was causing the slowdown is fixed.

@test gradcheck_gpu((x,y)->sum(abs2, x.^2 .+ y), xgpu, ygpu)

should cover the abs2 bug. But coming up with tests was tricky, so it is possible that I missed something.

@CarloLucibello (Member)

I just tested and closed all of them.
