diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index 88d640e336..ad3aae51a2 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -155,6 +155,8 @@ Fine-grained control over the compilation options for the Reactant compiler. optimization passes. This is `false` by default. - `disable_pad_optimization_passes`: Disables the pad optimization passes. This is `false` by default. + - `disable_licm_optimization_passes`: Disables the Loop Invariant Code Motion (LICM) + optimization passes. This is `false` by default. """ struct CompileOptions optimization_passes::Union{Symbol,String} @@ -182,6 +184,7 @@ struct CompileOptions ## private options for ablation studies disable_scatter_gather_optimization_passes::Bool disable_pad_optimization_passes::Bool + disable_licm_optimization_passes::Bool end function CompileOptions(; @@ -204,6 +207,7 @@ function CompileOptions(; sync::Bool=false, disable_scatter_gather_optimization_passes::Bool=false, disable_pad_optimization_passes::Bool=false, + disable_licm_optimization_passes::Bool=false, ) optimization_passes isa Bool && (optimization_passes = ifelse(optimization_passes, :all, :none)) @@ -251,6 +255,7 @@ function CompileOptions(; sync, disable_scatter_gather_optimization_passes, disable_pad_optimization_passes, + disable_licm_optimization_passes, ) end @@ -291,6 +296,7 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt compile_options.sync, compile_options.disable_scatter_gather_optimization_passes, compile_options.disable_pad_optimization_passes, + compile_options.disable_licm_optimization_passes, ) end diff --git a/src/Compiler.jl b/src/Compiler.jl index 2932390a5f..003719e80c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -833,16 +833,11 @@ function optimization_passes( "slice_reduce_window<1>", "while_deadresult", "while_dus", - "dus_licm(0)", "while_op_induction_replacement", "dus_concat", "slice_dus_to_concat", "while_induction_reduction", - "slice_licm(0)", - "elementwise_licm(0)", - "concatenate_licm(0)", "slice_broadcast", - "while_licm<1>(1)", "associative_common_mul_op_reordering", "slice_select_to_select_slice", "slice_if", @@ -915,6 +910,22 @@ function optimization_passes( "concat_insert_dim_reduce_window", ] + if !compile_options.disable_licm_optimization_passes + append!( + transform_passes_list, + [ + "dus_licm(0)", + "slice_licm(0)", + "elementwise_licm(0)", + "concatenate_licm(0)", + "while_licm<1>(1)", + "transpose_licm(0)", + "broadcastindim_licm(0)", + "reshape_licm(0)", + ], + ) + end + if !compile_options.disable_scatter_gather_optimization_passes append!( transform_passes_list, @@ -973,7 +984,6 @@ function optimization_passes( "unary_pad_push_tanh<1>", "unary_pad_push_exp<1>", "concat_to_pad<1>", - "pad_licm(0)", "while_pad_induction_reduction", "pad_concat_to_concat_pad", "rotate_pad", @@ -981,6 +991,10 @@ function optimization_passes( "speculate_if_pad_to_select", ], ) + + if !compile_options.disable_licm_optimization_passes + push!(transform_passes_list, "pad_licm(0)") + end end # constant prop patterns diff --git a/src/Ops.jl b/src/Ops.jl index 805126ab46..203749ee2d 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2932,8 +2932,9 @@ end location=mlir_stacktrace("dynamic_slice", @__FILE__, @__LINE__), ) where {T,N} start_indices = [ - Reactant.TracedUtils.promote_to(TracedRNumber{Int32}, index - 1).mlir_data for - index in start_indices + Reactant.TracedUtils.promote_to( + TracedRNumber{Int32}, index - Reactant.unwrapped_eltype(index)(1) + ).mlir_data for index in start_indices ] res = MLIR.IR.result( stablehlo.dynamic_slice(