Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,9 @@ function vendored_optimize_module!(
end
end

function vendored_buildEarlyOptimizerPipeline(mpm, @nospecialize(job), opt_level; instcombine=false)
function vendored_buildEarlyOptimizerPipeline(
mpm, @nospecialize(job), opt_level; instcombine=false
)
LLVM.add!(mpm, LLVM.NewPMCGSCCPassManager()) do cgpm
# TODO invokeCGSCCCallbacks
LLVM.add!(cgpm, LLVM.NewPMFunctionPassManager()) do fpm
Expand Down Expand Up @@ -496,7 +498,9 @@ function vendored_buildEarlyOptimizerPipeline(mpm, @nospecialize(job), opt_level
end
end

function vendored_buildIntrinsicLoweringPipeline(mpm, @nospecialize(job), opt_level; instcombine::Bool=false)
function vendored_buildIntrinsicLoweringPipeline(
mpm, @nospecialize(job), opt_level; instcombine::Bool=false
)
GPUCompiler.add!(mpm, LLVM.Interop.RemoveNIPass())

# lower GC intrinsics
Expand Down Expand Up @@ -561,7 +565,9 @@ function vendored_buildIntrinsicLoweringPipeline(mpm, @nospecialize(job), opt_le
else
LLVM.add!(fpm, LLVM.InstSimplifyPass())
end
LLVM.add!(fpm, LLVM.SimplifyCFGPass(; GPUCompiler.AggressiveSimplifyCFGOptions...))
LLVM.add!(
fpm, LLVM.SimplifyCFGPass(; GPUCompiler.AggressiveSimplifyCFGOptions...)
)
end
end

Expand All @@ -570,7 +576,7 @@ function vendored_buildIntrinsicLoweringPipeline(mpm, @nospecialize(job), opt_le

# Julia's operand bundles confuse the inliner, so repeat here now they are gone.
# FIXME: we should fix the inliner so that inlined code gets optimized early-on
LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
return LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
end

function vendored_buildNewPMPipeline!(mpm, @nospecialize(job), opt_level)
Expand All @@ -595,7 +601,7 @@ function vendored_buildNewPMPipeline!(mpm, @nospecialize(job), opt_level)
# end
end
vendored_buildIntrinsicLoweringPipeline(mpm, job, opt_level)
GPUCompiler.buildCleanupPipeline(mpm, job, opt_level)
return GPUCompiler.buildCleanupPipeline(mpm, job, opt_level)
end

# compile to executable machine code
Expand Down
36 changes: 19 additions & 17 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,10 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
if sroa
push!(passes, "propagate-constant-bounds")
if DUMP_LLVMIR[]
push!(passes, "sroa-wrappers{dump_prellvm=true dump_postllvm=true instcombine=false instsimplify=true}")
push!(
passes,
"sroa-wrappers{dump_prellvm=true dump_postllvm=true instcombine=false instsimplify=true}",
)
else
push!(passes, "sroa-wrappers{instcombine=false instsimplify=true}")
end
Expand Down Expand Up @@ -558,7 +561,6 @@ end
const DEBUG_KERNEL = Ref{Bool}(false)
const DUMP_LLVMIR = Ref{Bool}(false)


const Raise = Ref{Bool}(false)

function compile_mlir!(
Expand Down Expand Up @@ -657,7 +659,7 @@ function compile_mlir!(
opt_passes2,
kern,
raise,
jit
jit,
],
',',
),
Expand Down Expand Up @@ -711,7 +713,7 @@ function compile_mlir!(
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes2,
kern
kern,
],
',',
),
Expand Down Expand Up @@ -758,7 +760,7 @@ function compile_mlir!(
opt_passes2,
kern,
raise,
jit
jit,
],
',',
),
Expand All @@ -769,21 +771,21 @@ function compile_mlir!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod, join([
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
kern,
raise,
jit
], ',')
mod,
join(
[
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
kern,
raise,
jit,
],
',',
),
)
elseif optimize === :canonicalize
run_pass_pipeline!(
mod, "canonicalize"
)
run_pass_pipeline!(mod, "canonicalize")
elseif optimize === :just_batch
run_pass_pipeline!(
mod, "enzyme-batch"
)
run_pass_pipeline!(mod, "enzyme-batch")
elseif optimize !== :none
error("Invalid optimize option: $(Meta.quot(optimize))")
end
Expand Down
8 changes: 4 additions & 4 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ function initialize_dialect()
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
if !passes_initialized[]
@ccall MLIR.API.mlir_c.InitializePasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
passes_initialized[] = true
@ccall MLIR.API.mlir_c.InitializePasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
passes_initialized[] = true
end
return nothing
end
Expand Down
16 changes: 8 additions & 8 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,15 @@ function __init__()
println(stdout, e)
end
else
if !Reactant.precompiling()
try
gpu = GPUClient()
backends["gpu"] = gpu
default_backend[] = gpu
catch e
println(stdout, e)
if !Reactant.precompiling()
try
gpu = GPUClient()
backends["gpu"] = gpu
default_backend[] = gpu
catch e
println(stdout, e)
end
end
end
end
end

Expand Down