diff --git a/Project.toml b/Project.toml
index 654e091cb8..868ea7473b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Reactant"
 uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
 authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "]
-version = "0.2.17"
+version = "0.2.18"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -67,7 +67,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.3"
-Reactant_jll = "0.0.37"
+Reactant_jll = "0.0.39"
 Scratch = "1.2"
 SpecialFunctions = "2"
 Statistics = "1.10"
diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp
index 3a54f213aa..f7ada88b22 100644
--- a/deps/ReactantExtra/API.cpp
+++ b/deps/ReactantExtra/API.cpp
@@ -629,7 +629,8 @@ static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
 static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
                                                   mlir::ModuleOp source,
                                                   mlir::ModuleOp target,
-                                                  unsigned &lastUsedID) {
+                                                  unsigned &lastUsedID,
+                                                  bool &shouldRemove) {
   using namespace llvm;
   using namespace mlir;

@@ -639,6 +640,13 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
     return success();
   }

+  if (auto func = dyn_cast<FunctionOpInterface>(op.getOperation())) {
+    if (func.isExternal()) {
+      shouldRemove = true;
+      return success();
+    }
+  }
+
   StringAttr newSymName = renameSymbol(opName, lastUsedID, source, target);

   if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
@@ -658,7 +666,7 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,

   unsigned lastUsedID = 0;

-  for (auto &op : *newMod.getBody()) {
+  for (auto &op : llvm::make_early_inc_range(*newMod.getBody())) {
     auto symbolOp = dyn_cast<mlir::SymbolOpInterface>(op);
     if (!symbolOp)
       continue;
@@ -669,10 +677,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
       entryFn = &op;
     }

-    if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID))) {
+    bool shouldRemove = false;
+    if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) {
       assert(0 && "failed to update all uses");
     }
-    SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
+    if (shouldRemove)
+      op.erase();
+    else
+      SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
   }
   prevMod.getBody()->getOperations().splice(
       prevMod.getBody()->getOperations().end(),
diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index 21c094c4a3..6bdffaffd9 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )

-ENZYMEXLA_COMMIT = "85612ea74731f02aa4e30800038e065912d37ae2"
+ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873"
 ENZYMEXLA_SHA256 = ""

 http_archive(
@@ -138,7 +138,9 @@ http_archive(
     patches = ["@enzyme_ad//:patches/jax.patch"],
 )

-load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
+# Pin XLA directly rather than taking the commit from the JAX workspace.
+# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
+XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
+XLA_SHA256 = ""

 http_archive(
     name = "xla",
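For context (an aside, not part of the patch): the ext/ReactantCUDAExt.jl rewrite that follows replaces per-argument special cases with a linearization pass — Reactant.make_tracer walks every kernel argument, including nested tuples and structs, and records a path to each traced array leaf. A toy sketch of that idea in plain Julia, with hypothetical names (collect_leaves is not Reactant API):

# Recursively walk a value and record a (path, leaf) pair for every array
# found, so each leaf can later be passed as its own pointer parameter.
collect_leaves(x::AbstractArray, path, out) = (push!(out, (path, x)); out)
function collect_leaves(x::Union{Tuple,NamedTuple}, path, out)
    for (i, v) in pairs(x)
        collect_leaves(v, (path..., i), out)
    end
    return out
end
collect_leaves(x, path, out) = out  # plain bits (Int, Float64, ...) stay inline

leaves = collect_leaves((5, (rand(2), 3.0)), (), Any[])
# 1-element Vector{Any}: the lone array leaf, reached via path (2, 1) —
# analogous to the (kernel argument, field path...) tuples recorded below.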
diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 811698210d..847d94052d 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -281,9 +281,13 @@ function compile(job)
     # TODO: on 1.9, this actually creates a context. cache those.
     entry = GPUCompiler.JuliaContext() do ctx
         mod, meta = GPUCompiler.compile(
+            # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
             :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
+            # :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
+            # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
         )
+        GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job))
         entryname = LLVM.name(meta.entry)

         GPUCompiler.optimize_module!(job, mod)
@@ -319,6 +323,8 @@ function compile(job)
         end
     end

+    # GPUCompiler.check_ir(job, mod)
+
     LLVM.strip_debuginfo!(mod)
     modstr = string(mod)

@@ -363,6 +369,38 @@ function to_bytes(x)
     end
 end

+function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)
+    x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
+    Reactant.make_tracer(seen, x, path, mode; kwargs...)
+    return prev
+end
+
+function get_field_offset(T::Type, path)
+    offset = 0
+    current_type = T
+
+    for field in path
+        # Get the field index
+        field_idx = if field isa Integer
+            field
+        else
+            @assert field isa Symbol
+            findfirst(==(field), fieldnames(current_type))
+        end
+        if field_idx === nothing
+            error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
+        end
+
+        # Add the offset of this field
+        offset += fieldoffset(current_type, field_idx)
+
+        # Update current_type to the field's type for next iteration
+        current_type = fieldtype(current_type, field_idx)
+    end
+
+    return offset
+end
+
 Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     args...;
     convert=Val(false),
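The get_field_offset helper added above sums fieldoffset along a path of field names or indices, yielding the byte offset of a nested leaf within its enclosing argument. A small standalone illustration (Inner and Outer are hypothetical types, for exposition only):

struct Inner
    x::Int32
    y::Int64
end
struct Outer
    a::Int16
    b::Inner
end

# get_field_offset(Outer, (:b, :y)) walks: offset of b within Outer,
# plus offset of y within Inner.
off = fieldoffset(Outer, 2) + fieldoffset(Inner, 2)
# = 8 + 8 = 16 bytes on a typical 64-bit layout (b is 8-aligned after a).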
@@ -384,20 +422,19 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

     wrapper_tys = MLIR.IR.Type[]
     ctx = MLIR.IR.context()
-    cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1))
-    for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
-        if sizeof(a) == 0
+    cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1))
+
+    # linearize kernel arguments
+    seen = Reactant.OrderedIdDict()
+    prev = Any[func.f, args...]
+    kernelargsym = gensym("kernelarg")
+    Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
+    wrapper_tys = MLIR.IR.Type[]
+    for arg in values(seen)
+        if !(arg isa TracedRArray || arg isa TracedRNumber)
             continue
         end
-        if a isa CuTracedArray
-            a =
-                Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
-        end
-        if a isa TracedRArray || a isa TracedRNumber
-            push!(wrapper_tys, cullvm_ty)
-            continue
-        end
-        # Per below we assume we can inline all other types directly in
+        push!(wrapper_tys, cullvm_ty)
     end

     sym_name = String(gensym("call_$fname"))
@@ -426,20 +463,60 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))

     trueidx = 1
-    for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
+    allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type},Nothing}[]
+
+    llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
+    i8 = MLIR.IR.Type(UInt8)
+    allargs = [func.f, args...]
+    for a in allargs
         if sizeof(a) == 0
+            push!(allocs, nothing)
             continue
         end
-        if a isa CuTracedArray
-            a =
-                Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
+
+        # TODO check for only integer and explicitly non cutraced types
+        MLIR.IR.block!(wrapbody) do
+            argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1))
+            trueidx += 1
+            c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
+            alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1)
+            push!(allocs, (alloc, argty))
+
+            sz = sizeof(a)
+            array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
+            cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
+            MLIR.Dialects.llvm.store(cdata, alloc)
         end
-        if a isa TracedRArray || a isa TracedRNumber
-            push!(rarrays, a)
-            arg = a.mlir_data
+
+    end
+
+    argidx = 1
+    for arg in values(seen)
+        if !(arg isa TracedRArray || arg isa TracedRNumber)
+            continue
+        end
+        for p in Reactant.TracedUtils.get_paths(arg)
+            if p[1] !== kernelargsym
+                continue
+            end
+
+            arg = arg.mlir_data
             arg = Reactant.TracedUtils.transpose_val(arg)
             push!(restys, MLIR.IR.type(arg))
             push!(mlir_args, arg)
+
+            # Get the allocation corresponding to which arg we're doing
+            alloc = allocs[p[2]][1]
+
+            # we need to now compute the offset in bytes of the path
+            julia_arg = allargs[p[2]]
+
+            offset = get_field_offset(typeof(julia_arg), p[3:end])
+            MLIR.IR.block!(wrapbody) do
+                ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1)
+                MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
+            end
+
             push!(
                 aliases,
                 MLIR.IR.Attribute(
@@ -453,30 +530,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
                     ),
                 ),
             )
-            push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))
+
             argidx += 1
-            trueidx += 1
-            continue
-        end
-
-        # TODO check for only integer and explicitly non cutraced types
-        @show "Warning: using fallback for kernel argument type conversion for argument of type $(Core.Typeof(a)), if this contains a CuTracedArray this will segfault"
-        MLIR.IR.block!(wrapbody) do
-            argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
-            trueidx += 1
-            c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
-            alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))), 1)
-
-            sz = sizeof(a)
-            array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
-            cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
-            MLIR.Dialects.llvm.store(cdata, alloc)
-            argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
-            push!(wrapargs, argres)
         end
     end

     MLIR.IR.block!(wrapbody) do
+        for arg in allocs
+            if arg === nothing
+                continue
+            end
+            alloc, argty = arg
+            argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
+            push!(wrapargs, argres)
+        end
         MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
         MLIR.Dialects.llvm.return_(nothing)
     end
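Taken together, the two loops above emit a wrapper that (1) allocas each kernel argument and stores its bytes as an LLVM constant, (2) stores each real device pointer at the byte offset of its traced leaf, and (3) loads the patched values back before calling the kernel. A CPU-side toy model of the same byte-patching trick, using plain pointers (illustrative only; KernelArg and the fake pointer value are hypothetical):

struct KernelArg        # stand-in for a kernel argument like the tuple (5, a)
    n::Int64            # literal, baked in by value
    ptr::Ptr{Cvoid}     # traced leaf: slot patched with the real pointer
end

arg = KernelArg(5, C_NULL)
buf = Vector{UInt8}(undef, sizeof(arg))            # the llvm.alloca
fake = Ptr{Cvoid}(UInt(0xdeadbeef))                # stands in for a device pointer
GC.@preserve buf begin
    p = pointer(buf)
    unsafe_store!(Ptr{KernelArg}(p), arg)          # llvm.store of the constant bytes
    off = fieldoffset(KernelArg, 2)                # what get_field_offset computes
    unsafe_store!(Ptr{Ptr{Cvoid}}(p + off), fake)  # llvm.getelementptr + store
    patched = unsafe_load(Ptr{KernelArg}(p))       # llvm.load feeding the call
    @assert patched.n == 5 && patched.ptr == fake
end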
@@ -500,8 +567,14 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
         fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
         output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
     )
-    for (i, res) in enumerate(rarrays)
-        res.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, i))
+
+    argidx = 1
+    for arg in values(seen)
+        if !(arg isa TracedRArray || arg isa TracedRNumber)
+            continue
+        end
+        arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx))
+        argidx += 1
     end
 end

@@ -546,6 +619,12 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
     return Core.Typeof(res)(f, res.entry)
 end

+function Reactant.traced_type(
+    ::Type{A}, seen::ST, ::Val{mode}, track_numbers
+) where {A<:CuTracedArray,ST,mode}
+    return A
+end
+
 function Reactant.traced_type(
     ::Type{A}, seen::ST, ::Val{mode}, track_numbers
 ) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 4dacbaa217..3e3cea1348 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -293,7 +293,7 @@ function optimization_passes(; no_nan::Bool=false)
     )
     func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
     return join(
-        ["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ','
+        ["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ','
     )
 end

diff --git a/src/utils.jl b/src/utils.jl
index d1130ed90a..47a5ea72c6 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -116,6 +116,9 @@ function should_rewrite_ft(@nospecialize(ft))
             if ft.name.name == Symbol("#launch_configuration")
                 return false
             end
+            if ft.name.name == Symbol("cudaconvert")
+                return false
+            end
         end
     end
 end
@@ -161,7 +164,11 @@ function should_rewrite_ft(@nospecialize(ft))
         ft <: typeof(Base.getproperty) ||
         ft <: typeof(Base.vect) ||
         ft <: typeof(Base.eltype) ||
-        ft <: typeof(Base.argtail)
+        ft <: typeof(Base.argtail) ||
+        ft <: typeof(Base.identity) ||
+        ft <: typeof(Base.print) ||
+        ft <: typeof(Base.println) ||
+        ft <: typeof(Adapt.adapt_structure)
         return false
     end

diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl
index 4803a9fe81..598bd9b615 100644
--- a/test/integration/cuda.jl
+++ b/test/integration/cuda.jl
@@ -72,9 +72,6 @@ function smul!(x)
 end

 @static if !Sys.isapple()
-
-# Broken pending jll update
-@static if false
 @testset "Constant Op Kernel" begin
     oA = collect(1:1:64)
     A = Reactant.to_rarray(oA)
@@ -87,4 +84,37 @@ end
     end
 end

+
+function tuplef!(tup)
+    tup[1][] += 2
+    return nothing
+end
+
+function tuplef2!(tup)
+    tup[2][] *= tup[1]
+    return nothing
+end
+
+tuplef(a) = @cuda threads=1 tuplef!((a,))
+tuplef2(a) = @cuda threads=1 tuplef2!((5, a))
+
+@static if !Sys.isapple()
+@testset "Structured Kernel Arguments" begin
+    A = ConcreteRArray(fill(1))
+    if CUDA.functional()
+        @jit tuplef(A)
+        @test all(Array(A) .≈ 3)
+    else
+        @code_hlo optimize = :before_kernel tuplef(A)
+    end
+
+    A = ConcreteRArray(fill(1))
+    if CUDA.functional()
+        @jit tuplef2(A)
+        @test all(Array(A) .≈ 5)
+    else
+        @code_hlo optimize = :before_kernel tuplef2(A)
+    end
+
+end
 end
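End to end, the new tests exercise kernels whose arguments are tuples mixing literals and device arrays — the case the removed @show warning flagged as segfault-prone. For reference, the same pattern outside the test harness; a minimal sketch assuming a working Reactant + CUDA.jl environment (scale! and launch are illustrative names, not part of the patch):

using Reactant, CUDA, Test

function scale!(tup)
    tup[2][] *= tup[1]    # tup == (5, a): the literal 5 is baked in by value,
    return nothing        # the array leaf is passed as a patched pointer
end
launch(a) = @cuda threads = 1 scale!((5, a))

A = Reactant.ConcreteRArray(fill(1))
if CUDA.functional()
    @jit launch(A)        # should compile through the structured-argument path
    @test all(Array(A) .== 5)
else
    @code_hlo optimize = :before_kernel launch(A)   # inspect IR without a GPU
end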