diff --git a/src/Compiler.jl b/src/Compiler.jl index 0639c142d5..bdbbf230f2 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1738,6 +1738,8 @@ function compile_mlir!( () end + legal_to_run_shardy_passes = compile_options.optimization_passes === :all + if compile_options.optimization_passes === :all run_pass_pipeline!( mod, @@ -2173,7 +2175,7 @@ function compile_mlir!( # shardy passes use_shardy_partitioner = false result_shardings = missing - if is_sharded + if is_sharded && legal_to_run_shardy_passes module_op = copy(MLIR.IR.Operation(mod)) mod_copied = MLIR.IR.Module(module_op)