Remove ZygoteDistancesExt and associated tests #1460
To make this not completely breaking, I think one has to keep these rules for all Distances versions that do not contain the ChainRules definitions. At least on Julia < 1.9, users won't be able to use Distances with an upcoming Zygote release anymore, even if they update Distances, since Distances only defines these ChainRules defs on Julia >= 1.9.
Zygote still has to support Julia 1.6+ and extensions only work on 1.9+, so unless Distances.jl adds a fallback with a direct dep/Requires.jl I don't see how we could remove the Zygote rules? Edit: @devmotion beat me by a second! Would it be so painful to have one of the aforementioned fallbacks on the Distances side?
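For reference, the Requires.jl fallback mentioned above usually follows the backwards-compatibility pattern described in the Pkg documentation for extensions. The following is only a rough sketch of that pattern, not what Distances.jl actually does: the module name `MyDistances` and the file name `MyDistancesChainRulesCoreExt.jl` are hypothetical placeholders, and the package would additionally need Requires.jl listed as a dependency.

```julia
# Sketch of a package entry file that loads rule definitions on every Julia version.
# `MyDistances` and the ext file name are hypothetical; this is not Distances.jl code.
module MyDistances

# On Julia >= 1.9, `Base.get_extension` exists and the package manager loads
# ext/MyDistancesChainRulesCoreExt.jl automatically once ChainRulesCore is loaded.
# On older versions, fall back to Requires.jl.
if !isdefined(Base, :get_extension)
    using Requires
end

@static if !isdefined(Base, :get_extension)
    function __init__()
        # Load the same file lazily when the user loads ChainRulesCore.
        @require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include("../ext/MyDistancesChainRulesCoreExt.jl")
    end
end

end # module
```

The direct-dependency variant would instead drop Requires.jl and unconditionally `include` the same file on Julia < 1.9, at the cost of always loading ChainRulesCore there.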
There's a long history of discussions with Distances maintainers about adding ChainRulesCore that already started before extensions were a thing, but they did not approve the idea, so I don't think it's likely to be added as a direct dependency. Not sure about Requires, but I assume there might be similar concerns about increasing dependencies and loading times.
My recollection is that those discussions happened when CRC was a much heavier dep than it is now, and the topic hasn't really been revisited since. The problem with us gating the Zygote rules behind a version check is what @simsurace mentioned up top: they've bitrotted to the point where they're no longer super functional. Unless someone wants to put in the effort of fixing them, there's not too much we can do on the Zygote side to ensure users have access to functional rules for Distances.jl on Julia < 1.9.
Ok, maybe I haven't diagnosed the problem correctly.

Zygote#master, ChainRules@1.55.0, Distances@0.10.10:

julia> using Distances, Zygote
julia> x = rand(10);
julia> f(x) = iszero(x) ? zero(x) : x;
julia> Zygote.gradient(_x -> sum(f, pairwise(Euclidean(), reshape(_x, :, 1); dims=1)), x)
ERROR: MethodError: no method matching *(::Nothing, ::Float64)
Closest candidates are:
*(::Any, ::Any, ::Any, ::Any...)
@ Base operators.jl:578
*(::T, ::T) where T<:Union{Float16, Float32, Float64}
@ Base float.jl:410
*(::StridedArray{P}, ::Real) where P<:Dates.Period
@ Dates ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/Dates/src/deprecated.jl:44
...
Stacktrace:
[1] (::Zygote.var"#1412#1416"{Int64})(y1::Nothing, o1::ForwardDiff.Dual{Nothing, Float64, 2})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:298
[2] _broadcast_getindex_evalf
@ ./broadcast.jl:683 [inlined]
[3] _broadcast_getindex
@ ./broadcast.jl:656 [inlined]
[4] getindex
@ ./broadcast.jl:610 [inlined]
[5] macro expansion
@ ./broadcast.jl:974 [inlined]
[6] macro expansion
@ ./simdloop.jl:77 [inlined]
[7] copyto!
@ ./broadcast.jl:973 [inlined]
[8] copyto!
@ ./broadcast.jl:926 [inlined]
[9] copy
@ ./broadcast.jl:898 [inlined]
[10] materialize
@ ./broadcast.jl:873 [inlined]
[11] broadcast(::Zygote.var"#1412#1416"{Int64}, ::Matrix{Union{Nothing, Float64}}, ::Matrix{ForwardDiff.Dual{Nothing, Float64, 2}})
@ Base.Broadcast ./broadcast.jl:811
[12] #1411
@ ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:298 [inlined]
[13] ntuple
@ ./ntuple.jl:49 [inlined]
[14] bc_fwd_back
@ ~/.julia/packages/Zygote/XJ8pP/src/lib/broadcast.jl:297 [inlined]
[15] #4155#back
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
[16] #291
@ ~/.julia/packages/Zygote/XJ8pP/src/lib/lib.jl:206 [inlined]
[17] #2173#back
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
[18] Pullback
@ ./broadcast.jl:1317 [inlined]
[19] Pullback
@ ~/.julia/packages/Zygote/XJ8pP/ext/ZygoteDistancesExt.jl:104 [inlined]
[20] (::ZygoteDistancesExt.var"#pairwise_Euclidean_pullback#52"{Zygote.Pullback{Tuple{ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, SqEuclidean, Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eps_pullback#396"{Tuple{DataType}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(eltype), Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eltype_pullback#385"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(ZygoteDistancesExt._sqrt_if_positive), Matrix{Float64}, Float64}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4155#back#1376"{Zygote.var"#bc_fwd_back#1414"{Matrix{ForwardDiff.Dual{Nothing, Float64, 2}}, Tuple{Matrix{Float64}, Float64}, Val{2}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Float64}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1182"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Matrix{Float64}}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, ZygoteDistancesExt.var"#63#back#30"{ZygoteDistancesExt.var"#32#33"{LinearAlgebra.Transpose{Float64, Matrix{Float64}}, typeof(transpose)}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float64}}, Tuple{}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:dims, Zygote.Context{false}, 
ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, Int64}}}}})(Δ::Matrix{Union{Nothing, Float64}})
@ ZygoteDistancesExt ~/.julia/packages/Zygote/XJ8pP/ext/ZygoteDistancesExt.jl:107
[21] Pullback
@ ./REPL[86]:1 [inlined]
[22] (::Zygote.Pullback{Tuple{var"#89#90", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#sum_pullback_f2#1665"{typeof(f), Colon, Matrix{Tuple{Float64, Zygote.var"#ad_pullback#58"{Tuple{typeof(f), Float64}, Zygote.Pullback{Tuple{typeof(f), Float64}, Any}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{Type{Euclidean}}, Tuple{}}, ZygoteDistancesExt.var"#pairwise_Euclidean_pullback#52"{Zygote.Pullback{Tuple{ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, SqEuclidean, Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eps_pullback#396"{Tuple{DataType}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(eltype), Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eltype_pullback#385"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(ZygoteDistancesExt._sqrt_if_positive), Matrix{Float64}, Float64}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4155#back#1376"{Zygote.var"#bc_fwd_back#1414"{Matrix{ForwardDiff.Dual{Nothing, Float64, 2}}, Tuple{Matrix{Float64}, Float64}, Val{2}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Float64}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1182"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Matrix{Float64}}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, ZygoteDistancesExt.var"#63#back#30"{ZygoteDistancesExt.var"#32#33"{LinearAlgebra.Transpose{Float64, 
Matrix{Float64}}, typeof(transpose)}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float64}}, Tuple{}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:dims, Zygote.Context{false}, ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, Int64}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
[23] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#89#90", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#sum_pullback_f2#1665"{typeof(f), Colon, Matrix{Tuple{Float64, Zygote.var"#ad_pullback#58"{Tuple{typeof(f), Float64}, Zygote.Pullback{Tuple{typeof(f), Float64}, Any}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{Type{Euclidean}}, Tuple{}}, ZygoteDistancesExt.var"#pairwise_Euclidean_pullback#52"{Zygote.Pullback{Tuple{ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, SqEuclidean, Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eps_pullback#396"{Tuple{DataType}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(eltype), Matrix{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#eltype_pullback#385"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(ZygoteDistancesExt._sqrt_if_positive), Matrix{Float64}, Float64}, Tuple{Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4155#back#1376"{Zygote.var"#bc_fwd_back#1414"{Matrix{ForwardDiff.Dual{Nothing, Float64, 2}}, Tuple{Matrix{Float64}, Float64}, Val{2}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Float64}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2881#back#688"{Zygote.var"#map_back#682"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1182"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Matrix{Float64}}, Tuple{}}, Zygote.var"#2017#back#204"{typeof(identity)}}}, 
ZygoteDistancesExt.var"#63#back#30"{ZygoteDistancesExt.var"#32#33"{LinearAlgebra.Transpose{Float64, Matrix{Float64}}, typeof(transpose)}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float64}}, Tuple{}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:dims, Zygote.Context{false}, ZygoteDistancesExt.var"#_pairwise_euclidean#51"{Int64}, Int64}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
[24] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:97
[25] top-level scope
@ REPL[86]:1

Zygote#master, ChainRules@1.52.1, Distances@0.10.10:

julia> using Distances, Zygote
julia> x = rand(10);
julia> f(x) = iszero(x) ? zero(x) : x;
julia> Zygote.gradient(_x -> sum(f, pairwise(Euclidean(), reshape(_x, :, 1); dims=1)), x)
([-9.999999999999956, 6.0, -17.999999999999996, 1.9999999999999984, -14.000000000000028, 10.000000000000766, -2.000000000000006, 18.000000000000004, 13.999999999999257, -6.000000000000013],)

Zygote#simsurace:remove-distances, ChainRules@1.55.0, Distances@0.10.10: maybe another bug in Zygote, see JuliaStats/Distances.jl#256

julia> using Distances, Zygote
julia> x = rand(10);
julia> f(x) = iszero(x) ? zero(x) : x;
julia> Zygote.gradient(_x -> sum(f, pairwise(Euclidean(), reshape(_x, :, 1); dims=1)), x)
ERROR: MethodError: no method matching _normalize(::ChainRulesCore.ZeroTangent, ::Float64)
Closest candidates are:
_normalize(::Real, ::Real)
@ DistancesChainRulesCoreExt ~/.julia/packages/Distances/PvoXa/ext/DistancesChainRulesCoreExt.jl:83
Stacktrace:
[1] _broadcast_getindex_evalf
@ ./broadcast.jl:683 [inlined]
[2] _broadcast_getindex
@ ./broadcast.jl:656 [inlined]
[3] getindex
@ ./broadcast.jl:610 [inlined]
[4] copy
@ ./broadcast.jl:912 [inlined]
[5] materialize
@ ./broadcast.jl:873 [inlined]
[6] (::DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}})(ΔΩ::Matrix{Any})
@ DistancesChainRulesCoreExt ~/.julia/packages/Distances/PvoXa/ext/DistancesChainRulesCoreExt.jl:114
[7] ZBack
@ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:211 [inlined]
[8] (::Zygote.var"#kw_zpullback#53"{DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}}})(dy::Matrix{Union{Nothing, Float64}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:237
[9] Pullback
@ ./REPL[7]:1 [inlined]
[10] (::Zygote.Pullback{Tuple{var"#3#4", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#sum_pullback_f2#1665"{typeof(f), Colon, Matrix{Tuple{Float64, Zygote.var"#ad_pullback#58"{Tuple{typeof(f), Float64}, Zygote.Pullback{Tuple{typeof(f), Float64}, Any}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}, Zygote.Pullback{Tuple{Type{Euclidean}}, Tuple{}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#kw_zpullback#53"{DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}}}, Zygote.var"#2017#back#204"{typeof(identity)}}})(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[11] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#3#4", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#sum_pullback_f2#1665"{typeof(f), Colon, Matrix{Tuple{Float64, Zygote.var"#ad_pullback#58"{Tuple{typeof(f), Float64}, Zygote.Pullback{Tuple{typeof(f), Float64}, Any}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.var"#2799#back#625"{Zygote.var"#619#623"{Vector{Float64}, Tuple{Colon, Int64}}}, Zygote.Pullback{Tuple{Type{Euclidean}}, Tuple{}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:dims,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:dims,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#kw_zpullback#53"{DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}}}, Zygote.var"#2017#back#204"{typeof(identity)}}}})(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
[12] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
[13] top-level scope
@ REPL[7]:1
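Both failing traces receive a cotangent of type `Matrix{Union{Nothing, Float64}}`, presumably because `f` takes the `zero(x)` branch on the zero diagonal of the distance matrix and Zygote represents that gradient as `nothing`. A minimal sketch of that failure mode in plain Julia, without Zygote or Distances: the matrices are made-up values, and `mask_nothing` is a hypothetical helper for illustration, not a function from either package.

```julia
# Cotangent with `nothing` on the diagonal, mimicking the element type seen in
# the traces (Matrix{Union{Nothing, Float64}}). The values are invented.
Δ = Union{Nothing, Float64}[nothing 1.0; 1.0 nothing]
partials = [0.5 2.0; 2.0 0.5]

# Broadcasting plain multiplication hits `*(::Nothing, ::Float64)`,
# the same MethodError as in the first trace.
err = try
    Δ .* partials
catch e
    e
end
println(typeof(err))

# Hypothetical guard: treat `nothing` as a zero cotangent before multiplying.
mask_nothing(δ, p) = δ === nothing ? zero(p) : δ * p
safe = mask_nothing.(Δ, partials)
println(safe)
```

Whether such a guard belongs in the rule itself or in how Zygote converts `nothing` before calling a ChainRules pullback is exactly the open question in this thread.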
Closing this for now as it seems to be considered too breaking. I opened #1464 to track the issue that has been unmasked.
This extension should probably be removed as