From f2592b851ee9e5f02807d38018e3e5a54ce70046 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Jan 2025 19:58:37 -0500 Subject: [PATCH 1/4] feat: build the shardy dialect --- deps/ReactantExtra/BUILD | 15 +++++++++++++++ deps/ReactantExtra/WORKSPACE | 7 +++++++ deps/ReactantExtra/make-bindings.jl | 11 ++++++++++- 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 7825fb1d10..fd1ea88988 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -795,6 +795,21 @@ gentbl_cc_library( tblgen = "//:mlir-jl-tblgen", ) +gentbl_cc_library( + name = "ShardyJLIncGen", + tbl_outs = [( + ["--generator=jl-op-defs", "--disable-module-wrap=0"], + "Shardy.jl" + ) + ], + td_file = "@shardy//shardy/dialect/sdy/ir:ops.td", + deps = [ + "@shardy//shardy/dialect/sdy/ir:sdy_td_files", + ], + tblgen = "//:mlir-jl-tblgen", + includes = ["external/shardy"], +) + genrule( name = "libMLIR_h.jl", tags = [ diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 22f948fdc1..9e41f5ec88 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -185,6 +185,13 @@ http_archive( urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], ) +http_archive( + name = "shardy", + sha256 = "", + urls = ["https://github.com/openxla/shardy/archive/main.tar.gz"], + strip_prefix = "shardy-main/", +) + http_archive( name = "build_bazel_rules_apple", sha256 = "34c41bfb59cdaea29ac2df5a2fa79e5add609c71bb303b2ebb10985f93fa20e7", diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index db8f518583..6e6a85f649 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -1,8 +1,16 @@ +const bazel_cmd = if !isnothing(Sys.which("bazelisk")) + "bazelisk" +elseif !isnothing(Sys.which("bazel")) + "bazel" +else + error("Could not find `bazel` or `bazelisk` in PATH!") +end + function build_file(output_path) file = basename(output_path) run( Cmd( - `bazel build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --action_env=JULIA_DEPOT_PATH=$(Base.DEPOT_PATH) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`; + `$(bazel_cmd) build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --action_env=JULIA_DEPOT_PATH=$(Base.DEPOT_PATH) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`; dir=@__DIR__, ), ) @@ -29,6 +37,7 @@ for file in [ "Affine.jl", "TPU.jl", "Triton.jl", + "Shardy.jl", ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end From 59284137d902a623dc1d83e677285998a33df7bc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Jan 2025 20:07:21 -0500 Subject: [PATCH 2/4] fix: remove http_archive --- deps/ReactantExtra/WORKSPACE | 7 ------- 1 file changed, 7 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 9e41f5ec88..22f948fdc1 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -185,13 +185,6 @@ http_archive( urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], ) -http_archive( - name = "shardy", - sha256 = "", - urls = ["https://github.com/openxla/shardy/archive/main.tar.gz"], - strip_prefix = "shardy-main/", -) - http_archive( name = "build_bazel_rules_apple", sha256 = "34c41bfb59cdaea29ac2df5a2fa79e5add609c71bb303b2ebb10985f93fa20e7", From e010300afc63086daf40d5da74d2cb0ca32714e6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Jan 2025 20:13:12 -0500 Subject: [PATCH 3/4] fix: register dialect --- deps/ReactantExtra/API.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 9398368e24..9b00302d88 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -37,6 +37,7 @@ #include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" #include "src/enzyme_ad/jax/Passes/Passes.h" #include "llvm/Support/TargetSelect.h" +#include "shardy/dialect/sdy/ir/dialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "stablehlo/dialect/ChloOps.h" @@ -701,6 +702,8 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::registerNVVMDialectImport(registry); mlir::LLVM::registerInlinerInterface(registry); + mlir::sdy::registerAllDialects(registry); + /* registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { LLVM::LLVMFunctionType::attachInterface(*ctx); From 53fa443ce50ec493276deb931f8a8e17ee12c12c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Jan 2025 20:19:39 -0500 Subject: [PATCH 4/4] fix: missing loadDialect --- deps/ReactantExtra/API.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 9b00302d88..7c089c8223 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -674,6 +674,7 @@ extern "C" void RegisterDialects(MlirContext cctx) { context.loadDialect(); context.loadDialect(); context.loadDialect(); + context.loadDialect(); } #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"