LinearSolve with SciMLSensitivity Solution Handling Requires sol.u #483

Open
@marklau34

I wanted to try something simple with SciMLSensitivity.jl: computing the sensitivities of the solution of a LinearProblem with respect to its parameters. However, I get an unexpected error, which I outlined in a post.

using Zygote
using SciMLSensitivity
using ForwardDiff
using LinearSolve
import Random
Random.seed!(1234)

N = 2

function test_func(x::AbstractVector{T}) where {T<:Real}
    A = reshape(x[1:N*N], (N,N))
    b = x[N*N+1:end]
    # This works:
    # sol = A\b
    # But this does not seem to work:
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(sol)
end

# Random Point
x0 = rand(N*N+N)

# Try with Zygote
grad_zygote = Zygote.gradient(test_func, x0)
display(grad_zygote[1])

# Compare with ForwardDiff
grad_forwarddiff = ForwardDiff.gradient(test_func, x0)
display(grad_forwarddiff)

The following error occurs:

ERROR: type Fill has no field u
Stacktrace:
  [1] getproperty
    @ .\Base.jl:37 [inlined]
  [2] (::LinearSolve.var"#∇linear_solve#103"{})(∂sol::FillArrays.Fill{…})
    @ LinearSolve C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\adjoint.jl:58
  [3] ZBack
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
  [4] (::Zygote.var"#291#292"{Tuple{}, Zygote.ZBack{}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206
  [5] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
  [6] #solve#5
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:188 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [8] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
  [9] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [10] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:186 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [12] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [13] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [14] #solve#4
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:183 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [16] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [17] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [18] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:182 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [20] test_func
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:17 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [24] top-level scope
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:25
Some type information was truncated. Use `show(err)` to see complete types.

As suggested by @avik-pal, returning sum(sol.u) fixes the problem, so this may be a bug in which the rrule for getindex(sol, sym) is not handled correctly.
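For reference, a minimal sketch of that workaround, assuming the same packages and setup as the script above. The names `test_func_u` and `test_func_ref` are hypothetical and introduced only for this illustration; the reference function uses the backslash solve that the original script notes as working, so the Zygote result can be cross-checked against it.

```julia
using Zygote
using SciMLSensitivity
using ForwardDiff
using LinearSolve
import Random
Random.seed!(1234)

N = 2

# Hypothetical name: same body as test_func in the report, except that it
# sums the solution vector `sol.u` rather than the solution object `sol`.
function test_func_u(x::AbstractVector{T}) where {T<:Real}
    A = reshape(x[1:N*N], (N, N))
    b = x[N*N+1:end]
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(sol.u)
end

# Hypothetical name: reference version using the backslash operator,
# which the report states works.
function test_func_ref(x::AbstractVector{T}) where {T<:Real}
    A = reshape(x[1:N*N], (N, N))
    b = x[N*N+1:end]
    return sum(A \ b)
end

# Random point
x0 = rand(N*N + N)

# Reverse-mode gradient through the LinearSolve path
grad_zygote = Zygote.gradient(test_func_u, x0)[1]

# Forward-mode gradient through the backslash path, used as a cross-check
grad_reference = ForwardDiff.gradient(test_func_ref, x0)

display(grad_zygote)
display(grad_reference)
```

If the workaround behaves as described, the two displayed gradients should agree up to floating-point error.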

Is this a bug, or is there a reason sol.u should be used in this case?
