4 changes: 4 additions & 0 deletions Project.toml
@@ -40,6 +40,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -70,6 +71,7 @@ LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMetalExt = "Metal"
+LinearSolveMooncakeExt = "Mooncake"
LinearSolvePardisoExt = ["Pardiso", "SparseArrays"]
LinearSolveRecursiveFactorizationExt = "RecursiveFactorization"
LinearSolveSparseArraysExt = "SparseArrays"
@@ -112,6 +114,7 @@ MKL_jll = "2019, 2020, 2021, 2022, 2023, 2024, 2025"
MPI = "0.20"
Markdown = "1.10"
Metal = "1.4"
Mooncake = "0.4"
MultiFloats = "2.3"
OpenBLAS_jll = "0.3"
Pardiso = "1"
@@ -162,6 +165,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
32 changes: 32 additions & 0 deletions ext/LinearSolveMooncakeExt.jl
@@ -0,0 +1,32 @@
module LinearSolveMooncakeExt

using Mooncake
using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!
using LinearSolve: LinearSolve, SciMLLinearSolveAlgorithm, init, solve!, LinearProblem,
LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver,
defaultalg_adjoint_eval, solve
using LinearSolve.LinearAlgebra
using SciMLBase

@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve), LinearProblem, Nothing} true ReverseMode
@from_chainrules MinimalCtx Tuple{
typeof(SciMLBase.solve), LinearProblem, SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{
Type{<:LinearProblem}, AbstractMatrix, AbstractVector, SciMLBase.NullParameters} true ReverseMode

function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearProblem)
f.data.A .+= t.A
f.data.b .+= t.b

return NoRData()
end

function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
if Mooncake.is_init(x)
return Mooncake.to_cr_tangent(x.tangent)
else
error("Trying to convert uninitialized tangent to ChainRules tangent.")
end
end

end
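For orientation, a minimal usage sketch of what these rules enable, mirroring the tests added in test/nopre/mooncake.jl below; the 4×4 problem size and LUFactorization are just example choices:

```julia
using LinearSolve, LinearAlgebra, Mooncake

n = 4
A = rand(n, n); b = rand(n)

# Scalar loss of the linear-solve solution u = A \ b.
g(A, b) = norm(solve(LinearProblem(A, b), LUFactorization()).u)

# Build the reverse-mode cache once, then get value and gradients.
cache = Mooncake.prepare_gradient_cache(g, A, b)
value, grad = Mooncake.value_and_gradient!!(cache, g, A, b)
# grad[1] is the tangent of g itself; grad[2] and grad[3] hold the
# gradients with respect to A and b.
```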
5 changes: 2 additions & 3 deletions src/KLU/klu.jl
@@ -10,7 +10,7 @@ const libsuitesparseconfig = :libsuitesparseconfig
using Base: Ptr, Cvoid, Cint, Cdouble, Cchar, Csize_t
include("wrappers.jl")

-import Base: (\), size, getproperty, setproperty!, propertynames, show,
+import Base: size, getproperty, setproperty!, show,
copy, eachindex, view, sortperm, unsafe_load, zeros, convert, eltype,
length, parent, stride, finalizer, Complex, complex, imag, real, map!,
summary, println, oneunit, sizeof, isdefined, setfield!, getfield,

Member: what's this about?

@AstitvaAggarwal (Member Author), Oct 13, 2025: I was getting ExplicitImports and Aqua errors in the all / Tests - NoPre job since these imports were not technically being used directly; in this file they are referenced in qualified form, as Base.propertynames etc., so I removed the direct imports.
@@ -35,8 +35,7 @@ function increment!(A::AbstractArray{T}) where {T <: Integer}
end
increment(A::AbstractArray{<:Integer}) = increment!(copy(A))

-using LinearAlgebra: LinearAlgebra, ldiv!, Adjoint, Transpose, Factorization
-import LinearAlgebra: issuccess
+using LinearAlgebra: LinearAlgebra, Adjoint, Transpose

const AdjointFact = isdefined(LinearAlgebra, :AdjointFactorization) ?
LinearAlgebra.AdjointFactorization : Adjoint
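A side note on the review thread above: in Julia, adding a method to a Base function through its qualified name requires no import, which is why ExplicitImports and Aqua flag the now-unused direct imports. A hypothetical Demo type (not from the PR) to illustrate the pattern:

```julia
# Hypothetical example illustrating qualified method extension.
struct Demo
    x::Int
end

# Qualified extension: no `import Base: propertynames, getproperty` needed.
Base.propertynames(::Demo) = (:x, :doubled)
function Base.getproperty(d::Demo, s::Symbol)
    s === :doubled ? 2 * getfield(d, :x) : getfield(d, s)
end

Demo(3).doubled  # returns 6
```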
5 changes: 2 additions & 3 deletions src/default.jl
@@ -636,7 +636,7 @@ end
@generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
ex = :()
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
-newex = if alg == Symbol(DefaultAlgorithmChoice.RFLUFactorization)
+newex = if alg in Symbol.((DefaultAlgorithmChoice.RFLUFactorization, DefaultAlgorithmChoice.GenericLUFactorization))
quote
getproperty(cache.cacheval, $(Meta.quot(alg)))[1]' \ dy
end
@@ -661,8 +661,7 @@ end
DefaultAlgorithmChoice.SVDFactorization,
DefaultAlgorithmChoice.CholeskyFactorization,
DefaultAlgorithmChoice.NormalCholeskyFactorization,
-DefaultAlgorithmChoice.QRFactorizationPivoted,
-DefaultAlgorithmChoice.GenericLUFactorization))
+DefaultAlgorithmChoice.QRFactorizationPivoted))
quote
getproperty(cache.cacheval, $(Meta.quot(alg)))' \ dy
end
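Context for this change: both @generated branches solve the same transposed system; they differ only in how the cached factorization is stored. RFLUFactorization, and after this PR GenericLUFactorization, appear to keep a (factorization, ipiv) tuple in cacheval, hence the extra [1] before taking the adjoint. A minimal sketch of the underlying adjoint rule, with plain \ standing in for the cached factorization:

```julia
using LinearAlgebra

# Sketch of the reverse-mode rule behind defaultalg_adjoint_eval, assuming
# a scalar loss L(x) of the solution of A * x = b, with seed dy = ∂L/∂x.
A = rand(4, 4); b = rand(4)
x = A \ b
dy = ones(length(x))      # e.g. the seed for L = sum(x)

λ = A' \ dy               # the `cacheval' \ dy` (or `cacheval[1]' \ dy`) step
db = λ                    # ∂L/∂b
dA = -λ * x'              # ∂L/∂A
```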
3 changes: 2 additions & 1 deletion test/nopre/Project.toml
@@ -7,8 +7,9 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
155 changes: 155 additions & 0 deletions test/nopre/mooncake.jl
@@ -0,0 +1,155 @@
using ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff, RecursiveFactorization
using LazyArrays: BroadcastArray
using Mooncake

# first test
n = 4
A = rand(n, n);
b1 = rand(n);

function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

s1 = sol1.u
norm(s1)
end

f_primal = f(A, b1) # Uses BLAS

cache = prepare_gradient_cache(f, (copy(A), copy(b1))...)
value, gradient = Mooncake.value_and_gradient!!(cache, f, (copy(A), copy(b1))...)

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))

# Mooncake
@test value ≈ f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12

# Second test
A = rand(n, n);
b1 = rand(n);

_ff = (x,
y) -> f(x,
y;
alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization))
f_primal = _ff(copy(A), copy(b1))

cache = prepare_gradient_cache(_ff, (copy(A), copy(b1))...)
value, gradient = Mooncake.value_and_gradient!!(cache, _ff, (copy(A), copy(b1))...)

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))

# Mooncake
@test value ≈ f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12

# third test
# Test complex numbers
A = rand(n, n) + 1im * rand(n, n);
b1 = rand(n) + 1im * rand(n);

function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg)
prob = LinearProblem(A, b2)
sol2 = solve(prob, alg)
norm(sol1.u .+ sol2.u)
end

# Mooncake needs atomic Complex Number tangents instead of NamedTuples.
# cache = Mooncake.prepare_gradient_cache(f3, (copy(A), copy(b1), copy(b1))...)
# results = Mooncake.value_and_gradient!!(cache, f3, (copy(A), copy(b1), copy(b1))...)

# dA2 = FiniteDiff.finite_difference_gradient(
# x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
# db12 = FiniteDiff.finite_difference_gradient(
# x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
# db22 = FiniteDiff.finite_difference_gradient(
# x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))

# @test f3(A, b1, b1) ≈ results[1]
# @test dA2 ≈ results[2][2]
# @test db12 ≈ results[2][3]
# @test db22 ≈ results[2][4]

# fourth test
A = rand(n, n);
b1 = rand(n);

function f4(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
prob = LinearProblem(A, b2)
sol2 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_GMRES()))
norm(sol1.u .+ sol2.u)
end

cache = Mooncake.prepare_gradient_cache(f4, (copy(A), copy(b1), copy(b1))...)
results = Mooncake.value_and_gradient!!(cache, f4, (copy(A), copy(b1), copy(b1))...)

dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))

@test f4(A, b1, b1) ≈ results[1]
@test dA2 ≈ results[2][2]
@test db12 ≈ results[2][3]
@test db22 ≈ results[2][4]

# fifth test
A = rand(n, n);
b1 = rand(n);
for alg in (
LUFactorization(),
RFLUFactorization(),
KrylovJL_GMRES()
)
@show alg
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fb(b1)

fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fd_jac

cache = Mooncake.prepare_gradient_cache(fb, copy(b1))
results = Mooncake.value_and_gradient!!(cache, fb, copy(b1))
@show results

@test results[1] ≈ fb(b1)
@test results[2][2] ≈ fd_jac rtol = 1e-5

function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fA(A)

fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fd_jac

cache = Mooncake.prepare_gradient_cache(fA, copy(A))
results = Mooncake.value_and_gradient!!(cache, fA, copy(A))
@show results
mooncake_gradient = results[2][2] |> vec

@test results[1] ≈ fA(A)
@test mooncake_gradient ≈ fd_jac rtol = 1e-5
end
8 changes: 6 additions & 2 deletions test/runtests.jl
@@ -27,10 +27,11 @@ if GROUP == "NoPre" && isempty(VERSION.prerelease)
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
@time @safetestset "Quality Assurance" include("qa.jl")
@time @safetestset "Enzyme Derivative Rules" include("nopre/enzyme.jl")
@time @safetestset "Mooncake Derivative Rules" include("nopre/mooncake.jl")
@time @safetestset "JET Tests" include("nopre/jet.jl")
@time @safetestset "Static Arrays" include("nopre/static_arrays.jl")
@time @safetestset "Caching Allocation Tests" include("nopre/caching_allocation_tests.jl")
@time @safetestset "Enzyme Derivative Rules" include("nopre/enzyme.jl")
end

if GROUP == "DefaultsLoading"
@@ -39,7 +40,10 @@ end

if GROUP == "LinearSolveAutotune"
Pkg.activate(joinpath(dirname(@__DIR__), "lib", GROUP))
-Pkg.test(GROUP, julia_args=["--check-bounds=auto", "--compiled-modules=yes", "--depwarn=yes"], force_latest_compatible_version=false, allow_reresolve=true)
+Pkg.test(GROUP,
+    julia_args = ["--check-bounds=auto", "--compiled-modules=yes", "--depwarn=yes"],
+    force_latest_compatible_version = false,
+    allow_reresolve = true)
end

if GROUP == "Preferences"