diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 7b2c5bbf5d..19e94b2054 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -6,6 +6,7 @@ using ..Reactant: AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector, + AnyTracedRVecOrMat, unwrapped_eltype, Ops, MLIR @@ -347,4 +348,55 @@ function LinearAlgebra.ldiv!( return B end +# Kronecker Product +function LinearAlgebra.kron( + x::AnyTracedRVecOrMat{T1}, y::AnyTracedRVecOrMat{T2} +) where {T1,T2} + x = materialize_traced_array(x) + y = materialize_traced_array(y) + z = similar(x, Base.promote_op(*, T1, T2), LinearAlgebra._kronsize(x, y)) + LinearAlgebra.kron!(z, x, y) + return z +end + +function LinearAlgebra.kron(x::AnyTracedRVector{T1}, y::AnyTracedRVector{T2}) where {T1,T2} + x = materialize_traced_array(x) + y = materialize_traced_array(y) + z = similar(x, Base.promote_op(*, T1, T2), length(x) * length(y)) + LinearAlgebra.kron!(z, x, y) + return z +end + +function LinearAlgebra.kron!(C::AnyTracedRVector, A::AnyTracedRVector, B::AnyTracedRVector) + LinearAlgebra.kron!( + reshape(C, length(B), length(A)), reshape(A, 1, length(A)), reshape(B, length(B), 1) + ) + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRMatrix) + A = materialize_traced_array(A) + B = materialize_traced_array(B) + + final_shape = Int64[size(B, 1), size(A, 1), size(B, 2), size(A, 2)] + + A = Ops.broadcast_in_dim(A, Int64[2, 4], final_shape) + B = Ops.broadcast_in_dim(B, Int64[1, 3], final_shape) + + C_tmp = Ops.reshape(Ops.multiply(A, B), size(C)...) + set_mlir_data!(C, get_mlir_data(C_tmp)) + + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRVector, B::AnyTracedRMatrix) + LinearAlgebra._kron!(C, reshape(A, length(A), 1), B) + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRVector) + LinearAlgebra._kron!(C, A, reshape(B, length(B), 1)) + return C +end + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index ea39556f95..cd804d150e 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -169,3 +169,17 @@ mul_symmetric(x) = Symmetric(x) * x @test @jit(fn(x_ra)) ≈ fn(x) end end + +@testset "kron" begin + @testset for T in (Int64, Float64, ComplexF64) + @testset for (x_sz, y_sz) in [ + ((3, 4), (2, 5)), ((3, 4), (2,)), ((3,), (2, 5)), ((3,), (5,)), ((10,), ()) + ] + x = x_sz == () ? rand(T) : rand(T, x_sz) + y = y_sz == () ? rand(T) : rand(T, y_sz) + x_ra = Reactant.to_rarray(x; track_numbers=Number) + y_ra = Reactant.to_rarray(y; track_numbers=Number) + @test @jit(kron(x_ra, y_ra)) ≈ kron(x, y) + end + end +end