
gradient calculation with explicit type cast is broken #810

Closed
racinmat opened this issue Oct 15, 2020 · 7 comments · Fixed by #1171

Comments

@racinmat
Contributor

racinmat commented Oct 15, 2020

using Flux, Zygote, SparseArrays

# Float32 Dense layer with a Float64 sparse input, cast to Float32
W = randn(Float32, 2, 2)
b = randn(Float32, 2)
md = Dense(W, b)
xs = sparse(randn(Float64, 2, 2))
gradient(() -> sum(W * Float32.(xs) .+ b), Params([W, b]))

# Float64 Dense layer with a Float32 sparse input, cast to Float64
W = randn(Float64, 2, 2)
b = randn(Float64, 2)
md = Dense(W, b)
xs = sparse(randn(Float32, 2, 2))
gradient(() -> sum(W * Float64.(xs) .+ b), Params([W, b]))

Neither snippet works; both throw the following error:

ERROR: MethodError: no method matching zero(::Type{Any})
Closest candidates are:
  zero(::Type{Union{Missing, T}}) where T at missing.jl:104
  zero(::Type{LibGit2.GitHash}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LibGit2\src\oid.jl:220   
  zero(::Type{Pkg.Resolve.VersionWeight}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Resolve\versionweights.jl:15
  ...
Stacktrace:
 [1] zero(::Type{Any}) at .\missing.jl:105
 [2] _zeros_eltypes at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\SparseArrays\src\higherorderfns.jl:203 [inlined]
 [3] _noshapecheck_map(::Zygote.var"#1068#1075", ::SparseMatrixCSC{Any,Int64}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\SparseArrays\src\higherorderfns.jl:159
 [4] map(::Zygote.var"#1068#1075", ::SparseMatrixCSC{Any,Int64}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\SparseArrays\src\higherorderfns.jl:1153
 [5] adjoint at C:\Users\racinsky\.julia\packages\Zygote\2Likt\src\lib\broadcast.jl:137 [inlined]
 [6] _pullback at C:\Users\racinsky\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [7] adjoint at C:\Users\racinsky\.julia\packages\Zygote\2Likt\src\lib\lib.jl:188 [inlined]
 [8] _pullback at C:\Users\racinsky\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [9] broadcasted at .\broadcast.jl:1257 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Base.Broadcast.broadcasted), ::Type{Float64}, ::SparseMatrixCSC{Float32,Int64}) at C:\Users\racinsky\.julia\packages\Zygote\2Likt\src\compiler\interface2.jl:0
 [11] #165 at .\none:1 [inlined]
 [12] _pullback(::Zygote.Context, ::var"#165#166") at C:\Users\racinsky\.julia\packages\Zygote\2Likt\src\compiler\interface2.jl:0
 [13] pullback(::Function, ::Params) at C:\Users\racinsky\.julia\packages\Zygote\2Likt\src\compiler\interface.jl:172
 [14] gradient(::Function, ::Params) at C:\Users\racinsky\.julia\packages\Zygote\2Likt\src\compiler\interface.jl:53
 [15] top-level scope at none:1

The gradient computation works without the explicit cast, but Flux.jl performs this cast in the Dense layer in order to hit BLAS: https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl#L138

This causes the following issue: FluxML/Flux.jl#965

I guess Zygote should be able to handle this case, right?
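For context, a paraphrased sketch of what the linked Dense method does (the name dense_like is made up for illustration; this is not Flux's actual source):

# Sketch only: when the input eltype differs from the weight eltype, the input is cast
# so that W * x can dispatch to a fast BLAS method. The T.(x) broadcast is exactly the
# operation Zygote fails to differentiate when x is sparse.
function dense_like(W::AbstractMatrix{T}, b::AbstractVector{T}, x::AbstractArray) where {T}
    x = T.(x)          # explicit eltype cast, done for performance
    return W * x .+ b
end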

@racinmat racinmat changed the title from "gradient calculation including explicit type cast is broken" to "gradient calculation with explicit type cast is broken" on Oct 15, 2020
@mcabbott
Member

xref FluxML/Flux.jl#815 I think.

Ideally Zygote would handle this without blinking. Ideally (IMO) Flux would simply give you an error on mismatched eltypes, instead of weird silent conversions or slow paths.

And ideally you would make sparse x contain Float32 beforehand, in actual use?
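For illustration, one way to do that with the variables from the original report (xs32 is just an illustrative name; a sketch, assuming the conversion happens outside the differentiated closure):

xs32 = Float32.(xs)   # eltype conversion done up front; stays sparse, as shown later in this thread
gradient(() -> sum(W * xs32 .+ b), Params([W, b]))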

@racinmat
Contributor Author

Yes, if Flux raised an error suggesting that the eltypes are mismatched, I would cast my sparse x without a problem, but when it does a silent conversion it's very hard to figure out where the problem is.
So, should we forbid mixed-eltype sparse computation in Flux and raise an error during the forward pass of Dense with mismatched sparse data?

@mcabbott
Member

I tried to add warnings about this and other type mismatches in FluxML/Flux.jl#1031, but nothing came of it.

I did not think about sparse arrays at all there. I don't see why dense W * sparse x, as you have above, should be disallowed. You're going to get a dense dx in the gradient, I think.
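A minimal sketch of why dx comes back dense (illustrative variables, not taken from the issue):

using SparseArrays

# The pullback of y = W * x with respect to x is W' * Δ, where Δ is the cotangent of y.
# Even when x is sparse, W' * Δ is generally dense, so dx is dense.
W = randn(Float32, 2, 2)
x = sparse(randn(Float32, 2, 2))
Δ = ones(Float32, 2, 2)   # cotangent of sum(W * x)
dx = W' * Δ               # dense Matrix{Float32}, regardless of the sparsity of x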

@racinmat
Contributor Author

Well, what I meant is:
the other 3 combinations with type mismatches work and do not crash. This combination, dense W * sparse x, crashes with an obscure error, which is why I said it should raise an error, since it currently does not work either.

Thanks for the link to the PR. Is there a reason why you closed it? If it's still relevant, which I think it is, maybe I can ping other people to look at that PR and hopefully we can get it merged.

@mcabbott
Member

Besides utopian ideas like fixing Flux, this might be a duplicate of #575 -- any broadcast over a sparse array seems to give an error:

julia> using Zygote, SparseArrays

julia> gradient(x -> sum(x), sprand(6, 0.9))[1] # ok
6-element Fill{Float64}: entries equal to 1.0

julia> ans == ones(6)
true

julia> gradient(x -> sum(Float32.(x)), rand(6))[1] == ones(6) # ok
true

julia> gradient(x -> sum(Float32.(x)), sprand(6, 0.9))[1] # as above
ERROR: MethodError: no method matching zero(::Type{Any})

julia> gradient(x -> sum(exp.(x)), sprand(6, 0.9))[1]
ERROR: MethodError: no method matching zero(::Type{Tuple{Float64,Zygote.ZBack{ChainRules.var"#exp_pullback#1289"{Float64}}}})
  ...
Stacktrace:
 [1] _zeros_eltypes at /Applications/Julia-1.5.app/Contents/Resources/julia/share/julia/stdlib/v1.5/SparseArrays/src/higherorderfns.jl:203 [inlined]
 [2] _noshapecheck_map(::Zygote.var"#1068#1075", ::SparseVector{Tuple{Float64,Zygote.ZBack{ChainRules.var"#exp_pullback#1289"{Float64}}},Int64}) at /Applications/Julia-1.5.app/Contents/Resources/julia/share/julia/stdlib/v1.5/SparseArrays/src/higherorderfns.jl:159
 [3] map(::Zygote.var"#1068#1075", ::SparseVector{Tuple{Float64,Zygote.ZBack{ChainRules.var"#exp_pullback#1289"{Float64}}},Int64}) at /Applications/Julia-1.5.app/Contents/Resources/julia/share/julia/stdlib/v1.5/SparseArrays/src/higherorderfns.jl:142
 [4] adjoint at /Users/me/.julia/packages/Zygote/c0awc/src/lib/broadcast.jl:137 [inlined]

julia> Float32.(sprand(10,0.5)) # respects sparsity
10-element SparseVector{Float32,Int64} with 3 stored entries:
  [7 ]  =  0.453431
  [8 ]  =  0.898222
  [10]  =  0.298508

julia> exp.(sprand(10,0.5)) # all entries nonzero
10-element SparseVector{Float64,Int64} with 10 stored entries:
  [1 ]  =  1.26619
  [2 ]  =  1.0
  [3 ]  =  1.38125
  [4 ]  =  1.46056
  [5 ]  =  2.5163
  [6 ]  =  2.35844
  [7 ]  =  1.0
  [8 ]  =  1.99315
  [9 ]  =  2.48701
  [10]  =  1.24585

For broadcasting Float32 in particular, it looks like #762 has a special case for that:

https://github.com/FluxML/Zygote.jl/pull/762/files#diff-a9e025ac90a30d27e7512546971c5d92ea7c3496ba759336ae6bf1cace6db4b2R978

which fixes gradient(x -> sum(Float32.(x)), sprand(6,7, 0.9))[1] but not the sparse vector above.

@CarloLucibello
Member

CarloLucibello commented Feb 8, 2022

I implemented an rrule for this, but before filing a PR I was wondering: is this the expected behavior?

julia> using Zygote, Test, ChainRulesCore, Random, SparseArrays

julia> Zygote.refresh()

julia> function ChainRulesCore.rrule(::typeof(Broadcast.broadcasted), T::Type{TT}, x::AbstractSparseArray) where TT
           function broadcasted_cast_sparse(Δ)
               return NoTangent(), NoTangent(), Δ         
           end
           T.(x), broadcasted_cast_sparse
       end

julia> s = sprand(Float32, 5, 5, 0.5)
5×5 SparseMatrixCSC{Float32, Int64} with 12 stored entries:
  ⋅         ⋅         ⋅        0.321893  0.835169
  ⋅        0.459585  0.325418  0.265581  0.0364544
 0.678645   ⋅         ⋅        0.229887  0.73394
  ⋅         ⋅         ⋅        0.755401   ⋅
 0.953619  0.926159   ⋅         ⋅         ⋅

julia> l1, gs1 = withgradient(s) do s
                   sum(Float64.(s))
               end
(val = 6.521751821041107, grad = (sparse([3, 5, 2, 5, 2, 1, 2, 3, 4, 1, 2, 3], [1, 1, 2, 2, 3, 4, 4, 4, 4, 5, 5, 5], Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5, 5),))

# zero-entries get zero derivatives
julia> gs1[1]
5×5 SparseMatrixCSC{Float32, Int64} with 12 stored entries:
  ⋅    ⋅    ⋅   1.0  1.0
  ⋅   1.0  1.0  1.0  1.0
 1.0   ⋅    ⋅   1.0  1.0
  ⋅    ⋅    ⋅   1.0   ⋅
 1.0  1.0   ⋅    ⋅    ⋅

# same matrix but now collected in a dense array gets all 1 derivatives
julia> l2, gs2 = withgradient(collect(s)) do s
                   sum(Float64.(s))
               end
(val = 6.521751821041107, grad = (Float32[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0; … ; 1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0],))

julia> gs2[1]
5×5 Matrix{Float32}:
 1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0

So specifically, is it fine to return a gradient in the restricted manifold spanned by the non-zero entries? Or should we consider a sparse matrix as embedded in the full matrix space?
I think there have been a lot of discussions in ChainRules on similar issues; maybe @mcabbott knows if there is a default answer.

@mcabbott
Member

mcabbott commented Feb 8, 2022

Yes, that sounds great. The policy in CR is to regard zero entries as structurally zero, and ProjectTo enforces this.
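A small sketch of that ProjectTo behaviour (assuming a ChainRulesCore version with sparse projection support):

using ChainRulesCore, SparseArrays

s = sprand(Float32, 5, 5, 0.5)
proj = ProjectTo(s)             # remembers the eltype and sparsity pattern of s
dx = proj(ones(Float32, 5, 5))  # sparse again, with stored entries only where s has them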
