diff --git a/Project.toml b/Project.toml index 1e59127584..547112a4b6 100644 --- a/Project.toml +++ b/Project.toml @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.262" +Reactant_jll = "0.0.263" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" diff --git a/src/Compiler.jl b/src/Compiler.jl index 1366e87a1c..278eab497c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1722,6 +1722,8 @@ function compile_mlir!( blas_int_width = sizeof(BlasInt) * 8 lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \ + blas_int_width=$blas_int_width},\ + lower-enzymexla-lapack{backend=$backend \ blas_int_width=$blas_int_width}" legalize_chlo_to_stablehlo = diff --git a/src/Ops.jl b/src/Ops.jl index 22bf67679b..ed162501ba 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3312,6 +3312,60 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors ` return (res, ipiv, perm, info) end +@noinline function svd( + x::TracedRArray{T,N}, + ::Type{iT}=Int32; + full::Bool=false, + algorithm::String="DEFAULT", + location=mlir_stacktrace("svd", @__FILE__, @__LINE__), +) where {T,iT,N} + @assert N >= 2 + + batch_sizes = size(x)[1:(end - 2)] + m, n = size(x)[(end - 1):end] + r = min(m, n) + + U_size = (batch_sizes..., m, full ? m : r) + S_size = (batch_sizes..., r) + Vt_size = (batch_sizes..., full ? n : r, n) + info_size = batch_sizes + + if algorithm == "DEFAULT" + algint = 0 + elseif algorithm == "QRIteration" + algint = 1 + elseif algorithm == "DivideAndConquer" + algint = 2 + elseif algorithm == "Jacobi" + algint = 3 + else + error("Unsupported SVD algorithm: $algorithm") + end + + svd_op = enzymexla.linalg_svd( + x.mlir_data; + U=mlir_type(TracedRArray{T,N}, U_size), + S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size), + Vt=mlir_type(TracedRArray{T,N}, Vt_size), + info=mlir_type(TracedRArray{iT,N - 2}, info_size), + full=full, + algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.context(), algint), + location, + ) + + U = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 1), U_size) + S = TracedRArray{Base.real(T),N - 1}((), MLIR.IR.result(svd_op, 2), S_size) + Vt = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 3), Vt_size) + + if N == 2 + info = TracedRNumber{iT}((), MLIR.IR.result(svd_op, 4)) + else + info = TracedRArray{iT,N - 2}((), MLIR.IR.result(svd_op, 4), info_size) + end + + return U, S, Vt, info +end + @noinline function reduce_window( f::F, inputs::Vector{TracedRArray{T,N}}, diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index ff67859c5e..9bfc2ba2ae 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -25,10 +25,26 @@ function __init__() libblastrampoline_handle = Libdl.dlopen(BLAS.libblas) for (cname, enzymexla_name) in [ + # LU (BLAS.@blasfunc(sgetrf_), :enzymexla_lapack_sgetrf_), (BLAS.@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_), (BLAS.@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_), (BLAS.@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_), + # SVD QR Iteration + (BLAS.@blasfunc(sgesvd_), :enzymexla_lapack_sgesvd_), + (BLAS.@blasfunc(dgesvd_), :enzymexla_lapack_dgesvd_), + (BLAS.@blasfunc(cgesvd_), :enzymexla_lapack_cgesvd_), + (BLAS.@blasfunc(zgesvd_), :enzymexla_lapack_zgesvd_), + # SVD Divide and Conquer + (BLAS.@blasfunc(sgesdd_), :enzymexla_lapack_sgesdd_), + (BLAS.@blasfunc(dgesdd_), :enzymexla_lapack_dgesdd_), + (BLAS.@blasfunc(cgesdd_), :enzymexla_lapack_cgesdd_), + (BLAS.@blasfunc(zgesdd_), :enzymexla_lapack_zgesdd_), + # SVD Jacobi + (BLAS.@blasfunc(sgesvj_), :enzymexla_lapack_sgesvj_), + (BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_), + (BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_), + (BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_), ] sym = Libdl.dlsym(libblastrampoline_handle, cname) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(