Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1606,14 +1606,15 @@ genrule(
"@jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_passes.capi.h.inc",
"@jax//jaxlib/mosaic/dialect/gpu:integrations/c/attributes.h",
"@jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h",
"@enzyme_ad//src/enzyme_ad/jax:Integrations/c/EnzymeXLA.h",
"//:Project.toml",
"//:Manifest.toml",
"//:wrap.toml",
"//:missing_defs.jl",
"//:make.jl",
],
outs = ["libMLIR_h.jl"],
cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$(location @jax//jaxlib/triton:triton_dialect_capi.h)\" \"$(location @jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_dialect.h)\" \"$(location @jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h)\" \"$@\"",
cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$(location @jax//jaxlib/triton:triton_dialect_capi.h)\" \"$(location @jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_dialect.h)\" \"$(location @jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h)\" \"$(location @enzyme_ad//src/enzyme_ad/jax:Integrations/c/EnzymeXLA.h)\" \"$@\"",
tags = [
"jlrule",
],
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "db993e6a5ab93f831da937a42651b34fc1ac4248"
ENZYMEXLA_COMMIT = "c7e4ccff2a29ee37f1931f4048bdacadb1f707e9"

ENZYMEXLA_SHA256 = ""

Expand Down
18 changes: 13 additions & 5 deletions deps/ReactantExtra/make.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
using Pkg: Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()

import BinaryBuilderBase:
PkgSpec, Prefix, temp_prefix, setup_dependencies, cleanup_dependencies, destdir
using Clang.Generators
Expand All @@ -19,11 +23,12 @@ let options = deepcopy(options)
genarg = first(eachsplit(ARGS[3], " "))

gen_include_dir = joinpath(splitpath(genarg)[1:(end - 4)]...)
hlo_include_dir = joinpath(splitpath(ARGS[end - 5])[1:(end - 1)]...)
sdy_include_dir = joinpath(splitpath(ARGS[end - 4])[1:(end - 1)]...)
triton_include_dir = joinpath(splitpath(ARGS[end - 3])[1:(end - 1)]...)
mosaic_tpu_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...)
mosaic_gpu_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)
hlo_include_dir = joinpath(splitpath(ARGS[end - 6])[1:(end - 1)]...)
sdy_include_dir = joinpath(splitpath(ARGS[end - 5])[1:(end - 1)]...)
triton_include_dir = joinpath(splitpath(ARGS[end - 4])[1:(end - 1)]...)
mosaic_tpu_include_dir = joinpath(splitpath(ARGS[end - 3])[1:(end - 1)]...)
mosaic_gpu_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...)
enzymexla_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)

append!(
args,
Expand All @@ -44,6 +49,8 @@ let options = deepcopy(options)
mosaic_tpu_include_dir,
"-I",
mosaic_gpu_include_dir,
"-I",
enzymexla_include_dir,
"-x",
"c++",
],
Expand All @@ -56,6 +63,7 @@ let options = deepcopy(options)
detect_headers(triton_include_dir, args, Dict())...,
detect_headers(mosaic_tpu_include_dir, args, Dict())...,
detect_headers(mosaic_gpu_include_dir, args, Dict())...,
detect_headers(enzymexla_include_dir, args, Dict())...,
]

ctx = create_context(headers, args, options)
Expand Down
28 changes: 28 additions & 0 deletions src/mlir/libMLIR_h.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11479,4 +11479,32 @@ function mlirGetDialectHandle__mosaic_gpu__()
@ccall mlir_c.mlirGetDialectHandle__mosaic_gpu__()::MlirDialectHandle
end

function enzymexlaLapackLayoutAttrGet(ctx, col_major)
@ccall mlir_c.enzymexlaLapackLayoutAttrGet(
ctx::MlirContext, col_major::UInt8
)::MlirAttribute
end

function enzymexlaLapackTransposeAttrGet(ctx, mode)
@ccall mlir_c.enzymexlaLapackTransposeAttrGet(
ctx::MlirContext, mode::Int32
)::MlirAttribute
end

function enzymexlaLapackSideAttrGet(ctx, left_side)
@ccall mlir_c.enzymexlaLapackSideAttrGet(
ctx::MlirContext, left_side::UInt8
)::MlirAttribute
end

function enzymexlaQRAlgorithmAttrGet(ctx, mode)
@ccall mlir_c.enzymexlaQRAlgorithmAttrGet(ctx::MlirContext, mode::Int32)::MlirAttribute
end

function enzymexlaGeluApproximationAttrGet(ctx, mode)
@ccall mlir_c.enzymexlaGeluApproximationAttrGet(
ctx::MlirContext, mode::Int32
)::MlirAttribute
end

const MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL = -1
Loading