diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index e058b241e..d9feefeb6 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -1,7 +1,7 @@ module LinearSolveForwardDiffExt using LinearSolve -using LinearSolve: SciMLLinearSolveAlgorithm, __init +using LinearSolve: SciMLLinearSolveAlgorithm, __init, DefaultLinearSolver, DefaultAlgorithmChoice, defaultalg using LinearAlgebra using ForwardDiff using ForwardDiff: Dual, Partials @@ -196,6 +196,24 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactoriza return __init(prob, alg, args...; kwargs...) end +function SciMLBase.init(prob::DualAbstractLinearProblem, alg::DefaultLinearSolver, args...; kwargs...) + if alg.alg === DefaultAlgorithmChoice.GenericLUFactorization + return __init(prob, alg, args...; kwargs...) + else + return __dual_init(prob, alg, args...; kwargs...) + end +end + +function SciMLBase.init(prob::DualAbstractLinearProblem, alg::Nothing, + args...; + assumptions = OperatorAssumptions(issquare(prob.A)), + kwargs...) + new_A = nodual_value(prob.A) + new_b = nodual_value(prob.b) + SciMLBase.init( + prob, defaultalg(new_A, new_b, assumptions), args...; assumptions, kwargs...) +end + function __dual_init( prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm, args...; @@ -225,11 +243,8 @@ function __dual_init( dual_type = get_dual_type(prob.b) end - alg isa LinearSolve.DefaultLinearSolver ? - real_alg = LinearSolve.defaultalg(primal_prob.A, primal_prob.b) : real_alg = alg - non_partial_cache = init( - primal_prob, real_alg, assumptions, args...; + primal_prob, alg, assumptions, args...; alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl index b7710f9de..329287cd3 100644 --- a/test/forwarddiff_overloads.jl +++ b/test/forwarddiff_overloads.jl @@ -13,7 +13,7 @@ end A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) -overload_x_p = solve(prob) +overload_x_p = solve(prob, LUFactorization()) backslash_x_p = A \ b krylov_overload_x_p = solve(prob, KrylovJL_GMRES()) @test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) @@ -42,7 +42,7 @@ prob = LinearProblem(A, b) A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) prob = LinearProblem(A, b) -cache = init(prob) +cache = init(prob, LUFactorization()) new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.A = new_A @@ -60,7 +60,7 @@ backslash_x_p = new_A \ new_b A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) prob = LinearProblem(A, b) -cache = init(prob) +cache = init(prob, LUFactorization()) new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.A = new_A @@ -75,7 +75,7 @@ backslash_x_p = new_A \ b A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) -cache = init(prob) +cache = init(prob, LUFactorization()) _, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.b = new_b @@ -99,7 +99,7 @@ original_x_p = A \ b @test ≈(overload_x_p, original_x_p, rtol = 1e-9) prob = LinearProblem(A, b) -cache = init(prob) +cache = init(prob, LUFactorization()) new_A, new_b = h([ForwardDiff.Dual(ForwardDiff.Dual(10.0, 1.0, 0.0), 1.0, 0.0), @@ -155,7 +155,7 @@ end A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) -cache = init(prob) +cache = init(prob, LUFactorization()) new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) cache.A = new_A @@ -193,3 +193,5 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) prob = LinearProblem(A, b) @test init(prob, GenericLUFactorization()) isa LinearSolve.LinearCache + +@test init(prob) isa LinearSolve.LinearCache \ No newline at end of file diff --git a/test/nopre/jet.jl b/test/nopre/jet.jl index 47f11aef1..16f54537d 100644 --- a/test/nopre/jet.jl +++ b/test/nopre/jet.jl @@ -136,8 +136,21 @@ end @testset "JET Tests for creating Dual solutions" begin # Make sure there's no runtime dispatch when making solutions of Dual problems - dual_cache = init(dual_prob) + dual_cache = init(dual_prob, LUFactorization()) ext = Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt) JET.@test_opt ext.linearsolve_dual_solution( [1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dual_cache) +end + +@testset "JET Tests for default algs with DualLinear Problems" begin + # Test for Default alg choosing for DualLinear Problems + # These should both produce a LinearCache + alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization) + if VERSION < v"1.11" + JET.@test_opt init(dual_prob, alg) broken=true + JET.@test_opt init(dual_prob) broken=true + else + JET.@test_opt init(dual_prob, alg) + JET.@test_opt init(dual_prob) + end end \ No newline at end of file