27 changes: 14 additions & 13 deletions src/default.jl
@@ -134,7 +134,7 @@ end

function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
if assump.issq
@static if VERSION>=v"1.11"
@static if VERSION >= v"1.11"
DirectLdiv!()
else
DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization)
@@ -239,10 +239,11 @@ Get the tuned algorithm preference for the given element type and matrix size.
Returns `nothing` if no preference exists. Uses preloaded constants for efficiency,
with a fast path when no preferences are set.
"""
@inline function get_tuned_algorithm(::Type{eltype_A}, ::Type{eltype_b}, matrix_size::Integer) where {eltype_A, eltype_b}
@inline function get_tuned_algorithm(
::Type{eltype_A}, ::Type{eltype_b}, matrix_size::Integer) where {eltype_A, eltype_b}
# Determine the element type to use for preference lookup
target_eltype = eltype_A !== Nothing ? eltype_A : eltype_b

# Determine size category based on matrix size (matching LinearSolveAutotune categories)
size_category = if matrix_size <= 20
:tiny
@@ -255,10 +256,10 @@
else
:big
end

# Fast path: if no preferences are set, return nothing immediately
AUTOTUNE_PREFS_SET || return nothing

# Look up the tuned algorithm from preloaded constants with type specialization
return _get_tuned_algorithm_impl(target_eltype, size_category)
end
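For reference, a minimal sketch of how this lookup is meant to be called — assuming the internal, unexported name stays as defined in this diff and that no LinearSolveAutotune preferences have been stored:

```julia
using LinearSolve

# Look up a tuned preference for a 100×100 Float64 problem.
# With no autotune preferences set, the fast path above returns `nothing`,
# and the regular defaultalg heuristics are used instead.
alg = LinearSolve.get_tuned_algorithm(Float64, Float64, 100)
@show alg === nothing   # true when no preferences are stored
```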
@@ -286,11 +287,10 @@ end

@inline _get_tuned_algorithm_impl(::Type, ::Symbol) = nothing # Fallback for other types



# Convenience method for when A is nothing - delegate to main implementation
@inline get_tuned_algorithm(::Type{Nothing}, ::Type{eltype_b}, matrix_size::Integer) where {eltype_b} =
get_tuned_algorithm(eltype_b, eltype_b, matrix_size)
@inline get_tuned_algorithm(::Type{Nothing},
::Type{eltype_b},
matrix_size::Integer) where {eltype_b} = get_tuned_algorithm(eltype_b, eltype_b, matrix_size)

# Allows A === nothing as a stand-in for dense matrix
function defaultalg(A, b, assump::OperatorAssumptions{Bool})
@@ -304,7 +304,7 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
ArrayInterface.can_setindex(b) &&
(__conditioning(assump) === OperatorCondition.IllConditioned ||
__conditioning(assump) === OperatorCondition.WellConditioned)

# Small matrix override - always use GenericLUFactorization for tiny problems
if length(b) <= 10
DefaultAlgorithmChoice.GenericLUFactorization
@@ -313,7 +313,7 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
matrix_size = length(b)
eltype_A = A === nothing ? Nothing : eltype(A)
tuned_alg = get_tuned_algorithm(eltype_A, eltype(b), matrix_size)

if tuned_alg !== nothing
tuned_alg
elseif appleaccelerate_isavailable() && b isa Array &&
@@ -513,7 +513,7 @@ end
newex = quote
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
@SciMLMessage("LU factorization failed, falling back to QR factorization. `A` is potentially rank-deficient.",
@SciMLMessage("LU factorization failed, falling back to QR factorization. `A` is potentially rank-deficient.",
cache.verbose, :default_lu_fallback)
sol = SciMLBase.solve!(
cache, QRFactorization(ColumnNorm()), args...; kwargs...)
@@ -641,7 +641,8 @@ end
@generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
ex = :()
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
newex = if alg in Symbol.((DefaultAlgorithmChoice.RFLUFactorization, DefaultAlgorithmChoice.GenericLUFactorization))
newex = if alg in Symbol.((DefaultAlgorithmChoice.RFLUFactorization,
DefaultAlgorithmChoice.GenericLUFactorization))
quote
getproperty(cache.cacheval, $(Meta.quot(alg)))[1]' \ dy
end
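A hedged sketch of the safety fallback exercised above — assuming the default solver's `safetyfallback` stays enabled (its default) and that no autotune preference overrides the LU choice for this size:

```julia
using LinearSolve, LinearAlgebra

n = 20
A = rand(n, n)
A[:, 7] .= 0.0          # exactly singular: LU hits a zero pivot
b = rand(n)

# The default LU-based choice reports ReturnCode.Failure, and the fallback
# re-solves with QRFactorization(ColumnNorm()), giving a least-squares-style
# solution for the rank-deficient system.
sol = solve(LinearProblem(A, b))
@show sol.retcode
```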
76 changes: 72 additions & 4 deletions src/solve_function.jl
@@ -55,17 +55,29 @@ function SciMLBase.solve!(cache::LinearCache, alg::LinearSolveFunction,
end

"""
DirectLdiv! <: AbstractSolveFunction
DirectLdiv!{cache}() <: AbstractSolveFunction

A simple linear solver that directly applies the left-division operator (`\\`)
to solve the linear system. This algorithm calls `ldiv!(u, A, b)` which computes
`u = A \\ b` in-place.

## Type Parameter

- `cache::Bool`: Whether to cache a copy of the matrix for use with `ldiv!`.
When `true` (the default), a copy of the matrix is stored during `init` and used during
`solve!`, preventing mutation of `cache.A`. A cached copy is only allocated for matrix
types whose `ldiv!` mutates its input (e.g., `Tridiagonal`, `SymTridiagonal`).

## Usage

```julia
# Default: automatically caches for matrix types that need it
alg = DirectLdiv!()
sol = solve(prob, alg)

# Explicit caching control
alg = DirectLdiv!(Val(true)) # Always cache
alg = DirectLdiv!(Val(false)) # Never cache (may mutate A)
```

## Notes
@@ -75,12 +87,68 @@ sol = solve(prob, alg)
- Performance depends on the specific matrix type and its `ldiv!` implementation
- No preconditioners or advanced numerical techniques are applied
- Best used for small to medium problems or when `A` has special structure
- For `Tridiagonal` and `SymTridiagonal`, `ldiv!` performs in-place LU factorization
which mutates the matrix. Use `cache=true` (default) to preserve `cache.A`.
"""
struct DirectLdiv! <: AbstractSolveFunction end
struct DirectLdiv!{cache} <: AbstractSolveFunction
function DirectLdiv!(::Val{cache} = Val(true)) where {cache}
new{cache}()
end
end

function SciMLBase.solve!(cache::LinearCache, alg::DirectLdiv!, args...; kwargs...)
# Default solve! for non-caching or matrix types that don't need caching
function SciMLBase.solve!(cache::LinearCache, alg::DirectLdiv!{false}, args...; kwargs...)
(; A, b, u) = cache
ldiv!(u, A, b)
return SciMLBase.build_linear_solution(alg, u, nothing, cache)
end

# For caching DirectLdiv! with general matrices, just use regular ldiv!
# (caching is only needed for specific matrix types like Tridiagonal)
function SciMLBase.solve!(cache::LinearCache, alg::DirectLdiv!{true}, args...; kwargs...)
(; A, b, u) = cache
ldiv!(u, A, b)
return SciMLBase.build_linear_solution(alg, u, nothing, cache)
end

# Specialized handling for Tridiagonal matrices to avoid mutating cache.A
# ldiv! for Tridiagonal performs in-place LU factorization which would corrupt the cache.
# We cache a copy of the Tridiagonal matrix and use that for the factorization.
# See https://github.com/SciML/LinearSolve.jl/issues/825

function init_cacheval(alg::DirectLdiv!{true}, A::Tridiagonal, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Union{LinearVerbosity, Bool},
assumptions::OperatorAssumptions)
# Allocate a copy of the Tridiagonal matrix to use as workspace for ldiv!
return copy(A)
end

function init_cacheval(alg::DirectLdiv!{true}, A::SymTridiagonal, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
assumptions::OperatorAssumptions)
# SymTridiagonal also gets mutated by ldiv!, cache a copy
return copy(A)
end

function SciMLBase.solve!(cache::LinearCache{<:Tridiagonal}, alg::DirectLdiv!{true},
args...; kwargs...)
(; A, b, u, cacheval) = cache
# Copy current A values into the cached workspace (non-allocating)
copyto!(cacheval.dl, A.dl)
copyto!(cacheval.d, A.d)
copyto!(cacheval.du, A.du)
# Perform ldiv! on the copy, preserving the original A
ldiv!(u, cacheval, b)
return SciMLBase.build_linear_solution(alg, u, nothing, cache)
end

function SciMLBase.solve!(cache::LinearCache{<:SymTridiagonal}, alg::DirectLdiv!{true},
args...; kwargs...)
(; A, b, u, cacheval) = cache
# Copy current A values into the cached workspace (non-allocating)
copyto!(cacheval.dv, A.dv)
copyto!(cacheval.ev, A.ev)
# Perform ldiv! on the copy, preserving the original A
ldiv!(u, cacheval, b)
return SciMLBase.build_linear_solution(alg, u, nothing, cache)
end
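To see why the cached workspace above is needed, here is a self-contained sketch of the mutation described in issue #825 — it requires Julia ≥ 1.11, where `ldiv!(u, ::Tridiagonal, b)` is available but factorizes the matrix in place; the `workspace` name is illustrative, not part of the package:

```julia
using LinearAlgebra

A = Tridiagonal(ones(4), -2.0 .* ones(5), ones(4))
b = rand(5)
u = similar(b)

workspace = copy(A)      # mirrors what the cached `cacheval` holds in `solve!` above
ldiv!(u, workspace, b)   # the in-place LU mutates `workspace`, not the user's `A`
@show A == Tridiagonal(ones(4), -2.0 .* ones(5), ones(4))   # true: A is untouched
@show norm(A * u - b)                                        # ≈ 0: solution is correct
```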
81 changes: 69 additions & 12 deletions test/basictests.jl
@@ -34,7 +34,7 @@ A4 = A2 .|> ComplexF32
b4 = b2 .|> ComplexF32
x4 = x2 .|> ComplexF32

A5_ = A - 0.01Tridiagonal(ones(n,n)) + sparse([1], [8], 0.5, n,n)
A5_ = A - 0.01Tridiagonal(ones(n, n)) + sparse([1], [8], 0.5, n, n)
A5 = sparse(transpose(A5_) * A5_)
x5 = zeros(n)
u5 = ones(n)
@@ -46,7 +46,7 @@ prob3 = LinearProblem(A3, b3; u0 = x3)
prob4 = LinearProblem(A4, b4; u0 = x4)
prob5 = LinearProblem(A5, b5)

cache_kwargs = (;abstol = 1e-8, reltol = 1e-8, maxiter = 30)
cache_kwargs = (; abstol = 1e-8, reltol = 1e-8, maxiter = 30)

function test_interface(alg, prob1, prob2)
A1, b1 = prob1.A, prob1.b
@@ -79,10 +79,10 @@ end

function test_tolerance_update(alg, prob, u)
cache = init(prob, alg)
LinearSolve.update_tolerances!(cache; reltol = 1e-2, abstol=1e-8)
LinearSolve.update_tolerances!(cache; reltol = 1e-2, abstol = 1e-8)
u1 = copy(solve!(cache).u)

LinearSolve.update_tolerances!(cache; reltol = 1e-8, abstol=1e-8)
LinearSolve.update_tolerances!(cache; reltol = 1e-8, abstol = 1e-8)
u2 = solve!(cache).u

@test norm(u2 - u) < norm(u1 - u)
@@ -303,30 +303,86 @@ end
ρ = 0.95
A_tri = SymTridiagonal(ones(k) .+ ρ^2, -ρ * ones(k-1))
b = rand(k)

# Test with explicit LDLtFactorization
prob_tri = LinearProblem(A_tri, b)
sol = solve(prob_tri, LDLtFactorization())
@test A_tri * sol.u ≈ b

# Test that default algorithm uses LDLtFactorization for SymTridiagonal
default_alg = LinearSolve.defaultalg(A_tri, b, OperatorAssumptions(true))
@test default_alg isa LinearSolve.DefaultLinearSolver
@test default_alg.alg == LinearSolve.DefaultAlgorithmChoice.LDLtFactorization

# Test that the factorization is cached and reused
cache = init(prob_tri, LDLtFactorization())
sol1 = solve!(cache)
@test A_tri * sol1.u ≈ b
@test !cache.isfresh # Cache should not be fresh after first solve

# Solve again with same matrix to ensure cache is reused
cache.b = rand(k) # Change RHS
sol2 = solve!(cache)
@test A_tri * sol2.u ≈ cache.b
@test !cache.isfresh # Cache should still not be fresh
end

@testset "Tridiagonal cache not mutated (issue #825)" begin
# Test that solving with Tridiagonal does not mutate cache.A
# See https://github.com/SciML/LinearSolve.jl/issues/825
k = 6
lower = ones(k - 1)
diag = -2 * ones(k)
upper = ones(k - 1)
A_tri = Tridiagonal(lower, diag, upper)
b = rand(k)

# Store original matrix values for comparison
A_orig = Tridiagonal(copy(lower), copy(diag), copy(upper))

# Test that default algorithm uses DirectLdiv! for Tridiagonal on Julia 1.11+
default_alg = LinearSolve.defaultalg(A_tri, b, OperatorAssumptions(true))
@static if VERSION >= v"1.11"
@test default_alg isa DirectLdiv!
else
@test default_alg isa LinearSolve.DefaultLinearSolver
@test default_alg.alg == LinearSolve.DefaultAlgorithmChoice.LUFactorization
end

# Test with default algorithm
prob_tri = LinearProblem(A_tri, b)
cache = init(prob_tri)

# Verify solution is correct
sol1 = solve!(cache)
@test A_orig * sol1.u ≈ b

# Verify cache.A is not mutated
@test cache.A ≈ A_orig

# Verify multiple solves give correct answers
b2 = rand(k)
cache.b = b2
sol2 = solve!(cache)
@test A_orig * sol2.u ≈ b2

# Cache.A should still be unchanged
@test cache.A ≈ A_orig

# Verify solve! allocates minimally after first solve (warm-up)
# The small allocation (48 bytes) is from the return type construction,
# same as other factorization methods like LUFactorization
@static if VERSION >= v"1.11"
# Warm up
for _ in 1:3
solve!(cache)
end
# Test minimal allocations (same as LUFactorization)
allocs = @allocated solve!(cache)
@test allocs <= 64 # Allow small overhead from return type
end
end

test_algs = [
LUFactorization(),
QRFactorization(),
@@ -680,8 +736,10 @@ end
prob3 = LinearProblem(op1, b1; u0 = x1)
prob4 = LinearProblem(op2, b2; u0 = x2)

@test LinearSolve.defaultalg(op1, x1).alg === LinearSolve.DefaultAlgorithmChoice.DirectLdiv!
@test LinearSolve.defaultalg(op2, x2).alg === LinearSolve.DefaultAlgorithmChoice.DirectLdiv!
@test LinearSolve.defaultalg(op1, x1).alg ===
LinearSolve.DefaultAlgorithmChoice.DirectLdiv!
@test LinearSolve.defaultalg(op2, x2).alg ===
LinearSolve.DefaultAlgorithmChoice.DirectLdiv!
@test LinearSolve.defaultalg(op3, x1).alg ===
LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES
@test LinearSolve.defaultalg(op4, x2).alg ===
@@ -800,7 +858,6 @@ end
reinit!(cache; A = B1, b = b1)
u = solve!(cache)
@test norm(u - u0, Inf) < 1.0e-8

end

@testset "ParallelSolves" begin
Expand All @@ -818,7 +875,7 @@ end
for i in 1:2
@test sol[i] ≈ U[i]
end

Threads.@threads for i in 1:2
sol[i] = solve(LinearProblem(A_sparse, B[i]), KLUFactorization())
end