Skip to content

Commit

Permalink
Merge pull request #262 from avik-pal/ap/banded
Browse files Browse the repository at this point in the history
Special Case for Banded Matrices
  • Loading branch information
ChrisRackauckas authored Oct 26, 2023
2 parents f1bb4a7 + 345ec2a commit 191a237
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 8 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"

[extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"

[compat]
BandedMatrices = "1"
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
ConcreteStructs = "0.2"
Expand All @@ -58,6 +61,7 @@ Zygote = "0.6"
julia = "1.9"

[extras]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
Expand All @@ -77,4 +81,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices"]
8 changes: 8 additions & 0 deletions ext/NonlinearSolveBandedMatricesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module NonlinearSolveBandedMatricesExt

using BandedMatrices, LinearAlgebra, NonlinearSolve, SparseArrays

# Levenberg stacks the Jacobian on top of a Diagonal matrix via `_vcat`. When that
# Jacobian is a BandedMatrix, convert it to a sparse matrix first and concatenate that.
@inline function NonlinearSolve._vcat(B::BandedMatrix, D::Diagonal)
    return vcat(sparse(B), D)
end

end
2 changes: 1 addition & 1 deletion src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
rhs_tmp = nothing
else
# Preserve Types
mat_tmp = vcat(J, DᵀD)
mat_tmp = _vcat(J, DᵀD)
fill!(mat_tmp, zero(eltype(u)))
rhs_tmp = vcat(_vec(fu1), _vec(u))
fill!(rhs_tmp, zero(eltype(u)))
Expand Down
13 changes: 8 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,17 @@ function _try_factorize_and_check_singular!(linsolve, X)
end
# Without a linear solver (`nothing`) there is nothing to factorize: report the
# `_issingular` result for `x` together with `false` as the second tuple element.
function _try_factorize_and_check_singular!(::Nothing, x)
    return (_issingular(x), false)
end

# `_reshape`: forwards to `reshape(x, args...)`, except that a `Number` is returned
# unchanged (the extra arguments are ignored).
# The definitions appear twice with identical signatures, first plain and then with
# `@inline`; the later definitions replace the earlier ones.
_reshape(x, args...) = reshape(x, args...)
_reshape(x::Number, args...) = x
@inline _reshape(x, args...) = reshape(x, args...)
@inline _reshape(x::Number, args...) = x

# In-place `y += α * x`. Inside the generator `α`, `x`, `y` are the argument *types*:
# if an `axpy!` method exists for them it is called, otherwise a fused broadcast is used.
@generated function _axpy!(α, x, y)
    if hasmethod(axpy!, Tuple{α, x, y})
        return :(axpy!(α, x, y))
    else
        return :(@. y += α * x)
    end
end

# `_needs_square_A(alg, u)`: always `true` when the second argument is a `Number` or a
# `StaticArray`; otherwise defers to `LinearSolve.needs_square_A(alg.linsolve)`.
# The definitions appear twice with identical signatures, first plain and then with
# `@inline`; the later definitions replace the earlier ones.
_needs_square_A(_, ::Number) = true
_needs_square_A(_, ::StaticArray) = true
_needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)
@inline _needs_square_A(_, ::Number) = true
@inline _needs_square_A(_, ::StaticArray) = true
@inline _needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)

# Vertical-concatenation hook. The generic method simply forwards to `vcat`; methods
# for particular array-type combinations can be added to specialise it.
@inline function _vcat(x, y)
    return vcat(x, y)
end
2 changes: 2 additions & 0 deletions test/GPU/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"

[compat]
CUDA = "5"
LinearSolve = "2"
NonlinearSolve = "2"
2 changes: 1 addition & 1 deletion test/gpu.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using CUDA, NonlinearSolve
using CUDA, NonlinearSolve, LinearSolve

CUDA.allowscalar(false)

Expand Down
7 changes: 7 additions & 0 deletions test/misc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Miscellaneous Tests
using BandedMatrices, LinearAlgebra, NonlinearSolve, SparseArrays, Test

@testset "_vcat: BandedMatrix with Diagonal" begin
    # 5x5 banded matrix of ones with one sub-diagonal and one super-diagonal.
    b = BandedMatrix(Ones(5, 5), (1, 1))
    # Build the Diagonal from a vector; `Diagonal(ones(5, 5))` allocates a full matrix
    # only to keep its diagonal.
    d = Diagonal(ones(5))

    stacked = NonlinearSolve._vcat(b, d)

    # Original check: agrees with concatenating the sparse conversion directly.
    @test stacked == vcat(sparse(b), d)

    # Independent checks, so the test does not merely restate the implementation.
    @test issparse(stacked)
    @test size(stacked) == (10, 5)
    @test stacked == vcat(Matrix(b), Matrix(d))
end

0 comments on commit 191a237

Please sign in to comment.