diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 1cb682e601..36134e154f 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -1606,6 +1606,7 @@ genrule( "@jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_passes.capi.h.inc", "@jax//jaxlib/mosaic/dialect/gpu:integrations/c/attributes.h", "@jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h", + "@enzyme_ad//src/enzyme_ad/jax:Integrations/c/EnzymeXLA.h", "//:Project.toml", "//:Manifest.toml", "//:wrap.toml", @@ -1613,7 +1614,7 @@ genrule( "//:make.jl", ], outs = ["libMLIR_h.jl"], - cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$(location @jax//jaxlib/triton:triton_dialect_capi.h)\" \"$(location @jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_dialect.h)\" \"$(location @jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h)\" \"$@\"", + cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$(location @jax//jaxlib/triton:triton_dialect_capi.h)\" \"$(location @jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_dialect.h)\" \"$(location @jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h)\" \"$(location @enzyme_ad//src/enzyme_ad/jax:Integrations/c/EnzymeXLA.h)\" \"$@\"", tags = [ "jlrule", ], diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index cdde71e4c3..4b271f0e7b 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "db993e6a5ab93f831da937a42651b34fc1ac4248" +ENZYMEXLA_COMMIT = "c7e4ccff2a29ee37f1931f4048bdacadb1f707e9" ENZYMEXLA_SHA256 = "" diff --git a/deps/ReactantExtra/make.jl b/deps/ReactantExtra/make.jl index 194d43497f..d7d6110406 100644 --- a/deps/ReactantExtra/make.jl +++ b/deps/ReactantExtra/make.jl @@ -1,3 +1,7 @@ +using Pkg: Pkg +Pkg.activate(@__DIR__) +Pkg.instantiate() + import BinaryBuilderBase: PkgSpec, Prefix, temp_prefix, setup_dependencies, cleanup_dependencies, destdir using Clang.Generators @@ -19,11 +23,12 @@ let options = deepcopy(options) genarg = first(eachsplit(ARGS[3], " ")) gen_include_dir = joinpath(splitpath(genarg)[1:(end - 4)]...) - hlo_include_dir = joinpath(splitpath(ARGS[end - 5])[1:(end - 1)]...) - sdy_include_dir = joinpath(splitpath(ARGS[end - 4])[1:(end - 1)]...) - triton_include_dir = joinpath(splitpath(ARGS[end - 3])[1:(end - 1)]...) - mosaic_tpu_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...) - mosaic_gpu_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...) + hlo_include_dir = joinpath(splitpath(ARGS[end - 6])[1:(end - 1)]...) + sdy_include_dir = joinpath(splitpath(ARGS[end - 5])[1:(end - 1)]...) + triton_include_dir = joinpath(splitpath(ARGS[end - 4])[1:(end - 1)]...) + mosaic_tpu_include_dir = joinpath(splitpath(ARGS[end - 3])[1:(end - 1)]...) + mosaic_gpu_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...) + enzymexla_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...) append!( args, @@ -44,6 +49,8 @@ let options = deepcopy(options) mosaic_tpu_include_dir, "-I", mosaic_gpu_include_dir, + "-I", + enzymexla_include_dir, "-x", "c++", ], @@ -56,6 +63,7 @@ let options = deepcopy(options) detect_headers(triton_include_dir, args, Dict())..., detect_headers(mosaic_tpu_include_dir, args, Dict())..., detect_headers(mosaic_gpu_include_dir, args, Dict())..., + detect_headers(enzymexla_include_dir, args, Dict())..., ] ctx = create_context(headers, args, options) diff --git a/src/mlir/libMLIR_h.jl b/src/mlir/libMLIR_h.jl index e87b136406..3955d63279 100755 --- a/src/mlir/libMLIR_h.jl +++ b/src/mlir/libMLIR_h.jl @@ -11479,4 +11479,32 @@ function mlirGetDialectHandle__mosaic_gpu__() @ccall mlir_c.mlirGetDialectHandle__mosaic_gpu__()::MlirDialectHandle end +function enzymexlaLapackLayoutAttrGet(ctx, col_major) + @ccall mlir_c.enzymexlaLapackLayoutAttrGet( + ctx::MlirContext, col_major::UInt8 + )::MlirAttribute +end + +function enzymexlaLapackTransposeAttrGet(ctx, mode) + @ccall mlir_c.enzymexlaLapackTransposeAttrGet( + ctx::MlirContext, mode::Int32 + )::MlirAttribute +end + +function enzymexlaLapackSideAttrGet(ctx, left_side) + @ccall mlir_c.enzymexlaLapackSideAttrGet( + ctx::MlirContext, left_side::UInt8 + )::MlirAttribute +end + +function enzymexlaQRAlgorithmAttrGet(ctx, mode) + @ccall mlir_c.enzymexlaQRAlgorithmAttrGet(ctx::MlirContext, mode::Int32)::MlirAttribute +end + +function enzymexlaGeluApproximationAttrGet(ctx, mode) + @ccall mlir_c.enzymexlaGeluApproximationAttrGet( + ctx::MlirContext, mode::Int32 + )::MlirAttribute +end + const MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL = -1