Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions deps/ReactantExtra/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,32 @@ build --repo_env=RULES_PYTHON_ENABLE_PYSTAR=0

build -c opt

build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NVCC_CLANG=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.1"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
build:cuda --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5"
common:cuda --repo_env TF_NEED_CUDA=1
common:cuda --repo_env TF_NVCC_CLANG=1
common:cuda --repo_env TF_NCCL_USE_STUB=1
common:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.1"
common:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
common:cuda --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,compute_90"
build:cuda --crosstool_top="@local_config_cuda//crosstool:toolchain"
build:cuda --@local_config_cuda//:enable_cuda
common:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,compute_90"
common:cuda --crosstool_top="@local_config_cuda//crosstool:toolchain"
common:cuda --@local_config_cuda//:enable_cuda
# Default hermetic CUDA and CUDNN versions.
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
build:cuda --@local_config_cuda//:cuda_compiler=nvcc
# build:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true
# build:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true
common:cuda --@local_config_cuda//cuda:include_cuda_libs=true
common:cuda --@local_config_cuda//:cuda_compiler=nvcc
# common:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true
# common:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true


build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201"
common:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
common:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
common:rocm --repo_env TF_NEED_ROCM=1
common:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201"

# Build with hipcc for ROCm and clang for the host.
build:rocm --action_env=TF_ROCM_CLANG="1"
build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
build:rocm --copt=-Wno-gnu-offsetof-extensions
build:rocm --copt=-Qunused-arguments
build:rocm --action_env=TF_HIPCC_CLANG="1"
common:rocm --action_env=TF_ROCM_CLANG="1"
common:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
common:rocm --copt=-Wno-gnu-offsetof-extensions
common:rocm --copt=-Qunused-arguments
common:rocm --action_env=TF_HIPCC_CLANG="1"
9 changes: 5 additions & 4 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1626,8 +1626,9 @@ REACTANT_ABI HeldIfrtArray *ifrt_client_assemble_array_from_single_shards(
REACTANT_ABI HeldIfrtArray *
ifrt_pjrt_array_create(ifrt::PjRtClient *client,
HeldValue<std::shared_ptr<xla::PjRtBuffer>> *buffer) {
return reactant::capture(tsl::RCReference<ifrt::Array>(
MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj()))));
return reactant::capture(
tsl::RCReference<ifrt::Array>(MyValueOrThrow(xla::ifrt::PjRtArray::Create(
client, buffer->obj(), /*has_custom_layout*/ false))));
}

// we might be interested in the `Compiler::Compile` method variant that accepts
Expand Down Expand Up @@ -2373,7 +2374,7 @@ REACTANT_ABI bool hlo_sharding_check_eq(xla::HloSharding *hloSharding,

#pragma endregion

typedef ifrt::Future<> IfRtFutureType;
typedef tsl::Future<> IfRtFutureType;

REACTANT_ABI void ifrt_free_future(IfRtFutureType *Future) { delete Future; }

Expand Down Expand Up @@ -2600,7 +2601,7 @@ REACTANT_ABI void ifrt_loaded_executable_execute(
// there is only 1 status, and it is valid because we set
// `options.fill_status = true`
*futures = true;
*status = new FutureType(result.status);
*status = new IfRtFutureType(result.status);

for (int i = 0; i < num_results; i++) {
op_results[i] = reactant::capture(result.outputs[i]);
Expand Down
26 changes: 9 additions & 17 deletions deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "996b3e5ccbe5024d69d9a365e5f4db79f41dbc6c"

ENZYMEXLA_COMMIT = "6137ac98e710adf6f4e953bf441db4e25b2db40f"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down Expand Up @@ -36,7 +34,11 @@ LLVM_TARGETS = [
"AArch64",
"X86",
"ARM",
]
]
#+ [
# "PowerPC",
# "SystemZ"
#]

# Uncomment these lines to use a custom LLVM commit
# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3"
Expand Down Expand Up @@ -73,19 +75,9 @@ LLVM_TARGETS = [
# ],
# )

http_archive(
name = "rules_ml_toolchain",
patch_cmds = [
"""
sed -i.bak0 "/D_FORTIFY_SOURCE/d" cc/features/BUILD third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl
""",
],
sha256 = "e7e44c4e349a1c1f31398bd2257c51432e73ea0e7e24cce67090b68b0b50007e",
strip_prefix = "rules_ml_toolchain-55dcd0a52c7e0f9eec9927a32512229c09ac3b3e",
urls = [
"https://github.com/google-ml-infra/rules_ml_toolchain/archive/55dcd0a52c7e0f9eec9927a32512229c09ac3b3e.tar.gz",
],
)
load("@enzyme_ad//third_party/ml_toolchain:workspace.bzl", ml_toolchain_workspace = "repo")

ml_toolchain_workspace()

load("@enzyme_ad//third_party/jax:workspace.bzl", jax_workspace = "repo")

Expand Down
2 changes: 1 addition & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2149,7 +2149,7 @@ end
)

if !mincut
MLIR.IR.attr!(while_op, "enzymexla.disable_min_cut", MLIR.IR.UnitAttribute())
MLIR.IR.attr!(while_op, "enzyme.disable_mincut", MLIR.IR.UnitAttribute())
end

if checkpointing
Expand Down
2 changes: 1 addition & 1 deletion test/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ end
@test for_no_track_numbers_ra(x_ra, n_ra) == for_no_track_numbers(x, n)

ir = sprint(show, @code_hlo optimize = "enzyme-batch" for_no_track_numbers(x_ra, n_ra))
@test contains(ir, "enzymexla.disable_min_cut")
@test contains(ir, "enzyme.disable_mincut")
@test contains(ir, "enzymexla.enable_checkpointing")
end

Expand Down
Loading