From 4d263767a50e5e7d7689fa398641072b65d96e93 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 24 Nov 2025 22:48:58 -0600 Subject: [PATCH 1/7] feat: new syrk passes + lowering [skip ci] --- src/Compiler.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index d014d6a202..95a8e2219d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -703,6 +703,7 @@ function optimization_passes( recognize_comms::Bool=true, lower_comms::Bool=true, backend::String="gpu", + is_sharded::Bool=false, ) (; max_constant_threshold) = compile_options @@ -909,8 +910,16 @@ function optimization_passes( "transpose_symmetric_simplify", "divide_negated_operands_simplify", "multiply_negated_operands_simplify", + "transpose_syrk_to_syrk", + "fuse_mul_into_syrk", + "fuse_add_into_syrk", ] + if !is_sharded + # these passes don't have optimized sharding implementations + append!(transform_passes_list, ["dot_general_to_syrk"]) + end + if !compile_options.disable_auto_batching_passes append!( transform_passes_list, @@ -1693,10 +1702,10 @@ function compile_mlir!( end opt_passes = optimization_passes( - compile_options; sroa=true, recognize_comms, lower_comms, backend + compile_options; sroa=true, recognize_comms, lower_comms, backend, is_sharded ) opt_passes2 = optimization_passes( - compile_options; sroa=false, recognize_comms, lower_comms, backend + compile_options; sroa=false, recognize_comms, lower_comms, backend, is_sharded ) raise_passes = if raise isa String @@ -1718,6 +1727,7 @@ function compile_mlir!( recognize_comms, lower_comms, backend, + is_sharded, ) result = result * "," * opt_passes3 end @@ -1728,6 +1738,8 @@ function compile_mlir!( blas_int_width = sizeof(BlasInt) * 8 lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \ + blas_int_width=$blas_int_width},\ + lower-enzymexla-blas{backend=$backend \ blas_int_width=$blas_int_width},\ lower-enzymexla-lapack{backend=$backend \ blas_int_width=$blas_int_width}" @@ -2012,6 +2024,7 @@ function compile_mlir!( recognize_comms, lower_comms, backend, + is_sharded, ), "post_op_transpose_reshape", ) From ff5c046b34d46cf58f16dceddb7f36de01cf5f3c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Nov 2025 14:00:33 -0600 Subject: [PATCH 2/7] test: raising to syrk --- src/stdlibs/LinearAlgebra.jl | 15 +++++++++++++++ test/integration/linear_algebra.jl | 21 +++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index d575be61b0..7c329b8ae0 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -46,6 +46,21 @@ function __init__() (BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_), (BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_), (BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_), + # syrk + (BLAS.@blasfunc(ssyrk_), :enzymexla_blas_ssyrk_), + (BLAS.@blasfunc(dsyrk_), :enzymexla_blas_dsyrk_), + (BLAS.@blasfunc(csyrk_), :enzymexla_blas_csyrk_), + (BLAS.@blasfunc(zsyrk_), :enzymexla_blas_zsyrk_), + # trmm + (BLAS.@blasfunc(strmm_), :enzymexla_blas_strmm_), + (BLAS.@blasfunc(dtrmm_), :enzymexla_blas_dtrmm_), + (BLAS.@blasfunc(ctrmm_), :enzymexla_blas_ctrmm_), + (BLAS.@blasfunc(ztrmm_), :enzymexla_blas_ztrmm_), + # symm + (BLAS.@blasfunc(ssymm_), :enzymexla_blas_ssymm_), + (BLAS.@blasfunc(dsymm_), :enzymexla_blas_dsymm_), + (BLAS.@blasfunc(csymm_), :enzymexla_blas_csymm_), + (BLAS.@blasfunc(zsymm_), :enzymexla_blas_zsymm_), ] sym = Libdl.dlsym(libblastrampoline_handle, cname) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index d9db88932a..4c0be8170d 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -723,3 +723,24 @@ end @jit LinearAlgebra.normalize!(x_ra) @test x_ra ≈ x end + +raise_to_syrk(x, y) = 3 .* (x * transpose(x)) .+ 5 .* y +raise_to_syrk2(x, y) = 3 .* (transpose(x) * x) .+ 5 .* y + +@testset "syrk optimizations" begin + @testset for elty in (Float32, Float64, ComplexF32, ComplexF64) + x = Reactant.TestUtils.construct_test_array(elty, 4, 5) + y1 = Reactant.TestUtils.construct_test_array(elty, 4, 4) + y2 = Reactant.TestUtils.construct_test_array(elty, 5, 5) + x_ra = Reactant.to_rarray(x) + + @testset for (fn, y) in ((raise_to_syrk, y1), (raise_to_syrk2, y2)) + y_ra = Reactant.to_rarray(y) + + hlo = @code_hlo optimize=:before_jit fn(x_ra, y_ra) + @test occursin("enzymexla.blas.syrk", repr(hlo)) + + @test @jit(fn(x_ra, y_ra)) ≈ fn(x, y) + end + end +end From 4787af127f2b2960b3709d10380d867725125562 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Nov 2025 14:01:42 -0600 Subject: [PATCH 3/7] feat: more passes --- src/Compiler.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index 95a8e2219d..1a652763ca 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -913,6 +913,7 @@ function optimization_passes( "transpose_syrk_to_syrk", "fuse_mul_into_syrk", "fuse_add_into_syrk", + "factor_scalars_in_dot_general", ] if !is_sharded From f3d54a24fdd3500e29372a8f304d899797305dd0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Nov 2025 14:02:22 -0600 Subject: [PATCH 4/7] chore: run formatting --- test/integration/linear_algebra.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 4c0be8170d..cb3fe29577 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -737,7 +737,7 @@ raise_to_syrk2(x, y) = 3 .* (transpose(x) * x) .+ 5 .* y @testset for (fn, y) in ((raise_to_syrk, y1), (raise_to_syrk2, y2)) y_ra = Reactant.to_rarray(y) - hlo = @code_hlo optimize=:before_jit fn(x_ra, y_ra) + hlo = @code_hlo optimize = :before_jit fn(x_ra, y_ra) @test occursin("enzymexla.blas.syrk", repr(hlo)) @test @jit(fn(x_ra, y_ra)) ≈ fn(x, y) From 75ca0b399298457bfb68bc760e5a8912799b0e8f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Nov 2025 15:07:58 -0600 Subject: [PATCH 5/7] fix: dont accidentally raise after fallback lowering --- src/Compiler.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 1a652763ca..7beda792f4 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -704,6 +704,7 @@ function optimization_passes( lower_comms::Bool=true, backend::String="gpu", is_sharded::Bool=false, + raise_shlo_to_blas_lapack::Bool=true, ) (; max_constant_threshold) = compile_options @@ -918,7 +919,9 @@ function optimization_passes( if !is_sharded # these passes don't have optimized sharding implementations - append!(transform_passes_list, ["dot_general_to_syrk"]) + if raise_shlo_to_blas_lapack + append!(transform_passes_list, ["dot_general_to_syrk"]) + end end if !compile_options.disable_auto_batching_passes @@ -2026,6 +2029,7 @@ function compile_mlir!( lower_comms, backend, is_sharded, + raise_shlo_to_blas_lapack=false, ), "post_op_transpose_reshape", ) @@ -2168,7 +2172,15 @@ function compile_mlir!( run_pass_pipeline!( mod, join( - [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2], + [ + opt_passes, + "canonicalize", + "cse", + "canonicalize", + opt_passes2, + lower_enzymexla_linalg_pass, + jit, + ], ",", ), "mid_pad_opts", From 34dc4ae028b04c6dfc4cb4eb158b3dfc93dafec3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 28 Nov 2025 21:57:19 -0500 Subject: [PATCH 6/7] chore: bump versions --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 0dda8a7812..b3f5df9fa3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "] -version = "0.2.180" +version = "0.2.181" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.265" +Reactant_jll = "0.0.266" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" From ebfc9b20e9907160363437f92e24ec4016353d59 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 30 Nov 2025 10:28:35 -0500 Subject: [PATCH 7/7] chore: bump versions --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b3f5df9fa3..af33c5249a 100644 --- a/Project.toml +++ b/Project.toml @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.266" +Reactant_jll = "0.0.267" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10"