Record every path to aliased kernel arguments (NoStopTracedTrack)

src/Tracing.jl
  * Add the trace mode `NoStopTracedTrack = 6`.
  * `traced_type` handles it exactly like `TracedTrack` / `TracedSetPath`
    (the type is returned unchanged).
  * `make_tracer` no longer takes the early `haskey(seen, prev)` return in this
    mode, and three `make_tracer` methods append the current path to the
    value's existing paths, register the value in `seen` if it is not there
    yet, and return the value itself.

ext/ReactantCUDAExt.jl
  * The kernel-launch overlay traces its arguments with `NoStopTracedTrack`.
  * `restys` / `mlir_args` are pushed once per traced argument instead of once
    per path, and `argidx` advances once per traced argument.

test/integration/cuda.jl
  * Add a 0-d array check for `tuplef2` and an "Aliasing arguments" testset.
  * The CUDA branch of the aliasing test calls `aliased(s)` (the same call as
    the non-CUDA branch; `aliased` builds the `(s, s)` tuple itself) and
    compares element-wise with `.==`.

diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index ca4d6efdff..d849efc61e 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -443,7 +443,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     seen = Reactant.OrderedIdDict()
     prev = Any[func.f, args...]
     kernelargsym = gensym("kernelarg")
-    Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
+    Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.NoStopTracedTrack)
     wrapper_tys = MLIR.IR.Type[]
     for arg in values(seen)
         if !(arg isa TracedRArray || arg isa TracedRNumber)
@@ -536,16 +536,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
         if !(arg isa TracedRArray || arg isa TracedRNumber)
             continue
         end
-        for p in Reactant.TracedUtils.get_paths(arg)
+
+        paths = Reactant.TracedUtils.get_paths(arg)
+
+        arg = arg.mlir_data
+        arg = Reactant.TracedUtils.transpose_val(arg)
+        push!(restys, MLIR.IR.type(arg))
+        push!(mlir_args, arg)
+
+        for p in paths
             if p[1] !== kernelargsym
                 continue
             end
-
-            arg = arg.mlir_data
-            arg = Reactant.TracedUtils.transpose_val(arg)
-            push!(restys, MLIR.IR.type(arg))
-            push!(mlir_args, arg)
-
             # Get the allocation corresponding to which arg we're doing
             alloc = allocs[p[2]][1]
 
@@ -580,9 +582,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
                     ),
                 ),
             )
-
-            argidx += 1
         end
+        argidx += 1
     end
 
     MLIR.IR.block!(wrapbody) do
diff --git a/src/Tracing.jl b/src/Tracing.jl
index e00fdcb006..2f4885147d 100644
--- a/src/Tracing.jl
+++ b/src/Tracing.jl
@@ -4,6 +4,7 @@
     TracedToConcrete = 3
     ArrayToConcrete = 4
     TracedSetPath = 5
+    NoStopTracedTrack = 6
 end
 
 for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
@@ -249,7 +250,7 @@ function traced_type(
         @inline base_typec(TV::TT) where {TT<:DataType} =
             (T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
         return base_typec(T)
-    elseif mode == TracedTrack || mode == TracedSetPath
+    elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
         return T
     else
         throw("Abstract RArray $T cannot be made concrete in mode $mode")
@@ -261,7 +262,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}, track_numbers) where {T<:Trac
         throw("TracedRNG cannot be traced")
     elseif mode == TracedToConcrete
         return ConcreteRNG
-    elseif mode == TracedTrack || mode == TracedSetPath
+    elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
         return T
     else
         throw("Unsupported mode: $mode")
@@ -329,7 +330,7 @@ function make_tracer(
     track_numbers=(),
     kwargs...,
 ) where {RT}
-    if haskey(seen, prev)
+    if mode != NoStopTracedTrack && haskey(seen, prev)
         return seen[prev]
     end
     TT = traced_type(RT, (), Val(mode), track_numbers)
@@ -460,6 +461,13 @@ function make_tracer(
         end
         return prev
     end
+    if mode == NoStopTracedTrack
+        TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
+        if !haskey(seen, prev)
+            seen[prev] = prev # don't return!
+        end
+        return prev
+    end
     if mode == TracedSetPath
         if haskey(seen, prev)
             return seen[prev]
@@ -506,6 +514,13 @@ function make_tracer(
         end
         return prev
     end
+    if mode == NoStopTracedTrack
+        TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
+        if !haskey(seen, prev)
+            seen[prev] = prev # don't return!
+        end
+        return prev
+    end
     if mode == TracedSetPath
         if haskey(seen, prev)
             return seen[prev]
@@ -546,6 +561,13 @@ function make_tracer(
         end
         return prev
     end
+    if mode == NoStopTracedTrack
+        TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
+        if !haskey(seen, prev)
+            seen[prev] = prev # don't return!
+        end
+        return prev
+    end
     if mode == TracedSetPath
         haskey(seen, prev) && return seen[prev]
         res = MissingTracedValue((path,))
diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl
index ca445e3e26..817dfa7408 100644
--- a/test/integration/cuda.jl
+++ b/test/integration/cuda.jl
@@ -115,4 +115,40 @@ tuplef2(a) = @cuda threads = 1 tuplef2!((5, a))
             @code_hlo optimize = :before_kernel tuplef2(A)
         end
     end
+    A = ConcreteRArray(fill(1))
+    if CUDA.functional()
+        @jit tuplef2(A)
+        @test all(Array(A) .≈ 5)
+    else
+        @code_hlo optimize = :before_kernel tuplef2(A)
+    end
+end
+
+# TODO: this same code fails if we use a 0-d array; cause unknown,
+# maybe weird CUDA things.
+function aliased!(tup)
+    x, y = tup
+    x[2][1] *= y[2][1] # x[2] and y[2] may be the very same array
+    return nothing
+end
+
+function aliased(s)
+    tup = (s, s) # both entries are the same object, so the kernel args alias
+    @cuda threads = 1 aliased!(tup)
+    return nothing
+end
+
+@static if !Sys.isapple()
+    @testset "Aliasing arguments" begin
+        a = ConcreteRArray([3])
+
+        s = (10, a)
+
+        if CUDA.functional()
+            @jit aliased(s)
+            @test all(Array(a) .== 9) # 3 * 3: both operands read the aliased `a`
+        else
+            @code_hlo optimize = :before_kernel aliased(s)
+        end
+    end
 end