Skip to content
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.15"
Reactant_jll = "0.0.229"
Reactant_jll = "0.0.230"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
15 changes: 11 additions & 4 deletions deps/ReactantExtra/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ build:cuda --@local_config_cuda//:cuda_compiler=nvcc
# build:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true
# build:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true

build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --define=using_rocm=true
build:rocm --define=using_rocm_hipcc=true
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201"

# Build with hipcc for ROCm and clang for the host.
build:rocm --action_env=TF_ROCM_CLANG="1"
build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
build:rocm --copt=-Wno-gnu-offsetof-extensions
build:rocm --copt=-Qunused-arguments
build:rocm --action_env=TF_HIPCC_CLANG="1"
23 changes: 11 additions & 12 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ cc_toolchain_config(
abi_version = "local",
builtin_sysroot = "/opt/x86_64-linux-musl/bin/../x86_64-linux-musl/sys-root",
compile_flags = [
"-stdlib=libstdc++",
],
compiler = "clang",
coverage_compile_flags = ["--coverage"],
Expand Down Expand Up @@ -93,6 +94,7 @@ cc_toolchain_config(
#"--ld-path=/opt/x86_64-linux-musl/bin/ld.lld",
"--ld-path=/opt/bin/x86_64-linux-musl-cxx11/x86_64-linux-musl-ld.lld",
"-stdlib=libstdc++",
"-static-libstdc++",
],
link_libs = [
"-lstdc++",
Expand All @@ -105,7 +107,6 @@ cc_toolchain_config(
"-DNDEBUG",
"-ffunction-sections",
"-fdata-sections",
"-stdlib=libstdc++",
],
opt_link_flags = ["-Wl,--gc-sections"],
supports_start_end_lib = True,
Expand Down Expand Up @@ -188,6 +189,7 @@ cc_toolchain_config(
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/BB_TARGET",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/backward",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/parallel",
"-stdlib=libstdc++",
],
compiler = "compiler",
coverage_compile_flags = ["--coverage"],
Expand All @@ -209,7 +211,10 @@ cc_toolchain_config(
],
dbg_compile_flags = ["-g"],
host_system_name = "linux",
link_flags = [],
link_flags = [
"-stdlib=libstdc++",
"-static-libstdc++",
],
link_libs = [
"-lstdc++",
"-lm",
Expand All @@ -221,7 +226,6 @@ cc_toolchain_config(
"-DNDEBUG",
"-ffunction-sections",
"-fdata-sections",
# "-stdlib=libstdc++",
],
opt_link_flags = ["-Wl,--gc-sections"],
# TODO gcc doesn't support it, only put it on clang (maybe even only for clang on aarch64-darwin?)
Expand Down Expand Up @@ -349,7 +353,6 @@ cc_toolchain_config(
"-DNDEBUG",
"-ffunction-sections",
"-fdata-sections",
# "-stdlib=libstdc++",
],
opt_link_flags = ["-Wl,--gc-sections"],
# TODO gcc doesn't support it, only put it on clang (maybe even only for clang on aarch64-darwin?)
Expand Down Expand Up @@ -452,7 +455,6 @@ cc_toolchain_config(
],
dbg_compile_flags = [
"-g",
"-stdlib=libc++",
],
host_system_name = "linux",
link_flags = [
Expand All @@ -470,7 +472,6 @@ cc_toolchain_config(
"-DNDEBUG",
"-ffunction-sections",
"-fdata-sections",
"-stdlib=libc++",
],
opt_link_flags = ["-Wl,--gc-sections"],
supports_start_end_lib = True,
Expand Down Expand Up @@ -695,12 +696,11 @@ cc_toolchain_config(
],
dbg_compile_flags = [
"-g",
"-stdlib=libstdc++",
],
host_system_name = "linux",
link_flags = [
"-fuse-ld=lld",
"-stdlib=libc++",
"-stdlib=libstdc++",
],
link_libs = [
"-lstdc++",
Expand All @@ -713,7 +713,6 @@ cc_toolchain_config(
"-DNDEBUG",
"-ffunction-sections",
"-fdata-sections",
"-stdlib=libc++",
],
opt_link_flags = ["-Wl,--gc-sections"],
supports_start_end_lib = False,
Expand Down Expand Up @@ -758,7 +757,7 @@ cc_toolchain_config(
"-D__TIME__=\"redacted\"",
"-Wno-unused-command-line-argument",
"-Wno-gnu-offsetof-extensions",
"-mxsave",
"-mxsave",
],
)

Expand All @@ -781,6 +780,7 @@ cc_toolchain_config(
abi_version = "local",
compile_flags = [
"-I/usr/include/c++/11",
"-stdlib=libstdc++",
],
compiler = "clang",
coverage_compile_flags = ["--coverage"],
Expand Down Expand Up @@ -817,7 +817,6 @@ cc_toolchain_config(
"-DNDEBUG",
"-ffunction-sections",
"-fdata-sections",
"-stdlib=libstdc++",
"-I/usr/include/c++/11",
],
opt_link_flags = ["-Wl,--gc-sections"],
Expand Down Expand Up @@ -1193,7 +1192,7 @@ cc_library(
],
"@xla//xla/tsl:windows": [
"@xla//xla/tsl/platform/windows:platform_port",
],
],
"//conditions:default": [
"@gloo//:transport_tcp",
"@xla//xla/backends/cpu/collectives:gloo_collectives",
Expand Down
46 changes: 41 additions & 5 deletions deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "15d5dab32f0eb9da1c26b7b934148aa430bdb144"
ENZYMEXLA_COMMIT = "ced80ad0634e2f30f0d8bef2eaaff3cad1b97664"

ENZYMEXLA_SHA256 = ""

Expand Down Expand Up @@ -38,12 +38,47 @@ CUPTI_NEW = []

NEW_XLA_PATCHES = CUPTI_NEW + [
"""
echo "--- a/src/google/protobuf/stubs/port.h" >> third_party/proto.patch
echo "+++ b/src/google/protobuf/stubs/port.h" >> third_party/proto.patch
echo "@@ -27,7 +27,7 @@" >> third_party/proto.patch
echo " #include <intrin.h>" >> third_party/proto.patch
echo " #elif defined(__APPLE__)" >> third_party/proto.patch
echo " #include <libkern/OSByteOrder.h>" >> third_party/proto.patch
echo "-#elif defined(__linux__) || defined(__ANDROID__) || defined(__CYGWIN__)" >> third_party/proto.patch
echo "+#elif !defined(__NVCC__) && (defined(__linux__) || defined(__ANDROID__) || defined(__CYGWIN__))" >> third_party/proto.patch
echo " #include <byteswap.h> // IWYU pragma: export" >> third_party/proto.patch
echo " #endif" >> third_party/proto.patch
echo "" >> third_party/proto.patch
echo "@@ -143,7 +143,7 @@" >> third_party/proto.patch
echo " #define bswap_32(x) OSSwapInt32(x)" >> third_party/proto.patch
echo " #define bswap_64(x) OSSwapInt64(x)" >> third_party/proto.patch
echo "" >> third_party/proto.patch
echo "-#elif !defined(__linux__) && !defined(__ANDROID__) && !defined(__CYGWIN__)" >> third_party/proto.patch
echo "+#elif defined(__NVCC__) || (!defined(__linux__) && !defined(__ANDROID__) && !defined(__CYGWIN__))" >> third_party/proto.patch
echo "" >> third_party/proto.patch
echo " #ifndef bswap_16" >> third_party/proto.patch
echo " static inline uint16_t bswap_16(uint16_t x) {" >> third_party/proto.patch
sed -i.bak0 "s/protobuf-6.31.1.patch\\"/protobuf-6.31.1.patch\\", \\":proto.patch\\"/g" workspace2.bzl
""",
"""
sed -i.bak0 "s/def main():/def main():\\n if TMPDIR: os.environ['TMPDIR'] = TMPDIR/g" third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
""",
"""
sed -i.bak0 "s/__chkstk/__chkstk_ms/g" xla/service/cpu/runtime_symbol_generator.cc
""",
"""
sed -i.bak0 "1s/^/#include \\"llvm\\/Support\\/DynamicLibrary.h\\"\\n/g" xla/service/cpu/runtime_symbol_generator.cc
""",
"""
sed -i.bak0 "s/(__chkstk_ms)/(llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(\\"__chkstk_ms\\"))/g" xla/service/cpu/runtime_symbol_generator.cc
""",
"""
sed -i.bak0 "s/Shlwapi/shlwapi/g" xla/tsl/platform/windows/load_library.cc xla/tsl/platform/windows/windows_file_system.cc xla/tsl/platform/windows/env.cc
""",
"""
"""
sed -i.bak0 "1s/^/#ifdef PLATFORM_WINDOWS\\n#include <immintrin.h>\\n#include <intrin.h>\\n#endif/g" third_party/tsl/tsl/platform/cpu_info.cc
""",
"""
"""
sed -i.bak0 "1s/^/#define _USE_MATH_DEFINES\\n/g" xla/fp_util.h xla/hlo/builder/lib/prng.cc xla/literal_comparison.cc xla/hlo/builder/lib/math.cc xla/service/spmd/fft_handler.cc xla/service/cpu/onednn_contraction_rewriter.cc xla/hlo/evaluator/hlo_evaluator.cc
""",
"""
Expand Down Expand Up @@ -138,9 +173,11 @@ LLVM_TARGETS = [
# )

load("@enzyme_ad//third_party/jax:workspace.bzl", jax_workspace = "repo")

jax_workspace([])

load("@enzyme_ad//third_party/xla:workspace.bzl", xla_workspace = "repo")

xla_workspace(NEW_XLA_PATCHES)

#
Expand All @@ -158,10 +195,9 @@ xla_workspace(NEW_XLA_PATCHES)
#
# pip_install_dependencies()


load("@enzyme_ad//third_party/enzyme:workspace.bzl", enzyme_workspace = "repo")
enzyme_workspace()

enzyme_workspace()

# http_archive(
# name = "upb",
Expand Down
6 changes: 4 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1147,8 +1147,8 @@ function optimization_passes(
append!(
transform_passes_list,
[
"no_nan",
"no_nan_self_sub_simplify",
"no_nan_compare_simplify(1)",
"no_nan_self_sub_simplify(1)",
"no_nan_add_sub_simplify(1)",
"no_nan_mul_simplify(1)",
"no_nan_div_simplify(1)",
Expand All @@ -1158,6 +1158,8 @@ function optimization_passes(
append!(
transform_passes_list,
[
"no_nan_compare_simplify(0)",
"no_nan_self_sub_simplify(0)",
"no_nan_add_sub_simplify(0)",
"no_nan_mul_simplify(0)",
"no_nan_div_simplify(0)",
Expand Down
40 changes: 40 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,46 @@ const enzyme_dupnoneed = 3
const enzyme_outnoneed = 4
const enzyme_constnoneed = 5

# Zero for a Reactant number: the `zero` of the argument's own runtime type
# (`Core.Typeof(x)`), so the result is derived from the wrapper type of `x`.
@inline Enzyme.make_zero(x::RNumber) = zero(Core.Typeof(x))

# Zero for a real floating-point Reactant array: delegates to `Base.zero(x)`.
# The declared return type converts/asserts the result to `RArray{FT,N}`.
@inline function Enzyme.make_zero(
    x::RArray{FT,N}
)::RArray{FT,N} where {FT<:AbstractFloat,N}
    zeroed = Base.zero(x)
    return zeroed
end

# Zero for a complex floating-point Reactant array: delegates to
# `Base.zero(x)`. The declared return type converts/asserts the result to
# `RArray{Complex{FT},N}`.
@inline function Enzyme.make_zero(
    x::RArray{Complex{FT},N}
)::RArray{Complex{FT},N} where {FT<:AbstractFloat,N}
    zeroed = Base.zero(x)
    return zeroed
end

# Generate the in-place zeroing methods of `sym` for `RArray` arguments.
#
# `sym` is the (possibly module-qualified) name of the function to extend.
# Two methods are emitted:
#   * `sym(prev)`       -- floating-point element types only; forwards to the
#                          two-argument method with `nothing` as `seen`.
#   * `sym(prev, seen)` -- any element type; fills `prev` with `zero(T)` unless
#                          an early exit below applies.
# Both methods always return `nothing`.
macro register_make_zero_inplace(sym)
    quote
        @inline function $sym(prev::RArray{T,N})::Nothing where {T<:AbstractFloat,N}
            # No tracking container is supplied on this entry point.
            $sym(prev, nothing)
            return nothing
        end

        @inline function $sym(prev::RArray{T,N}, seen::ST)::Nothing where {T,N,ST}
            # Leave the array untouched when this query returns true for the
            # element type (presumably: Enzyme treats `T` as constant -- confirm
            # against Enzyme.Compiler).
            Enzyme.Compiler.guaranteed_const_nongen(T, nothing) && return nothing

            if !isnothing(seen)
                # Already recorded in `seen`: nothing more to do.
                (prev in seen) && return nothing
                push!(seen, prev)
            end

            fill!(prev, zero(T))
            return nothing
        end
    end
end

# Emit the `RArray` in-place zeroing methods for both Enzyme entry points.
@register_make_zero_inplace Enzyme.make_zero!
@register_make_zero_inplace Enzyme.remake_zero!

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
Expand Down
7 changes: 7 additions & 0 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ vector_forward_ad(x) = Enzyme.autodiff(Forward, fn, BatchDuplicated(x, Enzyme.on
@test res[1][4] ≈ res_enz[1][4]
end

@testset "make_zero!" begin
    # Start from a non-zero Reactant array so the zeroing is observable.
    zero_target = Reactant.to_rarray([3.1])
    @jit Enzyme.make_zero!(zero_target)

    # The single element must have been overwritten in place with zero.
    @test @allowscalar zero_target[1] ≈ 0.0
end

function simple_forward(x, st)
rng = copy(st.rng)
y = similar(x)
Expand Down
Loading