diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl index f1d1659b93..ffd7ddc60e 100644 --- a/ext/ReactantArrayInterfaceExt.jl +++ b/ext/ReactantArrayInterfaceExt.jl @@ -2,7 +2,7 @@ module ReactantArrayInterfaceExt using ArrayInterface: ArrayInterface using Reactant: - Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray + Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops ArrayInterface.can_setindex(::Type{<:RArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false @@ -14,7 +14,7 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where end function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} - return reshape(vcat(x...), size(x)) + return Ops.reshape(vcat(x...), size(x)...) end end diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index b90fa60fb2..8bfa5de02a 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -3,7 +3,15 @@ module ReactantNNlibExt using NNlib using GPUArraysCore: @allowscalar using Reactant: - Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber + Reactant, + Ops, + TracedRArray, + AnyTracedRArray, + materialize_traced_array, + MLIR, + TracedRNumber, + get_mlir_data, + set_mlir_data! using ReactantCore: @trace using LinearAlgebra: LinearAlgebra, triu @@ -12,14 +20,7 @@ for (jlop, hloop) in ( (:(NNlib.sigmoid_fast), :logistic), (:(NNlib.sigmoid), :logistic), ) - @eval function $(jlop)(x::TracedRNumber{T}) where {T} - return TracedRNumber{T}( - (), - Reactant.MLIR.IR.result( - Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1 - ), - ) - end + @eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x) end function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N} @@ -82,13 +83,6 @@ function NNlib.conv!( kernel_input_dim = N - 1 kernel_output_dim = N - output_spatial_shapes = map(input_spatial_dims) do i - K = kernel_size[i] - pl, pr = padding[2i - 1], padding[2i] - d = dilation[i] - s = stride[i] - return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1 - end output_batch_dim = input_batch_dim output_feature_dim = input_feature_dim output_spatial_dims = input_spatial_dims @@ -119,8 +113,8 @@ function NNlib.conv!( end conv = Reactant.MLIR.Dialects.stablehlo.convolution( - x.mlir_data, - weight.mlir_data; + get_mlir_data(x), + get_mlir_data(weight); result_0=result_type, window_strides=collect(stride), padding, @@ -130,7 +124,7 @@ function NNlib.conv!( feature_group_count, batch_group_count=1, ) - y.mlir_data = Reactant.MLIR.IR.result(conv) + set_mlir_data!(y, Reactant.MLIR.IR.result(conv)) return y end @@ -165,7 +159,9 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N} output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N)) result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T)) - unranked = Reactant.MLIR.IR.TensorType((), eltype(Reactant.MLIR.IR.type(x.mlir_data))) + unranked = Reactant.MLIR.IR.TensorType( + (), eltype(Reactant.MLIR.IR.type(get_mlir_data(x))) + ) body = let body = Reactant.MLIR.IR.Region(), loc = Reactant.MLIR.IR.Location(), @@ -189,7 +185,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N} Reactant.MLIR.Dialects.stablehlo.constant(; value=attr) ) reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window( - [x.mlir_data], + [get_mlir_data(x)], [init_value]; result_0=[result_type], window_dimensions, @@ -205,10 
+201,10 @@ end function NNlib.maxpool!( y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims ) where {T} - y.mlir_data = - reduce_window( - Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T) - ).mlir_data + res = reduce_window( + Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T) + ) + set_mlir_data!(y, get_mlir_data(res)) return y end @@ -216,13 +212,13 @@ function NNlib.meanpool!( y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims ) where {T} res = reduce_window(Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T)) - y.mlir_data = (res ./ T(prod(NNlib.kernel_size(pdims)))).mlir_data + set_mlir_data!(y, get_mlir_data(res ./ T(prod(NNlib.kernel_size(pdims))))) return y end -NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3)) +NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = PermutedDimsArray(x, (2, 1, 3)) function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T} - y = permutedims(x, (2, 1, 3)) + y = NNlib.batched_transpose(x) conj!(y) return y end @@ -238,14 +234,21 @@ function NNlib.batched_mul!( ), ) end + + if size(x, 3) != size(y, 3) + B = max(size(x, 3), size(y, 3)) + if size(x, 3) == 1 + x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) + elseif size(y, 3) == 1 + y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) + end + end + x = permutedims(x, (3, 1, 2)) y = permutedims(y, (3, 1, 2)) - B = max(size(x, 1), size(y, 1)) - out_shape = (B, size(x, 2), size(y, 3)) - resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data))) - if size(x, 1) != size(y, 1) + B = max(size(x, 1), size(y, 1)) if size(x, 1) == 1 x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) elseif size(y, 1) == 1 @@ -253,49 +256,25 @@ function NNlib.batched_mul!( end end - dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( - MLIR.IR.context(), 1, [0], 1, [0], 1, [2], 1, [1] + tmp = Ops.dot_general( + T1.(materialize_traced_array(x)), + T1.(materialize_traced_array(y)); + contracting_dimensions=([3], [2]), + batching_dimensions=([1], [1]), ) + set_mlir_data!(res, get_mlir_data(permutedims(tmp, (2, 3, 1)))) - prec = MLIR.IR.Attribute( - MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") - ) - tmp = TracedRArray{T1,3}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.dot_general( - x.mlir_data, - y.mlir_data; - result_0=resty, - dot_dimension_numbers=dot_dimension_numbers, - precision_config=prec, - ), - 1, - ), - size(resty), - ) - res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data return res end function NNlib.pad_constant( - x::TracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value + x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value ) where {T,N} value = Reactant.promote_to(TracedRNumber{T}, value) - edge_padding_low = [i[1] for i in pad] - edge_padding_high = [i[2] for i in pad] - interior_padding = [0 for i in pad] - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.pad( - x.mlir_data, - value.mlir_data; - edge_padding_low, - edge_padding_high, - interior_padding, - ), - 1, - ) - return TracedRArray{T,N}((), res, size(MLIR.IR.type(res))) + low = [i[1] for i in pad] + high = [i[2] for i in pad] + interior = [0 for i in pad] + return Ops.pad(materialize_traced_array(x), value; low, high, interior) end # XXX: reevaluate this manual optimization once @@ -305,7 +284,7 @@ function NNlib.gather!( src::AnyTracedRArray{T2,2}, idxs::Union{AbstractUnitRange{<:Number}}, ) where {T1,T2} - dst.mlir_data = src[:, 
idxs].mlir_data + set_mlir_data!(dst, get_mlir_data(src[:, idxs])) return dst end @@ -314,8 +293,8 @@ function NNlib.gather!( ) where {T1,T2} dims = NNlib.scatter_dims(src, dst, idxs) @assert dims == 1 # scatter_dims lets us do some size checks so we call that function - idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data - slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data + idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1) + slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1])) #! format: off dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( @@ -331,11 +310,11 @@ function NNlib.gather!( res = MLIR.IR.result( Reactant.MLIR.Dialects.stablehlo.dynamic_gather( - src.mlir_data, idxs, slice_sizes; dimension_numbers + get_mlir_data(src), idxs, slice_sizes; dimension_numbers ), 1, ) - dst.mlir_data = res + set_mlir_data!(dst, res) return dst end @@ -354,7 +333,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr return reshape(res, start_sizes..., :) end res = reshape(cat(results...; dims=(dims + 1)), size(dst)) - dst.mlir_data = res.mlir_data + set_mlir_data!(dst, get_mlir_data(res)) return dst end @@ -363,7 +342,7 @@ dilate_shape(s, d) = max(0, 1 + d * (s - 1)) # see lax._conv_general_dilated_transpose_rhs # https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495 function NNlib.∇conv_filter!( - dw::Reactant.TracedRArray{T,N}, + dw::TracedRArray{T,N}, x::AnyTracedRArray, dy::AnyTracedRArray, cdims::NNlib.DenseConvDims, @@ -437,8 +416,8 @@ function NNlib.∇conv_filter!( result_type = Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T)) conv = MLIR.Dialects.stablehlo.convolution( - x.mlir_data, - dy.mlir_data; + get_mlir_data(x), + get_mlir_data(dy); result_0=result_type, window_strides=collect(dilation), padding, @@ -447,11 +426,12 @@ function NNlib.∇conv_filter!( feature_group_count, batch_group_count, ) - - dw.mlir_data = MLIR.IR.result(conv) + set_mlir_data!(dw, MLIR.IR.result(conv)) if !NNlib.flipkernel(cdims) - dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data + set_mlir_data!( + dw, get_mlir_data(Reactant.Ops.reverse(dw; dimensions=output_spatial_dims)) + ) end return dw @@ -553,8 +533,8 @@ function NNlib.∇conv_data!( end conv = MLIR.Dialects.stablehlo.convolution( - dy.mlir_data, - w.mlir_data; + get_mlir_data(dy), + get_mlir_data(w); result_0=result_type, window_strides=1, padding, @@ -564,8 +544,7 @@ function NNlib.∇conv_data!( feature_group_count, batch_group_count=1, ) - - dx.mlir_data = MLIR.IR.result(conv) + set_mlir_data!(dx, MLIR.IR.result(conv)) return dx end diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 42c454b31a..ceb0844026 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -8,6 +8,9 @@ mutable struct ConcreteRArray{T,N} <: RArray{T,N} shape::NTuple{N,Int} end +const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} +const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}} + mutable struct ConcreteRNumber{T} <: RNumber{T} data::XLA.AsyncBuffer end @@ -74,21 +77,15 @@ function ConcreteRArray( ) end -Base.size(x::ConcreteRArray) = x.shape - -function Base.reshape(A::ConcreteRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT} - prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A))) - host = convert(Array{T,N}, A) - # HLO 
reshape semantics collapse the opposite so enforce on Julia Side - # until we later make the transpose/reshape/transpose - host = reshape(host, dims) - client = XLA.client(A.data) - device = XLA.device(A.data) - buffer = XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, host, device), nothing) - return ConcreteRArray{T,NT}(buffer, dims) - # ConcreteRArray{T, dims, NT}(XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, XLA.to_row_major(host), device), nothing)) +ConcreteRArray(x::AnyConcreteRArray) = ConcreteRArray{eltype(x),ndims(x)}(x) +ConcreteRArray{T}(x::AnyConcreteRArray) where {T} = ConcreteRArray{T,ndims(x)}(x) +ConcreteRArray{T,N}(x::ConcreteRArray{T,N}) where {T,N} = x +function ConcreteRArray{T,N}(x::AnyConcreteRArray) where {T,N} + return ConcreteRArray(convert(Array{T,N}, x)) end +Base.size(x::ConcreteRArray) = x.shape + function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,ElType,N} data = Array{ElType,N}(undef, size(X)...) # TODO replace for `similar`? XLA.await(X.data) @@ -99,7 +96,13 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El return data # XLA.from_row_major(data) end -Base.Array(x::ConcreteRArray) = convert(Array, x) +function Base.convert( + ::Type{T}, X::WrappedConcreteRArray{ElType,N} +) where {T<:Array,ElType,N} + fn = compile(materialize_traced_array, (X,)) + return convert(Array, fn(X)) +end +Base.Array(x::AnyConcreteRArray) = convert(Array, x) function synchronize(x::Union{ConcreteRArray,ConcreteRNumber}) XLA.synced_buffer(x.data) @@ -165,19 +168,21 @@ for T in (ConcreteRNumber, ConcreteRArray{<:Any,0}) end end -function Base.isapprox(x::ConcreteRArray, y::AbstractArray; kwargs...) +function Base.isapprox(x::AnyConcreteRArray, y::AbstractArray; kwargs...) return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) end -function Base.isapprox(x::AbstractArray, y::ConcreteRArray; kwargs...) +function Base.isapprox(x::AbstractArray, y::AnyConcreteRArray; kwargs...) return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) end -function Base.isapprox(x::ConcreteRArray, y::ConcreteRArray; kwargs...) +function Base.isapprox(x::AnyConcreteRArray, y::AnyConcreteRArray; kwargs...) return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) end -Base.:(==)(x::ConcreteRArray, y::AbstractArray) = convert(Array, x) == convert(Array, y) -Base.:(==)(x::AbstractArray, y::ConcreteRArray) = convert(Array, x) == convert(Array, y) -Base.:(==)(x::ConcreteRArray, y::ConcreteRArray) = convert(Array, x) == convert(Array, y) +Base.:(==)(x::AnyConcreteRArray, y::AbstractArray) = convert(Array, x) == convert(Array, y) +Base.:(==)(x::AbstractArray, y::AnyConcreteRArray) = convert(Array, x) == convert(Array, y) +function Base.:(==)(x::AnyConcreteRArray, y::AnyConcreteRArray) + return convert(Array, x) == convert(Array, y) +end function Base.show(io::IO, X::ConcreteRScalar{T}) where {T} if X.data == XLA.AsyncEmptyBuffer @@ -264,7 +269,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N end GPUArraysCore.assertscalar("setindex!(::ConcreteRArray, ::Any, ::Vararg{Int, N})") - fn = Reactant.compile(mysetindex!, (a, v, args...)) + fn = compile(mysetindex!, (a, v, args...)) fn(a, v, args...) 
return a end @@ -307,7 +312,7 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteR return ConcreteRArray(aux) end - fn = Reactant.compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,)) + fn = compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,)) return fn(bc.args...) end @@ -323,7 +328,7 @@ function Base.mapreduce( dims=:, init=nothing, ) where {T,N} - fn = Reactant.compile(CallMapReduce(f, op, dims, init), (A,)) + fn = compile(CallMapReduce(f, op, dims, init), (A,)) return fn(A) end diff --git a/src/Ops.jl b/src/Ops.jl index 748d06f662..2f1c7f6eb9 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -253,12 +253,14 @@ function reshape( dims::Vector{Int}; location=mlir_stacktrace("reshape", @__FILE__, @__LINE__), ) where {T,N} - restype = mlir_type(TracedRArray{T,length(dims)}, dims) - res = MLIR.IR.result(stablehlo.reshape(x.mlir_data; result_0=restype, location)) - result = TracedRArray{T,length(dims)}((), res, dims) + # HLO reshape semantics collapse the opposite way + res1 = transpose(x, Int64[N:-1:1...]) + restype = mlir_type(TracedRArray{T,length(dims)}, collect(Base.reverse(dims))) + res = MLIR.IR.result(stablehlo.reshape(res1.mlir_data; result_0=restype, location)) + result = TracedRArray{T,length(dims)}((), res, collect(Base.reverse(dims))) # NOTE this last `transpose` is required for consistency with Julia's column-major order # do not remove, as it will be optimized away by the compiler - return transpose(result, [length(dims):-1:1...]) + return transpose(result, Int64[length(dims):-1:1...]) end function get_dimension_size( diff --git a/src/Reactant.jl b/src/Reactant.jl index 0b73d3d96e..06fd59affe 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -57,10 +57,6 @@ abstract type RNumber{T<:ReactantPrimitive} <: Number end Base.collect(A::RArray) = copy(A) -function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}}) - return reshape(A, Base._reshape_uncolon(A, dims)) -end - function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) )::RT where {copy_if_inactive,RT<:RArray} diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 97a29e56f9..0e9bf6f779 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -17,6 +17,21 @@ mutable struct TracedRArray{T,N} <: RArray{T,N} end end +const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} +const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} +const AnyTracedRVector{T} = AnyTracedRArray{T,1} +const AnyTracedRMatrix{T} = Union{ + AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} +} +const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} + +function TracedRArray(data::MLIR.IR.Value) + data_type = MLIR.IR.type(data) + return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( + (), data, size(data_type) + ) +end + ReactantCore.is_traced(::TracedRArray) = true new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) @@ -38,6 +53,10 @@ function TracedRArray{T,N}(rhs::TracedRArray{T0,N}) where {T,T0,N} end end +function TracedRArray{T,N}(rhs::WrappedTracedRArray{T0,N}) where {T0,T,N} + return TracedRArray{T,N}(materialize_traced_array(rhs)) +end + function TracedRArray{T,N}(rhs::AbstractArray{T0,N}) where {T0,T,N} attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) return TracedRArray{T,N}( @@ -47,14 +66,31 @@ function TracedRArray{T,N}(rhs::AbstractArray{T0,N}) where {T0,T,N} ) end -const 
WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} -const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} -const AnyTracedRVector{T} = AnyTracedRArray{T,1} -const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} -const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} - materialize_traced_array(x::TracedRArray) = x materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] +function materialize_traced_array( + x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray} +) where {T,N} + return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) +end +function materialize_traced_array( + x::LinearAlgebra.Transpose{T,TracedRArray{T,N}} +) where {T,N} + px = parent(x) + A = ndims(px) == 1 ? reshape(px, :, 1) : px + return permutedims(A, (2, 1)) +end +function materialize_traced_array(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}) where {T,N} + return conj(materialize_traced_array(transpose(parent(x)))) +end +function materialize_traced_array( + x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}} +) where {T,N,perm,iperm} + return permutedims(parent(x), perm) +end +function materialize_traced_array(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}) where {T} + return LinearAlgebra.diagm(parent(x)) +end get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) @@ -63,12 +99,43 @@ function set_mlir_data!(x::TracedRArray, data) x.mlir_data = data return x end +function set_mlir_data!(x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray}, data) where {T,N} + res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data + set_mlir_data!(parent(x), res_mlir_data) + return x +end +function set_mlir_data!(x::LinearAlgebra.Transpose{T,TracedRArray{T,N}}, data) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + px.mlir_data = ( + if ndims(px) == 1 + Ops.reshape(tdata, length(tdata)) + else + Ops.transpose(tdata, [2, 1]) + end + ).mlir_data + return x +end +function set_mlir_data!(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}, data) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + transposed_data = + ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) + px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data + return x +end +function set_mlir_data!( + x::PermutedDimsArray{T,N,perm,iperm,TracedRArray{T,N}}, data +) where {T,N,perm,iperm} + parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data + return x +end +function set_mlir_data!(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}, data) where {T} + parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data + return x +end function set_mlir_data!(x::AnyTracedRArray, data) - data_type = MLIR.IR.type(data) - data = TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( - (), data, size(data_type) - ) - setindex!(x, data, axes(x)...) + setindex!(x, TracedRArray(data), axes(x)...) 
return x end @@ -198,46 +265,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC # return print(io, X.mlir_data, ")") end -function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT} - if prod(dims) != prod(size(A)) - throw( - DimensionMismatch( - "new shape $(dims) is incompatible with array size $(size(A))" - ), - ) - end - - # HLO reshape semantics collapse the opposite way - res1 = MLIR.IR.result( - MLIR.Dialects.stablehlo.transpose( - get_mlir_data(A); - permutation=MLIR.IR.DenseArrayAttribute([Int64(N - 1 - i) for i in 0:(N - 1)]), - ), - 1, - ) - - res2 = MLIR.IR.result( - MLIR.Dialects.stablehlo.reshape( - res1; - result_0=MLIR.IR.TensorType( - [Int64(i) for i in reverse(dims)], eltype(MLIR.IR.type(res1)) - ), - ), - ) - - res3 = MLIR.IR.result( - MLIR.Dialects.stablehlo.transpose( - res2; - permutation=MLIR.IR.DenseArrayAttribute([ - Int64(NT - 1 - i) for i in 0:(NT - 1) - ]), - ), - 1, - ) - - return TracedRArray{T,NT}((), res3, dims) -end - function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N} return TracedRArray{T,N}( (), @@ -303,12 +330,6 @@ function Base.imag(A::TracedRArray{Complex{T},N}) where {T,N} ) end -function Base.transpose(A::AnyTracedRVecOrMat) - A = ndims(A) == 1 ? reshape(A, :, 1) : A - return permutedims(A, (2, 1)) -end -Base.adjoint(A::AnyTracedRVecOrMat) = conj(transpose(A)) - promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs) promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) @@ -550,10 +571,10 @@ function BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N} end function Base.similar( - bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims + ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims ) where {T<:ReactantPrimitive,N} @assert N isa Int - return TracedRArray{T,N}((), nothing, map(length, dims)) + return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end function Base.similar( @@ -648,7 +669,7 @@ end function broadcast_to_size(arg::AnyTracedRArray, rsize) arg = materialize_traced_array(arg) - size(arg) == rsize && return arg + size(arg) == Tuple(rsize) && return arg return broadcast_to_size_internal(arg, rsize) end @@ -794,9 +815,3 @@ end Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x) Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x) - -# LinearAlgebra defines norm with some conditionals which cannot be traced directly -function LinearAlgebra.norm(x::TracedRArray{T,N}, p::Real=2) where {T,N} - isinf(p) && return maximum(abs, x) - return mapreduce(Base.Fix2(^, p), +, x)^(1 / p) -end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index d3a6f3b990..ebe733ce62 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -12,6 +12,9 @@ mutable struct TracedRNumber{T} <: RNumber{T} end end +get_mlir_data(x::TracedRNumber) = x.mlir_data +set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) + ReactantCore.is_traced(::TracedRNumber) = true new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index c7e72651d7..7451952173 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -1,35 +1,35 @@ function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,1}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,1}), + @nospecialize(C::TracedRArray{T,1}), + @nospecialize(A::AnyTracedRMatrix), + 
@nospecialize(B::AnyTracedRVector), α::Number=true, β::Number=false, -) where {T1,T2,T3} +) where {T} # TODO: The reshape operations are not getting optimized, we should directly call dot_general - rC = reshape(C, :, 1) + rC = Ops.reshape(C, length(C), 1) LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) C.mlir_data = get_mlir_data(vec(rC)) return C end function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,2}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,1}), + @nospecialize(C::TracedRArray{T,2}), + @nospecialize(A::AnyTracedRMatrix), + @nospecialize(B::AnyTracedRVector), α::Number=true, β::Number=false, -) where {T1,T2,T3} +) where {T} LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β) return C end function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,2}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,2}), + @nospecialize(C::TracedRArray{T,2}), + @nospecialize(A::AnyTracedRMatrix), + @nospecialize(B::AnyTracedRMatrix), α::Number=true, β::Number=false, -) where {T1,T2,T3} +) where {T} if size(C) != (size(A, 1), size(B, 2)) throw( DimensionMismatch( @@ -40,50 +40,21 @@ function LinearAlgebra.mul!( if size(A, 2) != size(B, 1) throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))")) end - resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1)) - dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( - MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0] - ) - prec = MLIR.IR.Attribute( - MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") - ) - precar = MLIR.IR.Attribute([prec, prec]) - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.dot_general( - get_mlir_data(A), - get_mlir_data(B); - result_0=resty, - dot_dimension_numbers=dot_dimension_numbers, - precision_config=precar, - ), - 1, + + tmp = Ops.dot_general( + T.(materialize_traced_array(A)), + T.(materialize_traced_array(B)); + contracting_dimensions=([2], [1]), ) - if iszero(β) - if isone(α) - C.mlir_data = res - else - C.mlir_data = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - res, broadcast_to_size(T1(α), size(C)).mlir_data - ), - 1, - ) - end + + res = if iszero(β) + isone(α) ? tmp : Ops.multiply(tmp, broadcast_to_size(T(α), size(C))) else - α_res = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - res, broadcast_to_size(T1(α), size(C)).mlir_data - ), - 1, - ) - β_C = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data - ), - 1, - ) - C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1) + α_res = Ops.multiply(tmp, broadcast_to_size(T(α), size(C))) + β_C = Ops.multiply(C, broadcast_to_size(T(β), size(C))) + Ops.add(α_res, β_C) end + set_mlir_data!(C, get_mlir_data(res)) return C end @@ -106,3 +77,68 @@ function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) wh X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data return X end + +# LinearAlgebra defines norm with some conditionals which cannot be traced directly +function LinearAlgebra.norm(x::TracedRArray{T,N}, p::Real=2) where {T,N} + isinf(p) && return maximum(abs, x) + return mapreduce(Base.Fix2(^, p), +, x)^(1 / p) +end + +function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} + y = materialize_traced_array(x) + + rows, cols = size(y) + (start_row, start_col) = k ≥ 0 ? 
(0, k) : (-k, 0) + diag_length = min(rows - start_row, cols - start_col) + + indices = stack(( + start_row:(start_row + diag_length - 1), start_col:(start_col + diag_length - 1) + )) + + # XXX: creating an empty array causes + # terminate called after throwing an instance of 'xla::XlaRuntimeError' + # what(): UNKNOWN: :0: error: 'tensor.empty' op unsupported op for export to XLA + # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64> + length(indices) ≤ 0 && return promote_to(TracedRArray{T,1}, T[]) + + idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,2}, indices)) + + #! format: off + dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( + MLIR.IR.context(), + Int64(0), Int64[], + Int64(2), Int64[0, 1], + Int64(0), Int64[], + Int64(0), Int64[], + Int64(2), Int64[0, 1], + Int64(1) + ) + #! format: on + + slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [1, 1])) + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.dynamic_gather( + get_mlir_data(y), idxs, slice_sizes; dimension_numbers + ), + 1, + ) + return TracedRArray{T,1}((), res, (diag_length,)) +end + +function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} + return LinearAlgebra.diagm(length(v), length(v), v) +end +function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} + m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check + + v = materialize_traced_array(v) + D = length(v) + row_idxs = Ops.iota(Int, [D, D]; iota_dimension=1) + col_idxs = Ops.iota(Int, [D, D]; iota_dimension=2) + diag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="EQ") + + mat = (v .+ zero(v)') .* diag_indicator + return Ops.pad( + mat, promote_to(TracedRNumber{T}, 0); high=[m - length(v), n - length(v)] + ) +end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 22fe07c1f6..0c6efc5fdb 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -113,3 +113,32 @@ end @jit(tril!(A_ra, -1)) @test A_ra ≈ tril(A, -1) end + +@testset "diag / diagm" begin + x = rand(2, 4) + x_ra = Reactant.to_rarray(x) + + @testset for k in (-size(x, 1) + 1):(size(x, 1) - 1) + @test @jit(diag(x_ra, k)) ≈ diag(x, k) + end + + x = rand(4) + x_ra = Reactant.to_rarray(x) + + @test @jit(diagm(x_ra)) ≈ diagm(x) + @test @jit(diagm(5, 4, x_ra)) ≈ diagm(5, 4, x) + @test @jit(diagm(4, 5, x_ra)) ≈ diagm(4, 5, x) + @test @jit(diagm(6, 6, x_ra)) ≈ diagm(6, 6, x) + @test_throws DimensionMismatch @jit(diagm(3, 3, x_ra)) +end + +# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be +# optimized +mul_diagonal(x) = Diagonal(x) * x + +@testset "mul_diagonal" begin + x = rand(4) + x_ra = Reactant.to_rarray(x) + + @test @jit(mul_diagonal(x_ra)) ≈ mul_diagonal(x) +end diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 5f9c92ef74..295d256ef8 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -393,3 +393,10 @@ end NNlib.∇conv_filter(x, dy, conv_dims) end end + +@testset "Upsampling" begin + x = randn(Float32, 4, 4, 3, 2) + x_ra = Reactant.ConcreteRArray(x) + + @test @jit(NNlib.upsample_nearest(x_ra, (2, 2))) ≈ NNlib.upsample_nearest(x, (2, 2)) +end diff --git a/test/ops.jl b/test/ops.jl index 0600d0b861..07f911e88b 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -242,8 +242,8 @@ end @test a .* b ≈ @jit f1(a, b) @test reshape(kron(Array(b), Array(a)), 4, 4) ≈ @jit f2(a, b) - x = reshape(a, (2, 2)) - y = reshape(b, (2, 2)) + x = ConcreteRArray(reshape(a, (2, 2))) + y = 
ConcreteRArray(reshape(b, (2, 2))) @test x .* y ≈ @jit f3(x, y) @test Array(x) * Array(y) ≈ @jit f4(x, y) end @@ -521,6 +521,9 @@ end @testset "reshape" begin x = ConcreteRArray([1, 2, 3, 4]) @test reshape(Array(x), 2, 2) == @jit Ops.reshape(x, 2, 2) + + x = ConcreteRArray(collect(reshape(1:12, 2, 2, 3))) + @test reshape(Array(x), 3, 1, 4) == @jit Ops.reshape(x, 3, 1, 4) end @testset "reverse" begin diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index f2aafbe2cd..f5418e5c80 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -1,4 +1,4 @@ -using Reactant, Test, Statistics, NNlib +using Reactant, Test, Statistics, NNlib, LinearAlgebra function view_getindex_1(x) x = view(x, 2:3, 1:2, :) @@ -21,8 +21,8 @@ end x_ra = Reactant.to_rarray(x) @test @allowscalar(@jit(view_getindex_1(x_ra))) ≈ view_getindex_1(x) - @test @jit(view_getindex_2(x_ra)) ≈ view_getindex_2(x) - @test @jit(view_getindex_3(x_ra)) ≈ view_getindex_3(x) + @test Array(@jit(view_getindex_2(x_ra))) ≈ view_getindex_2(x) + @test Array(@jit(view_getindex_3(x_ra))) ≈ view_getindex_3(x) end function reshape_wrapper(x) @@ -94,9 +94,81 @@ function bypass_permutedims(x) return view(x, 2:3, 1:2, :) end +add_perm_dims(x) = x .+ PermutedDimsArray(x, (2, 1)) + @testset "PermutedDimsArray" begin x = rand(4, 4, 3) x_ra = Reactant.to_rarray(x) y_ra = @jit(bypass_permutedims(x_ra)) @test @allowscalar(Array(y_ra)) ≈ bypass_permutedims(x) + + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(add_perm_dims(x_ra)) ≈ add_perm_dims(x) +end + +function writeto_reshaped_array!(x) + z1 = similar(x) + z2 = reshape(z1, 1, 2, 3, 1) + @. z2 = 1.0 + return z1 +end + +function write_to_transposed_array!(x) + z1 = similar(x) + z2 = transpose(z1) + @. z2 = 1.0 + return z1 +end + +function write_to_adjoint_array!(x) + z1 = similar(x) + z2 = adjoint(z1) + @. z2 = 1.0 + return z1 +end + +function write_to_permuted_dims_array!(x) + z1 = similar(x) + z2 = PermutedDimsArray(z1, (2, 1)) + @. z2 = 1.0 + return z1 +end + +function write_to_diagonal_array!(x) + z = Diagonal(x) + @. z = 1.0 + return z +end + +@testset "Preserve Aliasing with Parent" begin + @testset "$(aType)" for (aType, fn) in [ + ("ReshapedArray", writeto_reshaped_array!), + ("Transpose", write_to_transposed_array!), + ("Adjoint", write_to_adjoint_array!), + ] + x = ConcreteRArray(rand(3, 2)) + y = @jit fn(x) + @test all(isone, Array(y)) + end + + @testset "PermutedDimsArray" begin + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + @test @jit(write_to_permuted_dims_array!(x_ra)) ≈ write_to_permuted_dims_array!(x) + end + + @testset "Diagonal" begin + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + y_ra = copy(x_ra) + + y = @jit(write_to_diagonal_array!(x_ra)) + y_res = @allowscalar Array(y) + @test x_ra ≈ y_ra + @test all(isone, diag(y_res)) + y_res[diagind(y_res)] .= 0 + @test all(iszero, y_res) + end end
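The reworked Ops.reshape above preserves Julia's column-major semantics by wrapping stablehlo.reshape (which collapses dimensions in row-major order) in a transpose/reshape/transpose sandwich. A minimal round-trip check, mirroring the new case added to test/ops.jl and assuming a working Reactant/XLA client:

using Reactant
using Reactant: Ops

x = Reactant.ConcreteRArray(collect(reshape(1:12, 2, 2, 3)))

# Ops.reshape on the traced side must agree with Base.reshape on the host side.
@assert reshape(Array(x), 3, 1, 4) == @jit(Ops.reshape(x, 3, 1, 4))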
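NNlib.batched_mul! and the dense LinearAlgebra.mul! methods now both lower to Ops.dot_general. The sketch below uses the dimension numbers from the batched_mul! call in this diff (batch dimension first, 1-based indices) and checks the contraction against a host-side loop; it is illustrative only, not a restatement of the NNlib code path.

using Reactant
using Reactant: Ops

xh = rand(Float32, 2, 4, 3)  # (batch, M, K)
yh = rand(Float32, 2, 3, 5)  # (batch, K, N)
x, y = Reactant.to_rarray(xh), Reactant.to_rarray(yh)

# Contract K (dim 3 of x, dim 2 of y) and batch over dim 1 of both operands.
batched_mm(a, b) = Ops.dot_general(
    a, b; contracting_dimensions=([3], [2]), batching_dimensions=([1], [1])
)

z = @jit(batched_mm(x, y))  # (batch, M, N)
zh = stack([xh[b, :, :] * yh[b, :, :] for b in 1:size(xh, 1)]; dims=1)
@assert Array(z) ≈ zh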
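NNlib.pad_constant is now a thin wrapper over Ops.pad. A small sketch of that call, with the keyword names (low, high, interior) taken from the call sites in this diff; low and high add edge padding per dimension, interior inserts padding between elements:

using Reactant
using Reactant: Ops, TracedRNumber

pad_edges(x) = Ops.pad(
    x,
    Reactant.promote_to(TracedRNumber{Float32}, 0.0f0);
    low=[1, 0],
    high=[0, 2],
    interior=[0, 0],
)

x = Reactant.to_rarray(ones(Float32, 2, 2))
y = @jit(pad_edges(x))
@assert size(y) == (3, 4)   # one extra row in front, two extra columns at the end
@assert sum(Array(y)) ≈ 4   # the original ones survive, the padding is zero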
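The new WrappedConcreteRArray/AnyConcreteRArray aliases let lazy wrappers of a ConcreteRArray (Transpose, ReshapedArray, views) be converted back to host Arrays or to fresh ConcreteRArrays; converting a wrapper compiles a small materialization program under the hood. A hedged usage sketch of that conversion path:

using Reactant, LinearAlgebra

x = Reactant.ConcreteRArray(rand(Float32, 4, 2))
xt = transpose(x)                         # lazy wrapper, no device work yet

@assert Array(xt) ≈ transpose(Array(x))   # Base.Array now accepts AnyConcreteRArray
y = Reactant.ConcreteRArray(xt)           # materialize into a standalone ConcreteRArray
@assert size(y) == (2, 4)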
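The wrapper-aware set_mlir_data! methods exist so that writes made through a lazy wrapper stay visible in its parent inside traced code, which is what the new "Preserve Aliasing with Parent" tests assert. A condensed version of that pattern:

using Reactant, LinearAlgebra

function fill_through_transpose(x)
    z = similar(x)
    zt = transpose(z)   # lazy wrapper around z
    @. zt = 1.0         # broadcast writes through the wrapper...
    return z            # ...and the parent observes them
end

x = Reactant.ConcreteRArray(rand(3, 2))
@assert all(isone, Array(@jit(fill_through_transpose(x))))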
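The dense mul! path now builds Ops.dot_general plus Ops.multiply/Ops.add for the alpha/beta scaling. A sketch of the 5-argument form, assuming in-place updates of compiled inputs behave as in the existing tril!/triu! tests:

using Reactant, LinearAlgebra

A, B, C = rand(Float32, 3, 4), rand(Float32, 4, 2), rand(Float32, 3, 2)
rA, rB, rC = Reactant.to_rarray.((A, B, C))
α, β = 2.0f0, 0.5f0

@jit(mul!(rC, rA, rB, α, β))
@assert rC ≈ α .* (A * B) .+ β .* C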
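diag is implemented with a stablehlo dynamic_gather over stacked diagonal indices, and diagm builds an identity mask from two Ops.iota tensors compared with "EQ", scales it by the broadcasted vector, and pads out to the requested size. The host comparison below mirrors the new integration tests:

using Reactant, LinearAlgebra

Ah = rand(3, 5)
A = Reactant.to_rarray(Ah)
@assert @jit(diag(A, 1)) ≈ diag(Ah, 1)

vh = rand(4)
v = Reactant.to_rarray(vh)
@assert @jit(diagm(v)) ≈ diagm(vh)
@assert @jit(diagm(6, 6, v)) ≈ diagm(6, 6, vh)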