Skip to content

Commit

Permalink
Merge 85f8fd6 into 1186046
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Apr 6, 2022
2 parents 1186046 + 85f8fd6 commit 9791f56
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ using Reexport
abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
abstract type AbstractSolveFunction <: SciMLLinearSolveAlgorithm end

# Traits

needs_concrete_A(alg::AbstractFactorization) = true
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Code

Expand All @@ -39,6 +41,7 @@ include("factorization.jl")
include("simplelu.jl")
include("iterative_wrappers.jl")
include("preconditioners.jl")
include("solve_function.jl")
include("default.jl")
include("init.jl")

Expand All @@ -48,6 +51,9 @@ isopenblas() = IS_OPENBLAS[]
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
UMFPACKFactorization, KLUFactorization

export LinearSolveFunction

export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
Expand Down
2 changes: 1 addition & 1 deletion src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function set_cacheval(cache::LinearCache, alg_cache)
return cache
end

init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing

SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)

Expand Down
19 changes: 19 additions & 0 deletions src/solve_function.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
function DEFAULT_LINEAR_SOLVE(A,b,u,p,newA,Pl,Pr,solverdata;kwargs...)
solve(LinearProblem(A, b; u0=u); p=p, kwargs...).u
end

Base.@kwdef struct LinearSolveFunction{F} <: AbstractSolveFunction
solve_func::F = DEFAULT_LINEAR_SOLVE
end

function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,
args...; kwargs...)
@unpack A,b,u,p,isfresh,Pl,Pr,cacheval = cache
@unpack solve_func = alg

u = solve_func(A,b,u,p,isfresh,Pl,Pr,cacheval;kwargs...)
cache = set_u(cache, u)

return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
end
31 changes: 31 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,35 @@ end
@test sol13.u sol33.u
end

@testset "Solve Function" begin

A1 = rand(n) |> Diagonal; b1 = rand(n); x1 = zero(b1)
A2 = rand(n) |> Diagonal; b2 = rand(n); x2 = zero(b1)

function sol_func(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
if verbose == true
println("out-of-place solve")
end
u = A \ b
end

function sol_func!(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
if verbose == true
println("in-place solve")
end
ldiv!(u,A,b)
end

prob1 = LinearProblem(A1, b1; u0=x1)
prob2 = LinearProblem(A1, b1; u0=x1)

for alg in (
LinearSolveFunction(),
LinearSolveFunction(sol_func),
LinearSolveFunction(sol_func!),
)
test_interface(alg, prob1, prob2)
end
end

end # testset

0 comments on commit 9791f56

Please sign in to comment.