Support for linear algebra #68

Closed
gdalle opened this issue May 16, 2024 · 9 comments

Comments

@gdalle
Collaborator

gdalle commented May 16, 2024

Prompted by JuliaDiff/DifferentiationInterface.jl#263; related to, but not equivalent to, #55.

Here's an MWE:

julia> using ADTypes, SparseConnectivityTracer

julia> ADTypes.hessian_sparsity(logdet, rand(2, 2), TracerSparsityDetector())
ERROR: MethodError: HessianTracer{Int64, BitSet, Dict{Int64, BitSet}}(::HessianTracer{Int64, BitSet, Dict{Int64, BitSet}}) is ambiguous.

Candidates:
  (var"#ctor-self#"::Type{HessianTracer{I, S, D}} where {I<:Integer, S, D<:AbstractDict{I, S}})(inputs)
    @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/rpD4Z/src/tracers.jl:106
  HessianTracer{I, S, D}(::Number) where {I<:Integer, S, D}
    @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/rpD4Z/src/tracers.jl:122
  (::Type{T})(x::T) where T<:Number
    @ Core boot.jl:792

Possible fix, define
  HessianTracer{I, S, D}(::HessianTracer{I, S, D}) where {I<:Integer, S, D<:AbstractDict{I, S}}

Stacktrace:
 [1] oneunit(::Type{HessianTracer{Int64, BitSet, Dict{Int64, BitSet}}})
   @ Base ./number.jl:372
 [2] lutype(T::Type)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/lu.jl:209
 [3] lu(A::Matrix{HessianTracer{Int64, BitSet, Dict{Int64, BitSet}}}, pivot::RowMaximum; check::Bool)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/lu.jl:300
 [4] logabsdet(A::Matrix{HessianTracer{Int64, BitSet, Dict{Int64, BitSet}}})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1676
 [5] logdet(A::Matrix{HessianTracer{Int64, BitSet, Dict{Int64, BitSet}}})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1701
 [6] trace_function(::Type{HessianTracer{Int64, BitSet, Dict{Int64, BitSet}}}, f::typeof(logdet), x::Matrix{Float64})
   @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/rpD4Z/src/pattern.jl:22
 [7] hessian_pattern(f::Function, x::Matrix{Float64}, ::Type{BitSet})
   @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/rpD4Z/src/pattern.jl:198
 [8] hessian_sparsity(f::Function, x::Matrix{Float64}, ::TracerSparsityDetector{BitSet})
   @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/rpD4Z/src/adtypes.jl:44
 [9] top-level scope
   @ REPL[18]:1

In this case, I think the bug is not related to linear algebra itself, but for functions that are not implemented in pure Julia we might want to add vector- or matrix-level overloads.
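The idea of a matrix-level overload can be sketched with a toy first-order tracer. Everything here is an assumption for illustration: `ToyTracer` is a hypothetical type, not SCT's tracer, and it only tracks which inputs a value depends on (a gradient/Jacobian-style pattern, not a Hessian one). Instead of tracing through the scalar operations of an LU factorization, the overload declares that `det` and `logdet` depend on every entry of the matrix, which is conservative because the determinant of a generic matrix involves all of its entries.

```julia
using LinearAlgebra

# Hypothetical toy tracer (NOT SparseConnectivityTracer's types):
# it records the set of input indices a value depends on.
struct ToyTracer <: Number
    inputs::BitSet
end

# Hypothetical matrix-level overload: skip the LU factorization entirely and
# say that det depends on the union of the inputs of all matrix entries.
function LinearAlgebra.det(A::Matrix{ToyTracer})
    return ToyTracer(mapreduce(t -> t.inputs, union, A; init=BitSet()))
end

# Same dependence structure for logdet; only input indices are tracked, not values.
LinearAlgebra.logdet(A::Matrix{ToyTracer}) = det(A)

# Input k of the flattened 2x2 matrix is tagged with index k.
A = reshape([ToyTracer(BitSet([i])) for i in 1:4], 2, 2)
println(collect(logdet(A).inputs))  # [1, 2, 3, 4]
```

Because no scalar comparison is ever executed, such an overload sidesteps the pivoting branches inside `lu` altogether.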

@adrhill
Owner

adrhill commented May 16, 2024

Looks like a missing overload of oneunit on tracer types, which should return an empty tracer.
I'll add it to the tests and fix this in the ongoing refactor in #65.
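A minimal sketch of that diagnosis, using a hypothetical `ToyTracer` rather than SCT's actual `HessianTracer`: constants such as `zero`, `one` and `oneunit` depend on no input, so they map to an empty tracer. The identity constructor mirrors the "Possible fix" printed in the MethodError of the MWE. This is an illustration under those assumptions, not the fix that landed in #65.

```julia
# Hypothetical toy tracer (NOT SparseConnectivityTracer's HessianTracer):
# it records which input indices influence a value.
struct ToyTracer <: Number
    inputs::BitSet
end

# An empty tracer depends on no input at all.
ToyTracer() = ToyTracer(BitSet())

# Constructing a tracer from a tracer is the identity, as suggested by the
# "Possible fix" in the MethodError above.
ToyTracer(t::ToyTracer) = t

# Constants carry no input dependence, so they map to the empty tracer.
Base.zero(::Type{ToyTracer}) = ToyTracer()
Base.one(::Type{ToyTracer}) = ToyTracer()
Base.oneunit(::Type{ToyTracer}) = ToyTracer()

# Binary operations depend on the union of both operands' inputs.
Base.:+(a::ToyTracer, b::ToyTracer) = ToyTracer(union(a.inputs, b.inputs))
Base.:*(a::ToyTracer, b::ToyTracer) = ToyTracer(union(a.inputs, b.inputs))

x = [ToyTracer(BitSet([i])) for i in 1:3]
y = x[1] * x[2] + oneunit(ToyTracer)
println(collect(y.inputs))  # [1, 2]
```

Adding `oneunit` leaves the dependence set unchanged, which is exactly what an empty tracer should do.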

adrhill added a commit that referenced this issue May 16, 2024
@adrhill
Owner

adrhill commented May 17, 2024

Closed by #65.

adrhill closed this as completed May 17, 2024
@ElOceanografo

This MWE is still not working for me with SCT v0.4.0, though the error has changed:

(jl_YQ9xYx) pkg> st
Status `/tmp/jl_YQ9xYx/Project.toml`
  [47edcb42] ADTypes v1.2.1
  [9f842d2f] SparseConnectivityTracer v0.4.0

julia> using ADTypes, SparseConnectivityTracer, LinearAlgebra

julia> ADTypes.hessian_sparsity(logdet, rand(2, 2), TracerSparsityDetector())
ERROR: Function > requires primal value(s).
A dual-number tracer for local sparsity detection can be used via `local_hessian_pattern`.
Stacktrace:
  [1] >(tx::SparseConnectivityTracer.HessianTracer{…}, ty::SparseConnectivityTracer.HessianTracer{…})
    @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/JvHcU/src/overload_dual.jl:30
  [2] generic_lufact!(A::Matrix{SparseConnectivityTracer.HessianTracer{BitSet, Set{…}}}, pivot::RowMaximum; check::Bool)
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/lu.jl:152
  [3] generic_lufact!
    @ ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/lu.jl:134 [inlined]
  [4] lu!
    @ ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/lu.jl:132 [inlined]
  [5] #lu#164
    @ ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/lu.jl:300 [inlined]
  [6] lu (repeats 2 times)
    @ ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/lu.jl:299 [inlined]
  [7] logabsdet(A::Matrix{SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1676
  [8] logdet(A::Matrix{SparseConnectivityTracer.HessianTracer{BitSet, Set{Tuple{Int64, Int64}}}})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1701
  [9] trace_function(::Type{SparseConnectivityTracer.HessianTracer{BitSet, Set{…}}}, f::typeof(logdet), x::Matrix{Float64})
    @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/JvHcU/src/pattern.jl:32
 [10] hessian_pattern(f::Function, x::Matrix{Float64}, ::Type{BitSet}, ::Type{Set{Tuple{Int64, Int64}}})
    @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/JvHcU/src/pattern.jl:326
 [11] hessian_sparsity(f::Function, x::Matrix{Float64}, ::TracerSparsityDetector{BitSet, Set{Tuple{Int64, Int64}}})
    @ SparseConnectivityTracer ~/.julia/packages/SparseConnectivityTracer/JvHcU/src/adtypes.jl:45
 [12] top-level scope
    @ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.

gdalle reopened this May 18, 2024
@gdalle
Collaborator Author

gdalle commented May 18, 2024

That's because logdet apparently contains branches that compare the actual values of the variables (the pivoting in `generic_lufact!`, visible in the stack trace above). The basic tracing mechanism, which is the one interfaced with ADTypes, only propagates sets of influencing indices and discards the primal values, which explains the new error.
However, the point of the refactor was to add "local" tracing, which takes values into account and thus works with logdet. At the moment you have to use the native interface of SCT for that:

julia> local_hessian_pattern(logdet, rand(2, 2))
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 16 stored entries:
 1  1  1  1
 1  1  1  1
 1  1  1  1
 1  1  1  1
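The difference between the two mechanisms can be sketched with two hypothetical toy types (not SCT's internals): a global tracer that only knows which inputs a value depends on, and a dual-number tracer that also carries the primal value at a chosen point. The toy error message imitates the one in the report above; the branching function `f` is a made-up stand-in for the pivot comparison inside LU.

```julia
# Hypothetical toy types, NOT SparseConnectivityTracer's internals.
# Global tracer: input dependence only, no value.
struct GlobalTracer <: Real
    inputs::BitSet
end

# Local (dual-number) tracer: input dependence plus the primal value.
struct DualTracer <: Real
    primal::Float64
    inputs::BitSet
end

Base.:*(a::GlobalTracer, b::GlobalTracer) = GlobalTracer(union(a.inputs, b.inputs))
Base.:*(a::DualTracer, b::DualTracer) =
    DualTracer(a.primal * b.primal, union(a.inputs, b.inputs))

# A comparison needs actual values, which the global tracer has discarded.
Base.:>(a::GlobalTracer, b::GlobalTracer) =
    error("Function > requires primal value(s).")
# The dual tracer can answer by comparing its primal values.
Base.:>(a::DualTracer, b::DualTracer) = a.primal > b.primal

# Hypothetical function with a value-dependent branch.
f(x) = x[1] > x[2] ? x[1] * x[1] : x[2] * x[2]

xg = [GlobalTracer(BitSet([i])) for i in 1:2]
try
    f(xg)
catch err
    println(sprint(showerror, err))  # Function > requires primal value(s).
end

# At the point (3.0, 1.0) the first branch is taken, so only input 1 matters.
xd = [DualTracer(v, BitSet([i])) for (i, v) in enumerate([3.0, 1.0])]
println(collect(f(xd).inputs))  # [1]
```

The local result is only valid at the chosen point: at (1.0, 3.0) the other branch would be taken and the pattern would involve input 2 instead.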

@adrhill
Owner

adrhill commented May 18, 2024

Yes, exactly. > requires information from the primal computation, which can be obtained by using dual-number tracers via local_hessian_pattern.

  • Functions *_pattern return a conservative estimate of sparsity over the entire input domain.
  • Functions local_*_pattern return a less conservative estimate of sparsity at a specific point in input space.
julia> using LinearAlgebra

julia> using SparseConnectivityTracer

julia> local_hessian_pattern(det, rand(2, 2))
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 9 stored entries:
 ⋅  1  ⋅  1
 1  1  1  1
 ⋅  1  ⋅  ⋅
 1  1  ⋅  ⋅

julia> local_hessian_pattern(logdet, rand(2, 2))
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 16 stored entries:
 1  1  1  1
 1  1  1  1
 1  1  1  1
 1  1  1  1
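As a hand check of what "conservative" means here, the exact Hessian of det on a 2×2 input can be written out. This assumes the 4×4 pattern indexes the entries of A in Julia's column-major order, i.e. x = (a11, a21, a12, a22):

```math
\det A = x_1 x_4 - x_2 x_3,
\qquad
\nabla^2 \det A =
\begin{pmatrix}
0 & 0 & 0 & 1 \\
0 & 0 & -1 & 0 \\
0 & -1 & 0 & 0 \\
1 & 0 & 0 & 0
\end{pmatrix}
```

Its 4 nonzeros at (1,4), (4,1), (2,3) and (3,2) all appear in the 9-entry local pattern for det above; the remaining stored entries are over-approximations, presumably introduced by tracing through the scalar operations of the LU factorization.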

@gdalle
Collaborator Author

gdalle commented May 18, 2024

I reopened this because we should expose the local/global option in our ADTypes interface.

@adrhill
Owner

adrhill commented May 18, 2024

Agreed. This probably requires a PR to ADTypes to differentiate between local and global sparsity detection.

@adrhill
Owner

adrhill commented May 18, 2024

We could add local tracing as a kwarg to TracerSparsityDetector(), but this might cause downstream packages like DI to not recompute sparsity patterns when they should.

For this reason, I would argue that local and global sparsity detection should be distinguished at the level of ADTypes.
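The argument can be sketched with hypothetical types (none of these names are the ADTypes or SCT API). When the local/global distinction is part of the detector's type, downstream code that prepares a sparsity pattern once and reuses it can dispatch on that type to decide whether the cached pattern is still valid. A keyword hidden inside a single detector type gives downstream code no such signal unless it knows to inspect it.

```julia
# Hypothetical names throughout: NOT the ADTypes or SCT API.
abstract type AbstractDetector end
struct GlobalDetector <: AbstractDetector end
struct LocalDetector <: AbstractDetector end

# Trait: does a pattern detected with this detector hold over the whole input domain?
pattern_is_global(::GlobalDetector) = true
pattern_is_global(::LocalDetector) = false

# Hypothetical downstream "prepare once, reuse many times" object.
struct PreparedPattern
    x0::Vector{Float64}          # point at which the pattern was detected
    detector::AbstractDetector   # detector that produced the pattern
end

# A global pattern is valid at every point; a local one is only known to be
# valid at the point where it was detected, so any other point forces recomputation.
can_reuse(p::PreparedPattern, x::AbstractVector) =
    pattern_is_global(p.detector) || x == p.x0

pg = PreparedPattern([1.0, 2.0], GlobalDetector())
pl = PreparedPattern([1.0, 2.0], LocalDetector())
println(can_reuse(pg, [3.0, 4.0]))  # true
println(can_reuse(pl, [3.0, 4.0]))  # false
println(can_reuse(pl, [1.0, 2.0]))  # true
```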

@gdalle
Collaborator Author

gdalle commented May 18, 2024

Let's switch to #72

gdalle closed this as completed May 18, 2024

3 participants