diff --git a/src/device/cuda/atomics.jl b/src/device/cuda/atomics.jl index 7f7661d6..43a4115a 100644 --- a/src/device/cuda/atomics.jl +++ b/src/device/cuda/atomics.jl @@ -299,6 +299,10 @@ const inplace_ops = Dict( :(⊻=) => :(⊻) ) +struct AtomicError <: Exception + msg::AbstractString +end + """ @atomic a[I] = op(a[I], val) @atomic a[I] ...= val @@ -319,10 +323,10 @@ macro atomic(ex) if ex.head == :(=) ref = ex.args[1] rhs = ex.args[2] - rhs.head == :call || error("right-hand side of an @atomic assignment should be a call") + Meta.isexpr(rhs, :call) || throw(AtomicError("right-hand side of an @atomic assignment should be a call")) op = rhs.args[1] if rhs.args[2] != ref - error("non-inplace @atomic assignment should reference the same array elements") + throw(AtomicError("right-hand side of a non-inplace @atomic assignment should reference the left-hand side")) end val = rhs.args[3] elseif haskey(inplace_ops, ex.head) @@ -330,13 +334,11 @@ macro atomic(ex) ref = ex.args[1] val = ex.args[2] else - error("unknown @atomic expression") + throw(AtomicError("unknown @atomic expression")) end # decode array expression - if ref.head != :ref - error("@atomic should be applied to an array reference expression") - end + Meta.isexpr(ref, :ref) || throw(AtomicError("@atomic should be applied to an array reference expression")) array = ref.args[1] indices = Expr(:tuple, ref.args[2:end]...) diff --git a/test/device/cuda.jl b/test/device/cuda.jl index aa54166e..3f358e80 100644 --- a/test/device/cuda.jl +++ b/test/device/cuda.jl @@ -1100,6 +1100,29 @@ end end end +@testset "macro" begin + using CUDAnative: AtomicError + + @test_throws_macro AtomicError("right-hand side of an @atomic assignment should be a call") @macroexpand begin + @atomic a[1] = 1 + end + @test_throws_macro AtomicError("right-hand side of an @atomic assignment should be a call") @macroexpand begin + @atomic a[1] = b ? 1 : 2 + end + + @test_throws_macro AtomicError("right-hand side of a non-inplace @atomic assignment should reference the left-hand side") @macroexpand begin + @atomic a[1] = a[2] + 1 + end + + @test_throws_macro AtomicError("unknown @atomic expression") @macroexpand begin + @atomic wat(a[1]) + end + + @test_throws_macro AtomicError("@atomic should be applied to an array reference expression") @macroexpand begin + @atomic a = a + 1 + end +end + end ############################################################################################ diff --git a/test/util.jl b/test/util.jl index 39e2bcc8..eab61100 100644 --- a/test/util.jl +++ b/test/util.jl @@ -17,6 +17,20 @@ macro test_throws_message(f, typ, ex...) end end +# @test_throw, peeking into the load error for testing macro errors +macro test_throws_macro(ty, ex) + return quote + Test.@test_throws $(esc(ty)) try + $(esc(ex)) + catch err + @test err isa LoadError + @test err.file === $(string(__source__.file)) + @test err.line === $(__source__.line + 1) + rethrow(err.error) + end + end +end + # NOTE: based on test/pkg.jl::capture_stdout, but doesn't discard exceptions macro grab_output(ex) quote