diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 3a2a8c059..cacba0112 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -308,30 +308,54 @@ function SciMLBase.solve!( ) end -# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache +function setA!(dc::DualLinearCache, A) + # Put the Dual-stripped versions in the LinearCache + prop = nodual_value!(getproperty(dc.linear_cache, :A), A) # Update in-place + setproperty!(dc.linear_cache, :A, prop) # Does additional invalidation logic etc. + + # Update partials + setfield!(dc, :dual_A, A) + partial_vals!(getfield(dc, :partials_A), A) # Update in-place + + # Invalidate cache (if setting A or b) + setfield!(dc, :rhs_cache_valid, false) +end +function setb!(dc::DualLinearCache, b) + # Put the Dual-stripped versions in the LinearCache + prop = nodual_value!(getproperty(dc.linear_cache, :b), b) # Update in-place + setproperty!(dc.linear_cache, :b, prop) # Does additional invalidation logic etc. + + # Update partials + setfield!(dc, :dual_b, b) + partial_vals!(getfield(dc, :partials_b), b) # Update in-place + + # Invalidate cache (if setting A or b) + setfield!(dc, :rhs_cache_valid, false) +end +function setu!(dc::DualLinearCache, u) + # Put the Dual-stripped versions in the LinearCache + prop = nodual_value!(getproperty(dc.linear_cache, :u), u) # Update in-place + setproperty!(dc.linear_cache, :u, prop) # Does additional invalidation logic etc. + + # Update partials + setfield!(dc, :dual_u, u) + partial_vals!(getfield(dc, :partials_u), u) # Update in-place +end + function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) # If the property is A or b, also update it in the LinearCache - if sym === :A || sym === :b || sym === :u - setproperty!(dc.linear_cache, sym, nodual_value(val)) + if sym === :A + setA!(dc, val) + elseif sym === :b + setb!(dc, val) + elseif sym === :u + setu!(dc, val) elseif hasfield(DualLinearCache, sym) setfield!(dc, sym, val) elseif hasfield(LinearSolve.LinearCache, sym) setproperty!(dc.linear_cache, sym, val) end - - # Update the partials and invalidate cache if setting A or b - if sym === :A - setfield!(dc, :dual_A, val) - setfield!(dc, :partials_A, partial_vals(val)) - setfield!(dc, :rhs_cache_valid, false) # Invalidate cache - elseif sym === :b - setfield!(dc, :dual_b, val) - setfield!(dc, :partials_b, partial_vals(val)) - setfield!(dc, :rhs_cache_valid, false) # Invalidate cache - elseif sym === :u - setfield!(dc, :dual_u, val) - setfield!(dc, :partials_u, partial_vals(val)) - end + nothing end # "Forwards" getproperty to LinearCache if necessary @@ -360,30 +384,20 @@ partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.pa partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x) partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) partial_vals(x) = nothing +partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place # Add recursive handling for nested dual values nodual_value(x) = x nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x) nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact - -function nodual_value(x::AbstractArray{<:Dual}) - # Create a similar array with the appropriate element type - T = typeof(nodual_value(first(x))) - result = similar(x, T) - - # Fill the result array with values - for i in eachindex(x) - result[i] = nodual_value(x[i]) - end - - return result -end +nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x) +nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T} p = eachindex(first(partial_matrix)) for i in p for j in eachindex(partial_matrix) - list_cache[i][j] = partial_matrix[j][i] + @inbounds list_cache[i][j] = partial_matrix[j][i] end end return list_cache @@ -396,7 +410,7 @@ function update_partials_list!(partial_matrix, list_cache) for k in 1:p for i in 1:m for j in 1:n - list_cache[k][i, j] = partial_matrix[i, j][k] + @inbounds list_cache[k][i, j] = partial_matrix[i, j][k] end end end