Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
seen = Reactant.OrderedIdDict()
prev = Any[func.f, args...]
kernelargsym = gensym("kernelarg")
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.NoStopTracedTrack)
wrapper_tys = MLIR.IR.Type[]
for arg in values(seen)
if !(arg isa TracedRArray || arg isa TracedRNumber)
Expand Down Expand Up @@ -536,16 +536,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
if !(arg isa TracedRArray || arg isa TracedRNumber)
continue
end
for p in Reactant.TracedUtils.get_paths(arg)

paths = Reactant.TracedUtils.get_paths(arg)

arg = arg.mlir_data
arg = Reactant.TracedUtils.transpose_val(arg)
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)

for p in paths
if p[1] !== kernelargsym
continue
end

arg = arg.mlir_data
arg = Reactant.TracedUtils.transpose_val(arg)
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)

# Get the allocation corresponding to which arg we're doing
alloc = allocs[p[2]][1]

Expand Down Expand Up @@ -580,9 +582,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I moved argidx++ out of the loop, I'm not sure whether this is still correct because I don't really know what's happening here.

),
)

argidx += 1
end
argidx += 1
end

MLIR.IR.block!(wrapbody) do
Expand Down
28 changes: 25 additions & 3 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TracedToConcrete = 3
ArrayToConcrete = 4
TracedSetPath = 5
NoStopTracedTrack = 6
end

for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
Expand Down Expand Up @@ -249,7 +250,7 @@ function traced_type(
@inline base_typec(TV::TT) where {TT<:DataType} =
(T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
return base_typec(T)
elseif mode == TracedTrack || mode == TracedSetPath
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
return T
else
throw("Abstract RArray $T cannot be made concrete in mode $mode")
Expand All @@ -261,7 +262,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}, track_numbers) where {T<:Trac
throw("TracedRNG cannot be traced")
elseif mode == TracedToConcrete
return ConcreteRNG
elseif mode == TracedTrack || mode == TracedSetPath
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
return T
else
throw("Unsupported mode: $mode")
Expand Down Expand Up @@ -329,7 +330,7 @@ function make_tracer(
track_numbers=(),
kwargs...,
) where {RT}
if haskey(seen, prev)
if mode != NoStopTracedTrack && haskey(seen, prev)
return seen[prev]
end
TT = traced_type(RT, (), Val(mode), track_numbers)
Expand Down Expand Up @@ -460,6 +461,13 @@ function make_tracer(
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
if haskey(seen, prev)
return seen[prev]
Expand Down Expand Up @@ -506,6 +514,13 @@ function make_tracer(
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
if haskey(seen, prev)
return seen[prev]
Expand Down Expand Up @@ -546,6 +561,13 @@ function make_tracer(
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
haskey(seen, prev) && return seen[prev]
res = MissingTracedValue((path,))
Expand Down
36 changes: 36 additions & 0 deletions test/integration/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,40 @@ tuplef2(a) = @cuda threads = 1 tuplef2!((5, a))
@code_hlo optimize = :before_kernel tuplef2(A)
end
end
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef2(A)
@test all(Array(A) .≈ 5)
else
@code_hlo optimize = :before_kernel tuplef2(A)
end
end

# TODO this same code fails if we use a 0-d array...?
# maybe weird cuda things
function aliased!(tup)
x, y = tup
x[2][1] *= y[2][1]
return nothing
end

function aliased(s)
tup = (s, s)
@cuda threads = 1 aliased!(tup)
return nothing
end

@static if !Sys.isapple()
@testset "Aliasing arguments" begin
a = ConcreteRArray([3])

s = (10, a)

if CUDA.functional()
@jit aliased((s, s))
@test all(Array(a) == 9)
else
@code_hlo optimize = :before_kernel aliased(s)
end
end
end
Loading