Skip to content
Merged
41 changes: 38 additions & 3 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,40 @@ Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
## Iterable Collection Constructs

Base.map(f, A::ArrayPartition) = ArrayPartition(map(x -> map(f, x), A.x))
function Base.mapreduce(f, op, A::ArrayPartition{T}; kwargs...) where {T}
mapreduce(f, op, (i for i in A); kwargs...)
# Use @generated function for type stability on Julia 1.10
# The generated approach avoids type inference issues with kwargs in older Julia versions
@generated function _mapreduce_impl(f, op, A::ArrayPartition{T, S}) where {T, S}
N = length(S.parameters)
if N == 1
return :(mapreduce(f, op, A.x[1]))
else
expr = :(mapreduce(f, op, A.x[$N]))
for i in (N - 1):-1:1
expr = :(op(mapreduce(f, op, A.x[$i]), $expr))
end
return expr
end
end
@generated function _mapreduce_impl_init(f, op, A::ArrayPartition{T, S}, init) where {T, S}
N = length(S.parameters)
if N == 1
return :(mapreduce(f, op, A.x[1]))
else
expr = :(mapreduce(f, op, A.x[$N]))
for i in (N - 1):-1:1
expr = :(op(mapreduce(f, op, A.x[$i]), $expr))
end
# Apply init only at the outermost reduction
return :(op(init, $expr))
end
end
@inline function Base.mapreduce(f, op, A::ArrayPartition;
init = Base._InitialValue(), kwargs...)
if init isa Base._InitialValue
_mapreduce_impl(f, op, A)
else
_mapreduce_impl_init(f, op, A, init)
end
end
Base.filter(f, A::ArrayPartition) = ArrayPartition(map(x -> filter(f, x), A.x))
Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x))
Expand Down Expand Up @@ -442,7 +474,10 @@ end

## Linear Algebra

ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(Vector(A))
function ArrayInterface.zeromatrix(A::ArrayPartition)
x = reduce(vcat,vec.(A.x))
x .* x' .* false
end

function __get_subtypes_in_module(
mod, supertype; include_supertype = true, all = false, except = [])
Expand Down
7 changes: 7 additions & 0 deletions src/named_array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ end
return dest
end

#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
function ArrayInterface.zeromatrix(A::NamedArrayPartition)
B = ArrayPartition(A)
x = reduce(vcat,vec.(B.x))
x .* x' .* false
end

# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments.
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args)
function find_NamedArrayPartition(args::Tuple)
Expand Down
9 changes: 7 additions & 2 deletions test/gpu/arraypartition_gpu.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, CUDA, Test, Adapt
using RecursiveArrayTools, ArrayInterface, CUDA, Adapt, Test
CUDA.allowscalar(false)

# Test indexing with colon
Expand Down Expand Up @@ -40,4 +40,9 @@ for i in 1:length(part_a.x)
sub_b = part_b.x[i]
@test sub_a == sub_b # Test for value equality in sub-arrays
@test typeof(sub_a) === typeof(sub_b) # Test type equality
end
end

x = ArrayPartition((CUDA.zeros(2),CUDA.zeros(2)))
@test ArrayInterface.zeromatrix(x) isa CuMatrix
@test size(ArrayInterface.zeromatrix(x)) == (4,4)
@test maximum(abs, x) == 0f0
5 changes: 4 additions & 1 deletion test/named_array_partition_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, Test
using RecursiveArrayTools, ArrayInterface, Test

@testset "NamedArrayPartition tests" begin
x = NamedArrayPartition(a = ones(10), b = rand(20))
Expand All @@ -9,10 +9,13 @@ using RecursiveArrayTools, Test
@test x.a ≈ ones(10)
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
@test all(x .== x[1:end])
@test ArrayInterface.zeromatrix(x) isa Matrix
@test size(ArrayInterface.zeromatrix(x)) == (30,30)
y = copy(x)
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works
@test typeof(zero(x)) <: NamedArrayPartition
@test (y .*= 2).a[1] ≈ 2 # test in-place bcast


@test length(Array(x)) == 30
@test typeof(Array(x)) <: Array
Expand Down
Loading