From 6bc4a7da8406d2822e17501426dec97403154e4b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Sep 2025 12:37:18 -0400 Subject: [PATCH 1/2] fix: missing licm passes + accidental promotions --- src/Compiler.jl | 14 +++++++++----- src/Ops.jl | 5 +++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 2932390a5f..aae58965fc 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -833,16 +833,11 @@ function optimization_passes( "slice_reduce_window<1>", "while_deadresult", "while_dus", - "dus_licm(0)", "while_op_induction_replacement", "dus_concat", "slice_dus_to_concat", "while_induction_reduction", - "slice_licm(0)", - "elementwise_licm(0)", - "concatenate_licm(0)", "slice_broadcast", - "while_licm<1>(1)", "associative_common_mul_op_reordering", "slice_select_to_select_slice", "slice_if", @@ -913,6 +908,15 @@ function optimization_passes( "concat_insert_dim_reduce", "concat_insert_dim_sort", "concat_insert_dim_reduce_window", + # licm patterns + "dus_licm(0)", + "slice_licm(0)", + "elementwise_licm(0)", + "concatenate_licm(0)", + "while_licm<1>(1)", + "transpose_licm(0)", + "broadcastindim_licm(0)", + "reshape_licm(0)", ] if !compile_options.disable_scatter_gather_optimization_passes diff --git a/src/Ops.jl b/src/Ops.jl index 805126ab46..203749ee2d 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2932,8 +2932,9 @@ end location=mlir_stacktrace("dynamic_slice", @__FILE__, @__LINE__), ) where {T,N} start_indices = [ - Reactant.TracedUtils.promote_to(TracedRNumber{Int32}, index - 1).mlir_data for - index in start_indices + Reactant.TracedUtils.promote_to( + TracedRNumber{Int32}, index - Reactant.unwrapped_eltype(index)(1) + ).mlir_data for index in start_indices ] res = MLIR.IR.result( stablehlo.dynamic_slice( From 522c70a06e7c52579ba248cd481920a2e6e59316 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Sep 2025 14:01:03 -0400 Subject: [PATCH 2/2] feat: add a flag for the licm passes --- src/CompileOptions.jl | 6 ++++++ src/Compiler.jl | 30 ++++++++++++++++++++---------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index 88d640e336..ad3aae51a2 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -155,6 +155,8 @@ Fine-grained control over the compilation options for the Reactant compiler. optimization passes. This is `false` by default. - `disable_pad_optimization_passes`: Disables the pad optimization passes. This is `false` by default. + - `disable_licm_optimization_passes`: Disables the Loop Invariant Code Motion (LICM) + optimization passes. This is `false` by default. """ struct CompileOptions optimization_passes::Union{Symbol,String} @@ -182,6 +184,7 @@ struct CompileOptions ## private options for ablation studies disable_scatter_gather_optimization_passes::Bool disable_pad_optimization_passes::Bool + disable_licm_optimization_passes::Bool end function CompileOptions(; @@ -204,6 +207,7 @@ function CompileOptions(; sync::Bool=false, disable_scatter_gather_optimization_passes::Bool=false, disable_pad_optimization_passes::Bool=false, + disable_licm_optimization_passes::Bool=false, ) optimization_passes isa Bool && (optimization_passes = ifelse(optimization_passes, :all, :none)) @@ -251,6 +255,7 @@ function CompileOptions(; sync, disable_scatter_gather_optimization_passes, disable_pad_optimization_passes, + disable_licm_optimization_passes, ) end @@ -291,6 +296,7 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt compile_options.sync, compile_options.disable_scatter_gather_optimization_passes, compile_options.disable_pad_optimization_passes, + compile_options.disable_licm_optimization_passes, ) end diff --git a/src/Compiler.jl b/src/Compiler.jl index aae58965fc..003719e80c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -908,17 +908,24 @@ function optimization_passes( "concat_insert_dim_reduce", "concat_insert_dim_sort", "concat_insert_dim_reduce_window", - # licm patterns - "dus_licm(0)", - "slice_licm(0)", - "elementwise_licm(0)", - "concatenate_licm(0)", - "while_licm<1>(1)", - "transpose_licm(0)", - "broadcastindim_licm(0)", - "reshape_licm(0)", ] + if !compile_options.disable_licm_optimization_passes + append!( + transform_passes_list, + [ + "dus_licm(0)", + "slice_licm(0)", + "elementwise_licm(0)", + "concatenate_licm(0)", + "while_licm<1>(1)", + "transpose_licm(0)", + "broadcastindim_licm(0)", + "reshape_licm(0)", + ], + ) + end + if !compile_options.disable_scatter_gather_optimization_passes append!( transform_passes_list, @@ -977,7 +984,6 @@ function optimization_passes( "unary_pad_push_tanh<1>", "unary_pad_push_exp<1>", "concat_to_pad<1>", - "pad_licm(0)", "while_pad_induction_reduction", "pad_concat_to_concat_pad", "rotate_pad", @@ -985,6 +991,10 @@ function optimization_passes( "speculate_if_pad_to_select", ], ) + + if !compile_options.disable_licm_optimization_passes + push!(transform_passes_list, "pad_licm(0)") + end end # constant prop patterns