Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
version = "0.2.180"
version = "0.2.181"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -105,7 +105,7 @@ PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.16"
Reactant_jll = "0.0.265"
Reactant_jll = "0.0.267"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
32 changes: 29 additions & 3 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,8 @@ function optimization_passes(
recognize_comms::Bool=true,
lower_comms::Bool=true,
backend::String="gpu",
is_sharded::Bool=false,
raise_shlo_to_blas_lapack::Bool=true,
)
(; max_constant_threshold) = compile_options

Expand Down Expand Up @@ -909,8 +911,19 @@ function optimization_passes(
"transpose_symmetric_simplify",
"divide_negated_operands_simplify",
"multiply_negated_operands_simplify",
"transpose_syrk_to_syrk",
"fuse_mul_into_syrk",
"fuse_add_into_syrk",
"factor_scalars_in_dot_general",
]

if !is_sharded
# these passes don't have optimized sharding implementations
if raise_shlo_to_blas_lapack
append!(transform_passes_list, ["dot_general_to_syrk"])
end
end

if !compile_options.disable_auto_batching_passes
append!(
transform_passes_list,
Expand Down Expand Up @@ -1693,10 +1706,10 @@ function compile_mlir!(
end

opt_passes = optimization_passes(
compile_options; sroa=true, recognize_comms, lower_comms, backend
compile_options; sroa=true, recognize_comms, lower_comms, backend, is_sharded
)
opt_passes2 = optimization_passes(
compile_options; sroa=false, recognize_comms, lower_comms, backend
compile_options; sroa=false, recognize_comms, lower_comms, backend, is_sharded
)

raise_passes = if raise isa String
Expand All @@ -1718,6 +1731,7 @@ function compile_mlir!(
recognize_comms,
lower_comms,
backend,
is_sharded,
)
result = result * "," * opt_passes3
end
Expand All @@ -1728,6 +1742,8 @@ function compile_mlir!(

blas_int_width = sizeof(BlasInt) * 8
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
blas_int_width=$blas_int_width},\
lower-enzymexla-blas{backend=$backend \
blas_int_width=$blas_int_width},\
lower-enzymexla-lapack{backend=$backend \
blas_int_width=$blas_int_width}"
Expand Down Expand Up @@ -2012,6 +2028,8 @@ function compile_mlir!(
recognize_comms,
lower_comms,
backend,
is_sharded,
raise_shlo_to_blas_lapack=false,
),
"post_op_transpose_reshape",
)
Expand Down Expand Up @@ -2154,7 +2172,15 @@ function compile_mlir!(
run_pass_pipeline!(
mod,
join(
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
[
opt_passes,
"canonicalize",
"cse",
"canonicalize",
opt_passes2,
lower_enzymexla_linalg_pass,
jit,
],
",",
),
"mid_pad_opts",
Expand Down
15 changes: 15 additions & 0 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ function __init__()
(BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_),
(BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_),
(BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_),
# syrk
(BLAS.@blasfunc(ssyrk_), :enzymexla_blas_ssyrk_),
(BLAS.@blasfunc(dsyrk_), :enzymexla_blas_dsyrk_),
(BLAS.@blasfunc(csyrk_), :enzymexla_blas_csyrk_),
(BLAS.@blasfunc(zsyrk_), :enzymexla_blas_zsyrk_),
# trmm
(BLAS.@blasfunc(strmm_), :enzymexla_blas_strmm_),
(BLAS.@blasfunc(dtrmm_), :enzymexla_blas_dtrmm_),
(BLAS.@blasfunc(ctrmm_), :enzymexla_blas_ctrmm_),
(BLAS.@blasfunc(ztrmm_), :enzymexla_blas_ztrmm_),
# symm
(BLAS.@blasfunc(ssymm_), :enzymexla_blas_ssymm_),
(BLAS.@blasfunc(dsymm_), :enzymexla_blas_dsymm_),
(BLAS.@blasfunc(csymm_), :enzymexla_blas_csymm_),
(BLAS.@blasfunc(zsymm_), :enzymexla_blas_zsymm_),
]
sym = Libdl.dlsym(libblastrampoline_handle, cname)
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
Expand Down
21 changes: 21 additions & 0 deletions test/integration/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,24 @@ end
@jit LinearAlgebra.normalize!(x_ra)
@test x_ra ≈ x
end

# Test kernel: 3 * (x * transpose(x)) + 5 * y, with the scalings and the sum applied
# via broadcasting. The "syrk optimizations" testset asserts that the HLO generated
# for this function contains an `enzymexla.blas.syrk` op.
raise_to_syrk(x, y) = 3 .* (x * transpose(x)) .+ 5 .* y
# Test kernel: 3 * (transpose(x) * x) + 5 * y, i.e. the transposed-operand-first
# counterpart of the `x * transpose(x)` form. The "syrk optimizations" testset asserts
# that the HLO generated for this function contains an `enzymexla.blas.syrk` op.
raise_to_syrk2(x, y) = 3 .* (transpose(x) * x) .+ 5 .* y

# Checks, for each element type, that both product orderings are rewritten to a syrk
# op before JIT and that the compiled result agrees with plain Julia evaluation.
@testset "syrk optimizations" begin
    @testset for elty in (Float32, Float64, ComplexF32, ComplexF64)
        x = Reactant.TestUtils.construct_test_array(elty, 4, 5)

        # `x` is 4×5, so `x * transpose(x)` is 4×4 and `transpose(x) * x` is 5×5;
        # each kernel is paired with an addend of the matching shape.
        cases = (
            (raise_to_syrk, Reactant.TestUtils.construct_test_array(elty, 4, 4)),
            (raise_to_syrk2, Reactant.TestUtils.construct_test_array(elty, 5, 5)),
        )
        x_ra = Reactant.to_rarray(x)

        @testset for (fn, y) in cases
            y_ra = Reactant.to_rarray(y)

            # Inspect the IR before the JIT stage and look for the raised BLAS op.
            ir_text = repr(@code_hlo optimize = :before_jit fn(x_ra, y_ra))
            @test occursin("enzymexla.blas.syrk", ir_text)

            # Numerical agreement between the compiled and the reference evaluation.
            expected = fn(x, y)
            @test @jit(fn(x_ra, y_ra)) ≈ expected
        end
    end
end
Loading