From d78043ab7ad13afca67b69dbb88f4b37f42fbc36 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Nov 2025 15:43:39 -0600 Subject: [PATCH 1/4] feat: svd op --- src/Ops.jl | 40 ++++++++++++++++++++++++++++++++++++ src/stdlibs/LinearAlgebra.jl | 4 ++++ 2 files changed, 44 insertions(+) diff --git a/src/Ops.jl b/src/Ops.jl index 22bf67679b..cae1c9bc90 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3312,6 +3312,46 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors ` return (res, ipiv, perm, info) end +@noinline function svd( + x::TracedRArray{T,N}, + ::Type{iT}=Int32; + full::Bool=false, + location=mlir_stacktrace("svd", @__FILE__, @__LINE__), +) where {T,iT,N} + @assert N >= 2 + + batch_sizes = size(x)[1:(end - 2)] + m, n = size(x)[(end - 1):end] + r = min(m, n) + + U_size = (batch_sizes..., m, full ? m : r) + S_size = (batch_sizes..., r) + Vt_size = (batch_sizes..., full ? n : r, n) + info_size = batch_sizes + + svd_op = enzymexla.linalg_svd( + x.mlir_data; + U=mlir_type(TracedRArray{T,N}, U_size), + S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size), + Vt=mlir_type(TracedRArray{T,N}, Vt_size), + info=mlir_type(TracedRArray{iT,N - 2}, info_size), + full=full, + location, + ) + + U = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 1), U_size) + S = TracedRArray{Base.real(T),N - 1}((), MLIR.IR.result(svd_op, 2), S_size) + Vt = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 3), Vt_size) + + if N == 2 + info = TracedRNumber{iT}((), MLIR.IR.result(svd_op, 4)) + else + info = TracedRArray{iT,N - 2}((), MLIR.IR.result(svd_op, 4), info_size) + end + + return U, S, Vt, info +end + @noinline function reduce_window( f::F, inputs::Vector{TracedRArray{T,N}}, diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index ff67859c5e..055077c09a 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -29,6 +29,10 @@ function __init__() (BLAS.@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_), (BLAS.@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_), (BLAS.@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_), + (BLAS.@blasfunc(sgesvd_), :enzymexla_lapack_sgesvd_), + (BLAS.@blasfunc(dgesvd_), :enzymexla_lapack_dgesvd_), + (BLAS.@blasfunc(cgesvd_), :enzymexla_lapack_cgesvd_), + (BLAS.@blasfunc(zgesvd_), :enzymexla_lapack_zgesvd_), ] sym = Libdl.dlsym(libblastrampoline_handle, cname) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( From 8245a8188f5207681f646f19104b040ed6737f7d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Nov 2025 15:21:21 -0600 Subject: [PATCH 2/4] feat: map more symbols --- src/stdlibs/LinearAlgebra.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 055077c09a..9bfc2ba2ae 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -25,14 +25,26 @@ function __init__() libblastrampoline_handle = Libdl.dlopen(BLAS.libblas) for (cname, enzymexla_name) in [ + # LU (BLAS.@blasfunc(sgetrf_), :enzymexla_lapack_sgetrf_), (BLAS.@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_), (BLAS.@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_), (BLAS.@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_), + # SVD QR Iteration (BLAS.@blasfunc(sgesvd_), :enzymexla_lapack_sgesvd_), (BLAS.@blasfunc(dgesvd_), :enzymexla_lapack_dgesvd_), (BLAS.@blasfunc(cgesvd_), :enzymexla_lapack_cgesvd_), (BLAS.@blasfunc(zgesvd_), :enzymexla_lapack_zgesvd_), + # SVD Divide and Conquer + (BLAS.@blasfunc(sgesdd_), :enzymexla_lapack_sgesdd_), + (BLAS.@blasfunc(dgesdd_), :enzymexla_lapack_dgesdd_), + (BLAS.@blasfunc(cgesdd_), :enzymexla_lapack_cgesdd_), + (BLAS.@blasfunc(zgesdd_), :enzymexla_lapack_zgesdd_), + # SVD Jacobi + (BLAS.@blasfunc(sgesvj_), :enzymexla_lapack_sgesvj_), + (BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_), + (BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_), + (BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_), ] sym = Libdl.dlsym(libblastrampoline_handle, cname) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( From 36c4a2fc8cd9ed5f90cc0e7a52b3a2b3ee298f32 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 17 Nov 2025 20:36:20 -0600 Subject: [PATCH 3/4] feat: update to allow algorithms --- src/Compiler.jl | 2 ++ src/Ops.jl | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index 1366e87a1c..278eab497c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1722,6 +1722,8 @@ function compile_mlir!( blas_int_width = sizeof(BlasInt) * 8 lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \ + blas_int_width=$blas_int_width},\ + lower-enzymexla-lapack{backend=$backend \ blas_int_width=$blas_int_width}" legalize_chlo_to_stablehlo = diff --git a/src/Ops.jl b/src/Ops.jl index cae1c9bc90..ed162501ba 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3316,6 +3316,7 @@ end x::TracedRArray{T,N}, ::Type{iT}=Int32; full::Bool=false, + algorithm::String="DEFAULT", location=mlir_stacktrace("svd", @__FILE__, @__LINE__), ) where {T,iT,N} @assert N >= 2 @@ -3329,6 +3330,18 @@ end Vt_size = (batch_sizes..., full ? n : r, n) info_size = batch_sizes + if algorithm == "DEFAULT" + algint = 0 + elseif algorithm == "QRIteration" + algint = 1 + elseif algorithm == "DivideAndConquer" + algint = 2 + elseif algorithm == "Jacobi" + algint = 3 + else + error("Unsupported SVD algorithm: $algorithm") + end + svd_op = enzymexla.linalg_svd( x.mlir_data; U=mlir_type(TracedRArray{T,N}, U_size), @@ -3336,6 +3349,7 @@ end Vt=mlir_type(TracedRArray{T,N}, Vt_size), info=mlir_type(TracedRArray{iT,N - 2}, info_size), full=full, + algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.context(), algint), location, ) From bbf907dc8bec8592084fcd2d13409d4212bb9590 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 18:50:02 -0500 Subject: [PATCH 4/4] chore: bump reactant_jll version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1e59127584..547112a4b6 100644 --- a/Project.toml +++ b/Project.toml @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.262" +Reactant_jll = "0.0.263" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10"