diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 1b1a470cb3..bb2d9bb7b4 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -412,6 +412,56 @@ AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass)
 LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass)
 CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass)
 
+# From https://github.com/JuliaGPU/GPUCompiler.jl/blob/7b9322faa34685026c4601a5084eecf5a5d7f3fe/src/ptx.jl#L149
+function vendored_optimize_module!(@nospecialize(job),
+                                   mod::LLVM.Module,
+                                   instcombine::Bool=false
+                                   )
+    tm = GPUCompiler.llvm_machine(job.config.target)
+    # TODO: Use the registered target passes (JuliaGPU/GPUCompiler.jl#450)
+    LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
+        LLVM.register!(pb, GPUCompiler.NVVMReflectPass())
+
+        LLVM.add!(pb, LLVM.NewPMFunctionPassManager()) do fpm
+            # TODO: need to run this earlier; optimize_module! is called after addOptimizationPasses!
+            LLVM.add!(fpm, GPUCompiler.NVVMReflectPass())
+
+            # needed by GemmKernels.jl-like code
+            LLVM.add!(fpm, LLVM.SpeculativeExecutionPass())
+
+            # NVPTX's target machine info enables runtime unrolling,
+            # but Julia's pass sequence only invokes the simple unroller.
+            LLVM.add!(fpm, LLVM.LoopUnrollPass(; job.config.opt_level))
+            if instcombine
+                LLVM.add!(fpm, LLVM.InstCombinePass()) # clean up redundancy
+            else
+                LLVM.add!(fpm, LLVM.InstSimplifyPass()) # clean up redundancy
+            end
+            LLVM.add!(fpm, LLVM.NewPMLoopPassManager(; use_memory_ssa=true)) do lpm
+                LLVM.add!(lpm, LLVM.LICMPass()) # the inner runtime check might be
+                                                # outer loop invariant
+            end
+
+            # the above loop unroll pass might have unrolled regular, non-runtime nested loops.
+            # that code still needs to be optimized (arguably, multiple unroll passes should be
+            # scheduled by the Julia optimizer). do so here, instead of re-optimizing entirely.
+            if job.config.opt_level == 2
+                LLVM.add!(fpm, LLVM.GVNPass())
+            elseif job.config.opt_level == 1
+                LLVM.add!(fpm, LLVM.EarlyCSEPass())
+            end
+            LLVM.add!(fpm, LLVM.DSEPass())
+
+            LLVM.add!(fpm, LLVM.SimplifyCFGPass())
+        end
+
+        # get rid of the internalized functions; now possibly unused
+        LLVM.add!(pb, LLVM.GlobalDCEPass())
+
+        LLVM.run!(pb, mod, tm)
+    end
+end
+
 # compile to executable machine code
 function compile(job)
     # lower to PTX
@@ -452,7 +502,7 @@ function compile(job)
         end
         LLVM.run!(pb, mod, tm)
     end
-    GPUCompiler.optimize_module!(job, mod)
+    vendored_optimize_module!(job, mod)
 
     LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm)
     for fname in ("gpu_report_exception", "gpu_signal_exception")
diff --git a/test/basic.jl b/test/basic.jl
index 48a2002359..0526ae662d 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -2,6 +2,8 @@ using Reactant
 using Test
 using Enzyme
 using Statistics
+using Random
+Random.seed!(123)
 
 fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
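
# Note (illustrative; not part of the patch above): `vendored_optimize_module!` closely
# follows the upstream GPUCompiler PTX `optimize_module!` pipeline linked in the patch,
# except that `InstCombinePass` is gated behind the new `instcombine` flag and replaced
# by the cheaper `InstSimplifyPass` by default. A minimal sketch of the two call modes,
# assuming the `job` and `mod` values already in scope in `compile(job)`:
#
#     vendored_optimize_module!(job, mod)         # default: runs InstSimplifyPass
#     vendored_optimize_module!(job, mod, true)   # upstream-equivalent: InstCombinePass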