From a86d7f3df4a6873647f0d04ace2e7bc8ef04513d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Dec 2024 22:50:02 +0530 Subject: [PATCH 01/10] feat: add dynamic_update_slice_const_prop pass --- ext/ReactantNNlibExt.jl | 15 --------------- src/Compiler.jl | 1 + 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index b78716d291..b90fa60fb2 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -298,21 +298,6 @@ function NNlib.pad_constant( return TracedRArray{T,N}((), res, size(MLIR.IR.type(res))) end -function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) - len = size(x, dims) - # directly generating booleans were causing an incorrect constant attribute generation - # but the optimized IR removes the type case so we are probably ok - mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len))))) - return Reactant.promote_to( - TracedRArray{Bool,2}, - TracedRArray{Int,2}( - (), - MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=mask), 1), - (len, len), - ), - ) -end - # XXX: reevaluate this manual optimization once # https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled function NNlib.gather!( diff --git a/src/Compiler.jl b/src/Compiler.jl index 0bd3eaa4f8..586f33b053 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -245,6 +245,7 @@ const opt_passes::String = join( "pad_dot_general<1>(1)", "if_inline<1>", "if_to_select<1>", + "dynamic_update_slice_const_prop", ], ';', ) * From a579baaddf9502237b3823816f09c9542036a072 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 08:41:24 +0530 Subject: [PATCH 02/10] refactor: move linear algebra overloads to a different file --- src/Reactant.jl | 3 ++ src/TracedRArray.jl | 89 ------------------------------------------- src/linear_algebra.jl | 88 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 89 deletions(-) create mode 100644 src/linear_algebra.jl diff --git a/src/Reactant.jl b/src/Reactant.jl index 036edfa567..0b73d3d96e 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -98,8 +98,11 @@ include("utils.jl") include("ConcreteRArray.jl") include("TracedRNumber.jl") include("TracedRArray.jl") + include("Ops.jl") +include("linear_algebra.jl") + const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} include("ControlFlow.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 91d8df0049..97a29e56f9 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -415,95 +415,6 @@ for (jlop, hloop, hlocomp, merge) in end end -function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,1}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,1}), - α::Number=true, - β::Number=false, -) where {T1,T2,T3} - # TODO: The reshape operations are not getting optimized, we should directly call dot_general - rC = reshape(C, :, 1) - LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) - C.mlir_data = get_mlir_data(vec(rC)) - return C -end - -function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,2}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,1}), - α::Number=true, - β::Number=false, -) where {T1,T2,T3} - LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β) - return C -end - -function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,2}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,2}), - α::Number=true, - β::Number=false, -) where {T1,T2,T3} - if size(C) != (size(A, 1), size(B, 2)) - throw( - DimensionMismatch( - "C has size $(size(C)), A has size $(size(A)), B has size $(size(B))" - ), - ) - end - if size(A, 2) != size(B, 1) - throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))")) - end - resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1)) - dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( - MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0] - ) - prec = MLIR.IR.Attribute( - MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") - ) - precar = MLIR.IR.Attribute([prec, prec]) - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.dot_general( - get_mlir_data(A), - get_mlir_data(B); - result_0=resty, - dot_dimension_numbers=dot_dimension_numbers, - precision_config=precar, - ), - 1, - ) - if iszero(β) - if isone(α) - C.mlir_data = res - else - C.mlir_data = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - res, broadcast_to_size(T1(α), size(C)).mlir_data - ), - 1, - ) - end - else - α_res = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - res, broadcast_to_size(T1(α), size(C)).mlir_data - ), - 1, - ) - β_C = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data - ), - 1, - ) - C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1) - end - return C -end - function Enzyme.Compiler.active_reg_inner( ::Type{TracedRArray{T,N}}, seen::ST, diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl new file mode 100644 index 0000000000..29c9662c77 --- /dev/null +++ b/src/linear_algebra.jl @@ -0,0 +1,88 @@ +function LinearAlgebra.mul!( + @nospecialize(C::TracedRArray{T1,1}), + @nospecialize(A::AnyTracedRArray{T2,2}), + @nospecialize(B::AnyTracedRArray{T3,1}), + α::Number=true, + β::Number=false, +) where {T1,T2,T3} + # TODO: The reshape operations are not getting optimized, we should directly call dot_general + rC = reshape(C, :, 1) + LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) + C.mlir_data = get_mlir_data(vec(rC)) + return C +end + +function LinearAlgebra.mul!( + @nospecialize(C::TracedRArray{T1,2}), + @nospecialize(A::AnyTracedRArray{T2,2}), + @nospecialize(B::AnyTracedRArray{T3,1}), + α::Number=true, + β::Number=false, +) where {T1,T2,T3} + LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β) + return C +end + +function LinearAlgebra.mul!( + @nospecialize(C::TracedRArray{T1,2}), + @nospecialize(A::AnyTracedRArray{T2,2}), + @nospecialize(B::AnyTracedRArray{T3,2}), + α::Number=true, + β::Number=false, +) where {T1,T2,T3} + if size(C) != (size(A, 1), size(B, 2)) + throw( + DimensionMismatch( + "C has size $(size(C)), A has size $(size(A)), B has size $(size(B))" + ), + ) + end + if size(A, 2) != size(B, 1) + throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))")) + end + resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1)) + dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( + MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0] + ) + prec = MLIR.IR.Attribute( + MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") + ) + precar = MLIR.IR.Attribute([prec, prec]) + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.dot_general( + get_mlir_data(A), + get_mlir_data(B); + result_0=resty, + dot_dimension_numbers=dot_dimension_numbers, + precision_config=precar, + ), + 1, + ) + if iszero(β) + if isone(α) + C.mlir_data = res + else + C.mlir_data = MLIR.IR.result( + MLIR.Dialects.stablehlo.multiply( + res, broadcast_to_size(T1(α), size(C)).mlir_data + ), + 1, + ) + end + else + α_res = MLIR.IR.result( + MLIR.Dialects.stablehlo.multiply( + res, broadcast_to_size(T1(α), size(C)).mlir_data + ), + 1, + ) + β_C = MLIR.IR.result( + MLIR.Dialects.stablehlo.multiply( + C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data + ), + 1, + ) + C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1) + end + return C +end From 39ae791936244518f03daa8a9095439d3fc6822e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 09:43:49 +0530 Subject: [PATCH 03/10] feat: add triu and tril impl --- src/linear_algebra.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 29c9662c77..82fcfbdd87 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -86,3 +86,17 @@ function LinearAlgebra.mul!( end return C end + +function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} + idxs = + Ops.iota(Int64, [size(X)...]; iota_dimension=1) .< + Ops.iota(Int64, [size(X)...]; iota_dimension=2) .- (k - 1) + return ifelse.(idxs, X, promote_to(TracedRNumber{T}, zero(T))) +end + +function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} + idxs = + Ops.iota(Int64, [size(X)...]; iota_dimension=1) .> + Ops.iota(Int64, [size(X)...]; iota_dimension=2) .- (k + 1) + return ifelse.(idxs, X, promote_to(TracedRNumber{T}, zero(T))) +end From 1fa25aea76b70d196de01114906acd4ba2819d40 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 09:58:15 +0530 Subject: [PATCH 04/10] refactor: minimize batch_op --- src/linear_algebra.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 82fcfbdd87..790de98363 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -88,15 +88,19 @@ function LinearAlgebra.mul!( end function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} - idxs = - Ops.iota(Int64, [size(X)...]; iota_dimension=1) .< - Ops.iota(Int64, [size(X)...]; iota_dimension=2) .- (k - 1) - return ifelse.(idxs, X, promote_to(TracedRNumber{T}, zero(T))) + iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) + iota_2 = Ops.subtract( + Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) + ) + idxs = iota_1 .≤ iota_2 + return Ops.select(idxs, X, zero(X)) end function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} - idxs = - Ops.iota(Int64, [size(X)...]; iota_dimension=1) .> - Ops.iota(Int64, [size(X)...]; iota_dimension=2) .- (k + 1) - return ifelse.(idxs, X, promote_to(TracedRNumber{T}, zero(T))) + iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) + iota_2 = Ops.add( + Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) + ) + idxs = iota_1 .≥ iota_2 + return Ops.select(idxs, X, zero(X)) end From 49128c30bb4df99c2c9d82b27c0e6e4690be8ca0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 10:58:53 +0530 Subject: [PATCH 05/10] feat: add Ops.compare --- src/Ops.jl | 32 ++++++++++++++++++++++++++++++++ src/linear_algebra.jl | 6 ++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index e9f3d32520..f3fc6abb92 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1014,4 +1014,36 @@ function select( return TracedRNumber{T}((), res) end +# comparison +function compare( + lhs::Union{TracedRArray{T},TracedRNumber{T}}, + rhs::Union{TracedRArray{T},TracedRNumber{T}}; + comparison_direction::String, + compare_type=nothing, + location=mlir_stacktrace("compare", @__FILE__, @__LINE__), +) where {T} + @assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT") + @assert size(lhs) == size(rhs) + if lhs isa TracedRNumber + @assert rhs isa TracedRNumber + else + @assert rhs isa TracedRArray + end + + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.compare( + lhs.mlir_data, + rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( + MLIR.IR.context(), comparison_direction + ), + compare_type, + location, + ), + 1 + ) + lhs isa TracedRNumber && return TracedRNumber{Bool}((), res) + return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs)) +end + end diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 790de98363..1e333b4b26 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -92,8 +92,7 @@ function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) wh iota_2 = Ops.subtract( Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) ) - idxs = iota_1 .≤ iota_2 - return Ops.select(idxs, X, zero(X)) + return Ops.select(Ops.compare(iota_1, iota_2; comparison_direction="LE"), X, zero(X)) end function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} @@ -101,6 +100,5 @@ function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) wh iota_2 = Ops.add( Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) ) - idxs = iota_1 .≥ iota_2 - return Ops.select(idxs, X, zero(X)) + return Ops.select(Ops.compare(iota_1, iota_2; comparison_direction="GE"), X, zero(X)) end From 4b41a2bbc9c9fbbf2d39520b825deeec36163519 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 11:01:09 +0530 Subject: [PATCH 06/10] chore: apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Ops.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ops.jl b/src/Ops.jl index f3fc6abb92..2148cb5ebd 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1040,7 +1040,7 @@ function compare( compare_type, location, ), - 1 + 1, ) lhs isa TracedRNumber && return TracedRNumber{Bool}((), res) return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs)) From 0e3bb16920e058bb0b96402d0196590a98afe9ed Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 11:02:20 +0530 Subject: [PATCH 07/10] refactor: use ops in base dispatches --- src/TracedRNumber.jl | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 9780faf9d7..12c5123b75 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -151,19 +151,7 @@ for (jlop, hloop, hlocomp) in ( function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) ) where {T} - return TracedRNumber{Bool}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$(hloop)( - lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( - MLIR.IR.context(), $hlocomp - ), - ), - 1, - ), - ) + return Ops.compare(lhs, rhs; comparison_direction=$(hlocomp)) end function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T} From f4c0eb95f91fd6fd768accfb6a09d396b9e59714 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 11:03:24 +0530 Subject: [PATCH 08/10] refactor: move linear algebra tests --- test/{ => integration}/linear_algebra.jl | 0 test/runtests.jl | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename test/{ => integration}/linear_algebra.jl (100%) diff --git a/test/linear_algebra.jl b/test/integration/linear_algebra.jl similarity index 100% rename from test/linear_algebra.jl rename to test/integration/linear_algebra.jl diff --git a/test/runtests.jl b/test/runtests.jl index 34212e696b..fddc963ced 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,10 +56,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Shortcuts to MLIR ops" include("ops.jl") @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") - @safetestset "Linear Algebra" include("linear_algebra.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" + @safetestset "Linear Algebra" include("integration/linear_algebra.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") end From f2247052af26de39ac22e815e8bfe1878563f228 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 11:16:52 +0530 Subject: [PATCH 09/10] fix: tril defn and inplace ops --- src/linear_algebra.jl | 10 +++++++--- test/integration/linear_algebra.jl | 14 +++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 1e333b4b26..c7e72651d7 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -92,13 +92,17 @@ function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) wh iota_2 = Ops.subtract( Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) ) - return Ops.select(Ops.compare(iota_1, iota_2; comparison_direction="LE"), X, zero(X)) + idxs = Ops.compare(iota_1, iota_2; comparison_direction="LE") + X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data + return X end function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) - iota_2 = Ops.add( + iota_2 = Ops.subtract( Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) ) - return Ops.select(Ops.compare(iota_1, iota_2; comparison_direction="GE"), X, zero(X)) + idxs = Ops.compare(iota_1, iota_2; comparison_direction="GE") + X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data + return X end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 1b35c64833..910d6ea9af 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -45,7 +45,7 @@ function mul_with_view3(A, x) return C end -@testset begin +@testset "Matrix Multiplication" begin A = rand(4, 4) x = rand(4, 2) b = rand(4) @@ -77,3 +77,15 @@ end @jit(mul!(C_ra, A_ra, x_ra)) @test C_ra ≈ A * x end + +@testset "triu & tril" begin + A = rand(4, 6) + A_ra = Reactant.to_rarray(A) + + @test @jit(triu(A_ra)) ≈ triu(A) + @test @jit(tril(A_ra)) ≈ tril(A) + @test @jit(triu(A_ra, 2)) ≈ triu(A, 2) + @test @jit(tril(A_ra, 2)) ≈ tril(A, 2) + @test @jit(triu(A_ra, -1)) ≈ triu(A, -1) + @test @jit(tril(A_ra, -1)) ≈ tril(A, -1) +end From 1bddd1b4e7ebe1db8483bf042f203d46cc378b43 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 11:32:47 +0530 Subject: [PATCH 10/10] test: add inplace tests --- test/integration/linear_algebra.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 910d6ea9af..22fe07c1f6 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -88,4 +88,28 @@ end @test @jit(tril(A_ra, 2)) ≈ tril(A, 2) @test @jit(triu(A_ra, -1)) ≈ triu(A, -1) @test @jit(tril(A_ra, -1)) ≈ tril(A, -1) + + A_ra = Reactant.to_rarray(A) + @jit(triu!(A_ra)) + @test A_ra ≈ triu(A) + + A_ra = Reactant.to_rarray(A) + @jit(tril!(A_ra)) + @test A_ra ≈ tril(A) + + A_ra = Reactant.to_rarray(A) + @jit(triu!(A_ra, 2)) + @test A_ra ≈ triu(A, 2) + + A_ra = Reactant.to_rarray(A) + @jit(tril!(A_ra, 2)) + @test A_ra ≈ tril(A, 2) + + A_ra = Reactant.to_rarray(A) + @jit(triu!(A_ra, -1)) + @test A_ra ≈ triu(A, -1) + + A_ra = Reactant.to_rarray(A) + @jit(tril!(A_ra, -1)) + @test A_ra ≈ tril(A, -1) end