diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl
index db13cffc98..db8f518583 100644
--- a/deps/ReactantExtra/make-bindings.jl
+++ b/deps/ReactantExtra/make-bindings.jl
@@ -28,7 +28,7 @@ for file in [
     "Gpu.jl",
     "Affine.jl",
     "TPU.jl",
-    "Triton.jl"
+    "Triton.jl",
 ]
     build_file(joinpath(src_dir, "mlir", "Dialects", file))
 end
diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 847d94052d..ca4d6efdff 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -225,7 +225,7 @@ struct LLVMFunc{F,tt}
     entry::String
 end
 
-function Base.getproperty(f::LLVMFunc{F, tt}, sym::Symbol) where {F, tt}
+function Base.getproperty(f::LLVMFunc{F,tt}, sym::Symbol) where {F,tt}
     if sym === :fun
         f
     else
@@ -235,8 +235,14 @@ end
 
 # TODO in the future we may want to avoid doing a second cufunction compilation
 # for computing the thread/block count (or potentially do it ourselves).
-@noinline function CUDA.launch_configuration(f::LLVMFunc{F, tt}; shmem::Union{Integer, Base.Callable}=0, max_threads::Integer=0) where {F, tt}
-    CUDA.launch_configuration(Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun; shmem, max_threads)
+@noinline function CUDA.launch_configuration(
+    f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0
+) where {F,tt}
+    return CUDA.launch_configuration(
+        Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun;
+        shmem,
+        max_threads,
+    )
 end
 
 const GPUCompiler = CUDA.GPUCompiler
@@ -282,7 +288,12 @@ function compile(job)
     entry = GPUCompiler.JuliaContext() do ctx
         mod, meta = GPUCompiler.compile(
             # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
-            :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
+            :llvm,
+            job;
+            optimize=false,
+            cleanup=false,
+            validate=false,
+            libraries=false,
             # :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
             # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
         )
@@ -357,19 +368,21 @@ function link(job, compiled)
 end
 
 function to_bytes(x)
-  sz = sizeof(x)
-  ref = Ref(x)
-  GC.@preserve ref begin
-    ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
-    vec = Vector{UInt8}(undef, sz)
-    for i in 1:sz
-      @inbounds vec[i] = Base.unsafe_load(ptr, i)
-    end
-    vec
-  end
-end
-
-function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)
+    sz = sizeof(x)
+    ref = Ref(x)
+    GC.@preserve ref begin
+        ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
+        vec = Vector{UInt8}(undef, sz)
+        for i in 1:sz
+            @inbounds vec[i] = Base.unsafe_load(ptr, i)
+        end
+        vec
+    end
+end
+
+function Reactant.make_tracer(
+    seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...
+)
     x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
     Reactant.make_tracer(seen, x, path, mode; kwargs...)
     return prev
@@ -388,7 +401,9 @@ function get_field_offset(T::Type, path)
             findfirst(==(field), fieldnames(current_type))
         end
         if field_idx === nothing
-            error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
+            error(
+                "Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path",
+            )
         end
 
         # Add the offset of this field
@@ -419,7 +434,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     rarrays = TracedRArray[]
 
     fname = func.entry
-    
+
     wrapper_tys = MLIR.IR.Type[]
     ctx = MLIR.IR.context()
     cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1))
@@ -436,19 +451,23 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
         end
         push!(wrapper_tys, cullvm_ty)
     end
-    
+
     sym_name = String(gensym("call_$fname"))
     mod = MLIR.IR.mmodule()
-    CConv=MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel))
+    CConv = MLIR.IR.Attribute(
+        MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel)
+    )
     voidty = MLIR.IR.Type(MLIR.API.mlirLLVMVoidTypeGet(ctx))
-    wrapftype = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false))
+    wrapftype = MLIR.IR.Type(
+        MLIR.API.mlirLLVMFunctionTypeGet(voidty, length(wrapper_tys), wrapper_tys, false)
+    )
     wrapfunc = MLIR.IR.block!(MLIR.IR.body(mod)) do
         return MLIR.Dialects.llvm.func(;
             sym_name,
             sym_visibility=MLIR.IR.Attribute("private"),
             function_type=wrapftype,
             body=MLIR.IR.Region(),
-            CConv
+            CConv,
         )
     end
     wrapbody = MLIR.IR.Block(wrapper_tys, [MLIR.IR.Location() for _ in wrapper_tys])
@@ -459,11 +478,17 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
 
     symtab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
     gpufunc = MLIR.IR.lookup(symtab, fname)
-    MLIR.IR.attr!(gpufunc, "CConv", MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)))
-    gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))
+    MLIR.IR.attr!(
+        gpufunc,
+        "CConv",
+        MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)),
+    )
+    gpu_function_type = MLIR.IR.Type(
+        Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type")
+    )
 
     trueidx = 1
-    allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]
+    allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type},Nothing}[]
 
     llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
     i8 = MLIR.IR.Type(UInt8)
@@ -476,18 +501,34 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
 
         # TODO check for only integer and explicitly non cutraced types
         MLIR.IR.block!(wrapbody) do
-            argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
+            argty = MLIR.IR.Type(
+                MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1)
+            )
             trueidx += 1
-            c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
-            alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1)
+            c1 = MLIR.IR.result(
+                MLIR.Dialects.llvm.mlir_constant(;
+                    res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)
+                ),
+                1,
+            )
+            alloc = MLIR.IR.result(
+                MLIR.Dialects.llvm.alloca(
+                    c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr
+                ),
+                1,
+            )
             push!(allocs, (alloc, argty))
 
             sz = sizeof(a)
             array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
-            cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
+            cdata = MLIR.IR.result(
+                MLIR.Dialects.llvm.mlir_constant(;
+                    res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))
+                ),
+                1,
+            )
             MLIR.Dialects.llvm.store(cdata, alloc)
         end
-
     end
 
     argidx = 1
@@ -499,21 +540,30 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
             if p[1] !== kernelargsym
                 continue
             end
-            
+
             arg = arg.mlir_data
             arg = Reactant.TracedUtils.transpose_val(arg)
             push!(restys, MLIR.IR.type(arg))
             push!(mlir_args, arg)
-            
+
             # Get the allocation corresponding to which arg we're doing
            alloc = allocs[p[2]][1]
 
             # we need to now compute the offset in bytes of the path
             julia_arg = allargs[p[2]]
-            
+
             offset = get_field_offset(typeof(julia_arg), p[3:end])
 
             MLIR.IR.block!(wrapbody) do
-                ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1)
+                ptr = MLIR.IR.result(
+                    MLIR.Dialects.llvm.getelementptr(
+                        alloc,
+                        MLIR.IR.Value[];
+                        res=llvmptr,
+                        elem_type=i8,
+                        rawConstantIndices=MLIR.IR.Attribute([Int32(offset)]),
+                    ),
+                    1,
+                )
                 MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
             end
@@ -530,11 +580,11 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
                     ),
                 ),
             )
-            
+
             argidx += 1
         end
     end
-    
+
     MLIR.IR.block!(wrapbody) do
         for arg in allocs
             if arg === nothing
@@ -544,7 +594,12 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
             argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
             push!(wrapargs, argres)
         end
-        MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
+        MLIR.Dialects.llvm.call(
+            wrapargs,
+            MLIR.IR.Value[];
+            callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)),
+            op_bundle_sizes=MLIR.IR.Attribute(Int32[]),
+        )
         MLIR.Dialects.llvm.return_(nothing)
     end
 
@@ -565,7 +620,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
         mlir_args;
         result_0=restys,
         fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
-        output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
+        output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases),
     )
 
     argidx = 1
@@ -574,7 +629,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
             continue
         end
         arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx))
-        argidx+=1
+        argidx += 1
     end
 end
 
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 3e3cea1348..569e4b77f6 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -293,7 +293,12 @@ function optimization_passes(; no_nan::Bool=false)
     )
     func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
     return join(
-        ["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ','
+        [
+            "inline{default-pipeline=canonicalize max-iterations=4}",
+            "libdevice-funcs-raise",
+            func_passes,
+        ],
+        ',',
     )
 end
 
diff --git a/src/mlir/Dialects/Nvvm.jl b/src/mlir/Dialects/Nvvm.jl
index 60f59be07e..8c0017b366 100755
--- a/src/mlir/Dialects/Nvvm.jl
+++ b/src/mlir/Dialects/Nvvm.jl
@@ -78,18 +78,15 @@ function barrier(
     attributes = NamedAttribute[]
     !isnothing(barrierId) && push!(operands, barrierId)
     !isnothing(numberOfThreads) && push!(operands, numberOfThreads)
-    push!(
-        attributes,
-        operandsegmentsizes([
-            if (barrierId == nothing)
-                0
-            elseif 1(numberOfThreads == nothing)
-                0
-            else
-                1
-            end
-        ]),
-    )
+    push!(attributes, operandsegmentsizes([
+        if (barrierId == nothing)
+            0
+        elseif 1(numberOfThreads == nothing)
+            0
+        else
+            1
+        end,
+    ]))
 
     return create_operation(
         "nvvm.barrier",
diff --git a/src/mlir/Dialects/TPU.jl b/src/mlir/Dialects/TPU.jl
index ee104ff15d..e0a43e1788 100644
--- a/src/mlir/Dialects/TPU.jl
+++ b/src/mlir/Dialects/TPU.jl
@@ -902,18 +902,17 @@ function sem_signal(
     attributes = NamedAttribute[]
     !isnothing(device_id) && push!(operands, device_id)
     !isnothing(core_id) && push!(operands, core_id)
-    push!(
-        attributes,
-        operandsegmentsizes([
-            1, 1, if (device_id == nothing)
-                0
-            elseif 1(core_id == nothing)
-                0
-            else
-                1
-            end
-        ]),
-    )
+    push!(attributes, operandsegmentsizes([
+        1,
+        1,
+        if (device_id == nothing)
+            0
+        elseif 1(core_id == nothing)
+            0
+        else
+            1
+        end,
+    ]))
     !isnothing(core_type) && push!(attributes, namedattribute("core_type", core_type))
 
     return create_operation(
diff --git a/src/mlir/Dialects/Triton.jl b/src/mlir/Dialects/Triton.jl
index f02e239e61..d0eb666c8f 100755
--- a/src/mlir/Dialects/Triton.jl
+++ b/src/mlir/Dialects/Triton.jl
@@ -482,18 +482,18 @@ function dot_scaled(
     ]
     !isnothing(lhs_scale) && push!(operands, lhs_scale)
     !isnothing(rhs_scale) && push!(operands, rhs_scale)
-    push!(
-        attributes,
-        operandsegmentsizes([
-            1, 1, 1, if (lhs_scale == nothing)
-                0
-            elseif 1(rhs_scale == nothing)
-                0
-            else
-                1
-            end
-        ]),
-    )
+    push!(attributes, operandsegmentsizes([
+        1,
+        1,
+        1,
+        if (lhs_scale == nothing)
+            0
+        elseif 1(rhs_scale == nothing)
+            0
+        else
+            1
+        end,
+    ]))
 
     return create_operation(
         "tt.dot_scaled",
@@ -949,16 +949,16 @@ function load(
     attributes = NamedAttribute[]
     !isnothing(mask) && push!(operands, mask)
     !isnothing(other) && push!(operands, other)
-    push!(
-        attributes,
-        operandsegmentsizes([1, if (mask == nothing)
+    push!(attributes, operandsegmentsizes([
+        1,
+        if (mask == nothing)
             0
         elseif 1(other == nothing)
             0
         else
             1
-        end]),
-    )
+        end,
+    ]))
     !isnothing(result) && push!(op_ty_results, result)
     !isnothing(boundaryCheck) &&
         push!(attributes, namedattribute("boundaryCheck", boundaryCheck))
diff --git a/src/utils.jl b/src/utils.jl
index 47a5ea72c6..2f79036cf8 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -95,7 +95,9 @@ function should_rewrite_ft(@nospecialize(ft))
         return false
     end
     if ft <: Core.Function
-        if hasfield(typeof(ft), :name) && hasfield(typeof(ft.name), :name) && isdefined(ft.name, :name)
+        if hasfield(typeof(ft), :name) &&
+            hasfield(typeof(ft.name), :name) &&
+            isdefined(ft.name, :name)
             namestr = String(ft.name.name)
             if startswith(namestr, "##(overlay (. Reactant (inert REACTANT_METHOD_TABLE)")
                 return false
@@ -172,8 +174,6 @@ function should_rewrite_ft(@nospecialize(ft))
         return false
     end
 
-
-
     # Default assume all functions need to be reactant-ified
     return true
 end
diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl
index 598bd9b615..ca445e3e26 100644
--- a/test/integration/cuda.jl
+++ b/test/integration/cuda.jl
@@ -17,19 +17,19 @@ function square!(x, y)
 end
 
 @static if !Sys.isapple()
-@testset "Square Kernel" begin
-    oA = collect(1:1:64)
-    A = Reactant.to_rarray(oA)
-    B = Reactant.to_rarray(100 .* oA)
-    if CUDA.functional()
-        @jit square!(A, B)
-        @test all(Array(A) .≈ (oA .* oA .* 100))
-        @test all(Array(B) .≈ (oA .* 100))
-    else
-        @code_hlo optimize = :before_kernel square!(A, B)
+    @testset "Square Kernel" begin
+        oA = collect(1:1:64)
+        A = Reactant.to_rarray(oA)
+        B = Reactant.to_rarray(100 .* oA)
+        if CUDA.functional()
+            @jit square!(A, B)
+            @test all(Array(A) .≈ (oA .* oA .* 100))
+            @test all(Array(B) .≈ (oA .* 100))
+        else
+            @code_hlo optimize = :before_kernel square!(A, B)
+        end
     end
 end
-end
 
 function sin_kernel!(x, y)
     i = threadIdx().x
@@ -44,19 +44,19 @@ function sin!(x, y)
 end
 
 @static if !Sys.isapple()
-@testset "Sin Kernel" begin
-    oA = collect(Float64, 1:1:64)
-    A = Reactant.to_rarray(oA)
-    B = Reactant.to_rarray(100 .* oA)
-    if CUDA.functional()
-        @jit sin!(A, B)
-        @test all(Array(A) .≈ oA .* sin.(oA .* 100))
-        @test all(Array(B) .≈ (oA .* 100))
-    else
-        @code_hlo optimize = :before_kernel sin!(A, B)
+    @testset "Sin Kernel" begin
+        oA = collect(Float64, 1:1:64)
+        A = Reactant.to_rarray(oA)
+        B = Reactant.to_rarray(100 .* oA)
+        if CUDA.functional()
+            @jit sin!(A, B)
+            @test all(Array(A) .≈ oA .* sin.(oA .* 100))
+            @test all(Array(B) .≈ (oA .* 100))
+        else
+            @code_hlo optimize = :before_kernel sin!(A, B)
+        end
     end
 end
-end
 
 function smul_kernel!(x, y)
     i = threadIdx().x
@@ -72,18 +72,17 @@ function smul!(x)
 end
 
 @static if !Sys.isapple()
-@testset "Constant Op Kernel" begin
-    oA = collect(1:1:64)
-    A = Reactant.to_rarray(oA)
-    if CUDA.functional()
-        @jit smul!(A)
-        @test all(Array(A) .≈ oA .* 15)
-    else
-        @code_hlo optimize = :before_kernel smul!(A)
+    @testset "Constant Op Kernel" begin
+        oA = collect(1:1:64)
+        A = Reactant.to_rarray(oA)
+        if CUDA.functional()
+            @jit smul!(A)
+            @test all(Array(A) .≈ oA .* 15)
+        else
+            @code_hlo optimize = :before_kernel smul!(A)
+        end
     end
 end
-end
-
 
 function tuplef!(tup)
     tup[1][] += 2
@@ -95,26 +94,25 @@ function tuplef2!(tup)
     return nothing
 end
 
-tuplef(a) = @cuda threads=1 tuplef!((a,))
-tuplef2(a) = @cuda threads=1 tuplef2!((5, a))
+tuplef(a) = @cuda threads = 1 tuplef!((a,))
+tuplef2(a) = @cuda threads = 1 tuplef2!((5, a))
 
 @static if !Sys.isapple()
-@testset "Structured Kernel Arguments" begin
-    A = ConcreteRArray(fill(1))
-    if CUDA.functional()
-        @jit tuplef(A)
-        @test all(Array(A) .≈ 3)
-    else
-        @code_hlo optimize = :before_kernel tuplef(A)
+    @testset "Structured Kernel Arguments" begin
+        A = ConcreteRArray(fill(1))
+        if CUDA.functional()
+            @jit tuplef(A)
+            @test all(Array(A) .≈ 3)
+        else
+            @code_hlo optimize = :before_kernel tuplef(A)
+        end
+
+        A = ConcreteRArray(fill(1))
+        if CUDA.functional()
+            @jit tuplef2(A)
+            @test all(Array(A) .≈ 5)
+        else
+            @code_hlo optimize = :before_kernel tuplef2(A)
+        end
     end
-
-    A = ConcreteRArray(fill(1))
-    if CUDA.functional()
-        @jit tuplef2(A)
-        @test all(Array(A) .≈ 5)
-    else
-        @code_hlo optimize = :before_kernel tuplef2(A)
-    end
-
-end
 end
diff --git a/test/integration/kernelabstractions.jl b/test/integration/kernelabstractions.jl
index 51d3e4c8cd..af6c0e80ba 100644
--- a/test/integration/kernelabstractions.jl
+++ b/test/integration/kernelabstractions.jl
@@ -14,14 +14,14 @@ using CUDA: CuArray
         tmp_sum += a[i, k] * a[k, j]
     end
 
-    output[i, j] = tmp_sum
+    return output[i, j] = tmp_sum
 end
 
 # Creating a wrapper kernel for launching with error checks
 function matmul!(output, a, backend)
     kernel! = matmul_kernel!(backend)
-    kernel!(output, a, ndrange = size(output))
-    KernelAbstractions.synchronize(backend)
+    kernel!(output, a; ndrange=size(output))
+    return KernelAbstractions.synchronize(backend)
 end
 
 @testset "KernelAbstractions Call" begin
@@ -31,4 +31,3 @@ end
     @jit matmul!(out, A, backend)
     @test all(Array(out) .≈ 100)
 end
-