From bdaf745c07221a88450472e3ac41c059aad74320 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 23 Jun 2023 14:46:43 -0400 Subject: [PATCH] Add projection mechanism --- src/nditeration.jl | 28 ++++++++++++++++++++++------ test/test.jl | 22 ++++++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/nditeration.jl b/src/nditeration.jl index d7598ae2..cca9ba0c 100644 --- a/src/nditeration.jl +++ b/src/nditeration.jl @@ -46,16 +46,17 @@ for block in ndrange end ``` """ -struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems} +struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems, Projection} blocks::DynamicBlock workitems::DynamicWorkitems + projection::Projection - function NDRange{N, B, W}() where {N, B, W} - new{N, B, W, Nothing, Nothing}(nothing, nothing) + function NDRange{N, B, W}(projection=identity) where {N, B, W} + new{N, B, W, Nothing, Nothing, typeof(projection)}(nothing, nothing, projection) end - function NDRange{N, B, W}(blocks, workitems) where {N, B, W} - new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems) + function NDRange{N, B, W}(blocks, workitems, projection=identity) where {N, B, W} + new{N, B, W, typeof(blocks), typeof(workitems), typeof(projection)}(blocks, workitems, projection) end end @@ -77,7 +78,7 @@ Base.length(range::NDRange) = length(blocks(range)) gidx = groupidx.I[I] (gidx-1)*stride + idx.I[I] end - CartesianIndex(nI) + ndrange.projection(CartesianIndex(nI)) end Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer) @@ -126,4 +127,19 @@ needs to perform dynamic bounds-checking. end end +abstract type IndexProjection end +struct Identity <: IndexProjection end +(::Identity)(idx::CartesianIndex) = idx +const identity = Identity() + +struct Offsets{N} <: IndexProjection + offsets::NTuple{N} +end +function (o::Offsets{N})(idx::CartesianIndex{N}) where N + nI = ntuple(Val{N}) do i + idx.I[i] + o.offsets[i] + end + CartesianIndex(nI) +end + end #module diff --git a/test/test.jl b/test/test.jl index 88086342..aaeb6710 100644 --- a/test/test.jl +++ b/test/test.jl @@ -215,6 +215,28 @@ end synchronize(Backend()) end +@kernel function index_global_offset!(a) + i, j = @index(Global, NTuple) + n, m = size(a) + @inbounds a[i, j] = i + n * j +end + +@conditional_testset "Offset iteration space $Backend" skip_tests begin + a = KernelAbstractions.zeros(Backend(), 7, 9) + index_global_offset!(Backend(), (2, 2), size(a) .- 4, (2, 2))(a) + synchronize(Backend()) + + b = KernelAbstractions.zeros(CPU(), 7, 9) + b .= a + + c = [i + 7 * j for i in 1:7, j in 1:9] + + @test b[3:5, 3:7] == c[3:5, 3:7] + @test b[1:2, :] == zeros(2, 9) + @test b[6:7, :] == zeros(2, 9) + @test b[:, 1:2] == zeros(7, 2) + @test b[:, 8:9] == zeros(7, 2) +end @conditional_testset "return statement" skip_tests begin try