From fdcb111316bf4a1099054eabc28792924610288f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 12 Jan 2025 19:27:54 -0500 Subject: [PATCH 01/15] WIP: adapt to sroa jll --- deps/ReactantExtra/API.cpp | 2 +- deps/ReactantExtra/WORKSPACE | 74 ++++++++++++++++++------------------ src/Compiler.jl | 37 ++++++++++-------- 3 files changed, 60 insertions(+), 53 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index c5dee2cade..ceccf70d17 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -599,7 +599,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { prepareRegistry(registry); mlir::registerenzymePasses(); - regsiterenzymeXLAPasses(); + registerenzymexlaPasses(); // Register the standard passes we want. mlir::registerCSEPass(); diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 6bdffaffd9..631d8b8a65 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873" +ENZYMEXLA_COMMIT = "3b217bbfd5680ecd88c20285fe7b5693c541fa8b" ENZYMEXLA_SHA256 = "" http_archive( @@ -95,39 +95,39 @@ LLVM_TARGETS = select({ }) + ["AArch64", "X86", "ARM"] # Uncomment these lines to use a custom LLVM commit -# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908" -# LLVM_SHA256 = "" -# http_archive( -# name = "llvm-raw", -# build_file_content = "# empty", -# sha256 = LLVM_SHA256, -# strip_prefix = "llvm-project-" + LLVM_COMMIT, -# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], -# ) -# -# -# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") -# maybe( -# http_archive, -# name = "llvm_zlib", -# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", -# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", -# strip_prefix = "zlib-ng-2.0.7", -# urls = [ -# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", -# ], -# ) -# -# maybe( -# http_archive, -# name = "llvm_zstd", -# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", -# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", -# strip_prefix = "zstd-1.5.2", -# urls = [ -# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" -# ], -# ) +LLVM_COMMIT = "9b4bf06be33f0fe6a4c487bb9244d8c0f6acab3f" +LLVM_SHA256 = "" +http_archive( + name = "llvm-raw", + build_file_content = "# empty", + sha256 = LLVM_SHA256, + strip_prefix = "llvm-project-" + LLVM_COMMIT, + urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], +) + + +load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +maybe( + http_archive, + name = "llvm_zlib", + build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", + sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", + strip_prefix = "zlib-ng-2.0.7", + urls = [ + "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", + ], +) + +maybe( + http_archive, + name = "llvm_zstd", + build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", + sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", + strip_prefix = "zstd-1.5.2", + urls = [ + "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" + ], +) http_archive( name = "jax", @@ -138,9 +138,9 @@ http_archive( patches = ["@enzyme_ad//:patches/jax.patch"], ) -# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") -XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5" -XLA_SHA256 = "" +load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") +# XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5" +# XLA_SHA256 = "" http_archive( name = "xla", diff --git a/src/Compiler.jl b/src/Compiler.jl index d6565f7612..ae30b9c294 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -116,7 +116,7 @@ function create_result( end # Optimization passes via transform dialect -function optimization_passes(; no_nan::Bool=false) +function optimization_passes(; no_nan::Bool=false, sroa::Bool=false) transform_passes_list = [ "patterns=compare_op_canon<16>", "transpose_transpose<16>", @@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false) ",", ) func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") - return join( - [ - "inline{default-pipeline=canonicalize max-iterations=4}", - "libdevice-funcs-raise", - func_passes, - ], + passes = [ + "inline{default-pipeline=canonicalize max-iterations=4}" + ] + if sroa + push!(passes, "sroa-wrappers") + push!(passes, "libdevice-funcs-raise") + push!(passes, "canonicalize") + end + push!(passes, func_passes) + return join(passes, ',', ) end @@ -310,6 +314,8 @@ end const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) + @show pass_pipeline + flush(stdout) pm = MLIR.IR.PassManager() MLIR.IR.enable_verifier!(pm, enable_verifier) opm = MLIR.IR.OpPassManager(pm) @@ -382,9 +388,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" opt_passes = optimization_passes(; no_nan) + opt_passes2 = optimization_passes(; no_nan, sroa=false) if optimize === :all - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false ) @@ -395,14 +402,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, kern, ], ',', ), ) elseif optimize === :before_kernel - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false ) @@ -413,13 +420,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, ], ',', ), ) elseif optimize === :no_enzyme - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, @@ -428,7 +435,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, ], ',', ), @@ -457,14 +464,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", - opt_passes, + opt_passes2, kern, ], ',', ), ) elseif optimize === :before_enzyme - run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) run_pass_pipeline!( mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false ) From 4b2bd6b77a13ff5d76c8dc64578ee513070a4ebe Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 13 Jan 2025 13:18:12 -0500 Subject: [PATCH 02/15] fixup --- deps/ReactantExtra/WORKSPACE | 2 +- ext/ReactantCUDAExt.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 631d8b8a65..a9921022b2 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "3b217bbfd5680ecd88c20285fe7b5693c541fa8b" +ENZYMEXLA_COMMIT = "1b473e8e77850ece61bb0e85b152b95cb6a70be0" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 2c6870f2c5..1add274123 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -650,6 +650,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( end location = MLIR.IR.Location() + @assert length(restys) == length(aliases) call = MLIR.Dialects.enzymexla.kernel_call( blk_operands..., mlir_args; From 68afe868c99babb9b4ed08583448408d9e333d96 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 13 Jan 2025 15:23:17 -0500 Subject: [PATCH 03/15] fix --- deps/ReactantExtra/WORKSPACE | 4 ++-- ext/ReactantCUDAExt.jl | 28 ++++++++++++++-------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index a9921022b2..f29f2e49bb 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "1b473e8e77850ece61bb0e85b152b95cb6a70be0" +ENZYMEXLA_COMMIT = "1d9cab766d5c35646da956a3b32aff9d61315f90" ENZYMEXLA_SHA256 = "" http_archive( @@ -95,7 +95,7 @@ LLVM_TARGETS = select({ }) + ["AArch64", "X86", "ARM"] # Uncomment these lines to use a custom LLVM commit -LLVM_COMMIT = "9b4bf06be33f0fe6a4c487bb9244d8c0f6acab3f" +LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3" LLVM_SHA256 = "" http_archive( name = "llvm-raw", diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 1add274123..2ba3f527bf 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -578,6 +578,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( push!(restys, MLIR.IR.type(arg)) push!(mlir_args, arg) + push!( + aliases, + MLIR.IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), + length(wrapper_tys) == 1 ? 0 : 1, + length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1), + argidx - 1, + 0, + C_NULL, + ), + ), + ) + for p in paths if p[1] !== kernelargsym continue @@ -602,20 +616,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( ) MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr) end - - push!( - aliases, - MLIR.IR.Attribute( - MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), - length(wrapper_tys) == 1 ? 0 : 1, - length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1), - argidx - 1, - 0, - C_NULL, - ), - ), - ) end argidx += 1 end From 843c462bebf6ee23ac16a8e39422862776c517a7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 13 Jan 2025 22:29:47 -0500 Subject: [PATCH 04/15] fixup --- deps/ReactantExtra/WORKSPACE | 2 +- ext/ReactantCUDAExt.jl | 1 + src/Compiler.jl | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index f29f2e49bb..be0a427932 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "1d9cab766d5c35646da956a3b32aff9d61315f90" +ENZYMEXLA_COMMIT = "64a1c283072d4ce4eb319c69b32a6f3c68f30cbe" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 2ba3f527bf..01439e78b3 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -658,6 +658,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( fn=MLIR.IR.FlatSymbolRefAttribute(sym_name), output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases), ) + @show string(call), typeof(func.f), collect(map(typeof, args)) argidx = 1 for arg in values(seen) diff --git a/src/Compiler.jl b/src/Compiler.jl index ae30b9c294..9c65473907 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -387,7 +387,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: end kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" - opt_passes = optimization_passes(; no_nan) + opt_passes = optimization_passes(; no_nan, sroa=true) opt_passes2 = optimization_passes(; no_nan, sroa=false) if optimize === :all From 4f3c68bc30112ef08adc619d0ec5c5f3779d9815 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 15 Jan 2025 00:09:45 -0500 Subject: [PATCH 05/15] fixup --- deps/ReactantExtra/API.cpp | 9 +++++ deps/ReactantExtra/BUILD | 1 + deps/ReactantExtra/WORKSPACE | 74 ++++++++++++++++++------------------ ext/ReactantCUDAExt.jl | 19 +++++++-- src/Compiler.jl | 11 +++++- src/Tracing.jl | 18 +++++++-- 6 files changed, 88 insertions(+), 44 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index ceccf70d17..c7b8f4ff18 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -119,6 +119,15 @@ template T MyValueOrThrow(absl::StatusOr v) { } } +extern "C" void ReactantHandleCuResult(uint32_t curesult) { + if (curesult != 0) { + std::string err = "Bad Cuda Result = " + std::to_string(curesult); + if (ReactantThrowError) { + ReactantThrowError(err.c_str()); + } + } +} + // MLIR C-API extras #pragma region MLIR Extra extern "C" MlirAttribute mlirComplexAttrDoubleGet(MlirContext ctx, diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 44424a50fc..c1e9916be7 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -436,6 +436,7 @@ cc_library( "-Wl,-exported_symbol,_ConvertLLVMToMLIR", "-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler", "-Wl,-exported_symbol,_ReactantThrowError", +"-Wl,-exported_symbol,_ReactantHandleCuResult", ]}), deps = [ "@enzyme//:EnzymeMLIR", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index be0a427932..c1dd256aa3 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "64a1c283072d4ce4eb319c69b32a6f3c68f30cbe" +ENZYMEXLA_COMMIT = "362f33f518900ebf66cee7f0135a436907f8f692" ENZYMEXLA_SHA256 = "" http_archive( @@ -95,39 +95,39 @@ LLVM_TARGETS = select({ }) + ["AArch64", "X86", "ARM"] # Uncomment these lines to use a custom LLVM commit -LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3" -LLVM_SHA256 = "" -http_archive( - name = "llvm-raw", - build_file_content = "# empty", - sha256 = LLVM_SHA256, - strip_prefix = "llvm-project-" + LLVM_COMMIT, - urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], -) - - -load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") -maybe( - http_archive, - name = "llvm_zlib", - build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", - sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", - strip_prefix = "zlib-ng-2.0.7", - urls = [ - "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", - ], -) - -maybe( - http_archive, - name = "llvm_zstd", - build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", - sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", - strip_prefix = "zstd-1.5.2", - urls = [ - "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" - ], -) +# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3" +# LLVM_SHA256 = "" +# http_archive( +# name = "llvm-raw", +# build_file_content = "# empty", +# sha256 = LLVM_SHA256, +# strip_prefix = "llvm-project-" + LLVM_COMMIT, +# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], +# ) +# +# +# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +# maybe( +# http_archive, +# name = "llvm_zlib", +# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", +# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", +# strip_prefix = "zlib-ng-2.0.7", +# urls = [ +# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", +# ], +# ) +# +# maybe( +# http_archive, +# name = "llvm_zstd", +# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", +# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", +# strip_prefix = "zstd-1.5.2", +# urls = [ +# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz" +# ], +# ) http_archive( name = "jax", @@ -138,9 +138,9 @@ http_archive( patches = ["@enzyme_ad//:patches/jax.patch"], ) -load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") -# XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5" -# XLA_SHA256 = "" +# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") +XLA_COMMIT = "1bb4fc18e73faa1c001d96bfe3a22f733987b018" +XLA_SHA256 = "" http_archive( name = "xla", diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 01439e78b3..91e833bad8 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -425,6 +425,7 @@ function get_field_offset(T::Type, path) offset = 0 current_type = T + for field in path # Get the field index field_idx = if field isa Integer @@ -440,11 +441,17 @@ function get_field_offset(T::Type, path) end # Add the offset of this field - offset += fieldoffset(current_type, field_idx) + toffset = fieldoffset(current_type, field_idx) + tcurrent_type = fieldtype(current_type, field_idx) + offset += toffset + @show current_type, field_idx, toffset, offset, tcurrent_type # Update current_type to the field's type for next iteration - current_type = fieldtype(current_type, field_idx) + current_type = tcurrent_type + end + + @show T, path, offset return offset end @@ -552,6 +559,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( 1, ) push!(allocs, (alloc, argty)) + @show string(alloc), string(argty), typeof(a) sz = sizeof(a) array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz)) @@ -658,7 +666,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( fn=MLIR.IR.FlatSymbolRefAttribute(sym_name), output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases), ) - @show string(call), typeof(func.f), collect(map(typeof, args)) + # @show string(call), typeof(func.f), collect(map(typeof, args)) argidx = 1 for arg in values(seen) @@ -788,6 +796,11 @@ function __init__() Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) + ptr4 = Reactant.XLA.Libdl.dlsym(handle, "cuStreamSynchronize"; throw_error=false) + if ptr4 === nothing + ptr4 = C_NULL + end + Reactant.Compiler.cuSync[] = Base.reinterpret(UInt, ptr4) end return nothing end diff --git a/src/Compiler.jl b/src/Compiler.jl index 9c65473907..68070978bd 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -357,6 +357,8 @@ end const cuLaunch = Ref{UInt}(0) const cuFunc = Ref{UInt}(0) const cuModule = Ref{UInt}(0) +const cuSync = Ref{UInt}(0) +const DEBUG_KERNEL = Ref{Bool}(false) function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) # Explicitly don't use block! to avoid creating a closure, which creates @@ -385,7 +387,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: if isdefined(Reactant_jll, :ptxas_path) toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] end - kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" + if DEBUG_KERNEL[] + curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult") + @assert curesulthandler !== nothing + curesulthandler = Base.reinterpret(UInt, curesulthandler) + kern = "lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[])},symbol-dce" + else + kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" + end opt_passes = optimization_passes(; no_nan, sroa=true) opt_passes2 = optimization_passes(; no_nan, sroa=false) diff --git a/src/Tracing.jl b/src/Tracing.jl index 2f4885147d..36ef8edfbb 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -318,6 +318,18 @@ function Base.showerror(io::IO, err::NoFieldMatchError) ) end +function make_tracer( + seen, + @nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}), + @nospecialize(path), + mode; + toscalar=false, + tobatch=nothing, + track_numbers=(), + kwargs..., +) + return prev +end append_path(path, i) = (path..., i) function make_tracer( @@ -590,7 +602,7 @@ function make_tracer( if mode == ArrayToConcrete return ConcreteRNumber(prev) else - if mode == TracedTrack + if mode == TracedTrack || mode == NoStopTracedTrack res = TracedRNumber{RT}( (path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data ) @@ -638,7 +650,7 @@ end function make_tracer( seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs... ) where {RT<:Array} - if haskey(seen, prev) + if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive @@ -699,7 +711,7 @@ function make_tracer( end function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) - if haskey(seen, prev) + if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end prev2 = prev.contents From 149f24b7dca10d6b90f4fc27137a68462dbb4645 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 14 Jan 2025 23:10:32 -0600 Subject: [PATCH 06/15] rmprint --- ext/ReactantCUDAExt.jl | 4 ---- src/Compiler.jl | 2 -- 2 files changed, 6 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 91e833bad8..8a683d33ac 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -444,14 +444,12 @@ function get_field_offset(T::Type, path) toffset = fieldoffset(current_type, field_idx) tcurrent_type = fieldtype(current_type, field_idx) offset += toffset - @show current_type, field_idx, toffset, offset, tcurrent_type # Update current_type to the field's type for next iteration current_type = tcurrent_type end - @show T, path, offset return offset end @@ -559,7 +557,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( 1, ) push!(allocs, (alloc, argty)) - @show string(alloc), string(argty), typeof(a) sz = sizeof(a) array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz)) @@ -666,7 +663,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( fn=MLIR.IR.FlatSymbolRefAttribute(sym_name), output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases), ) - # @show string(call), typeof(func.f), collect(map(typeof, args)) argidx = 1 for arg in values(seen) diff --git a/src/Compiler.jl b/src/Compiler.jl index 68070978bd..7c1b41db8f 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -314,8 +314,6 @@ end const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) - @show pass_pipeline - flush(stdout) pm = MLIR.IR.PassManager() MLIR.IR.enable_verifier!(pm, enable_verifier) opm = MLIR.IR.OpPassManager(pm) From cb86b3c3d23851661f891bc70f19f3fb56234fbb Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 15 Jan 2025 00:15:16 -0600 Subject: [PATCH 07/15] fix patch --- deps/ReactantExtra/WORKSPACE | 3 --- 1 file changed, 3 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index c1dd256aa3..82e337b148 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -54,9 +54,6 @@ XLA_PATCHES = XLA_PATCHES + [ sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h """, """ -sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc -""", -""" sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/tsl/concurrency/async_value_ref.h """, """ From 9c1381d60fe57654918eaa23e91e3d5d2351963c Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 15 Jan 2025 09:24:34 -0600 Subject: [PATCH 08/15] Update WORKSPACE --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 82e337b148..07876648e3 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "362f33f518900ebf66cee7f0135a436907f8f692" +ENZYMEXLA_COMMIT = "0d2bbcef73e106ce31e0243cb0b38cc5830fb39f" ENZYMEXLA_SHA256 = "" http_archive( From e68b0dd29fa0e1a206e7f8a4b64071d489205b64 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 15 Jan 2025 12:35:34 -0600 Subject: [PATCH 09/15] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 868ea7473b..1e44a5c18a 100644 --- a/Project.toml +++ b/Project.toml @@ -67,7 +67,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.3" -Reactant_jll = "0.0.39" +Reactant_jll = "0.0.40" Scratch = "1.2" SpecialFunctions = "2" Statistics = "1.10" From 060f245bf95abb6c12630d24d8617f77b3aeec3f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 15 Jan 2025 13:00:35 -0600 Subject: [PATCH 10/15] adapt to upstream properly --- deps/ReactantExtra/WORKSPACE | 6 ++---- deps/ReactantExtra/workspace.bzl | 14 -------------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 07876648e3..9924552771 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "0d2bbcef73e106ce31e0243cb0b38cc5830fb39f" +ENZYMEXLA_COMMIT = "12dc0bf6932befe236eacfcd19ca9522f870f7b9" ENZYMEXLA_SHA256 = "" http_archive( @@ -135,9 +135,7 @@ http_archive( patches = ["@enzyme_ad//:patches/jax.patch"], ) -# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") -XLA_COMMIT = "1bb4fc18e73faa1c001d96bfe3a22f733987b018" -XLA_SHA256 = "" +load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") http_archive( name = "xla", diff --git a/deps/ReactantExtra/workspace.bzl b/deps/ReactantExtra/workspace.bzl index 695f1d8578..e69de29bb2 100644 --- a/deps/ReactantExtra/workspace.bzl +++ b/deps/ReactantExtra/workspace.bzl @@ -1,14 +0,0 @@ -ENZYMEXLA_COMMIT = "049a05abfaf23abee646ad26834bb8725c348f51" -ENZYMEXLA_SHA256 = "" - -NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" -NSYNC_SHA256 = "" - -RULES_CC_COMMIT = "c8c38f8c710cbbf834283e4777916b68261b359c" -RULES_CC_SHA256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4" - -RULES_PYTHON_VERSION = "0.34.0" -RULES_PYTHON_SHA256 = "778aaeab3e6cfd56d681c89f5c10d7ad6bf8d2f1a72de9de55b23081b2d31618" - -UPB_COMMIT = "9effcbcb27f0a665f9f345030188c0b291e32482" -UPB_SHA256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454" From 46e8210e230caa7e6d3c5455a363e0932e1e569b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 15 Jan 2025 14:35:03 -0500 Subject: [PATCH 11/15] cuconvert --- ext/ReactantCUDAExt.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 8a683d33ac..f7f638e365 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -239,9 +239,12 @@ function Adapt.adapt_structure( ) end -Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg) +function recudaconvert(arg) return adapt(ReactantKernelAdaptor(), arg) end +Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg) + return recudaconvert(arg) +end function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N} res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs) @@ -456,7 +459,7 @@ end Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( args...; - convert=Val(false), + convert=Val(true), blocks::CuDim=1, threads::CuDim=1, cooperative::Bool=false, @@ -466,6 +469,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( blockdim = CUDA.CuDim3(blocks) threaddim = CUDA.CuDim3(threads) + if convert == Val(true) + args = recudaconvert.(args) + end + mlir_args = MLIR.IR.Value[] restys = MLIR.IR.Type[] aliases = MLIR.IR.Attribute[] From 475529299216f994d845bb0561978749aadaf0cb Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 15 Jan 2025 17:19:21 -0600 Subject: [PATCH 12/15] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1e44a5c18a..c8a12c7ae7 100644 --- a/Project.toml +++ b/Project.toml @@ -67,7 +67,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.3" -Reactant_jll = "0.0.40" +Reactant_jll = "0.0.41" Scratch = "1.2" SpecialFunctions = "2" Statistics = "1.10" From 63c6c25ceb853f9ae36d2c4bed844af279462dc8 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 15 Jan 2025 18:47:29 -0600 Subject: [PATCH 13/15] fix ci errs --- test/integration/python.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/integration/python.jl b/test/integration/python.jl index 54c2eec16d..91921dedfc 100644 --- a/test/integration/python.jl +++ b/test/integration/python.jl @@ -2,6 +2,9 @@ using Reactant using Reactant: Ops using Test + +# Jax on Github CI dislikes X86 macos +@static if !Sys.isapple() || Sys.ARCH != :x86_64 using PythonCall @testset "PythonCall" begin @@ -11,3 +14,4 @@ using PythonCall @test typeof(result) == ConcreteRNumber{Float32} @test result ≈ 6 end +end \ No newline at end of file From 522221eab25cd008faa4b1caeec7b321a78d3f5a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 15 Jan 2025 21:57:36 -0500 Subject: [PATCH 14/15] alias --- test/integration/cuda.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl index da5d3c52ba..1703ff9914 100644 --- a/test/integration/cuda.jl +++ b/test/integration/cuda.jl @@ -142,13 +142,11 @@ end @testset "Aliasing arguments" begin a = ConcreteRArray([3]) - s = (10, a) - if CUDA.functional() - @jit aliased((s, s)) + @jit aliased(a) @test all(Array(a) == 9) else - @code_hlo optimize = :before_kernel aliased(s) + @code_hlo optimize = :before_kernel aliased(a) end end end From 0aaebc5f9d50e24a8ff723753827b3f0122f4427 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 15 Jan 2025 22:33:12 -0500 Subject: [PATCH 15/15] cuda test --- test/integration/cuda.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl index 1703ff9914..4ec52ab869 100644 --- a/test/integration/cuda.jl +++ b/test/integration/cuda.jl @@ -128,7 +128,7 @@ end # maybe weird cuda things function aliased!(tup) x, y = tup - x[2][1] *= y[2][1] + x[1] *= y[1] return nothing end @@ -144,7 +144,7 @@ end if CUDA.functional() @jit aliased(a) - @test all(Array(a) == 9) + @test all(Array(a) .== 9) else @code_hlo optimize = :before_kernel aliased(a) end @@ -168,10 +168,9 @@ end if CUDA.functional() a = CuArray([4]) b = ConcreteRArray([3]) - @jit mixed(a, b) - @test all(Array(a) == 4) - @test all(Array(b) == 12) + @test all(Array(a) .== 4) + @test all(Array(b) .== 12) end end end