From 0388574486316dc5a235064d0f1cac36720429ac Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Fri, 14 Nov 2025 17:06:46 +0100 Subject: [PATCH 1/4] Update (dual) linear cache values/partials in-place --- ext/LinearSolveForwardDiffExt.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 3a2a8c059..1f649f8f3 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -309,10 +309,11 @@ function SciMLBase.solve!( end # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache -function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) +function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val::AbstractArray) # If the property is A or b, also update it in the LinearCache if sym === :A || sym === :b || sym === :u - setproperty!(dc.linear_cache, sym, nodual_value(val)) + prop = nodual_value!(getproperty(dc.linear_cache, sym), val) # Update in-place + setproperty!(dc.linear_cache, sym, prop) # Does additional invalidation logic etc. elseif hasfield(DualLinearCache, sym) setfield!(dc, sym, val) elseif hasfield(LinearSolve.LinearCache, sym) @@ -322,15 +323,15 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # Update the partials and invalidate cache if setting A or b if sym === :A setfield!(dc, :dual_A, val) - setfield!(dc, :partials_A, partial_vals(val)) + partial_vals!(getfield(dc, :partials_A), val) # Update in-place setfield!(dc, :rhs_cache_valid, false) # Invalidate cache elseif sym === :b setfield!(dc, :dual_b, val) - setfield!(dc, :partials_b, partial_vals(val)) + partial_vals!(getfield(dc, :partials_b), val) # Update in-place setfield!(dc, :rhs_cache_valid, false) # Invalidate cache elseif sym === :u setfield!(dc, :dual_u, val) - setfield!(dc, :partials_u, partial_vals(val)) + partial_vals!(getfield(dc, :partials_u), val) # Update in-place end end @@ -360,11 +361,13 @@ partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.pa partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x) partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) partial_vals(x) = nothing +partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place # Add recursive handling for nested dual values nodual_value(x) = x nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x) nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact +nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place function nodual_value(x::AbstractArray{<:Dual}) # Create a similar array with the appropriate element type From 212b4338b1c57a4c3f5b12d3556131cf2c2ff9f9 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Mon, 17 Nov 2025 12:41:59 +0100 Subject: [PATCH 2/4] Refactor out-of-place nodual_value with new in-place dispatch --- ext/LinearSolveForwardDiffExt.jl | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 1f649f8f3..414bd56cb 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -367,21 +367,9 @@ partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place nodual_value(x) = x nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x) nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact +nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x) nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place -function nodual_value(x::AbstractArray{<:Dual}) - # Create a similar array with the appropriate element type - T = typeof(nodual_value(first(x))) - result = similar(x, T) - - # Fill the result array with values - for i in eachindex(x) - result[i] = nodual_value(x[i]) - end - - return result -end - function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T} p = eachindex(first(partial_matrix)) for i in p From b8aec5e6a4d1fa38f37f27d3c386d9f7ff613f29 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Fri, 14 Nov 2025 23:11:07 +0100 Subject: [PATCH 3/4] @inbounds in update_partials_list! --- ext/LinearSolveForwardDiffExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 414bd56cb..ad43df02c 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -374,7 +374,7 @@ function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) wh p = eachindex(first(partial_matrix)) for i in p for j in eachindex(partial_matrix) - list_cache[i][j] = partial_matrix[j][i] + @inbounds list_cache[i][j] = partial_matrix[j][i] end end return list_cache @@ -387,7 +387,7 @@ function update_partials_list!(partial_matrix, list_cache) for k in 1:p for i in 1:m for j in 1:n - list_cache[k][i, j] = partial_matrix[i, j][k] + @inbounds list_cache[k][i, j] = partial_matrix[i, j][k] end end end From b2ec6ff96a405d7a35aa282fdbab7ce13b37e1be Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Mon, 17 Nov 2025 15:06:31 +0100 Subject: [PATCH 4/4] Simplify branching and fix type instability in setproperty! --- ext/LinearSolveForwardDiffExt.jl | 61 ++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index ad43df02c..cacba0112 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -308,31 +308,54 @@ function SciMLBase.solve!( ) end -# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache -function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val::AbstractArray) +function setA!(dc::DualLinearCache, A) + # Put the Dual-stripped versions in the LinearCache + prop = nodual_value!(getproperty(dc.linear_cache, :A), A) # Update in-place + setproperty!(dc.linear_cache, :A, prop) # Does additional invalidation logic etc. + + # Update partials + setfield!(dc, :dual_A, A) + partial_vals!(getfield(dc, :partials_A), A) # Update in-place + + # Invalidate cache (if setting A or b) + setfield!(dc, :rhs_cache_valid, false) +end +function setb!(dc::DualLinearCache, b) + # Put the Dual-stripped versions in the LinearCache + prop = nodual_value!(getproperty(dc.linear_cache, :b), b) # Update in-place + setproperty!(dc.linear_cache, :b, prop) # Does additional invalidation logic etc. + + # Update partials + setfield!(dc, :dual_b, b) + partial_vals!(getfield(dc, :partials_b), b) # Update in-place + + # Invalidate cache (if setting A or b) + setfield!(dc, :rhs_cache_valid, false) +end +function setu!(dc::DualLinearCache, u) + # Put the Dual-stripped versions in the LinearCache + prop = nodual_value!(getproperty(dc.linear_cache, :u), u) # Update in-place + setproperty!(dc.linear_cache, :u, prop) # Does additional invalidation logic etc. + + # Update partials + setfield!(dc, :dual_u, u) + partial_vals!(getfield(dc, :partials_u), u) # Update in-place +end + +function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache - if sym === :A || sym === :b || sym === :u - prop = nodual_value!(getproperty(dc.linear_cache, sym), val) # Update in-place - setproperty!(dc.linear_cache, sym, prop) # Does additional invalidation logic etc. + if sym === :A + setA!(dc, val) + elseif sym === :b + setb!(dc, val) + elseif sym === :u + setu!(dc, val) elseif hasfield(DualLinearCache, sym) setfield!(dc, sym, val) elseif hasfield(LinearSolve.LinearCache, sym) setproperty!(dc.linear_cache, sym, val) end - - # Update the partials and invalidate cache if setting A or b - if sym === :A - setfield!(dc, :dual_A, val) - partial_vals!(getfield(dc, :partials_A), val) # Update in-place - setfield!(dc, :rhs_cache_valid, false) # Invalidate cache - elseif sym === :b - setfield!(dc, :dual_b, val) - partial_vals!(getfield(dc, :partials_b), val) # Update in-place - setfield!(dc, :rhs_cache_valid, false) # Invalidate cache - elseif sym === :u - setfield!(dc, :dual_u, val) - partial_vals!(getfield(dc, :partials_u), val) # Update in-place - end + nothing end # "Forwards" getproperty to LinearCache if necessary