diff --git a/docs/make.jl b/docs/make.jl index 59af495eb..d46044f02 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -29,6 +29,7 @@ makedocs( ], "Advanced" => Any[ "advanced/developing.md" + "advanced/custom.md" ] ] ) diff --git a/docs/src/advanced/custom.md b/docs/src/advanced/custom.md new file mode 100644 index 000000000..973a13f34 --- /dev/null +++ b/docs/src/advanced/custom.md @@ -0,0 +1,55 @@ +# Passing in a Custom Linear Solver +Julia users are building a wide variety of applications in the SciML ecosystem, +often requiring problem-specific handling of their linear solves. As existing solvers in `LinearSolve.jl` may not +be optimally suited for novel applications, it is essential for the linear solve +interface to be easily extendable by users. To that end, the linear solve algorithm +`LinearSolveFunction()` accepts a user-defined function for handling the solve. A +user can pass in their custom linear solve function, say `my_linsolve`, to +`LinearSolveFunction()`. A contrived example of solving a linear system with a custom solver is below. +```julia +using LinearSolve, LinearAlgebra + +function my_linsolve(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...) + if verbose == true + println("solving Ax=b") + end + u = A \ b + return u +end + +prob = LinearProblem(Diagonal(rand(4)), rand(4)) +alg = LinearSolveFunction(my_linsolve) +sol = solve(prob, alg) +``` +The inputs to the function are as follows: +- `A`, the linear operator +- `b`, the right-hand-side +- `u`, the solution initialized as `zero(b)`, +- `p`, a set of parameters +- `newA`, a `Bool` which is `true` if `A` has been modified since last solve +- `Pl`, left-preconditioner +- `Pr`, right-preconditioner +- `solverdata`, solver cache set to `nothing` if solver hasn't been initialized +- `kwargs`, standard SciML keyword arguments such as `verbose`, `maxiters`, +`abstol`, `reltol` +The function `my_linsolve` must accept the above specified arguments, and return +the solution, `u`. 
As memory for `u` is already allocated, the user may choose +to modify `u` in place as follows: +```julia +function my_linsolve!(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...) + if verbose == true + println("solving Ax=b") + end + u .= A \ b # in place + return u +end + +alg = LinearSolveFunction(my_linsolve!) +sol = solve(prob, alg) +``` +Finally, note that to fall back to the default linear solve algorithm handling, +one passes `nothing` as the algorithm (or omits the algorithm argument entirely). +```julia +alg = nothing +sol = solve(prob, alg) # same as solve(prob) +``` diff --git a/docs/src/solvers/solvers.md b/docs/src/solvers/solvers.md index e3ac46606..21590553c 100644 --- a/docs/src/solvers/solvers.md +++ b/docs/src/solvers/solvers.md @@ -19,9 +19,17 @@ Pardiso.jl's methods are also known to be very efficient sparse linear solvers. As sparse matrices get larger, iterative solvers tend to get more efficient than factorization methods if a lower tolerance of the solution is required. -Krylov.jl works with CPUs and GPUs and tends to be more efficient than other + +IterativeSolvers.jl uses a low-rank Q update in its GMRES so it tends to be +faster than Krylov.jl for CPU-based arrays, but it's only compatible with +CPU-based arrays while Krylov.jl is more general and will support accelerators +like CUDA. Krylov.jl works with CPUs and GPUs and tends to be more efficient than other Krylov-based methods. +Finally, a user can pass a custom function for the linear solve using +`LinearSolveFunction()` if existing solvers are not optimal for their application. 
+The interface is detailed [here](#passing-in-a-custom-linear-solver). + ## Full List of Methods ### RecursiveFactorization.jl diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index b78f55ef9..50eaca349 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -26,11 +26,13 @@ using Reexport abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end +abstract type AbstractSolveFunction <: SciMLLinearSolveAlgorithm end # Traits needs_concrete_A(alg::AbstractFactorization) = true needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false +needs_concrete_A(alg::AbstractSolveFunction) = false # Code @@ -39,6 +41,7 @@ include("factorization.jl") include("simplelu.jl") include("iterative_wrappers.jl") include("preconditioners.jl") +include("solve_function.jl") include("default.jl") include("init.jl") @@ -48,6 +51,9 @@ isopenblas() = IS_OPENBLAS[] export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization, GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, UMFPACKFactorization, KLUFactorization + +export LinearSolveFunction + export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES, IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES, IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES, diff --git a/src/common.jl b/src/common.jl index 2f329f621..2e8b6e31d 100644 --- a/src/common.jl +++ b/src/common.jl @@ -65,7 +65,7 @@ function set_cacheval(cache::LinearCache, alg_cache) return cache end -init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing +init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...) 
diff --git a/src/solve_function.jl b/src/solve_function.jl new file mode 100644 index 000000000..e3c01947a --- /dev/null +++ b/src/solve_function.jl @@ -0,0 +1,15 @@ +# Wraps a user-supplied function as a linear solve algorithm +struct LinearSolveFunction{F} <: AbstractSolveFunction + solve_func::F +end + +function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction, + args...; kwargs...) + @unpack A,b,u,p,isfresh,Pl,Pr,cacheval = cache + @unpack solve_func = alg + + u = solve_func(A,b,u,p,isfresh,Pl,Pr,cacheval;kwargs...) + cache = set_u(cache, u) + + return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache) +end diff --git a/test/basictests.jl b/test/basictests.jl index 99934d919..faa63c133 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -287,4 +287,35 @@ end @test sol13.u ≈ sol33.u end +@testset "Solve Function" begin + + A1 = rand(n) |> Diagonal; b1 = rand(n); x1 = zero(b1) + A2 = rand(n) |> Diagonal; b2 = rand(n); x2 = zero(b2) + + function sol_func(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...) + if verbose == true + println("out-of-place solve") + end + u = A \ b + end + + function sol_func!(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...) + if verbose == true + println("in-place solve") + end + ldiv!(u,A,b) + end + + prob1 = LinearProblem(A1, b1; u0=x1) + prob2 = LinearProblem(A2, b2; u0=x2) + + for alg in ( + LinearSolveFunction(sol_func), + LinearSolveFunction(sol_func!), + ) + + test_interface(alg, prob1, prob2) + end +end + end # testset