Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 46 additions & 32 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,30 +308,54 @@ function SciMLBase.solve!(
)
end

# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
function setA!(dc::DualLinearCache, A)
# Put the Dual-stripped versions in the LinearCache
prop = nodual_value!(getproperty(dc.linear_cache, :A), A) # Update in-place
setproperty!(dc.linear_cache, :A, prop) # Does additional invalidation logic etc.

# Update partials
setfield!(dc, :dual_A, A)
partial_vals!(getfield(dc, :partials_A), A) # Update in-place

# Invalidate cache (if setting A or b)
setfield!(dc, :rhs_cache_valid, false)
end
function setb!(dc::DualLinearCache, b)
# Put the Dual-stripped versions in the LinearCache
prop = nodual_value!(getproperty(dc.linear_cache, :b), b) # Update in-place
setproperty!(dc.linear_cache, :b, prop) # Does additional invalidation logic etc.

# Update partials
setfield!(dc, :dual_b, b)
partial_vals!(getfield(dc, :partials_b), b) # Update in-place

# Invalidate cache (if setting A or b)
setfield!(dc, :rhs_cache_valid, false)
end
function setu!(dc::DualLinearCache, u)
# Put the Dual-stripped versions in the LinearCache
prop = nodual_value!(getproperty(dc.linear_cache, :u), u) # Update in-place
setproperty!(dc.linear_cache, :u, prop) # Does additional invalidation logic etc.

# Update partials
setfield!(dc, :dual_u, u)
partial_vals!(getfield(dc, :partials_u), u) # Update in-place
end

function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
# If the property is A or b, also update it in the LinearCache
if sym === :A || sym === :b || sym === :u
setproperty!(dc.linear_cache, sym, nodual_value(val))
if sym === :A
setA!(dc, val)
elseif sym === :b
setb!(dc, val)
elseif sym === :u
setu!(dc, val)
elseif hasfield(DualLinearCache, sym)
setfield!(dc, sym, val)
elseif hasfield(LinearSolve.LinearCache, sym)
setproperty!(dc.linear_cache, sym, val)
end

# Update the partials and invalidate cache if setting A or b
if sym === :A
setfield!(dc, :dual_A, val)
setfield!(dc, :partials_A, partial_vals(val))
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
elseif sym === :b
setfield!(dc, :dual_b, val)
setfield!(dc, :partials_b, partial_vals(val))
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
elseif sym === :u
setfield!(dc, :dual_u, val)
setfield!(dc, :partials_u, partial_vals(val))
end
nothing
end

# "Forwards" getproperty to LinearCache if necessary
Expand Down Expand Up @@ -360,30 +384,20 @@ partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.pa
partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x)
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
partial_vals(x) = nothing
partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place

# Add recursive handling for nested dual values
nodual_value(x) = x
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact

function nodual_value(x::AbstractArray{<:Dual})
# Create a similar array with the appropriate element type
T = typeof(nodual_value(first(x)))
result = similar(x, T)

# Fill the result array with values
for i in eachindex(x)
result[i] = nodual_value(x[i])
end

return result
end
nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place

function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T}
p = eachindex(first(partial_matrix))
for i in p
for j in eachindex(partial_matrix)
list_cache[i][j] = partial_matrix[j][i]
@inbounds list_cache[i][j] = partial_matrix[j][i]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not required, but they gave me a non-negligible speedup, as I saw checkbounds showing up in profiling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it so the user should not suffer a performance penalty from library internals. I can remove them if you want to keep it safe.

This function is just shuffling data around. The optimal solution would be to avoid it altogether, but I am not sure if it's easily possible.

end
end
return list_cache
Expand All @@ -396,7 +410,7 @@ function update_partials_list!(partial_matrix, list_cache)
for k in 1:p
for i in 1:m
for j in 1:n
list_cache[k][i, j] = partial_matrix[i, j][k]
@inbounds list_cache[k][i, j] = partial_matrix[i, j][k]
end
end
end
Expand Down
Loading