diff --git a/Project.toml b/Project.toml
index 0dda8a7812..af33c5249a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Reactant"
 uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
 authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "]
-version = "0.2.180"
+version = "0.2.181"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.265"
+Reactant_jll = "0.0.267"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"
diff --git a/src/Compiler.jl b/src/Compiler.jl
index d014d6a202..7beda792f4 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -703,6 +703,8 @@ function optimization_passes(
     recognize_comms::Bool=true,
     lower_comms::Bool=true,
     backend::String="gpu",
+    is_sharded::Bool=false,
+    raise_shlo_to_blas_lapack::Bool=true,
 )
     (; max_constant_threshold) = compile_options
 
@@ -909,8 +911,19 @@
         "transpose_symmetric_simplify",
         "divide_negated_operands_simplify",
         "multiply_negated_operands_simplify",
+        "transpose_syrk_to_syrk",
+        "fuse_mul_into_syrk",
+        "fuse_add_into_syrk",
+        "factor_scalars_in_dot_general",
     ]
 
+    if !is_sharded
+        # these passes don't have optimized sharding implementations
+        if raise_shlo_to_blas_lapack
+            append!(transform_passes_list, ["dot_general_to_syrk"])
+        end
+    end
+
     if !compile_options.disable_auto_batching_passes
         append!(
             transform_passes_list,
@@ -1693,10 +1706,10 @@
     end
 
     opt_passes = optimization_passes(
-        compile_options; sroa=true, recognize_comms, lower_comms, backend
+        compile_options; sroa=true, recognize_comms, lower_comms, backend, is_sharded
     )
     opt_passes2 = optimization_passes(
-        compile_options; sroa=false, recognize_comms, lower_comms, backend
+        compile_options; sroa=false, recognize_comms, lower_comms, backend, is_sharded
     )
 
     raise_passes = if raise isa String
@@ -1718,6 +1731,7 @@
             recognize_comms,
             lower_comms,
             backend,
+            is_sharded,
         )
         result = result * "," * opt_passes3
     end
@@ -1728,6 +1742,8 @@
 
     blas_int_width = sizeof(BlasInt) * 8
     lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
+                                   blas_int_width=$blas_int_width},\
+                                   lower-enzymexla-blas{backend=$backend \
                                    blas_int_width=$blas_int_width},\
                                    lower-enzymexla-lapack{backend=$backend \
                                    blas_int_width=$blas_int_width}"
@@ -2012,6 +2028,8 @@
                 recognize_comms,
                 lower_comms,
                 backend,
+                is_sharded,
+                raise_shlo_to_blas_lapack=false,
             ),
             "post_op_transpose_reshape",
         )
@@ -2154,7 +2172,15 @@
         run_pass_pipeline!(
             mod,
             join(
-                [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
+                [
+                    opt_passes,
+                    "canonicalize",
+                    "cse",
+                    "canonicalize",
+                    opt_passes2,
+                    lower_enzymexla_linalg_pass,
+                    jit,
+                ],
                 ",",
             ),
             "mid_pad_opts",
diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index d575be61b0..7c329b8ae0 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -46,6 +46,21 @@ function __init__()
         (BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_),
         (BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_),
         (BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_),
+        # syrk
+        (BLAS.@blasfunc(ssyrk_), :enzymexla_blas_ssyrk_),
+        (BLAS.@blasfunc(dsyrk_), :enzymexla_blas_dsyrk_),
+        (BLAS.@blasfunc(csyrk_), :enzymexla_blas_csyrk_),
+        (BLAS.@blasfunc(zsyrk_), :enzymexla_blas_zsyrk_),
+        # trmm
+        (BLAS.@blasfunc(strmm_), :enzymexla_blas_strmm_),
+        (BLAS.@blasfunc(dtrmm_), :enzymexla_blas_dtrmm_),
+        (BLAS.@blasfunc(ctrmm_), :enzymexla_blas_ctrmm_),
+        (BLAS.@blasfunc(ztrmm_), :enzymexla_blas_ztrmm_),
+        # symm
+        (BLAS.@blasfunc(ssymm_), :enzymexla_blas_ssymm_),
+        (BLAS.@blasfunc(dsymm_), :enzymexla_blas_dsymm_),
+        (BLAS.@blasfunc(csymm_), :enzymexla_blas_csymm_),
+        (BLAS.@blasfunc(zsymm_), :enzymexla_blas_zsymm_),
     ]
         sym = Libdl.dlsym(libblastrampoline_handle, cname)
         @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index d9db88932a..cb3fe29577 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -723,3 +723,24 @@
     @jit LinearAlgebra.normalize!(x_ra)
     @test x_ra ≈ x
 end
+
+raise_to_syrk(x, y) = 3 .* (x * transpose(x)) .+ 5 .* y
+raise_to_syrk2(x, y) = 3 .* (transpose(x) * x) .+ 5 .* y
+
+@testset "syrk optimizations" begin
+    @testset for elty in (Float32, Float64, ComplexF32, ComplexF64)
+        x = Reactant.TestUtils.construct_test_array(elty, 4, 5)
+        y1 = Reactant.TestUtils.construct_test_array(elty, 4, 4)
+        y2 = Reactant.TestUtils.construct_test_array(elty, 5, 5)
+        x_ra = Reactant.to_rarray(x)
+
+        @testset for (fn, y) in ((raise_to_syrk, y1), (raise_to_syrk2, y2))
+            y_ra = Reactant.to_rarray(y)
+
+            hlo = @code_hlo optimize = :before_jit fn(x_ra, y_ra)
+            @test occursin("enzymexla.blas.syrk", repr(hlo))
+
+            @test @jit(fn(x_ra, y_ra)) ≈ fn(x, y)
+        end
+    end
+end
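
For reference, a minimal usage sketch of the raising this patch enables, mirroring the new test rather than adding behaviour of its own; `scaled_gram` is an illustrative name and plain `rand` stands in for `Reactant.TestUtils.construct_test_array`. On a non-sharded compile the `dot_general_to_syrk` pass should rewrite the rank-k update into an `enzymexla.blas.syrk` op, which `lower-enzymexla-blas` then lowers to the BLAS syrk symbols registered in `__init__` above.

    using Reactant, LinearAlgebra

    # Illustrative rank-k update: x * x' scaled and combined with another matrix.
    scaled_gram(x, y) = 3 .* (x * transpose(x)) .+ 5 .* y

    x = rand(Float64, 4, 5)
    y = rand(Float64, 4, 4)
    x_ra = Reactant.to_rarray(x)
    y_ra = Reactant.to_rarray(y)

    # Inspect the IR before JIT lowering; on this non-sharded compile the
    # dot_general should have been raised to an enzymexla.blas.syrk op.
    hlo = @code_hlo optimize = :before_jit scaled_gram(x_ra, y_ra)
    occursin("enzymexla.blas.syrk", repr(hlo))  # expected: true

    # The compiled result should still match plain Julia.
    @jit(scaled_gram(x_ra, y_ra)) ≈ scaled_gram(x, y)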