From 59157563a930fa0e85ff644d403be115a96599b7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:30:18 -0600 Subject: [PATCH 01/48] Kernel-supporting jll --- deps/ReactantExtra/BUILD | 1 + deps/ReactantExtra/WORKSPACE | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index c538bbb8a4..559533c7df 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -426,6 +426,7 @@ cc_library( "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", +"-Wl,-exported_symbol,_EnzymeGPUCustomCall", ]}), deps = [ "@enzyme//:EnzymeMLIR", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 174cc6715c..bb83aae780 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "f6587e37ff7298f2a1a273b08c24d69fca7ff30f" +ENZYMEXLA_COMMIT = "e059f8c6e559c92846b110537c9a8b53f65ec053" ENZYMEXLA_SHA256 = "" http_archive( From 1b35e8e23a078db4ed16b7dfa3c17cd196e56a6e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:34:00 -0600 Subject: [PATCH 02/48] fix rulescc --- deps/ReactantExtra/WORKSPACE | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index bb83aae780..42823c35d8 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -57,19 +57,6 @@ sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config. # """, ] -http_archive( - name = "rules_cc", - sha256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4", - strip_prefix = "rules_cc-c8c38f8c710cbbf834283e4777916b68261b359c", - urls = [ - "https://github.com/bazelbuild/rules_cc/archive/c8c38f8c710cbbf834283e4777916b68261b359c.tar.gz", - ], -) - -load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies") - -rules_cc_dependencies() - LLVM_TARGETS = select({ "@bazel_tools//src/conditions:windows": ["AMDGPU", "NVPTX"], "@bazel_tools//src/conditions:darwin": [], From 3f364ca0813e6dcebe7877d5dede2405800b740e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:37:19 -0600 Subject: [PATCH 03/48] adapt to hedron dep --- deps/ReactantExtra/WORKSPACE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 42823c35d8..6c35de819b 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -19,6 +19,27 @@ http_archive( urls = ["https://github.com/EnzymeAD/Enzyme-JAX/archive/{commit}.tar.gz".format(commit = ENZYMEXLA_COMMIT)], ) + +# Hedron's Compile Commands Extractor for Bazel +# https://github.com/hedronvision/bazel-compile-commands-extractor +http_archive( + name = "hedron_compile_commands", + + # Replace the commit hash (0e990032f3c5a866e72615cf67e5ce22186dcb97) in both places (below) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main), rather than using the stale one here. + # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README). + url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz", + strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e", + # When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..." +) +load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") +hedron_compile_commands_setup() +load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive") +hedron_compile_commands_setup_transitive() +load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive") +hedron_compile_commands_setup_transitive_transitive() +load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive") +hedron_compile_commands_setup_transitive_transitive_transitive() + load("@enzyme_ad//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256", "XLA_PATCHES") XLA_PATCHES = XLA_PATCHES + [ From 2d745c4dede6710821693736833c0713697a5efe Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:42:35 -0600 Subject: [PATCH 04/48] init target --- src/XLA.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/XLA.jl b/src/XLA.jl index 00420edb84..b21999f962 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -131,6 +131,7 @@ function __init__() end end + @ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c)), "CUDA") return nothing end From 2892212dd3c6a09f3f8a0dc4e37c934f4f075ea3 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:48:34 -0600 Subject: [PATCH 05/48] fixup --- deps/ReactantExtra/API.cpp | 1 - ext/ReactantCUDAExt.jl | 16 +++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 3ae7a7ebf9..e3068aa5dd 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -50,7 +50,6 @@ #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" -#include "xla/service/cpu/simple_orc_jit.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "llvm/MC/TargetRegistry.h" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 2b0f87994c..83e8179de0 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -388,8 +388,19 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name") # Force public for now while we don't have real users - MLIR.IR.rmattr!(func.entry, "sym_visibility") - + # MLIR.IR.rmattr!(func.entry, "sym_visibility") + + op_ty_results = IR.Type[result_0...,] + operands = Value[inputs...,] + owned_regions = MLIR.IR.Region[] + successors = MLIR.IR.Block[] + attributes = MLIR.IR.NamedAttribute[ + MLIR.IR.namedattribute("fn", fname), + MLIR.IR.namedattribute("output_operand_aliases", output_operand_aliases) + ] + + location = MLIR.IR.Location() + call = MLIR.Dialects.stablehlo.custom_call( mlir_args; result_0=restys, @@ -397,7 +408,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( output_operand_aliases, backend_config=MLIR.IR.Attribute(fname), ) - # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod)) for (i, res) in enumerate(rarrays) res.mlir_data = transpose_val(MLIR.IR.result(call, i)) end From 261b3c21bcdcfe0a1d4002c3e22351ec5d2c9e75 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:57:33 -0600 Subject: [PATCH 06/48] additional fixups --- deps/ReactantExtra/BUILD | 3 +++ ext/ReactantCUDAExt.jl | 8 +++++++- src/Compiler.jl | 19 ++++++++++++++++++- test/runtests.jl | 35 +---------------------------------- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 559533c7df..b5daf012e6 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -470,6 +470,9 @@ cc_library( "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/pjrt/cpu:cpu_client", + "@xla//xla/service:metrics_proto_cc", + "@xla//xla/service:metrics_proto_cc_impl", + "@xla//xla/service/cpu:cpu_compiler", "@xla//xla/stream_executor/tpu:tpu_on_demand_compiler", "@xla//xla/stream_executor/tpu:tpu_executor", diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 83e8179de0..6b7d16e997 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -391,7 +391,13 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( # MLIR.IR.rmattr!(func.entry, "sym_visibility") op_ty_results = IR.Type[result_0...,] - operands = Value[inputs...,] + operands = MLIR.IR.Value[] + for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z) + push!(operands, TracedUtils.promote_to(TracedRNumber{Int}, idx).mlir_data) + end + for arg in mlir_ir_args + push!(operands, arg) + end owned_regions = MLIR.IR.Region[] successors = MLIR.IR.Block[] attributes = MLIR.IR.NamedAttribute[ diff --git a/src/Compiler.jl b/src/Compiler.jl index deb1248699..bd5ab49b9e 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -306,6 +306,22 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) optimize isa Bool && (optimize = ifelse(optimize, :all, :none)) if optimize === :all + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, + join( + [ + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes, + "lower-kernel" + ], + ',', + ), + ) + elseif optimize === :before_kernel run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( @@ -341,6 +357,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", opt_passes, + "lower-kernel" ], ',', ), @@ -349,7 +366,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( - mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math" + mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,lower-kernel" ) elseif optimize !== :none error("Invalid optimize option: $(Meta.quot(optimize))") diff --git a/test/runtests.jl b/test/runtests.jl index 68dfcaead3..297df1a33f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,37 +41,4 @@ end const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) -@testset "Reactant.jl Tests" begin - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" - @safetestset "Layout" include("layout.jl") - @safetestset "Tracing" include("tracing.jl") - @safetestset "Basic" include("basic.jl") - @safetestset "Autodiff" include("autodiff.jl") - @safetestset "Complex" include("complex.jl") - @safetestset "Broadcast" include("bcast.jl") - @safetestset "Struct" include("struct.jl") - @safetestset "Closure" include("closure.jl") - @safetestset "Compile" include("compile.jl") - @safetestset "Buffer Donation" include("buffer_donation.jl") - @safetestset "Shortcuts to MLIR ops" include("ops.jl") - @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") - @safetestset "Control Flow" include("control_flow.jl") - end - - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" - @safetestset "Linear Algebra" include("integration/linear_algebra.jl") - @safetestset "AbstractFFTs" include("integration/fft.jl") - @safetestset "Random" include("integration/random.jl") - end - - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" - @testset "Neural Networks" begin - @safetestset "NNlib Primitives" include("nn/nnlib.jl") - @safetestset "Flux.jl Integration" include("nn/flux.jl") - if Sys.islinux() - @safetestset "LuxLib Primitives" include("nn/luxlib.jl") - @safetestset "Lux Integration" include("nn/lux.jl") - end - end - end -end +include("cuda.jl") From 7ef39a4f7cef0e6e0c8745ee7c35be5d5f64c315 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:03:21 -0600 Subject: [PATCH 07/48] fixup --- src/XLA.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/XLA.jl b/src/XLA.jl index b21999f962..3fcf4423ab 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -131,7 +131,7 @@ function __init__() end end - @ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c)), "CUDA") + @ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, "CUDA"::Cstring)::Cvoid return nothing end From e86af4f589d1b6c9fc7bde7883378c95440f8e3a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:28:44 -0600 Subject: [PATCH 08/48] fix --- deps/ReactantExtra/API.cpp | 12 +++++------- deps/ReactantExtra/WORKSPACE | 2 +- ext/ReactantCUDAExt.jl | 11 +++++------ 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index e3068aa5dd..c08b79219b 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -459,15 +459,13 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable* exec, int num_args, PjRtBu } } +void prepareRegistry(mlir::DialectRegistry ®istry); + extern "C" void RegisterDialects(MlirContext cctx) { mlir::MLIRContext &context = *unwrap(cctx); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); + DialectRegistry registry; + prepareRegistry(registry); + context.appendDialectRegistry(registry); } #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 6c35de819b..88db233258 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "e059f8c6e559c92846b110537c9a8b53f65ec053" +ENZYMEXLA_COMMIT = "fb483c06f697990c60cc3c0bda7fb1d730fca3de" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 6b7d16e997..ea16b2957e 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -390,19 +390,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( # Force public for now while we don't have real users # MLIR.IR.rmattr!(func.entry, "sym_visibility") - op_ty_results = IR.Type[result_0...,] operands = MLIR.IR.Value[] - for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z) - push!(operands, TracedUtils.promote_to(TracedRNumber{Int}, idx).mlir_data) + for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) + push!(operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data) end - for arg in mlir_ir_args + for arg in mlir_args push!(operands, arg) end owned_regions = MLIR.IR.Region[] successors = MLIR.IR.Block[] attributes = MLIR.IR.NamedAttribute[ - MLIR.IR.namedattribute("fn", fname), - MLIR.IR.namedattribute("output_operand_aliases", output_operand_aliases) + MLIR.IR.NamedAttribute("fn", fname), + MLIR.IR.NamedAttribute("output_operand_aliases", output_operand_aliases) ] location = MLIR.IR.Location() From f1d289c6ee08c73a9f6ef6955fa758f4d5be0fdc Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:33:59 -0600 Subject: [PATCH 09/48] registry utils --- deps/ReactantExtra/BUILD | 1 + test/cuda.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index b5daf012e6..44953fd307 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -358,6 +358,7 @@ cc_library( ], ) + [ + "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", "@xla//xla:xla.pb.cc", "@xla//xla:xla_data.pb.cc", diff --git a/test/cuda.jl b/test/cuda.jl index b5744d35fb..8240add121 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -19,6 +19,7 @@ end oA = collect(1:1:64) A = Reactant.to_rarray(oA) @show @code_hlo optimize = false square!(A) + @show @code_hlo optimize=:before_kernel square!(A) @show @code_hlo square!(A) func = @compile square!(A) @test all(Array(A) .≈ (oA .* oA)) From 802f4451fcc9b3082587c821496c68fdf4383d0d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:40:06 -0600 Subject: [PATCH 10/48] callname --- ext/ReactantCUDAExt.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index ea16b2957e..cf0f6a4abd 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -405,13 +405,15 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( ] location = MLIR.IR.Location() - - call = MLIR.Dialects.stablehlo.custom_call( - mlir_args; - result_0=restys, - call_target_name="reactant_gpu_call", - output_operand_aliases, - backend_config=MLIR.IR.Attribute(fname), + call = MLIR.IR.create_operation( + "enzymexla.kernel_call", + location; + operands, + owned_regions, + successors, + attributes, + results=restys, + result_inference=false, ) for (i, res) in enumerate(rarrays) res.mlir_data = transpose_val(MLIR.IR.result(call, i)) From 9aefd5e18785a5679d4f50e2d87eb9c1337c0e4a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:50:44 -0600 Subject: [PATCH 11/48] reg --- deps/ReactantExtra/API.cpp | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index c08b79219b..f603105a4a 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -473,34 +473,10 @@ extern "C" void RegisterDialects(MlirContext cctx) { #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::DialectRegistry ®istry = *unwrap(creg); - - // Register MLIR stuff - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - - registry.insert(); + prepareRegistry(registry); mlir::registerenzymePasses(); regsiterenzymeXLAPasses(); - mlir::enzyme::registerXLAAutoDiffInterfaces(registry); - - mlir::func::registerInlinerExtension(registry); // Register the standard passes we want. mlir::registerCSEPass(); @@ -517,7 +493,6 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::registerLLVMDialectImport(registry); mlir::registerNVVMDialectImport(registry); - mlir::LLVM::registerInlinerInterface(registry); /* registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { @@ -535,15 +510,10 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { }); */ - // Register the autodiff interface implementations for upstream dialects. - enzyme::registerCoreDialectAutodiffInterfaces(registry); - // Transform dialect and extensions. mlir::transform::registerInterpreterPass(); - mlir::linalg::registerTransformDialectExtension(registry); mlir::enzyme::registerGenerateApplyPatternsPass(); mlir::enzyme::registerRemoveTransformPass(); - mlir::enzyme::registerEnzymeJaxTransformExtension(registry); } From 312ee5bbcd54f0c929ee4916f0aeaa64c9a49cb9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 20:10:38 -0600 Subject: [PATCH 12/48] fix --- deps/ReactantExtra/API.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f603105a4a..8be25aa888 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -466,6 +466,14 @@ extern "C" void RegisterDialects(MlirContext cctx) { DialectRegistry registry; prepareRegistry(registry); context.appendDialectRegistry(registry); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); } #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" From bd94773fe2c546dc043f41ce6a61f57e93a51e54 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 20:25:26 -0600 Subject: [PATCH 13/48] fix bld --- deps/ReactantExtra/API.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 8be25aa888..f374e3399e 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -11,6 +11,7 @@ #include "Enzyme/MLIR/Passes/Passes.h" #include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" #include "src/enzyme_ad/jax/Passes/Passes.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" #include "src/enzyme_ad/jax/TransformOps/TransformOps.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -36,6 +37,7 @@ #include "mlir/Transforms/Passes.h" #include "llvm/Support/TargetSelect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -500,7 +502,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::registerLLVMDialectImport(registry); mlir::registerNVVMDialectImport(registry); - + mlir::LLVM::registerInlinerInterface(registry); /* registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { From ef143c3ddf2c8e94cc22896abcfd164f469a3afb Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 21:01:00 -0600 Subject: [PATCH 14/48] cleanup --- deps/ReactantExtra/WORKSPACE | 2 +- ext/ReactantCUDAExt.jl | 13 ++----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 88db233258..2e04c1ea40 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "fb483c06f697990c60cc3c0bda7fb1d730fca3de" +ENZYMEXLA_COMMIT = "3ce6c51887642fa85313dd17a4bbde227e109a35" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index cf0f6a4abd..0637eeb964 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -351,8 +351,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( shmem::Integer=0, call_kwargs..., ) where {F,tt} - @show call_kwargs - blockdim = CUDA.CuDim3(blocks) threaddim = CUDA.CuDim3(threads) @@ -400,8 +398,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( owned_regions = MLIR.IR.Region[] successors = MLIR.IR.Block[] attributes = MLIR.IR.NamedAttribute[ - MLIR.IR.NamedAttribute("fn", fname), - MLIR.IR.NamedAttribute("output_operand_aliases", output_operand_aliases) + MLIR.IR.NamedAttribute("fn", MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))), + MLIR.IR.NamedAttribute("output_operand_aliases", MLIR.IR.Attribute(output_operand_aliases)) ] location = MLIR.IR.Location() @@ -418,13 +416,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( for (i, res) in enumerate(rarrays) res.mlir_data = transpose_val(MLIR.IR.result(call, i)) end - - @show blockdim - @show threaddim - #CUDA.cuLaunchKernel(f, - # blockdim.x, blockdim.y, blockdim.z, - # threaddim.x, threaddim.y, threaddim.z, - # shmem, stream, kernelParams, C_NULL) end # cache of compilation caches, per context From 1be673204d8472ce8153aa3d6106e9ff2a71f08f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 21:44:45 -0600 Subject: [PATCH 15/48] no pip --- deps/ReactantExtra/WORKSPACE | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 2e04c1ea40..d43833fdd3 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "3ce6c51887642fa85313dd17a4bbde227e109a35" +ENZYMEXLA_COMMIT = "dea63960da134128b152c1624d1425048cd9fb3a" ENZYMEXLA_SHA256 = "" http_archive( @@ -105,7 +105,7 @@ http_archive( load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") python_init_rules() - + load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") python_init_repositories( requirements = { @@ -116,23 +116,23 @@ python_init_repositories( "3.13": "//build:requirements_lock_3_13.txt", }, ) - + load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") python_init_toolchains() - -load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") -python_init_pip() - -load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") -python_init_rules() - -load("@rules_python//python:repositories.bzl", "py_repositories") - -py_repositories() - -load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") - -pip_install_dependencies() +# +# load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") +# python_init_pip() +# +# load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") +# python_init_rules() +# +# load("@rules_python//python:repositories.bzl", "py_repositories") +# +# py_repositories() +# +# load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") +# +# pip_install_dependencies() http_archive( name = "enzyme", From 8e553deeb84b7aedf1e87c480155a46db13edba8 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 22:05:58 -0600 Subject: [PATCH 16/48] fix --- src/Compiler.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index bd5ab49b9e..2616986216 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1,5 +1,7 @@ module Compiler +using Reactant_jll + import ..Reactant: Reactant, MLIR, @@ -305,6 +307,11 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) optimize isa Bool && (optimize = ifelse(optimize, :all, :none)) + toolkit = "" + if isdefined(Reactant_jll, :ptxas_path) + toolkit = Reactant_jll.ptxas_path[1:end-length("/bin/ptxas")] + end + kern = "lower-kernel{toolkitPath=$toolkit}" if optimize === :all run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) @@ -316,7 +323,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", opt_passes, - "lower-kernel" + kern ], ',', ), @@ -357,7 +364,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", opt_passes, - "lower-kernel" + kern ], ',', ), @@ -366,7 +373,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( - mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,lower-kernel" + mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,"*kern ) elseif optimize !== :none error("Invalid optimize option: $(Meta.quot(optimize))") From a2c664c906a7fc72108d321f9db33eaa6ecb5b04 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 17 Dec 2024 14:39:14 -0600 Subject: [PATCH 17/48] force rules python to older version before bug --- deps/ReactantExtra/WORKSPACE | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index d43833fdd3..cc0b3d179f 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -20,6 +20,13 @@ http_archive( ) +http_archive( + name = "rules_python", + sha256 = "778aaeab3e6cfd56d681c89f5c10d7ad6bf8d2f1a72de9de55b23081b2d31618", + strip_prefix = "rules_python-0.34.0", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.34.0/rules_python-0.34.0.tar.gz", +) + # Hedron's Compile Commands Extractor for Bazel # https://github.com/hedronvision/bazel-compile-commands-extractor http_archive( From e41bb8fbbc0ceb1a9eede0a7bba090c1d2a998d2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 17 Dec 2024 17:14:33 -0600 Subject: [PATCH 18/48] fixup jll --- deps/ReactantExtra/BUILD | 9 ++++++--- deps/ReactantExtra/WORKSPACE | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 44953fd307..20500242f3 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -3,6 +3,8 @@ load("@xla//tools/toolchains/cross_compile/cc:cc_toolchain_config.bzl", "cc_tool # load("//toolchain:yggdrasil.bzl", "ygg_cc_toolchain") licenses(["notice"]) +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") + package( default_applicable_licenses = [], default_visibility = ["//:__subpackages__"], @@ -360,7 +362,7 @@ cc_library( ) + [ "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", - "@xla//xla:xla.pb.cc", + # "@xla//xla:xla.pb.cc", "@xla//xla:xla_data.pb.cc", "@xla//xla/stream_executor:device_description.pb.cc", "@xla//xla/service:hlo.pb.cc", @@ -509,12 +511,13 @@ cc_library( "@xla//xla/service/gpu/model:hlo_op_profiles", "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl", "@xla//xla/service/gpu:nvptx_compiler", - "@xla//xla/service/gpu:amdgpu_compiler", "@xla//xla/service/gpu:gpu_transfer_manager", "@xla//xla/stream_executor:kernel", ], "//conditions:default": [], - }), + }) + if_rocm([ + "@xla//xla/service/gpu:amdgpu_compiler", + ]), ) # cc_shared_library( diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index cc0b3d179f..9a755c620d 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "dea63960da134128b152c1624d1425048cd9fb3a" +ENZYMEXLA_COMMIT = "e5f20f21e8d1ad8f698ce0c03b07acf3a078ec9d" ENZYMEXLA_SHA256 = "" http_archive( From 6f92d004de61b6c6cf28b66e15ee3a388d7bfa8f Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 17 Dec 2024 19:12:51 -0600 Subject: [PATCH 19/48] with proto --- deps/ReactantExtra/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 20500242f3..26b80761f3 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -472,6 +472,9 @@ cc_library( "@xla//xla/pjrt:pjrt_api", "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/pjrt/cpu:cpu_client", + + "@xla//xla:xla_proto_cc", + "@xla//xla:xla_proto_cc_impl", "@xla//xla/service:metrics_proto_cc", "@xla//xla/service:metrics_proto_cc_impl", From 1787ece69d31beb8e2d70d1704fd9c37bddba41d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 17 Dec 2024 19:21:32 -0600 Subject: [PATCH 20/48] fix --- deps/ReactantExtra/WORKSPACE | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 9a755c620d..d0b207ec15 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -57,6 +57,9 @@ sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/ sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc """, """ +sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/tsl/concurrency/async_value_ref.h +""", +""" sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_LINK_H=1\\/HAVE_LINK_H=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl """, """ From d38eac905f9ec760ec89501b88d454b5ab5ed20f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 18 Dec 2024 01:26:20 -0500 Subject: [PATCH 21/48] fix --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index d0b207ec15..d710881722 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "e5f20f21e8d1ad8f698ce0c03b07acf3a078ec9d" +ENZYMEXLA_COMMIT = "7a9b0c28b6744c4ebecaf6b70ec9569ffbe3f713" ENZYMEXLA_SHA256 = "" http_archive( From b50c8f1b4cc040bcd1502ecddc84d48af593a0de Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 01:09:08 -0600 Subject: [PATCH 22/48] Update WORKSPACE --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index d710881722..a26c0327f9 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "7a9b0c28b6744c4ebecaf6b70ec9569ffbe3f713" +ENZYMEXLA_COMMIT = "b7468f5fff13c6cec9c152032ad8745b2afbd7a3" ENZYMEXLA_SHA256 = "" http_archive( From 4002eff383d3cc9019fb5520c6ba2a580c97a65a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 18 Dec 2024 10:54:18 -0500 Subject: [PATCH 23/48] more deps for apple --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index a26c0327f9..fa7db6a4d7 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "b7468f5fff13c6cec9c152032ad8745b2afbd7a3" +ENZYMEXLA_COMMIT = "6fa2a71ed44f509b6954700edee1e6bdd700037f" ENZYMEXLA_SHA256 = "" http_archive( From 65189e2061344b32d57d8f40588a36a4a50b7acd Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 18 Dec 2024 12:48:10 -0500 Subject: [PATCH 24/48] bump --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index fa7db6a4d7..1dea9f0b8d 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "6fa2a71ed44f509b6954700edee1e6bdd700037f" +ENZYMEXLA_COMMIT = "e420581c013a5d679310cf3eb8525714ccc7d45f" ENZYMEXLA_SHA256 = "" http_archive( From 5c9ef9a27e1073d7d2a230965c960ed810619db4 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 18 Dec 2024 14:50:44 -0500 Subject: [PATCH 25/48] fix --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 1dea9f0b8d..adfe597a1f 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "e420581c013a5d679310cf3eb8525714ccc7d45f" +ENZYMEXLA_COMMIT = "12a4638523a996f41fb6ba5af77369ae449144dc" ENZYMEXLA_SHA256 = "" http_archive( From a149e0a12326dc3b8294a154793cc12c667f6575 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 18 Dec 2024 17:27:04 -0500 Subject: [PATCH 26/48] workspace bump --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index adfe597a1f..80adaab03f 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "12a4638523a996f41fb6ba5af77369ae449144dc" +ENZYMEXLA_COMMIT = "53bda7276e75bc3199493fefb317328e5508f713" ENZYMEXLA_SHA256 = "" http_archive( From a049cf2f4eb732cf2296ff1e820bf0de61296b9c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 18 Dec 2024 17:29:51 -0500 Subject: [PATCH 27/48] workspace --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 80adaab03f..6d791a91ca 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "53bda7276e75bc3199493fefb317328e5508f713" +ENZYMEXLA_COMMIT = "8fdb884d2d7abc860cf5a609bb6f97e24a831264" ENZYMEXLA_SHA256 = "" http_archive( From cc3e5e5686a814226bbf5afc79d170361c96dc2b Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 20:35:44 -0500 Subject: [PATCH 28/48] Update Compiler.jl --- src/Compiler.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 2616986216..2c34085a63 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -290,6 +290,10 @@ function compile_mlir(f, args; kwargs...) end end +const cuLaunch = Ref{UInt}(0) +const cuFunc = Ref{UInt}(0) +const cuModule = Ref{UInt}(0) + function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, @@ -311,7 +315,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) if isdefined(Reactant_jll, :ptxas_path) toolkit = Reactant_jll.ptxas_path[1:end-length("/bin/ptxas")] end - kern = "lower-kernel{toolkitPath=$toolkit}" + kern = "lower-kernel{toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" if optimize === :all run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) From 11165984a051964dfa094c7e6dd0a3df7c24e2bb Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 20:40:47 -0500 Subject: [PATCH 29/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 0637eeb964..abcd578b50 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -459,6 +459,10 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( return res end -function __init__() end +function __init__() + Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, Base.cglobal((:cuLaunchKernel, CUDA.libcuda)) + Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, Base.cglobal((:cuModuleLoadData, CUDA.libcuda)) + Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, Base.cglobsl((:cuModuleGetFunction, CUDA.libcuda)) +end end # module ReactantCUDAExt From ab5e5751c44f0e5997f5f020cea09ee7fbc6dd9d Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:01:48 -0500 Subject: [PATCH 30/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index abcd578b50..67456d5e30 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -460,9 +460,9 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, Base.cglobal((:cuLaunchKernel, CUDA.libcuda)) - Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, Base.cglobal((:cuModuleLoadData, CUDA.libcuda)) - Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, Base.cglobsl((:cuModuleGetFunction, CUDA.libcuda)) + Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, Base.cglobal((:cuLaunchKernel, CUDA.libcuda))) + Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, Base.cglobal((:cuModuleLoadData, CUDA.libcuda))) + Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, Base.cglobsl((:cuModuleGetFunction, CUDA.libcuda))) end end # module ReactantCUDAExt From 54ae8233cb360d93da7dd5c32e565b052a0011ba Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:10:40 -0500 Subject: [PATCH 31/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 67456d5e30..3172727ad3 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -460,9 +460,9 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, Base.cglobal((:cuLaunchKernel, CUDA.libcuda))) - Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, Base.cglobal((:cuModuleLoadData, CUDA.libcuda))) - Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, Base.cglobsl((:cuModuleGetFunction, CUDA.libcuda))) + Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, Base.cglobal((:cuLaunchKernel, CUDA.CUDA_Driver_jll.libcuda))) + Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, Base.cglobal((:cuModuleLoadData, CUDA.CUDA_Driver_jll.libcuda))) + Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, Base.cglobsl((:cuModuleGetFunction, CUDA.CUDA_Driver_jll.libcuda))) end end # module ReactantCUDAExt From cc47802e3a634dd4e21a6b9d9cb1725d5bae0727 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:14:50 -0500 Subject: [PATCH 32/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 3172727ad3..80ef5bd753 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -460,9 +460,9 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, Base.cglobal((:cuLaunchKernel, CUDA.CUDA_Driver_jll.libcuda))) - Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, Base.cglobal((:cuModuleLoadData, CUDA.CUDA_Driver_jll.libcuda))) - Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, Base.cglobsl((:cuModuleGetFunction, CUDA.CUDA_Driver_jll.libcuda))) + Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, Base.cglobal((:cuLaunchKernel, "libcuda.so.1"))) + Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, Base.cglobal((:cuModuleLoadData, "libcuda.so.1"))) + Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, Base.cglobsl((:cuModuleGetFunction, "libcuda.so.1"))) end end # module ReactantCUDAExt From ce055913d76ddd7242fb58967e64bbf9647138df Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:23:58 -0500 Subject: [PATCH 33/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 80ef5bd753..75afe9c102 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -460,9 +460,12 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, Base.cglobal((:cuLaunchKernel, "libcuda.so.1"))) - Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, Base.cglobal((:cuModuleLoadData, "libcuda.so.1"))) - Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, Base.cglobsl((:cuModuleGetFunction, "libcuda.so.1"))) + ptr1 = Reactant.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuLaunchKernel") + ptr2 = Reactant.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleLoadData") + ptr3 = Reactant.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleGetFunction") + Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) + Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) + Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) end end # module ReactantCUDAExt From e5ca9dd088b5f01e64ffee3821c92de80aa29820 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:29:03 -0500 Subject: [PATCH 34/48] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index acc3130cb7..0f68251777 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ ReactantCore = {path = "lib/ReactantCore"} [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" -ReactantCUDAExt = "CUDA" +ReactantCUDAExt = ["CUDA", "Libdl"] ReactantNNlibExt = "NNlib" ReactantRandom123Ext = "Random123" ReactantStatisticsExt = "Statistics" From 0a7da9721fda6fb0924fd37a19f26280bdd6c419 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:29:48 -0500 Subject: [PATCH 35/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 75afe9c102..b30e9fce9b 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -3,6 +3,7 @@ module ReactantCUDAExt using CUDA using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber using ReactantCore: @trace +using Libdl using Adapt @@ -460,9 +461,9 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - ptr1 = Reactant.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuLaunchKernel") - ptr2 = Reactant.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleLoadData") - ptr3 = Reactant.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleGetFunction") + ptr1 = Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuLaunchKernel") + ptr2 = Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleLoadData") + ptr3 = Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleGetFunction") Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) From a43ccf04e7d517c5b101079b92daa2eaaf8ace78 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:32:57 -0500 Subject: [PATCH 36/48] Update Project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 0f68251777..de2306c0a1 100644 --- a/Project.toml +++ b/Project.toml @@ -67,4 +67,5 @@ julia = "1.10" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" From cc411a35d4a365dc13619dc6b7dad7695fe0a7d2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:35:07 -0500 Subject: [PATCH 37/48] Update Project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index de2306c0a1..e561b29502 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353" AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random123 = "74087812-796a-5b5d-8853-05524746bad3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" From 0d6260c7a0004d3cdce8bf26a684f4afde81a480 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:47:21 -0500 Subject: [PATCH 38/48] fix --- Project.toml | 4 +--- ext/ReactantCUDAExt.jl | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index e561b29502..acc3130cb7 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,6 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353" AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random123 = "74087812-796a-5b5d-8853-05524746bad3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -35,7 +34,7 @@ ReactantCore = {path = "lib/ReactantCore"} [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" -ReactantCUDAExt = ["CUDA", "Libdl"] +ReactantCUDAExt = "CUDA" ReactantNNlibExt = "NNlib" ReactantRandom123Ext = "Random123" ReactantStatisticsExt = "Statistics" @@ -68,5 +67,4 @@ julia = "1.10" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index b30e9fce9b..1368e10e3a 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -461,9 +461,18 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - ptr1 = Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuLaunchKernel") - ptr2 = Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleLoadData") - ptr3 = Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleGetFunction") + ptr1 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuLaunchKernel"; throw_error=false) + if ptr1 === nothing + ptr1 = C_NULL + end + ptr2 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleLoadData"; throw_error=false) + if ptr2 === nothing + ptr2 = C_NULL + end + ptr3 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleGetFunction"; throw_error=false) + if ptr3 === nothing + ptr3 = C_NULL + end Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) From 17a2e12410a0ff985874988719ae41d403136c3e Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:55:11 -0500 Subject: [PATCH 39/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 1368e10e3a..857c8c3238 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -461,15 +461,15 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - ptr1 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuLaunchKernel"; throw_error=false) + ptr1 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda, "cuLaunchKernel"; throw_error=false) if ptr1 === nothing ptr1 = C_NULL end - ptr2 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleLoadData"; throw_error=false) + ptr2 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda, "cuModuleLoadData"; throw_error=false) if ptr2 === nothing ptr2 = C_NULL end - ptr3 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda_handle, "cuModuleGetFunction"; throw_error=false) + ptr3 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda, "cuModuleGetFunction"; throw_error=false) if ptr3 === nothing ptr3 = C_NULL end From 31c20e49d7621d63535500dda9179823dd58e8e4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 21:59:19 -0500 Subject: [PATCH 40/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 857c8c3238..77fe797b49 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -461,15 +461,16 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - ptr1 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda, "cuLaunchKernel"; throw_error=false) + handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda) + ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false) if ptr1 === nothing ptr1 = C_NULL end - ptr2 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda, "cuModuleLoadData"; throw_error=false) + ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false) if ptr2 === nothing ptr2 = C_NULL end - ptr3 = Reactant.XLA.Libdl.dlsym(CUDA.CUDA_Driver_jll.libcuda, "cuModuleGetFunction"; throw_error=false) + ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false) if ptr3 === nothing ptr3 = C_NULL end From eca6ee992d0aca78d2680daea0879a6030322b77 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 22:01:34 -0500 Subject: [PATCH 41/48] Update ReactantCUDAExt.jl --- ext/ReactantCUDAExt.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 77fe797b49..3c08e6f11c 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -461,7 +461,10 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda) + handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false) + if handle === nothing + handle = C_NULL + end ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false) if ptr1 === nothing ptr1 = C_NULL From 54fcea83ce7152dfab5e64d2235a1fcde726a6fe Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 22:05:31 -0500 Subject: [PATCH 42/48] Update cuda.jl --- test/cuda.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/cuda.jl b/test/cuda.jl index 8240add121..7252d37f37 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -22,5 +22,6 @@ end @show @code_hlo optimize=:before_kernel square!(A) @show @code_hlo square!(A) func = @compile square!(A) + func!(A) @test all(Array(A) .≈ (oA .* oA)) end From f5951781740f1b2b1b7631b82197eeb3276938a2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 22:10:27 -0500 Subject: [PATCH 43/48] Update cuda.jl --- test/cuda.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cuda.jl b/test/cuda.jl index 7252d37f37..aed8781f29 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -21,7 +21,7 @@ end @show @code_hlo optimize = false square!(A) @show @code_hlo optimize=:before_kernel square!(A) @show @code_hlo square!(A) - func = @compile square!(A) + func! = @compile square!(A) func!(A) @test all(Array(A) .≈ (oA .* oA)) end From ac98093b58aacf3506c32a920aa5edaed9254860 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Dec 2024 22:16:39 -0500 Subject: [PATCH 44/48] Update cuda.jl --- test/cuda.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/cuda.jl b/test/cuda.jl index aed8781f29..b524b3b9d1 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -23,5 +23,7 @@ end @show @code_hlo square!(A) func! = @compile square!(A) func!(A) + @show A + @show oA @test all(Array(A) .≈ (oA .* oA)) end From f8d8c95726b79c3a78ca895fb6f9052f9d9f4157 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 23 Dec 2024 14:35:15 -0500 Subject: [PATCH 45/48] Cuda kernel v2 --- Project.toml | 9 +++++---- src/Compiler.jl | 2 +- test/cuda.jl | 17 ++++++++++++----- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index acc3130cb7..4e2dbd4d61 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.2.11" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" @@ -28,9 +29,6 @@ Random123 = "74087812-796a-5b5d-8853-05524746bad3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" -[sources] -ReactantCore = {path = "lib/ReactantCore"} - [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" @@ -57,7 +55,7 @@ Preferences = "1.4" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.3" -Reactant_jll = "0.0.27" +Reactant_jll = "0.0.31" Scratch = "1.2" Statistics = "1.10" YaoBlocks = "0.13" @@ -68,3 +66,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" + +[sources.ReactantCore] +path = "lib/ReactantCore" diff --git a/src/Compiler.jl b/src/Compiler.jl index 2c34085a63..51efeb8aa8 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -315,7 +315,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) if isdefined(Reactant_jll, :ptxas_path) toolkit = Reactant_jll.ptxas_path[1:end-length("/bin/ptxas")] end - kern = "lower-kernel{toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" + kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" if optimize === :all run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) diff --git a/test/cuda.jl b/test/cuda.jl index b524b3b9d1..f5c8fe68f1 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -2,9 +2,16 @@ using Reactant using Test using CUDA +using Reactant_jll +@show Reactant_jll.libReactantExtra_path + function square_kernel!(x) - i = threadIdx().x - x[i] *= x[i] + #i = threadIdx().x + #x[i] *= x[i] + #@cuprintf("overwrote value of %f was thrown during kernel execution on thread (%d, %d, %d) in block (%d, %d, %d).\n", + # 0.0, threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) + #x[i], threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) + # sync_threads() return nothing end @@ -18,9 +25,9 @@ end @testset "Square Kernel" begin oA = collect(1:1:64) A = Reactant.to_rarray(oA) - @show @code_hlo optimize = false square!(A) - @show @code_hlo optimize=:before_kernel square!(A) - @show @code_hlo square!(A) + #@show @code_hlo optimize = false square!(A) + #@show @code_hlo optimize=:before_kernel square!(A) + #@show @code_hlo square!(A) func! = @compile square!(A) func!(A) @show A From 1c71a3ba7ff2949445843b6cc37bec90c485afab Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 23 Dec 2024 18:58:41 -0500 Subject: [PATCH 46/48] Update Project.toml --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2f406f4bf9..a2723854be 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.2.11" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" From c22c01ad7b8a1fcdbfddeaa28e9b53f77a7722b0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 23 Dec 2024 18:58:53 -0500 Subject: [PATCH 47/48] Update API.cpp --- deps/ReactantExtra/API.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index e77de21111..3292f38800 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -49,6 +49,8 @@ #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" From fa621fe3f5e2e3b7c8354a4a24cb0c3b8f979419 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 23 Dec 2024 19:48:41 -0500 Subject: [PATCH 48/48] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantCUDAExt.jl | 2 +- test/cuda.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index c6ea606508..ba0765af5f 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -479,7 +479,7 @@ function __init__() Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) - return + return nothing end end # module ReactantCUDAExt diff --git a/test/cuda.jl b/test/cuda.jl index 1a654428d9..549002e4f1 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -10,7 +10,7 @@ function square_kernel!(x) #x[i] *= x[i] #@cuprintf("overwrote value of %f was thrown during kernel execution on thread (%d, %d, %d) in block (%d, %d, %d).\n", # 0.0, threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) - #x[i], threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) + #x[i], threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) # sync_threads() return nothing