-
Notifications
You must be signed in to change notification settings - Fork 33
linearize kernel args #497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5fd4c39
839050b
5290b0b
7c40198
020fcd3
832a20c
a93e455
fa2a1df
8127ed6
42496d8
7edc8a0
f3e1f03
9eacb5e
9fcf444
97e3f79
6e9a5f2
0b2df16
1880c7e
c5a5c14
eef65ce
1f3a6a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶 Reactant.jl/ext/ReactantCUDAExt.jl Lines 462 to 463 in 7edc8a0
|
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -281,9 +281,13 @@ function compile(job) | |||||||||||||||||||||||||||||
| # TODO: on 1.9, this actually creates a context. cache those. | ||||||||||||||||||||||||||||||
| entry = GPUCompiler.JuliaContext() do ctx | ||||||||||||||||||||||||||||||
| mod, meta = GPUCompiler.compile( | ||||||||||||||||||||||||||||||
| # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true | ||||||||||||||||||||||||||||||
| :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| # :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false | ||||||||||||||||||||||||||||||
| # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job)) | ||||||||||||||||||||||||||||||
| entryname = LLVM.name(meta.entry) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| GPUCompiler.optimize_module!(job, mod) | ||||||||||||||||||||||||||||||
|
|
@@ -319,6 +323,8 @@ function compile(job) | |||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # GPUCompiler.check_ir(job, mod) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| LLVM.strip_debuginfo!(mod) | ||||||||||||||||||||||||||||||
| modstr = string(mod) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -363,6 +369,38 @@ function to_bytes(x) | |||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...) | ||||||||||||||||||||||||||||||
| x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray | ||||||||||||||||||||||||||||||
| Reactant.make_tracer(seen, x, path, mode; kwargs...) | ||||||||||||||||||||||||||||||
| return prev | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| function get_field_offset(T::Type, path) | ||||||||||||||||||||||||||||||
| offset = 0 | ||||||||||||||||||||||||||||||
| current_type = T | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| for field in path | ||||||||||||||||||||||||||||||
| # Get the field index | ||||||||||||||||||||||||||||||
| field_idx = if field isa Integer | ||||||||||||||||||||||||||||||
| field | ||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||
| @assert field isa Symbol | ||||||||||||||||||||||||||||||
| findfirst(==(field), fieldnames(current_type)) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| if field_idx === nothing | ||||||||||||||||||||||||||||||
| error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path") | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Add the offset of this field | ||||||||||||||||||||||||||||||
| offset += fieldoffset(current_type, field_idx) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Update current_type to the field's type for next iteration | ||||||||||||||||||||||||||||||
| current_type = fieldtype(current_type, field_idx) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return offset | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | ||||||||||||||||||||||||||||||
| args...; | ||||||||||||||||||||||||||||||
| convert=Val(false), | ||||||||||||||||||||||||||||||
|
|
@@ -384,20 +422,19 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| wrapper_tys = MLIR.IR.Type[] | ||||||||||||||||||||||||||||||
| ctx = MLIR.IR.context() | ||||||||||||||||||||||||||||||
| cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1)) | ||||||||||||||||||||||||||||||
| for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...] | ||||||||||||||||||||||||||||||
| if sizeof(a) == 0 | ||||||||||||||||||||||||||||||
| cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # linearize kernel arguments | ||||||||||||||||||||||||||||||
| seen = Reactant.OrderedIdDict() | ||||||||||||||||||||||||||||||
| prev = Any[func.f, args...] | ||||||||||||||||||||||||||||||
| kernelargsym = gensym("kernelarg") | ||||||||||||||||||||||||||||||
| Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack) | ||||||||||||||||||||||||||||||
| wrapper_tys = MLIR.IR.Type[] | ||||||||||||||||||||||||||||||
| for arg in values(seen) | ||||||||||||||||||||||||||||||
| if !(arg isa TracedRArray || arg isa TracedRNumber) | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| if a isa CuTracedArray | ||||||||||||||||||||||||||||||
| a = | ||||||||||||||||||||||||||||||
| Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| if a isa TracedRArray || a isa TracedRNumber | ||||||||||||||||||||||||||||||
| push!(wrapper_tys, cullvm_ty) | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| # Per below we assume we can inline all other types directly in | ||||||||||||||||||||||||||||||
| push!(wrapper_tys, cullvm_ty) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| sym_name = String(gensym("call_$fname")) | ||||||||||||||||||||||||||||||
|
|
@@ -426,20 +463,60 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||||||||||||||||||||||||||||
| gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type")) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| trueidx = 1 | ||||||||||||||||||||||||||||||
| for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...] | ||||||||||||||||||||||||||||||
| allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[] | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0)) | ||||||||||||||||||||||||||||||
| i8 = MLIR.IR.Type(UInt8) | ||||||||||||||||||||||||||||||
| allargs = [func.f, args...] | ||||||||||||||||||||||||||||||
| for a in allargs | ||||||||||||||||||||||||||||||
| if sizeof(a) == 0 | ||||||||||||||||||||||||||||||
| push!(allocs, nothing) | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| if a isa CuTracedArray | ||||||||||||||||||||||||||||||
| a = | ||||||||||||||||||||||||||||||
| Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # TODO check for only integer and explicitly non cutraced types | ||||||||||||||||||||||||||||||
| MLIR.IR.block!(wrapbody) do | ||||||||||||||||||||||||||||||
| argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1)) | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| trueidx += 1 | ||||||||||||||||||||||||||||||
| c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1) | ||||||||||||||||||||||||||||||
| alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1) | ||||||||||||||||||||||||||||||
|
Comment on lines
+481
to
+482
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| push!(allocs, (alloc, argty)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| sz = sizeof(a) | ||||||||||||||||||||||||||||||
| array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz)) | ||||||||||||||||||||||||||||||
| cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1) | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| MLIR.Dialects.llvm.store(cdata, alloc) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| if a isa TracedRArray || a isa TracedRNumber | ||||||||||||||||||||||||||||||
| push!(rarrays, a) | ||||||||||||||||||||||||||||||
| arg = a.mlir_data | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| argidx = 1 | ||||||||||||||||||||||||||||||
| for arg in values(seen) | ||||||||||||||||||||||||||||||
| if !(arg isa TracedRArray || arg isa TracedRNumber) | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| for p in Reactant.TracedUtils.get_paths(arg) | ||||||||||||||||||||||||||||||
| if p[1] !== kernelargsym | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| arg = arg.mlir_data | ||||||||||||||||||||||||||||||
| arg = Reactant.TracedUtils.transpose_val(arg) | ||||||||||||||||||||||||||||||
| push!(restys, MLIR.IR.type(arg)) | ||||||||||||||||||||||||||||||
| push!(mlir_args, arg) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| # Get the allocation corresponding to which arg we're doing | ||||||||||||||||||||||||||||||
| alloc = allocs[p[2]][1] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # we need to now compute the offset in bytes of the path | ||||||||||||||||||||||||||||||
| julia_arg = allargs[p[2]] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| offset = get_field_offset(typeof(julia_arg), p[3:end]) | ||||||||||||||||||||||||||||||
| MLIR.IR.block!(wrapbody) do | ||||||||||||||||||||||||||||||
| ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1) | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| push!( | ||||||||||||||||||||||||||||||
| aliases, | ||||||||||||||||||||||||||||||
| MLIR.IR.Attribute( | ||||||||||||||||||||||||||||||
|
|
@@ -453,30 +530,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| push!(wrapargs, MLIR.IR.argument(wrapbody, argidx)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| argidx += 1 | ||||||||||||||||||||||||||||||
| trueidx += 1 | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # TODO check for only integer and explicitly non cutraced types | ||||||||||||||||||||||||||||||
| @show "Warning: using fallback for kernel argument type conversion for argument of type $(Core.Typeof(a)), if this contains a CuTracedArray this will segfault" | ||||||||||||||||||||||||||||||
| MLIR.IR.block!(wrapbody) do | ||||||||||||||||||||||||||||||
| argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1)) | ||||||||||||||||||||||||||||||
| trueidx += 1 | ||||||||||||||||||||||||||||||
| c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1) | ||||||||||||||||||||||||||||||
| alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))), 1) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| sz = sizeof(a) | ||||||||||||||||||||||||||||||
| array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz)) | ||||||||||||||||||||||||||||||
| cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1) | ||||||||||||||||||||||||||||||
| MLIR.Dialects.llvm.store(cdata, alloc) | ||||||||||||||||||||||||||||||
| argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1) | ||||||||||||||||||||||||||||||
| push!(wrapargs, argres) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| MLIR.IR.block!(wrapbody) do | ||||||||||||||||||||||||||||||
| for arg in allocs | ||||||||||||||||||||||||||||||
| if arg === nothing | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| alloc, argty = arg | ||||||||||||||||||||||||||||||
| argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1) | ||||||||||||||||||||||||||||||
| push!(wrapargs, argres) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[])) | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| MLIR.Dialects.llvm.return_(nothing) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
@@ -500,8 +567,14 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||||||||||||||||||||||||||||
| fn=MLIR.IR.FlatSymbolRefAttribute(sym_name), | ||||||||||||||||||||||||||||||
| output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases) | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| for (i, res) in enumerate(rarrays) | ||||||||||||||||||||||||||||||
| res.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, i)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| argidx = 1 | ||||||||||||||||||||||||||||||
| for arg in values(seen) | ||||||||||||||||||||||||||||||
| if !(arg isa TracedRArray || arg isa TracedRNumber) | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx)) | ||||||||||||||||||||||||||||||
| argidx+=1 | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -546,6 +619,12 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( | |||||||||||||||||||||||||||||
| return Core.Typeof(res)(f, res.entry) | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| function Reactant.traced_type( | ||||||||||||||||||||||||||||||
| ::Type{A}, seen::ST, ::Val{mode}, track_numbers | ||||||||||||||||||||||||||||||
| ) where {A<:CuTracedArray,ST,mode} | ||||||||||||||||||||||||||||||
| return A | ||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| function Reactant.traced_type( | ||||||||||||||||||||||||||||||
| ::Type{A}, seen::ST, ::Val{mode}, track_numbers | ||||||||||||||||||||||||||||||
| ) where {T,N,A<:CUDA.CuArray{T,N},ST,mode} | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -293,7 +293,7 @@ function optimization_passes(; no_nan::Bool=false) | |||||||||||||||
| ) | ||||||||||||||||
| func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") | ||||||||||||||||
| return join( | ||||||||||||||||
| ["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ',' | ||||||||||||||||
| ["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ',' | ||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||
| ) | ||||||||||||||||
| end | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶 Reactant.jl/test/integration/cuda.jl Lines 75 to 82 in 832a20c
|
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -72,9 +72,6 @@ function smul!(x) | |||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| @static if !Sys.isapple() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Broken pending jll update | ||||||||||||||||||||||||||||||||||||||||||||||||
| @static if false | ||||||||||||||||||||||||||||||||||||||||||||||||
| @testset "Constant Op Kernel" begin | ||||||||||||||||||||||||||||||||||||||||||||||||
| oA = collect(1:1:64) | ||||||||||||||||||||||||||||||||||||||||||||||||
| A = Reactant.to_rarray(oA) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -87,4 +84,37 @@ end | |||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
85
to
86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| function tuplef!(tup) | ||||||||||||||||||||||||||||||||||||||||||||||||
| tup[1][] += 2 | ||||||||||||||||||||||||||||||||||||||||||||||||
| return nothing | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| function tuplef2!(tup) | ||||||||||||||||||||||||||||||||||||||||||||||||
| tup[2][] *= tup[1] | ||||||||||||||||||||||||||||||||||||||||||||||||
| return nothing | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| tuplef(a) = @cuda threads=1 tuplef!((a,)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| tuplef2(a) = @cuda threads=1 tuplef2!((5, a)) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+98
to
+99
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| @static if !Sys.isapple() | ||||||||||||||||||||||||||||||||||||||||||||||||
| @testset "Structured Kernel Arguments" begin | ||||||||||||||||||||||||||||||||||||||||||||||||
| A = ConcreteRArray(fill(1)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if CUDA.functional() | ||||||||||||||||||||||||||||||||||||||||||||||||
| @jit tuplef(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| @test all(Array(A) .≈ 3) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| @code_hlo optimize = :before_kernel tuplef(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+102
to
+108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| A = ConcreteRArray(fill(1)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if CUDA.functional() | ||||||||||||||||||||||||||||||||||||||||||||||||
| @jit tuplef2(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| @test all(Array(A) .≈ 5) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| @code_hlo optimize = :before_kernel tuplef2(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+110
to
+119
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/ext/ReactantCUDAExt.jl
Lines 360 to 372 in 7edc8a0