Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -579,3 +579,7 @@ end
end
return sum_expr
end

function Adapt.adapt_structure(to, ap::ArrayPartition)
ArrayPartition(map(x -> Adapt.adapt(to, x), ap.x)...)
end
22 changes: 21 additions & 1 deletion test/gpu/arraypartition_gpu.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, CUDA, Test
using RecursiveArrayTools, CUDA, Test, Adapt
CUDA.allowscalar(false)

# Test indexing with colon
Expand All @@ -21,3 +21,23 @@ fill!(pA, false)
a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu))
b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu))
@. a + b

# Test adapt from ArrayPartition with CuArrays to ArrayPartition with CPU arrays

a = CuArray(Float64.([1., 2., 3., 4.]))
b = CuArray(Float64.([1., 2., 3., 4.]))
part_a_gpu = ArrayPartition(a, b)
part_a = adapt(Array{Float32}, part_a_gpu)

c = Float32.([1., 2., 3., 4.])
d = Float32.([1., 2., 3., 4.])
part_b = ArrayPartition(c, d)

@test part_a == part_b # Test equality

for i in 1:length(part_a.x)
sub_a = part_a.x[i]
sub_b = part_b.x[i]
@test sub_a == sub_b # Test for value equality in sub-arrays
@test typeof(sub_a) === typeof(sub_b) # Test type equality
end
21 changes: 20 additions & 1 deletion test/partitions_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, Test, Statistics, ArrayInterface
using RecursiveArrayTools, Test, Statistics, ArrayInterface, Adapt

@test length(ArrayPartition()) == 0
@test isempty(ArrayPartition())
Expand Down Expand Up @@ -306,3 +306,22 @@ end
copyto!(u, ArrayPartition(1.0, -1.2))
@test u == [1.0, -1.2]
end

# Test adapt on ArrayPartition from Float64 to Float32 arrays
a = Float64.([1., 2., 3., 4.])
b = Float64.([1., 2., 3., 4.])
part_a_64 = ArrayPartition(a, b)
part_a = adapt(Array{Float32}, part_a_64)

c = Float32.([1., 2., 3., 4.])
d = Float32.([1., 2., 3., 4.])
part_b = ArrayPartition(c, d)

@test part_a == part_b # Test equality of partitions

for i in 1:length(part_a.x)
sub_a = part_a.x[i]
sub_b = part_b.x[i]
@test sub_a == sub_b # Test for value equality
@test typeof(sub_a) === typeof(sub_b) # Test type equality
end
Loading