diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index ef749eaf60..bccdce05db 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -3,6 +3,7 @@ module ReactantNNlibExt
 using NNlib
 using Reactant:
     Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
+using LinearAlgebra: LinearAlgebra, triu
 
 for (jlop, hloop) in (
     (:(NNlib.tanh_fast), :tanh),
@@ -280,4 +281,19 @@ function NNlib.pad_constant(
     return TracedRArray{T,N}((), res, size(MLIR.IR.type(res)))
 end
 
+function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
+    len = size(x, dims)
+    # Emit the mask as Int and promote to Bool: generating booleans directly produced an
+    # incorrect constant attribute; the optimized IR drops the cast, so we are probably ok
+    mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len)))'))
+    return Reactant.promote_to(
+        TracedRArray{Bool,2},
+        TracedRArray{Int,2}(
+            (),
+            MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=mask), 1),
+            (len, len),
+        ),
+    )
+end
+
 end # module ReactantNNlibExt
diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl
index 65240ed885..9aa024c73f 100644
--- a/test/nn/nnlib.jl
+++ b/test/nn/nnlib.jl
@@ -163,3 +163,13 @@ end
 
     @test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1))
 end
+
+@testset "make_causal_mask" begin
+    x = rand(2, 10)
+    x_ra = Reactant.ConcreteRArray(x)
+
+    @test @jit(NNlib.make_causal_mask(x_ra)) ≈ NNlib.make_causal_mask(x)
+
+    causal_mask2(x) = NNlib.make_causal_mask(x; dims=1)
+    @test @jit(causal_mask2(x_ra)) ≈ causal_mask2(x)
+end