From 3845f3a42262efa78bd05fd8a865504cc7d91cd6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Nov 2025 16:46:26 -0500 Subject: [PATCH 1/3] docs: fix npm --- docs/package.json | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/package.json b/docs/package.json index cf8296f03c..321e0d0f3c 100644 --- a/docs/package.json +++ b/docs/package.json @@ -5,6 +5,7 @@ "@types/node": "^22.13.9", "markdown-it": "^14.1.0", "markdown-it-mathjax3": "^4.3.2", + "patch-package": "^8.0.1", "vitepress": "^1.6.3", "vitepress-plugin-tabs": "^0.6.0" }, From 6b4d9b480c019b57e36d2fa36b1f65f3383196db Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Nov 2025 14:18:54 -0600 Subject: [PATCH 2/3] feat: new enzymexla_tt_ext dialect bindings --- deps/ReactantExtra/BUILD | 18 ++++++ deps/ReactantExtra/make-bindings.jl | 1 + docs/src/.vitepress/config.mts | 2 + docs/src/api/dialects/tritonext.md | 11 ++++ src/mlir/Dialects/Enzyme.jl | 6 +- src/mlir/Dialects/TritonExt.jl | 89 +++++++++++++++++++++++++++++ 6 files changed, 124 insertions(+), 3 deletions(-) create mode 100644 docs/src/api/dialects/tritonext.md create mode 100644 src/mlir/Dialects/TritonExt.jl diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index e014f71e80..efd323037b 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -1447,6 +1447,24 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "TritonExtJLIncGen", + tbl_outs = [ + ( + [ + "--generator=jl-op-defs", + "--disable-module-wrap=0", + ], + "TritonExt.jl", + ), + ], + tblgen = "//:mlir-jl-tblgen", + td_file = "@enzyme_ad//src/enzyme_ad/jax:Dialect/TritonExt/Ops.td", + deps = [ + "@enzyme_ad//src/enzyme_ad/jax:TritonExtDialectTdFiles", + ], +) + gentbl_cc_library( name = "TPUJLIncGen", tbl_outs = [ diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index f84309fef1..ebdb7cd9b0 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -42,6 +42,7 @@ for file in [ "MPI.jl", "MemRef.jl", "SparseTensor.jl", + "TritonExt.jl" ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index dacce466fb..2853e7abd5 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -131,6 +131,7 @@ export default defineConfig({ { text: "SparseTensor", link: "/api/dialects/sparsetensor" }, { text: "StableHLO", link: "/api/dialects/stablehlo" }, { text: "Triton", link: "/api/dialects/triton" }, + { text: "TritonExt", link: "/api/dialects/tritonext" }, { text: "TPU", link: "/api/dialects/tpu" }, { text: "VHLO", link: "/api/dialects/vhlo" }, ], @@ -221,6 +222,7 @@ export default defineConfig({ { text: "SparseTensor", link: "/api/dialects/sparsetensor" }, { text: "StableHLO", link: "/api/dialects/stablehlo" }, { text: "Triton", link: "/api/dialects/triton" }, + { text: "TritonExt", link: "/api/dialects/tritonext" }, { text: "TPU", link: "/api/dialects/tpu" }, { text: "VHLO", link: "/api/dialects/vhlo" }, ], diff --git a/docs/src/api/dialects/tritonext.md b/docs/src/api/dialects/tritonext.md new file mode 100644 index 0000000000..7508ee5b80 --- /dev/null +++ b/docs/src/api/dialects/tritonext.md @@ -0,0 +1,11 @@ +```@meta +CollapsedDocStrings = true +``` + +# TritonExt Dialect + +Provides extensions to the Triton dialect. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.enzymexla_tt_ext] +``` diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index c9c190f505..fadd6c94cb 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -361,12 +361,12 @@ end Extract value from batched operand at index """ -function extract(input::Value, index::Value; output::IR.Type, location=Location()) +function extract(input::Value; output::IR.Type, index, location=Location()) op_ty_results = IR.Type[output,] - operands = Value[input, index] + operands = Value[input,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[] + attributes = NamedAttribute[namedattribute("index", index),] return create_operation( "enzyme.extract", diff --git a/src/mlir/Dialects/TritonExt.jl b/src/mlir/Dialects/TritonExt.jl new file mode 100644 index 0000000000..ddf36e9ae2 --- /dev/null +++ b/src/mlir/Dialects/TritonExt.jl @@ -0,0 +1,89 @@ +module enzymexla_tt_ext +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +function call( + gridx::Value, + gridy::Value, + gridz::Value, + blockx::Value, + blocky::Value, + blockz::Value, + clusterx::Value, + clustery::Value, + clusterz::Value, + inputs::Vector{Value}; + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + arg_attrs=nothing, + res_attrs=nothing, + output_operand_aliases=nothing, + xla_side_effect_free=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[ + gridx, gridy, gridz, blockx, blocky, blockz, clusterx, clustery, clusterz, inputs... + ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && + push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + return create_operation( + "enzymexla_tt_ext.call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function module_(; sym_name, bodyRegion::Region, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[bodyRegion,] + successors = Block[] + attributes = NamedAttribute[namedattribute("sym_name", sym_name),] + + return create_operation( + "enzymexla_tt_ext.module", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # enzymexla_tt_ext From d1d82846f3d9dad5c36fe2bdfb8823bd411a5262 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Nov 2025 15:40:51 -0500 Subject: [PATCH 3/3] Update deps/ReactantExtra/make-bindings.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- deps/ReactantExtra/make-bindings.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index ebdb7cd9b0..9e4295e9cb 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -42,7 +42,7 @@ for file in [ "MPI.jl", "MemRef.jl", "SparseTensor.jl", - "TritonExt.jl" + "TritonExt.jl", ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end