From 6121b596caa72128c228ec619df41ec030e8e76f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 7 Feb 2025 17:53:49 -0600 Subject: [PATCH 1/5] JLL related fixups --- deps/ReactantExtra/API.cpp | 19 +++++++++++-------- deps/ReactantExtra/BUILD | 3 ++- src/Compiler.jl | 15 +++++++++++++-- src/Precompile.jl | 4 ++-- src/Reactant.jl | 13 +++++++++---- src/XLA.jl | 4 +++- 6 files changed, 40 insertions(+), 18 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 685969cd91..a07e05f7d5 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -852,14 +852,7 @@ extern "C" void RegisterDialects(MlirContext cctx) { #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" -extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { - mlir::DialectRegistry &registry = *unwrap(creg); - prepareRegistry(registry); - - mlir::registerLLVMDialectImport(registry); - mlir::registerNVVMDialectImport(registry); - mlir::LLVM::registerInlinerInterface(registry); - +extern "C" void InitializePasses(MlirDialectRegistry creg) { mlir::registerenzymePasses(); enzyme::registerenzymexlaPasses(); @@ -901,6 +894,16 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { xla::sdy::registerSdyRoundTripImportPipeline(); } +extern "C" void InitializeRegistry(MlirDialectRegistry creg) { + mlir::DialectRegistry &registry = *unwrap(creg); + prepareRegistry(registry); + + mlir::registerLLVMDialectImport(registry); + mlir::registerNVVMDialectImport(registry); + mlir::LLVM::registerInlinerInterface(registry); + +} + /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric /// suffix in `lastUsedID`. 
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName, diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 30a78f308e..bf100c09f2 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -435,7 +435,8 @@ cc_library( "-Wl,-exported_symbol,_FutureAwait", "-Wl,-exported_symbol,_XLAExecute", "-Wl,-exported_symbol,_RegisterDialects", -"-Wl,-exported_symbol,_InitializeRegistryAndPasses", +"-Wl,-exported_symbol,_InitializeRegistry", +"-Wl,-exported_symbol,_InitializePasses", "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", diff --git a/src/Compiler.jl b/src/Compiler.jl index 04f54e3cd4..aac96fa053 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -202,7 +202,7 @@ function create_result( end # Optimization passes via transform dialect -function optimization_passes(; no_nan::Bool=false, sroa::Bool=false) +function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Bool=true) transform_passes_list = [ "patterns=compare_op_canon<16>", "transpose_transpose<16>", @@ -407,7 +407,10 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false) ",", ) func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") - passes = ["inline{default-pipeline=canonicalize max-iterations=4}"] + passes = String[] + if inline + push!(passes, "inline{default-pipeline=canonicalize max-iterations=4}") + end if sroa push!(passes, "propagate-constant-bounds") if DUMP_LLVMIR[] @@ -703,6 +706,14 @@ function compile_mlir!( run_pass_pipeline!( mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern ) + elseif optimize === :canonicalize + run_pass_pipeline!( + mod, "canonicalize" + ) + elseif optimize === :just_batch + run_pass_pipeline!( + mod, "enzyme-batch" + ) elseif optimize !== :none error("Invalid optimize option: $(Meta.quot(optimize))") end diff --git a/src/Precompile.jl b/src/Precompile.jl index 
a38d4720b8..41c3fc8922 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -63,10 +63,10 @@ end @compile_workload begin @static if precompilation_supported() x = ConcreteRNumber(2.0; client) - Reactant.compile(sin, (x,); client) + Reactant.compile(sin, (x,); client, optimize=:all) y = ConcreteRArray([2.0]; client) - Reactant.compile(Base.sum, (y,); client) + Reactant.compile(Base.sum, (y,); client, optimize=:all) end end XLA.free_client(client) diff --git a/src/Reactant.jl b/src/Reactant.jl index 1d969f628f..06e01aa910 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -153,18 +153,23 @@ export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace, withi const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() -const initialize_dialect_first_run = Ref{Bool}(true) - +const passes_initialized = Ref(false) function initialize_dialect() registry[] = MLIR.IR.DialectRegistry() - @ccall MLIR.API.mlir_c.InitializeRegistryAndPasses( + @ccall MLIR.API.mlir_c.InitializeRegistry( registry[]::MLIR.API.MlirDialectRegistry )::Cvoid - initialize_dialect_first_run[] = false + if !passes_initialized[] + @ccall MLIR.API.mlir_c.InitializePasses( + registry[]::MLIR.API.MlirDialectRegistry + )::Cvoid + passes_initialized[] = true + end return nothing end function deinitialize_dialect() + passes_initialized[] = false return registry[] = nothing end diff --git a/src/XLA.jl b/src/XLA.jl index 219b8f665c..6511efd498 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -217,13 +217,15 @@ function __init__() println(stdout, e) end else + if !Reactant.precompiling() try gpu = GPUClient() backends["gpu"] = gpu default_backend[] = gpu catch e - println(stdout, e) + println(stdout, e) end + end end end From bbe6cfa07462a4aca8391b06521b0f1828a9dffc Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 7 Feb 2025 18:04:24 -0600 Subject: [PATCH 2/5] Bump enzymexla --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 1ae05a5ca1..21ac0ce651 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "e2cc5276372199a5b291b8140bd55c46e8e1538a" +ENZYMEXLA_COMMIT = "f69d8621df3a5491d796e3da6d937c5572e9af52" ENZYMEXLA_SHA256 = "" http_archive( From aebb7700f0d0053c507b9deab6051e2dd06bbbc0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 7 Feb 2025 18:56:33 -0600 Subject: [PATCH 3/5] Update WORKSPACE --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 21ac0ce651..54fa79077d 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "f69d8621df3a5491d796e3da6d937c5572e9af52" +ENZYMEXLA_COMMIT = "e3b0a810763eab1fdab9a8231088160cd3c42e0c" ENZYMEXLA_SHA256 = "" http_archive( From c8f6a7993bcc997d476f3c71fb5de8b241368efc Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 7 Feb 2025 22:17:25 -0600 Subject: [PATCH 4/5] bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0c7a133b5e..b852cc5f37 100644 --- a/Project.toml +++ b/Project.toml @@ -79,7 +79,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.5" -Reactant_jll = "0.0.62" +Reactant_jll = "0.0.64" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" From b5b7a06b9689fbe977953fc1f3e00ef456c5af79 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 7 Feb 2025 22:24:00 -0600 Subject: [PATCH 5/5] add new opts --- src/Compiler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index aac96fa053..ae89256d45 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -389,6 +389,8 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo "common_compare_expression_rewrite", "compare_select_simplify", "while_simplify<1>", + "scatter_update_computation_const_prop", + "if_remove_unused", ] if no_nan append!(