diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 279227b811..ba0765af5f 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -478,7 +478,8 @@ function __init__() end Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) - return Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) + Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) + return nothing end end # module ReactantCUDAExt diff --git a/src/Compiler.jl b/src/Compiler.jl index dae77a0679..bb912d0663 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -328,7 +328,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) if isdefined(Reactant_jll, :ptxas_path) toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] end - kern = "lower-kernel{toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" + kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" if optimize === :all run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) diff --git a/test/cuda.jl b/test/cuda.jl index 711a6ef7e4..549002e4f1 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -2,9 +2,16 @@ using Reactant using Test using CUDA +using Reactant_jll +@show Reactant_jll.libReactantExtra_path + function square_kernel!(x) - i = threadIdx().x - x[i] *= x[i] + #i = threadIdx().x + #x[i] *= x[i] + #@cuprintf("overwrote value of %f was thrown during kernel execution on thread (%d, %d, %d) in block (%d, %d, %d).\n", + # 0.0, threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) + #x[i], threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) + # sync_threads() return nothing end @@ -18,9 +25,9 @@ end @testset "Square Kernel" begin oA = collect(1:1:64) A = Reactant.to_rarray(oA) - @show @code_hlo optimize = false square!(A) - @show @code_hlo optimize = :before_kernel square!(A) - @show @code_hlo square!(A) + # @show @code_hlo optimize = false square!(A) + # @show @code_hlo optimize = :before_kernel square!(A) + # @show @code_hlo square!(A) func! = @compile square!(A) func!(A) @show A