From 4425ce97c9c447f65e0c0e05cb795cf96c3fb866 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simen=20Hus=C3=B8y?= Date: Wed, 15 Oct 2025 16:56:07 +0200 Subject: [PATCH] Adapt rule for ArrayPartition --- src/array_partition.jl | 4 ++++ test/gpu/arraypartition_gpu.jl | 22 +++++++++++++++++++++- test/partitions_test.jl | 21 ++++++++++++++++++++- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 89261f4b..5eb1289a 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -579,3 +579,7 @@ end end return sum_expr end + +function Adapt.adapt_structure(to, ap::ArrayPartition) + ArrayPartition(map(x -> Adapt.adapt(to, x), ap.x)...) +end \ No newline at end of file diff --git a/test/gpu/arraypartition_gpu.jl b/test/gpu/arraypartition_gpu.jl index 08fdb69a..76c7a6ea 100644 --- a/test/gpu/arraypartition_gpu.jl +++ b/test/gpu/arraypartition_gpu.jl @@ -1,4 +1,4 @@ -using RecursiveArrayTools, CUDA, Test +using RecursiveArrayTools, CUDA, Test, Adapt CUDA.allowscalar(false) # Test indexing with colon @@ -21,3 +21,23 @@ fill!(pA, false) a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu)) b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu)) @. a + b + +# Test adapt from ArrayPartition with CuArrays to ArrayPartition with CPU arrays + +a = CuArray(Float64.([1., 2., 3., 4.])) +b = CuArray(Float64.([1., 2., 3., 4.])) +part_a_gpu = ArrayPartition(a, b) +part_a = adapt(Array{Float32}, part_a_gpu) + +c = Float32.([1., 2., 3., 4.]) +d = Float32.([1., 2., 3., 4.]) +part_b = ArrayPartition(c, d) + +@test part_a == part_b # Test equality + +for i in 1:length(part_a.x) + sub_a = part_a.x[i] + sub_b = part_b.x[i] + @test sub_a == sub_b # Test for value equality in sub-arrays + @test typeof(sub_a) === typeof(sub_b) # Test type equality +end \ No newline at end of file diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 3c0c2232..27abfaf6 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -1,4 +1,4 @@ -using RecursiveArrayTools, Test, Statistics, ArrayInterface +using RecursiveArrayTools, Test, Statistics, ArrayInterface, Adapt @test length(ArrayPartition()) == 0 @test isempty(ArrayPartition()) @@ -306,3 +306,22 @@ end copyto!(u, ArrayPartition(1.0, -1.2)) @test u == [1.0, -1.2] end + +# Test adapt on ArrayPartition from Float64 to Float32 arrays +a = Float64.([1., 2., 3., 4.]) +b = Float64.([1., 2., 3., 4.]) +part_a_64 = ArrayPartition(a, b) +part_a = adapt(Array{Float32}, part_a_64) + +c = Float32.([1., 2., 3., 4.]) +d = Float32.([1., 2., 3., 4.]) +part_b = ArrayPartition(c, d) + +@test part_a == part_b # Test equality of partitions + +for i in 1:length(part_a.x) + sub_a = part_a.x[i] + sub_b = part_b.x[i] + @test sub_a == sub_b # Test for value equality + @test typeof(sub_a) === typeof(sub_b) # Test type equality +end \ No newline at end of file