Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ struct SMCMixedModeSparseJacobianPrep{
BSr<:DI.BatchSizeSettings,
P<:AbstractMatrix,
C<:AbstractColoringResult{:nonsymmetric,:bidirectional},
M<:AbstractMatrix{<:Number},
Mf<:AbstractMatrix{<:Number},
Mr<:AbstractMatrix{<:Number},
Sfp<:NTuple,
Srp<:NTuple,
Sf<:Vector{<:NTuple},
Sr<:Vector{<:NTuple},
Rf<:Vector{<:NTuple},
Expand All @@ -19,8 +22,10 @@ struct SMCMixedModeSparseJacobianPrep{
batch_size_settings_reverse::BSr
sparsity::P
coloring_result::C
compressed_matrix_forward::M
compressed_matrix_reverse::M
compressed_matrix_forward::Mf
compressed_matrix_reverse::Mr
batched_seed_forward_prep::Sfp
batched_seed_reverse_prep::Srp
batched_seeds_forward::Sf
batched_seeds_reverse::Sr
batched_results_forward::Rf
Expand Down Expand Up @@ -111,12 +116,24 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
groups_forward = column_groups(coloring_result)
groups_reverse = row_groups(coloring_result)

seed_forward_prep = DI.multibasis(x, eachindex(x))
seed_reverse_prep = DI.multibasis(y, eachindex(y))
seeds_forward = [DI.multibasis(x, eachindex(x)[group]) for group in groups_forward]
seeds_reverse = [DI.multibasis(y, eachindex(y)[group]) for group in groups_reverse]

compressed_matrix_forward = stack(_ -> vec(similar(y)), groups_forward; dims=2)
compressed_matrix_reverse = stack(_ -> vec(similar(x)), groups_reverse; dims=1)
compressed_matrix_forward = if isempty(groups_forward)
similar(vec(y), length(y), 0)
else
stack(_ -> vec(similar(y)), groups_forward; dims=2)
end
compressed_matrix_reverse = if isempty(groups_reverse)
similar(vec(x), 0, length(x))
else
stack(_ -> vec(similar(x)), groups_reverse; dims=1)
end

batched_seed_forward_prep = ntuple(b -> copy(seed_forward_prep), Val(Bf))
batched_seed_reverse_prep = ntuple(b -> copy(seed_reverse_prep), Val(Br))
batched_seeds_forward = [
ntuple(b -> seeds_forward[1 + ((a - 1) * Bf + (b - 1)) % Nf], Val(Bf)) for a in 1:Af
]
Expand All @@ -136,15 +153,15 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
f_or_f!y...,
DI.forward_backend(dense_backend),
x,
batched_seeds_forward[1],
batched_seed_forward_prep,
contexts...;
)
pullback_prep = DI.prepare_pullback_nokwarg(
strict,
f_or_f!y...,
DI.reverse_backend(dense_backend),
x,
batched_seeds_reverse[1],
batched_seed_reverse_prep,
contexts...;
)

Expand All @@ -156,6 +173,8 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
coloring_result,
compressed_matrix_forward,
compressed_matrix_reverse,
batched_seed_forward_prep,
batched_seed_reverse_prep,
batched_seeds_forward,
batched_seeds_reverse,
batched_results_forward,
Expand Down Expand Up @@ -183,6 +202,8 @@ function _sparse_jacobian_aux!(
coloring_result,
compressed_matrix_forward,
compressed_matrix_reverse,
batched_seed_forward_prep,
batched_seed_reverse_prep,
batched_seeds_forward,
batched_seeds_reverse,
batched_results_forward,
Expand All @@ -200,15 +221,15 @@ function _sparse_jacobian_aux!(
pushforward_prep,
DI.forward_backend(dense_backend),
x,
batched_seeds_forward[1],
batched_seed_forward_prep,
contexts...,
)
pullback_prep_same = DI.prepare_pullback_same_point(
f_or_f!y...,
pullback_prep,
DI.reverse_backend(dense_backend),
x,
batched_seeds_reverse[1],
batched_seed_reverse_prep,
contexts...,
)

Expand Down
5 changes: 5 additions & 0 deletions DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ end
@test only(row_groups(jac_rev_prep)) == 1:10
@test only(column_groups(hess_prep)) == 1:10
end

@testset "Empty colors for mixed mode" begin # issue 857
backend = MyAutoSparse(MixedMode(adaptive_backends[1], adaptive_backends[2]))
@test jacobian(copyto!, zeros(10), backend, ones(10)) isa AbstractMatrix
end
end

@testset "Misc" begin
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function MyAutoSparse(backend::AbstractADType)
return AutoSparse(
backend;
sparsity_detector=TracerSparsityDetector(),
coloring_algorithm=GreedyColoringAlgorithm(),
coloring_algorithm=GreedyColoringAlgorithm(; postprocessing=true),
)
end

Expand Down
Loading