diff --git a/Project.toml b/Project.toml index 868ea7473b..c8a12c7ae7 100644 --- a/Project.toml +++ b/Project.toml @@ -67,7 +67,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.3" -Reactant_jll = "0.0.39" +Reactant_jll = "0.0.41" Scratch = "1.2" SpecialFunctions = "2" Statistics = "1.10" diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index c5dee2cade..c7b8f4ff18 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -119,6 +119,15 @@ template T MyValueOrThrow(absl::StatusOr v) { } } +extern "C" void ReactantHandleCuResult(uint32_t curesult) { + if (curesult != 0) { + std::string err = "Bad Cuda Result = " + std::to_string(curesult); + if (ReactantThrowError) { + ReactantThrowError(err.c_str()); + } + } +} + // MLIR C-API extras #pragma region MLIR Extra extern "C" MlirAttribute mlirComplexAttrDoubleGet(MlirContext ctx, @@ -599,7 +608,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { prepareRegistry(registry); mlir::registerenzymePasses(); - regsiterenzymeXLAPasses(); + registerenzymexlaPasses(); // Register the standard passes we want. mlir::registerCSEPass(); diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 44424a50fc..c1e9916be7 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -436,6 +436,7 @@ cc_library( "-Wl,-exported_symbol,_ConvertLLVMToMLIR", "-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler", "-Wl,-exported_symbol,_ReactantThrowError", +"-Wl,-exported_symbol,_ReactantHandleCuResult", ]}), deps = [ "@enzyme//:EnzymeMLIR", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 6bdffaffd9..9924552771 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873" +ENZYMEXLA_COMMIT = "12dc0bf6932befe236eacfcd19ca9522f870f7b9" ENZYMEXLA_SHA256 = "" http_archive( @@ -54,9 +54,6 @@ XLA_PATCHES = XLA_PATCHES + [ sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h """, """ -sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc -""", -""" sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/tsl/concurrency/async_value_ref.h """, """ @@ -95,7 +92,7 @@ LLVM_TARGETS = select({ }) + ["AArch64", "X86", "ARM"] # Uncomment these lines to use a custom LLVM commit -# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908" +# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3" # LLVM_SHA256 = "" # http_archive( # name = "llvm-raw", @@ -138,9 +135,7 @@ http_archive( patches = ["@enzyme_ad//:patches/jax.patch"], ) -# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") -XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5" -XLA_SHA256 = "" +load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") http_archive( name = "xla", diff --git a/deps/ReactantExtra/workspace.bzl b/deps/ReactantExtra/workspace.bzl index 695f1d8578..e69de29bb2 100644 --- a/deps/ReactantExtra/workspace.bzl +++ b/deps/ReactantExtra/workspace.bzl @@ -1,14 +0,0 @@ -ENZYMEXLA_COMMIT = "049a05abfaf23abee646ad26834bb8725c348f51" -ENZYMEXLA_SHA256 = "" - -NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" -NSYNC_SHA256 = "" - -RULES_CC_COMMIT = "c8c38f8c710cbbf834283e4777916b68261b359c" -RULES_CC_SHA256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4" - -RULES_PYTHON_VERSION = "0.34.0" -RULES_PYTHON_SHA256 = "778aaeab3e6cfd56d681c89f5c10d7ad6bf8d2f1a72de9de55b23081b2d31618" - -UPB_COMMIT = "9effcbcb27f0a665f9f345030188c0b291e32482" -UPB_SHA256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 2c6870f2c5..f7f638e365 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -239,9 +239,12 @@ function Adapt.adapt_structure( ) end -Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg) +function recudaconvert(arg) return adapt(ReactantKernelAdaptor(), arg) end +Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg) + return recudaconvert(arg) +end function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N} res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs) @@ -425,6 +428,7 @@ function get_field_offset(T::Type, path) offset = 0 current_type = T + for field in path # Get the field index field_idx = if field isa Integer @@ -440,18 +444,22 @@ function get_field_offset(T::Type, path) end # Add the offset of this field - offset += fieldoffset(current_type, field_idx) + toffset = fieldoffset(current_type, field_idx) + tcurrent_type = fieldtype(current_type, field_idx) + offset += toffset # Update current_type to the field's type for next iteration - current_type = fieldtype(current_type, field_idx) + current_type = tcurrent_type + end + return offset end Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( args...; - convert=Val(false), + convert=Val(true), blocks::CuDim=1, threads::CuDim=1, cooperative::Bool=false, @@ -461,6 +469,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( blockdim = CUDA.CuDim3(blocks) threaddim = CUDA.CuDim3(threads) + if convert == Val(true) + args = recudaconvert.(args) + end + mlir_args = MLIR.IR.Value[] restys = MLIR.IR.Type[] aliases = MLIR.IR.Attribute[] @@ -578,6 +590,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( push!(restys, MLIR.IR.type(arg)) push!(mlir_args, arg) + push!( + aliases, + MLIR.IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), + length(wrapper_tys) == 1 ? 0 : 1, + length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1), + argidx - 1, + 0, + C_NULL, + ), + ), + ) + for p in paths if p[1] !== kernelargsym continue @@ -602,20 +628,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( ) MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr) end - - push!( - aliases, - MLIR.IR.Attribute( - MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), - length(wrapper_tys) == 1 ? 0 : 1, - length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1), - argidx - 1, - 0, - C_NULL, - ), - ), - ) end argidx += 1 end @@ -650,6 +662,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( end location = MLIR.IR.Location() + @assert length(restys) == length(aliases) call = MLIR.Dialects.enzymexla.kernel_call( blk_operands..., mlir_args; @@ -786,6 +799,11 @@ function __init__() Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) + ptr4 = Reactant.XLA.Libdl.dlsym(handle, "cuStreamSynchronize"; throw_error=false) + if ptr4 === nothing + ptr4 = C_NULL + end + Reactant.Compiler.cuSync[] = Base.reinterpret(UInt, ptr4) end return nothing end diff --git a/src/Compiler.jl b/src/Compiler.jl index d6565f7612..7c1b41db8f 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -116,7 +116,7 @@ function create_result( end # Optimization passes via transform dialect -function optimization_passes(; no_nan::Bool=false) +function optimization_passes(; no_nan::Bool=false, sroa::Bool=false) transform_passes_list = [ "patterns=compare_op_canon<16>", "transpose_transpose<16>", @@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false) ",", ) func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") - return join( - [ - "inline{default-pipeline=canonicalize max-iterations=4}", - "libdevice-funcs-raise", - func_passes, - ], + passes = [ + "inline{default-pipeline=canonicalize max-iterations=4}" + ] + if sroa + push!(passes, "sroa-wrappers") + push!(passes, "libdevice-funcs-raise") + push!(passes, "canonicalize") + end + push!(passes, func_passes) + return join(passes, ',', ) end @@ -351,6 +355,8 @@ end const cuLaunch = Ref{UInt}(0) const cuFunc = Ref{UInt}(0) const cuModule = Ref{UInt}(0) +const cuSync = Ref{UInt}(0) +const DEBUG_KERNEL = Ref{Bool}(false) function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) # Explicitly don't use block! to avoid creating a closure, which creates @@ -379,12 +385,20 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: if isdefined(Reactant_jll, :ptxas_path) toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] end - kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" + if DEBUG_KERNEL[] + curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult") + @assert curesulthandler !== nothing + curesulthandler = Base.reinterpret(UInt, curesulthandler) + kern = "lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[])},symbol-dce" + else + kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" + end - opt_passes = optimization_passes(; no_nan) + opt_passes = optimization_passes(; no_nan, sroa=true) + opt_passes2 = optimization_passes(; no_nan, sroa=false) if optimize === :all - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false ) @@ -395,14 +409,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, kern, ], ',', ), ) elseif optimize === :before_kernel - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false ) @@ -413,13 +427,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, ], ',', ), ) elseif optimize === :no_enzyme - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, @@ -428,7 +442,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, ], ',', ), @@ -457,14 +471,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, kern, ], ',', ), ) elseif optimize === :before_enzyme - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false ) diff --git a/src/Tracing.jl b/src/Tracing.jl index 2f4885147d..36ef8edfbb 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -318,6 +318,18 @@ function Base.showerror(io::IO, err::NoFieldMatchError) ) end +function make_tracer( + seen, + @nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}), + @nospecialize(path), + mode; + toscalar=false, + tobatch=nothing, + track_numbers=(), + kwargs..., +) + return prev +end append_path(path, i) = (path..., i) function make_tracer( @@ -590,7 +602,7 @@ function make_tracer( if mode == ArrayToConcrete return ConcreteRNumber(prev) else - if mode == TracedTrack + if mode == TracedTrack || mode == NoStopTracedTrack res = TracedRNumber{RT}( (path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data ) @@ -638,7 +650,7 @@ end function make_tracer( seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs... ) where {RT<:Array} - if haskey(seen, prev) + if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive @@ -699,7 +711,7 @@ function make_tracer( end function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) - if haskey(seen, prev) + if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end prev2 = prev.contents diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl index da5d3c52ba..4ec52ab869 100644 --- a/test/integration/cuda.jl +++ b/test/integration/cuda.jl @@ -128,7 +128,7 @@ end # maybe weird cuda things function aliased!(tup) x, y = tup - x[2][1] *= y[2][1] + x[1] *= y[1] return nothing end @@ -142,13 +142,11 @@ end @testset "Aliasing arguments" begin a = ConcreteRArray([3]) - s = (10, a) - if CUDA.functional() - @jit aliased((s, s)) - @test all(Array(a) == 9) + @jit aliased(a) + @test all(Array(a) .== 9) else - @code_hlo optimize = :before_kernel aliased(s) + @code_hlo optimize = :before_kernel aliased(a) end end end @@ -170,10 +168,9 @@ end if CUDA.functional() a = CuArray([4]) b = ConcreteRArray([3]) - @jit mixed(a, b) - @test all(Array(a) == 4) - @test all(Array(b) == 12) + @test all(Array(a) .== 4) + @test all(Array(b) .== 12) end end end diff --git a/test/integration/python.jl b/test/integration/python.jl index 54c2eec16d..91921dedfc 100644 --- a/test/integration/python.jl +++ b/test/integration/python.jl @@ -2,6 +2,9 @@ using Reactant using Reactant: Ops using Test + +# Jax on Github CI dislikes X86 macos +@static if !Sys.isapple() || Sys.ARCH != :x86_64 using PythonCall @testset "PythonCall" begin @@ -11,3 +14,4 @@ using PythonCall @test typeof(result) == ConcreteRNumber{Float32} @test result ≈ 6 end +end \ No newline at end of file