4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
version = "0.2.17"
version = "0.2.18"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -67,7 +67,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.37"
Reactant_jll = "0.0.39"
Scratch = "1.2"
SpecialFunctions = "2"
Statistics = "1.10"
20 changes: 16 additions & 4 deletions deps/ReactantExtra/API.cpp
@@ -629,7 +629,8 @@ static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
mlir::ModuleOp source,
mlir::ModuleOp target,
unsigned &lastUsedID) {
unsigned &lastUsedID,
bool &shouldRemove) {
using namespace llvm;
using namespace mlir;

@@ -639,6 +640,13 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
return success();
}

if (auto func = dyn_cast<FunctionOpInterface>(op.getOperation())) {
if (func.isExternal()) {
shouldRemove = true;
return success();
}
}

StringAttr newSymName = renameSymbol(opName, lastUsedID, source, target);

if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
@@ -658,7 +666,7 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,

unsigned lastUsedID = 0;

for (auto &op : *newMod.getBody()) {
for (auto &op : make_early_inc_range(*newMod.getBody())) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
@@ -669,10 +677,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
entryFn = &op;
}

if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID))) {
bool shouldRemove = false;
if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) {
assert(0 && "failed to update all uses");
}
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
if (shouldRemove)
op.erase();
else
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
}
prevMod.getBody()->getOperations().splice(
prevMod.getBody()->getOperations().end(),
6 changes: 4 additions & 2 deletions deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "85612ea74731f02aa4e30800038e065912d37ae2"
ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873"
ENZYMEXLA_SHA256 = ""

http_archive(
@@ -138,7 +138,9 @@ http_archive(
patches = ["@enzyme_ad//:patches/jax.patch"],
)

load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
XLA_SHA256 = ""

http_archive(
name = "xla",
159 changes: 119 additions & 40 deletions ext/ReactantCUDAExt.jl
[JuliaFormatter] reported by reviewdog 🐶

sz = sizeof(x)
ref = Ref(x)
GC.@preserve ref begin
ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
vec = Vector{UInt8}(undef, sz)
for i in 1:sz
@inbounds vec[i] = Base.unsafe_load(ptr, i)
end
vec
end
end
function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)

[JuliaFormatter] reported by reviewdog 🐶

MLIR.IR.attr!(gpufunc, "CConv", MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)))
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))

@@ -281,9 +281,13 @@ function compile(job)
# TODO: on 1.9, this actually creates a context. cache those.
entry = GPUCompiler.JuliaContext() do ctx
mod, meta = GPUCompiler.compile(
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
:llvm,
job;
optimize=false,
cleanup=false,
validate=false,
libraries=false,

# :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
)

GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job))
entryname = LLVM.name(meta.entry)

GPUCompiler.optimize_module!(job, mod)
@@ -319,6 +323,8 @@ function compile(job)
end
end

# GPUCompiler.check_ir(job, mod)

LLVM.strip_debuginfo!(mod)
modstr = string(mod)

@@ -363,6 +369,38 @@ function to_bytes(x)
end
end

function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
Reactant.make_tracer(seen, x, path, mode; kwargs...)
return prev
end

function get_field_offset(T::Type, path)
offset = 0
current_type = T

for field in path
# Get the field index
field_idx = if field isa Integer
field
else
@assert field isa Symbol
findfirst(==(field), fieldnames(current_type))
end
if field_idx === nothing
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
error(
"Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path",
)

end

# Add the offset of this field
offset += fieldoffset(current_type, field_idx)

# Update current_type to the field's type for next iteration
current_type = fieldtype(current_type, field_idx)
end

return offset
end
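For context, a minimal usage sketch of the path-to-offset computation above; Outer and Inner are hypothetical types introduced only for illustration, and the calls assume the get_field_offset definition from this diff is in scope:

struct Inner
    a::Int32
    b::Float64
end

struct Outer
    x::Int64
    y::Inner
end

# Walk Outer -> field :y (offset 8) -> field :b (offset 8 within Inner),
# accumulating byte offsets exactly as the loop above does.
get_field_offset(Outer, (:y, :b)) == fieldoffset(Outer, 2) + fieldoffset(Inner, 2)  # 16 bytes

# Integer indices along the path are accepted as well:
get_field_offset(Outer, (2, :b))  # also 16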

Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
args...;
convert=Val(false),
Expand All @@ -384,20 +422,19 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

wrapper_tys = MLIR.IR.Type[]
ctx = MLIR.IR.context()
cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1))
for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
if sizeof(a) == 0
cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1))

# linearize kernel arguments
seen = Reactant.OrderedIdDict()
prev = Any[func.f, args...]
kernelargsym = gensym("kernelarg")
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
wrapper_tys = MLIR.IR.Type[]
for arg in values(seen)
if !(arg isa TracedRArray || arg isa TracedRNumber)
continue
end
if a isa CuTracedArray
a =
Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
end
if a isa TracedRArray || a isa TracedRNumber
push!(wrapper_tys, cullvm_ty)
continue
end
# Per below we assume we can inline all other types directly in
push!(wrapper_tys, cullvm_ty)
end

sym_name = String(gensym("call_$fname"))
@@ -426,20 +463,60 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))

trueidx = 1
for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]
allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type},Nothing}[]


llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
i8 = MLIR.IR.Type(UInt8)
allargs = [func.f, args...]
for a in allargs
if sizeof(a) == 0
push!(allocs, nothing)
continue
end
if a isa CuTracedArray
a =
Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray

# TODO check for only integer and explicitly non cutraced types
MLIR.IR.block!(wrapbody) do
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
argty = MLIR.IR.Type(
MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1)
)

trueidx += 1
c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1)
Comment on lines +481 to +482
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1)
c1 = MLIR.IR.result(
MLIR.Dialects.llvm.mlir_constant(;
res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)
),
1,
)
alloc = MLIR.IR.result(
MLIR.Dialects.llvm.alloca(
c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr
),
1,
)

push!(allocs, (alloc, argty))

sz = sizeof(a)
array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
cdata = MLIR.IR.result(
MLIR.Dialects.llvm.mlir_constant(;
res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))
),
1,
)

MLIR.Dialects.llvm.store(cdata, alloc)
end
if a isa TracedRArray || a isa TracedRNumber
push!(rarrays, a)
arg = a.mlir_data

end

argidx = 1
for arg in values(seen)
if !(arg isa TracedRArray || arg isa TracedRNumber)
continue
end
for p in Reactant.TracedUtils.get_paths(arg)
if p[1] !== kernelargsym
continue
end

arg = arg.mlir_data
arg = Reactant.TracedUtils.transpose_val(arg)
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)

# Get the allocation corresponding to which arg we're doing
alloc = allocs[p[2]][1]

# we need to now compute the offset in bytes of the path
julia_arg = allargs[p[2]]

offset = get_field_offset(typeof(julia_arg), p[3:end])
MLIR.IR.block!(wrapbody) do
ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1)
ptr = MLIR.IR.result(
MLIR.Dialects.llvm.getelementptr(
alloc,
MLIR.IR.Value[];
res=llvmptr,
elem_type=i8,
rawConstantIndices=MLIR.IR.Attribute([Int32(offset)]),
),
1,
)

MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
end

push!(
aliases,
MLIR.IR.Attribute(
@@ -453,30 +530,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
),
),
)
push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))

argidx += 1
trueidx += 1
continue
end

# TODO check for only integer and explicitly non cutraced types
@show "Warning: using fallback for kernel argument type conversion for argument of type $(Core.Typeof(a)), if this contains a CuTracedArray this will segfault"
MLIR.IR.block!(wrapbody) do
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
trueidx += 1
c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))), 1)

sz = sizeof(a)
array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
MLIR.Dialects.llvm.store(cdata, alloc)
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
push!(wrapargs, argres)
end
end

MLIR.IR.block!(wrapbody) do
for arg in allocs
if arg === nothing
continue
end
alloc, argty = arg
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
push!(wrapargs, argres)
end
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
MLIR.Dialects.llvm.call(
wrapargs,
MLIR.IR.Value[];
callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)),
op_bundle_sizes=MLIR.IR.Attribute(Int32[]),
)

MLIR.Dialects.llvm.return_(nothing)
end
@@ -500,8 +567,14 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases),

)
for (i, res) in enumerate(rarrays)
res.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, i))

argidx = 1
for arg in values(seen)
if !(arg isa TracedRArray || arg isa TracedRNumber)
continue
end
arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx))
argidx+=1
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
argidx+=1
argidx += 1

end
end

@@ -546,6 +619,12 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
return Core.Typeof(res)(f, res.entry)
end

function Reactant.traced_type(
::Type{A}, seen::ST, ::Val{mode}, track_numbers
) where {A<:CuTracedArray,ST,mode}
return A
end

function Reactant.traced_type(
::Type{A}, seen::ST, ::Val{mode}, track_numbers
) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}
2 changes: 1 addition & 1 deletion src/Compiler.jl
@@ -293,7 +293,7 @@ function optimization_passes(; no_nan::Bool=false)
)
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
return join(
["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ','
["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ','
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ','
[
"inline{default-pipeline=canonicalize max-iterations=4}",
"libdevice-funcs-raise",
func_passes,
],
',',

)
end
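For reference, a rough sketch of the pipeline string this join assembles; the transform passes inside func_passes are elided here, so the trailing part is abbreviated:

# Stand-in for func_passes, which is itself a comma-joined list of
# "canonicalize,cse,canonicalize,<transform passes>".
func_passes = "canonicalize,cse,canonicalize"
join(
    [
        "inline{default-pipeline=canonicalize max-iterations=4}",
        "libdevice-funcs-raise",
        func_passes,
    ],
    ',',
)
# => "inline{default-pipeline=canonicalize max-iterations=4},libdevice-funcs-raise,canonicalize,cse,canonicalize"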

9 changes: 8 additions & 1 deletion src/utils.jl
@@ -116,6 +116,9 @@ function should_rewrite_ft(@nospecialize(ft))
if ft.name.name == Symbol("#launch_configuration")
return false
end
if ft.name.name == Symbol("cudaconvert")
return false
end
end
end
end
@@ -161,7 +164,11 @@ function should_rewrite_ft(@nospecialize(ft))
ft <: typeof(Base.getproperty) ||
ft <: typeof(Base.vect) ||
ft <: typeof(Base.eltype) ||
ft <: typeof(Base.argtail)
ft <: typeof(Base.argtail) ||
ft <: typeof(Base.identity) ||
ft <: typeof(Base.print) ||
ft <: typeof(Base.println) ||
ft <: typeof(Adapt.adapt_structure)
return false
end
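For context, a standalone sketch (not Reactant's actual predicate) of how a skip list like the one extended above behaves: function types on the list are left untraced, everything else remains a candidate for the overlay rewrite.

# Hypothetical, simplified stand-in for should_rewrite_ft's skip list.
skip_rewrite(@nospecialize(ft)) =
    ft <: typeof(Base.identity) ||
    ft <: typeof(Base.print) ||
    ft <: typeof(Base.println)

skip_rewrite(typeof(identity))  # true  -> call passes through unmodified
skip_rewrite(typeof(sin))       # false -> still eligible for rewriting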

36 changes: 33 additions & 3 deletions test/integration/cuda.jl
[JuliaFormatter] reported by reviewdog 🐶

@testset "Constant Op Kernel" begin
oA = collect(1:1:64)
A = Reactant.to_rarray(oA)
if CUDA.functional()
@jit smul!(A)
@test all(Array(A) .≈ oA .* 15)
else
@code_hlo optimize = :before_kernel smul!(A)

@@ -72,9 +72,6 @@ function smul!(x)
end

@static if !Sys.isapple()

# Broken pending jll update
@static if false
@testset "Constant Op Kernel" begin
oA = collect(1:1:64)
A = Reactant.to_rarray(oA)
@@ -87,4 +84,37 @@ end
end
end

Comment on lines 85 to 86
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end


function tuplef!(tup)
tup[1][] += 2
return nothing
end

function tuplef2!(tup)
tup[2][] *= tup[1]
return nothing
end

tuplef(a) = @cuda threads=1 tuplef!((a,))
tuplef2(a) = @cuda threads=1 tuplef2!((5, a))
Comment on lines +98 to +99
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
tuplef(a) = @cuda threads=1 tuplef!((a,))
tuplef2(a) = @cuda threads=1 tuplef2!((5, a))
tuplef(a) = @cuda threads = 1 tuplef!((a,))
tuplef2(a) = @cuda threads = 1 tuplef2!((5, a))


@static if !Sys.isapple()
@testset "Structured Kernel Arguments" begin
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef(A)
@test all(Array(A) .≈ 3)
else
@code_hlo optimize = :before_kernel tuplef(A)
Comment on lines +102 to +108
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@testset "Structured Kernel Arguments" begin
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef(A)
@test all(Array(A) .≈ 3)
else
@code_hlo optimize = :before_kernel tuplef(A)
@testset "Structured Kernel Arguments" begin
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef(A)
@test all(Array(A) .≈ 3)
else
@code_hlo optimize = :before_kernel tuplef(A)
end
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef2(A)
@test all(Array(A) .≈ 5)
else
@code_hlo optimize = :before_kernel tuplef2(A)
end

end

A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef2(A)
@test all(Array(A) .≈ 5)
else
@code_hlo optimize = :before_kernel tuplef2(A)
end

end
Comment on lines +110 to +119
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef2(A)
@test all(Array(A) .≈ 5)
else
@code_hlo optimize = :before_kernel tuplef2(A)
end
end

end