Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.32"
Reactant_jll = "0.0.33"
Scratch = "1.2"
Statistics = "1.10"
YaoBlocks = "0.13"
Expand Down
4 changes: 2 additions & 2 deletions deps/ReactantExtra/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ build -c opt
build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NVCC_CLANG=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.6.0"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
Expand Down
33 changes: 20 additions & 13 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ cc_toolchain_config(
coverage_link_flags = ["--coverage"],
cpu = "k8",
cxx_builtin_include_directories = [
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/x86_64-linux-musl",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/backward",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/x86_64-linux-musl",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/backward",
Expand Down Expand Up @@ -149,14 +152,14 @@ cc_toolchain_config(
abi_libc_version = "local",
abi_version = "local",
cxx_builtin_include_directories = [
"/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include",
"/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed",
"/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include",
"/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed",
"/opt/BB_TARGET/BB_TARGET/include",
"/opt/BB_TARGET/BB_TARGET/sys-root/usr/include",
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0",
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET",
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward",
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel"
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0",
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET",
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward",
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel"
],
tool_paths = {
"ar": "/opt/bin/BB_FULL_TARGET/ar",
Expand Down Expand Up @@ -193,14 +196,14 @@ cc_toolchain_config(
"-Wno-free-nonheap-object",
"-fno-omit-frame-pointer",
# TODO cxx_builtin_include_directories doesn't seem to be working, so we add the INCLUDE_PATHs manually
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include",
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed",
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include",
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed",
"-isystem /opt/BB_TARGET/BB_TARGET/include",
"-isystem /opt/BB_TARGET/BB_TARGET/sys-root/usr/include",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel",
],
opt_compile_flags = [
"-g0",
Expand Down Expand Up @@ -361,6 +364,7 @@ cc_library(

) + [
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
# "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
# "@xla//xla:xla.pb.cc",
"@xla//xla:xla_data.pb.cc",
Expand Down Expand Up @@ -429,7 +433,7 @@ cc_library(
"-Wl,-exported_symbol,_ifrt_*",
"-Wl,-exported_symbol,_RegisterCustomCallTarget",
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
"-Wl,-exported_symbol,_EnzymeGPUCustomCall",
"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
"-Wl,-exported_symbol,_ReactantThrowError",
]}),
deps = [
Expand Down Expand Up @@ -469,6 +473,9 @@ cc_library(
"@llvm-project//llvm:X86CodeGen",
"@enzyme_ad//src/enzyme_ad/jax:TransformOps",
"@enzyme_ad//src/enzyme_ad/jax:XLADerivatives",
# "@enzyme_ad//src/enzyme_ad/jax:gpu",
"@xla//xla/ffi/api:ffi",
"@xla//xla/ffi:ffi_api",
"@stablehlo//:chlo_ops",
"@xla//xla/pjrt:pjrt_api",
"@xla//xla/pjrt:pjrt_c_api_client",
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "b6d6563aa3a3050474a4250bf18322f7ebf0b486"
ENZYMEXLA_COMMIT = "74046d05089c02946058f8fd94ed23efd0bf3ccc"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down
2 changes: 2 additions & 0 deletions deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ else
run(
Cmd(
`bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
--repo_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc
--repo_env=CC=/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang
--repo_env HERMETIC_PYTHON_VERSION="3.10"
--check_visibility=false --verbose_failures :libReactantExtra.so`;
dir=source_dir,
Expand Down
52 changes: 26 additions & 26 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where
return res
end

const _kernel_instances = Dict{Any,Any}()

# Since we cache these objects, we cannot cache data containing MLIR operations (i.e. the entry must be a string
# and not the operation itself).
struct LLVMFunc{F,tt}
f::Union{F,Nothing}
entry::MLIR.IR.Operation
entry::String
end

const GPUCompiler = CUDA.GPUCompiler
Expand Down Expand Up @@ -324,9 +324,9 @@ function compile(job)
)::MLIR.API.MlirOperation

entry = MLIR.IR.Operation(linkRes)

entry
String(Reactant.TracedUtils.get_attribute_by_name(linkRes, "sym_name"))
end

return LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry)
end

Expand Down Expand Up @@ -378,9 +378,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

output_operand_aliases = MLIR.IR.Attribute(aliases)

fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name")
# Force public for now while we don't have real users
# MLIR.IR.rmattr!(func.entry, "sym_visibility")
fname = func.entry

operands = MLIR.IR.Value[]
for idx in
Expand Down Expand Up @@ -460,25 +458,27 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
end

function __init__()
handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
if handle === nothing
handle = C_NULL
end
ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
if ptr1 === nothing
ptr1 = C_NULL
end
ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
if ptr2 === nothing
ptr2 = C_NULL
end
ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
if ptr3 === nothing
ptr3 = C_NULL
if CUDA.CUDA_Driver_jll.libcuda !== nothing
handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
if handle === nothing
handle = C_NULL
end
ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
if ptr1 === nothing
ptr1 = C_NULL
end
ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
if ptr2 === nothing
ptr2 = C_NULL
end
ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
if ptr3 === nothing
ptr3 = C_NULL
end
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
end
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
return nothing
end

Expand Down
6 changes: 1 addition & 5 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,7 @@ function __init__()
end
end

@ccall MLIR.API.mlir_c.RegisterCustomCallTarget(
"enzymexla_gpu"::Cstring,
cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid},
"CUDA"::Cstring,
)::Cvoid
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid

# This wasn't properly exported on macos, we'll remove the try once macOS JLL
# has the fix.
Expand Down
5 changes: 4 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,10 @@ function call_with_reactant_generator(
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
end

ir, any_changed = rewrite_insts!(ir, interp)
if !is_reactant_method(mi::Core.MethodInstance)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jumerckx as a note: this is something I missed earlier. We previously had a catch-all like this, so that we only run `rewrite_insts!` if the top-level function is not a Reactant method, but I think it was accidentally removed by your earlier PR. In any case, it is restored here (and once CUDA lands, it will actually be tested that this is done).

ir, any_changed = rewrite_insts!(ir, interp)
end

src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
src.slotnames = fill(:none, length(ir.argtypes) + 1)
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
Expand Down
36 changes: 0 additions & 36 deletions test/cuda.jl

This file was deleted.

29 changes: 29 additions & 0 deletions test/integration/cuda.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using Reactant
using Test
using CUDA

function square_kernel!(x, y)
    # Each thread updates the element at its own x-dimension thread index:
    # x[tid] is overwritten with the product x[tid] * y[tid].
    tid = threadIdx().x
    x[tid] = x[tid] * y[tid]
    sync_threads()
    return nothing
end

# Launch `square_kernel!` on a single block with one thread per element of `x`,
# multiplying `x` element-wise by `y` in place.
function square!(x, y)
    nthreads = length(x)
    @cuda blocks = 1 threads = nthreads square_kernel!(x, y)
    return nothing
end

@testset "Square Kernel" begin
    # Host-side reference data: the integers 1 through 64.
    host_vals = collect(1:64)
    A = Reactant.to_rarray(host_vals)
    B = Reactant.to_rarray(100 .* host_vals)
    if !CUDA.functional()
        # No functional CUDA setup: only generate the HLO via `@code_hlo`;
        # nothing is asserted about the result.
        @code_hlo optimize = :before_kernel square!(A, B)
    else
        @jit square!(A, B)
        # `A` must now hold A .* B, while `B` must be left untouched.
        expected_A = host_vals .* host_vals .* 100
        expected_B = host_vals .* 100
        @test all(Array(A) .≈ expected_A)
        @test all(Array(B) .≈ expected_B)
    end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
# Temporarily disabled while the minutiae are debugged
# @safetestset "CUDA" include("integration/cuda.jl")
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
@safetestset "Random" include("integration/random.jl")
Expand Down
Loading