LinearSolve with SciMLSensitivity Solution Handling Requires sol.u #483

Open
marklau34 opened this issue Mar 20, 2024 · 1 comment
Labels
good first issue, question

Comments

@marklau34

I wanted to try something simple with SciMLSensitivity.jl: finding the sensitivities of the solution of a LinearProblem with respect to its parameters. However, I get an unexpected error, which I outlined in a post.

using Zygote
using SciMLSensitivity
using ForwardDiff
using LinearSolve
import Random
Random.seed!(1234)

N = 2

function test_func(x::AbstractVector{T}) where {T<:Real}
    A = reshape(x[1:N*N], (N,N))
    b = x[N*N+1:end]
    # This works:
    # sol = A\b
    # But this seems to not work:
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(sol)
end

# Random Point
x0 = rand(N*N+N)

# Try with Zygote
grad_zygote = Zygote.gradient(test_func, x0)
display(grad_zygote[1])

# Compare with ForwardDiff
grad_forwarddiff = ForwardDiff.gradient(test_func, x0)
display(grad_forwarddiff)

The following error occurs:

ERROR: type Fill has no field u
Stacktrace:
  [1] getproperty
    @ .\Base.jl:37 [inlined]
  [2] (::LinearSolve.var"#∇linear_solve#103"{})(∂sol::FillArrays.Fill{…})
    @ LinearSolve C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\adjoint.jl:58
  [3] ZBack
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
  [4] (::Zygote.var"#291#292"{Tuple{}, Zygote.ZBack{}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206
  [5] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
  [6] #solve#5
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:188 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [8] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
  [9] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [10] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:186 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [12] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [13] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [14] #solve#4
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:183 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [16] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [17] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [18] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:182 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [20] test_func
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:17 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [24] top-level scope
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:25
Some type information was truncated. Use `show(err)` to see complete types.

As suggested by @avik-pal, returning sum(sol.u) fixes the problem, so this may be a bug where the getindex(sol, sym) rrule is not handled correctly.

Is this a bug, or is there a reason sol.u should be used in this case?
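For reference, the workaround can be sketched as follows. This is the same script as above with only the return line changed to use sol.u; the function name test_func_u is introduced here for illustration. Whether the two gradients match on a given set of package versions is something to check rather than assume.

```julia
using Zygote
using SciMLSensitivity
using ForwardDiff
using LinearSolve
import Random
Random.seed!(1234)

N = 2

# Same as test_func, but sums the raw solution vector `sol.u`
# instead of the solution object itself.
function test_func_u(x::AbstractVector{T}) where {T<:Real}
    A = reshape(x[1:N*N], (N, N))   # first N*N entries form the matrix (column-major)
    b = x[N*N+1:end]                # remaining N entries form the right-hand side
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(sol.u)
end

x0 = rand(N*N + N)

# Reverse-mode gradient via Zygote
grad_zygote = Zygote.gradient(test_func_u, x0)[1]

# Forward-mode gradient via ForwardDiff, for comparison
grad_forwarddiff = ForwardDiff.gradient(test_func_u, x0)

display(grad_zygote)
display(grad_forwarddiff)
```

If both calls succeed, the two displayed vectors should agree up to floating-point error, since they are gradients of the same scalar function.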

@marklau34 added the question label Mar 20, 2024
@ChrisRackauckas
Member

We just need a similar getindex overload: https://github.com/SciML/SciMLBase.jl/blob/master/ext/SciMLBaseZygoteExt.jl#L97-L109
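To illustrate the general shape of such an overload, here is a minimal sketch written against ChainRulesCore. It is not the SciMLBase code at the linked lines and not the LinearSolve fix: ToySolution is a hypothetical stand-in for a solution type, and only integer indexing is covered. The idea is that indexing the solution object returns a structural tangent carrying a u field, which is the field the error above reports as missing on the Fill cotangent.

```julia
using ChainRulesCore

# Hypothetical stand-in for a solution object; NOT the LinearSolve solution type.
struct ToySolution{T}
    u::Vector{T}
end

# Indexing the solution forwards to the underlying vector.
Base.getindex(sol::ToySolution, i::Int) = sol.u[i]

# Reverse rule for getindex: the cotangent of the scalar output is scattered
# into a zero vector and wrapped in a structural tangent with a `u` field.
function ChainRulesCore.rrule(::typeof(getindex), sol::ToySolution, i::Int)
    function getindex_pullback(Δ)
        du = zero(sol.u)
        du[i] = unthunk(Δ)
        return NoTangent(), Tangent{typeof(sol)}(; u = du), NoTangent()
    end
    return sol.u[i], getindex_pullback
end

# Usage: call the rule directly and inspect the tangent for the solution.
sol = ToySolution([1.0, 2.0, 3.0])
y, pullback = ChainRulesCore.rrule(getindex, sol, 2)
_, dsol, _ = pullback(1.0)
```

Here dsol.u holds the cotangent with respect to the stored vector, so a downstream pullback that reads a u field from its input would find one. A real overload would also need to handle the other indexing forms the solution type supports.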

@ChrisRackauckas added the good first issue label Mar 25, 2024