Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LinearSolveForwardDiffExt

using LinearSolve
using LinearSolve: SciMLLinearSolveAlgorithm, __init
using LinearSolve: SciMLLinearSolveAlgorithm, __init, DefaultLinearSolver, DefaultAlgorithmChoice, defaultalg
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Partials
Expand Down Expand Up @@ -196,6 +196,24 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactoriza
return __init(prob, alg, args...; kwargs...)
end

function SciMLBase.init(prob::DualAbstractLinearProblem, alg::DefaultLinearSolver, args...; kwargs...)
if alg.alg === DefaultAlgorithmChoice.GenericLUFactorization
return __init(prob, alg, args...; kwargs...)
else
return __dual_init(prob, alg, args...; kwargs...)
end
Comment on lines +199 to +204
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a JET test

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests added. Looks good on 1.11 but fails on LTS

end

function SciMLBase.init(prob::DualAbstractLinearProblem, alg::Nothing,
args...;
assumptions = OperatorAssumptions(issquare(prob.A)),
kwargs...)
new_A = nodual_value(prob.A)
new_b = nodual_value(prob.b)
SciMLBase.init(
prob, defaultalg(new_A, new_b, assumptions), args...; assumptions, kwargs...)
end

function __dual_init(
prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
Expand Down Expand Up @@ -225,11 +243,8 @@ function __dual_init(
dual_type = get_dual_type(prob.b)
end

alg isa LinearSolve.DefaultLinearSolver ?
real_alg = LinearSolve.defaultalg(primal_prob.A, primal_prob.b) : real_alg = alg

non_partial_cache = init(
primal_prob, real_alg, assumptions, args...;
primal_prob, alg, assumptions, args...;
alias = alias, abstol = abstol, reltol = reltol,
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
sensealg = sensealg, u0 = new_u0, kwargs...)
Expand Down
14 changes: 8 additions & 6 deletions test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
overload_x_p = solve(prob)
overload_x_p = solve(prob, LUFactorization())
backslash_x_p = A \ b
krylov_overload_x_p = solve(prob, KrylovJL_GMRES())
@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)
Expand Down Expand Up @@ -42,7 +42,7 @@ prob = LinearProblem(A, b)
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
cache = init(prob)
cache = init(prob, LUFactorization())

new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.A = new_A
Expand All @@ -60,7 +60,7 @@ backslash_x_p = new_A \ new_b
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
cache = init(prob)
cache = init(prob, LUFactorization())

new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.A = new_A
Expand All @@ -75,7 +75,7 @@ backslash_x_p = new_A \ b
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
cache = init(prob)
cache = init(prob, LUFactorization())

_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.b = new_b
Expand All @@ -99,7 +99,7 @@ original_x_p = A \ b
@test ≈(overload_x_p, original_x_p, rtol = 1e-9)

prob = LinearProblem(A, b)
cache = init(prob)
cache = init(prob, LUFactorization())

new_A,
new_b = h([ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 1.0, 0.0),
Expand Down Expand Up @@ -155,7 +155,7 @@ end
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
cache = init(prob)
cache = init(prob, LUFactorization())

new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
cache.A = new_A
Expand Down Expand Up @@ -193,3 +193,5 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

prob = LinearProblem(A, b)
@test init(prob, GenericLUFactorization()) isa LinearSolve.LinearCache

@test init(prob) isa LinearSolve.LinearCache
15 changes: 14 additions & 1 deletion test/nopre/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,21 @@ end

@testset "JET Tests for creating Dual solutions" begin
# Make sure there's no runtime dispatch when making solutions of Dual problems
dual_cache = init(dual_prob)
dual_cache = init(dual_prob, LUFactorization())
ext = Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt)
JET.@test_opt ext.linearsolve_dual_solution(
[1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dual_cache)
end

@testset "JET Tests for default algs with DualLinear Problems" begin
# Test for Default alg choosing for DualLinear Problems
# These should both produce a LinearCache
alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization)
if VERSION < v"1.11"
JET.@test_opt init(dual_prob, alg) broken=true
JET.@test_opt init(dual_prob) broken=true
else
JET.@test_opt init(dual_prob, alg)
JET.@test_opt init(dual_prob)
end
end
Loading