diff --git a/Project.toml b/Project.toml index 66da87a205..641a7b759a 100644 --- a/Project.toml +++ b/Project.toml @@ -97,7 +97,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.15" -Reactant_jll = "0.0.229" +Reactant_jll = "0.0.230" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" diff --git a/deps/ReactantExtra/.bazelrc b/deps/ReactantExtra/.bazelrc index 1b442a68c4..be37689e71 100644 --- a/deps/ReactantExtra/.bazelrc +++ b/deps/ReactantExtra/.bazelrc @@ -44,8 +44,15 @@ build:cuda --@local_config_cuda//:cuda_compiler=nvcc # build:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true # build:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true -build:rocm --repo_env TF_NEED_ROCM=1 -build:rocm --define=using_rocm=true -build:rocm --define=using_rocm_hipcc=true -build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030" +build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain +build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true +build:rocm --repo_env TF_NEED_ROCM=1 +build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201" + +# Build with hipcc for ROCm and clang for the host. +build:rocm --action_env=TF_ROCM_CLANG="1" +build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" +build:rocm --copt=-Wno-gnu-offsetof-extensions +build:rocm --copt=-Qunused-arguments +build:rocm --action_env=TF_HIPCC_CLANG="1" diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 4c3717e851..a04c9f587f 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -64,6 +64,7 @@ cc_toolchain_config( abi_version = "local", builtin_sysroot = "/opt/x86_64-linux-musl/bin/../x86_64-linux-musl/sys-root", compile_flags = [ + "-stdlib=libstdc++", ], compiler = "clang", coverage_compile_flags = ["--coverage"], @@ -93,6 +94,7 @@ cc_toolchain_config( #"--ld-path=/opt/x86_64-linux-musl/bin/ld.lld", "--ld-path=/opt/bin/x86_64-linux-musl-cxx11/x86_64-linux-musl-ld.lld", "-stdlib=libstdc++", + "-static-libstdc++", ], link_libs = [ "-lstdc++", @@ -105,7 +107,6 @@ cc_toolchain_config( "-DNDEBUG", "-ffunction-sections", "-fdata-sections", - "-stdlib=libstdc++", ], opt_link_flags = ["-Wl,--gc-sections"], supports_start_end_lib = True, @@ -188,6 +189,7 @@ cc_toolchain_config( "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/BB_TARGET", "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/backward", "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/GCC_VERSION/parallel", + "-stdlib=libstdc++", ], compiler = "compiler", coverage_compile_flags = ["--coverage"], @@ -209,7 +211,10 @@ cc_toolchain_config( ], dbg_compile_flags = ["-g"], host_system_name = "linux", - link_flags = [], + link_flags = [ + "-stdlib=libstdc++", + "-static-libstdc++", + ], link_libs = [ "-lstdc++", "-lm", @@ -221,7 +226,6 @@ cc_toolchain_config( "-DNDEBUG", "-ffunction-sections", "-fdata-sections", - # "-stdlib=libstdc++", ], opt_link_flags = ["-Wl,--gc-sections"], # TODO gcc doesn't support it, only put it on clang (maybe even only for clang on aarch64-darwin?) @@ -349,7 +353,6 @@ cc_toolchain_config( "-DNDEBUG", "-ffunction-sections", "-fdata-sections", - # "-stdlib=libstdc++", ], opt_link_flags = ["-Wl,--gc-sections"], # TODO gcc doesn't support it, only put it on clang (maybe even only for clang on aarch64-darwin?) @@ -452,7 +455,6 @@ cc_toolchain_config( ], dbg_compile_flags = [ "-g", - "-stdlib=libc++", ], host_system_name = "linux", link_flags = [ @@ -470,7 +472,6 @@ cc_toolchain_config( "-DNDEBUG", "-ffunction-sections", "-fdata-sections", - "-stdlib=libc++", ], opt_link_flags = ["-Wl,--gc-sections"], supports_start_end_lib = True, @@ -695,12 +696,11 @@ cc_toolchain_config( ], dbg_compile_flags = [ "-g", - "-stdlib=libstdc++", ], host_system_name = "linux", link_flags = [ "-fuse-ld=lld", - "-stdlib=libc++", + "-stdlib=libstdc++", ], link_libs = [ "-lstdc++", @@ -713,7 +713,6 @@ cc_toolchain_config( "-DNDEBUG", "-ffunction-sections", "-fdata-sections", - "-stdlib=libc++", ], opt_link_flags = ["-Wl,--gc-sections"], supports_start_end_lib = False, @@ -758,7 +757,7 @@ cc_toolchain_config( "-D__TIME__=\"redacted\"", "-Wno-unused-command-line-argument", "-Wno-gnu-offsetof-extensions", - "-mxsave", + "-mxsave", ], ) @@ -781,6 +780,7 @@ cc_toolchain_config( abi_version = "local", compile_flags = [ "-I/usr/include/c++/11", + "-stdlib=libstdc++", ], compiler = "clang", coverage_compile_flags = ["--coverage"], @@ -817,7 +817,6 @@ cc_toolchain_config( "-DNDEBUG", "-ffunction-sections", "-fdata-sections", - "-stdlib=libstdc++", "-I/usr/include/c++/11", ], opt_link_flags = ["-Wl,--gc-sections"], @@ -1193,7 +1192,7 @@ cc_library( ], "@xla//xla/tsl:windows": [ "@xla//xla/tsl/platform/windows:platform_port", - ], + ], "//conditions:default": [ "@gloo//:transport_tcp", "@xla//xla/backends/cpu/collectives:gloo_collectives", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 9cbcddb575..87b364df19 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "15d5dab32f0eb9da1c26b7b934148aa430bdb144" +ENZYMEXLA_COMMIT = "ced80ad0634e2f30f0d8bef2eaaff3cad1b97664" ENZYMEXLA_SHA256 = "" @@ -38,12 +38,47 @@ CUPTI_NEW = [] NEW_XLA_PATCHES = CUPTI_NEW + [ """ +echo "--- a/src/google/protobuf/stubs/port.h" >> third_party/proto.patch +echo "+++ b/src/google/protobuf/stubs/port.h" >> third_party/proto.patch +echo "@@ -27,7 +27,7 @@" >> third_party/proto.patch +echo " #include " >> third_party/proto.patch +echo " #elif defined(__APPLE__)" >> third_party/proto.patch +echo " #include " >> third_party/proto.patch +echo "-#elif defined(__linux__) || defined(__ANDROID__) || defined(__CYGWIN__)" >> third_party/proto.patch +echo "+#elif !defined(__NVCC__) && (defined(__linux__) || defined(__ANDROID__) || defined(__CYGWIN__))" >> third_party/proto.patch +echo " #include // IWYU pragma: export" >> third_party/proto.patch +echo " #endif" >> third_party/proto.patch +echo "" >> third_party/proto.patch +echo "@@ -143,7 +143,7 @@" >> third_party/proto.patch +echo " #define bswap_32(x) OSSwapInt32(x)" >> third_party/proto.patch +echo " #define bswap_64(x) OSSwapInt64(x)" >> third_party/proto.patch +echo "" >> third_party/proto.patch +echo "-#elif !defined(__linux__) && !defined(__ANDROID__) && !defined(__CYGWIN__)" >> third_party/proto.patch +echo "+#elif defined(__NVCC__) || (!defined(__linux__) && !defined(__ANDROID__) && !defined(__CYGWIN__))" >> third_party/proto.patch +echo "" >> third_party/proto.patch +echo " #ifndef bswap_16" >> third_party/proto.patch +echo " static inline uint16_t bswap_16(uint16_t x) {" >> third_party/proto.patch +sed -i.bak0 "s/protobuf-6.31.1.patch\\"/protobuf-6.31.1.patch\\", \\":proto.patch\\"/g" workspace2.bzl +""", + """ +sed -i.bak0 "s/def main():/def main():\\n if TMPDIR: os.environ['TMPDIR'] = TMPDIR/g" third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +""", + """ +sed -i.bak0 "s/__chkstk/__chkstk_ms/g" xla/service/cpu/runtime_symbol_generator.cc +""", + """ +sed -i.bak0 "1s/^/#include \\"llvm\\/Support\\/DynamicLibrary.h\\"\\n/g" xla/service/cpu/runtime_symbol_generator.cc +""", + """ +sed -i.bak0 "s/(__chkstk_ms)/(llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(\\"__chkstk_ms\\"))/g" xla/service/cpu/runtime_symbol_generator.cc +""", + """ sed -i.bak0 "s/Shlwapi/shlwapi/g" xla/tsl/platform/windows/load_library.cc xla/tsl/platform/windows/windows_file_system.cc xla/tsl/platform/windows/env.cc """, -""" + """ sed -i.bak0 "1s/^/#ifdef PLATFORM_WINDOWS\\n#include \\n#include \\n#endif/g" third_party/tsl/tsl/platform/cpu_info.cc """, -""" + """ sed -i.bak0 "1s/^/#define _USE_MATH_DEFINES\\n/g" xla/fp_util.h xla/hlo/builder/lib/prng.cc xla/literal_comparison.cc xla/hlo/builder/lib/math.cc xla/service/spmd/fft_handler.cc xla/service/cpu/onednn_contraction_rewriter.cc xla/hlo/evaluator/hlo_evaluator.cc """, """ @@ -138,9 +173,11 @@ LLVM_TARGETS = [ # ) load("@enzyme_ad//third_party/jax:workspace.bzl", jax_workspace = "repo") + jax_workspace([]) load("@enzyme_ad//third_party/xla:workspace.bzl", xla_workspace = "repo") + xla_workspace(NEW_XLA_PATCHES) # @@ -158,10 +195,9 @@ xla_workspace(NEW_XLA_PATCHES) # # pip_install_dependencies() - load("@enzyme_ad//third_party/enzyme:workspace.bzl", enzyme_workspace = "repo") -enzyme_workspace() +enzyme_workspace() # http_archive( # name = "upb", diff --git a/src/Compiler.jl b/src/Compiler.jl index 56c5b413dd..2dce2dac1b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1147,8 +1147,8 @@ function optimization_passes( append!( transform_passes_list, [ - "no_nan", - "no_nan_self_sub_simplify", + "no_nan_compare_simplify(1)", + "no_nan_self_sub_simplify(1)", "no_nan_add_sub_simplify(1)", "no_nan_mul_simplify(1)", "no_nan_div_simplify(1)", @@ -1158,6 +1158,8 @@ function optimization_passes( append!( transform_passes_list, [ + "no_nan_compare_simplify(0)", + "no_nan_self_sub_simplify(0)", "no_nan_add_sub_simplify(0)", "no_nan_mul_simplify(0)", "no_nan_div_simplify(0)", diff --git a/src/Enzyme.jl b/src/Enzyme.jl index d33b5e6729..d8cc086e05 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -5,6 +5,46 @@ const enzyme_dupnoneed = 3 const enzyme_outnoneed = 4 const enzyme_constnoneed = 5 +@inline function Enzyme.make_zero(x::RNumber) + return zero(Core.Typeof(x)) +end + +@inline function Enzyme.make_zero(x::RArray{FT,N})::RArray{FT,N} where {FT<:AbstractFloat,N} + return Base.zero(x) +end + +@inline function Enzyme.make_zero( + x::RArray{Complex{FT},N} +)::RArray{Complex{FT},N} where {FT<:AbstractFloat,N} + return Base.zero(x) +end + +macro register_make_zero_inplace(sym) + quote + @inline function $sym(prev::RArray{T,N})::Nothing where {T<:AbstractFloat,N} + $sym(prev, nothing) + return nothing + end + + @inline function $sym(prev::RArray{T,N}, seen::ST)::Nothing where {T,N,ST} + if Enzyme.Compiler.guaranteed_const_nongen(T, nothing) + return nothing + end + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev, zero(T)) + return nothing + end + end +end + +@register_make_zero_inplace(Enzyme.make_zero!) +@register_make_zero_inplace(Enzyme.remake_zero!) + function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) )::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}} diff --git a/test/autodiff.jl b/test/autodiff.jl index 25641bc231..a11ec900e4 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -231,6 +231,13 @@ vector_forward_ad(x) = Enzyme.autodiff(Forward, fn, BatchDuplicated(x, Enzyme.on @test res[1][4] ≈ res_enz[1][4] end +@testset "make_zero!" begin + x = Reactant.to_rarray([3.1]) + @jit Enzyme.make_zero!(x) + + @test @allowscalar x[1] ≈ 0.0 +end + function simple_forward(x, st) rng = copy(st.rng) y = similar(x)