Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,14 +393,6 @@ buffer_on_cpu(::Any) = true
buffer_on_cpu(x::ConcretePJRTArray) = all(XLA.buffer_on_cpu, x.data)
buffer_on_cpu(x::ConcreteIFRTArray) = XLA.buffer_on_cpu(x.data)

function Ops.constant(x::AbstractConcreteArray; kwargs...)
return Ops.constant(Base.convert(Array, x); kwargs...)
end

function Ops.constant(x::AbstractConcreteNumber{T}; kwargs...) where {T}
return Ops.constant(Base.convert(T, x); kwargs...)
end

function Base.zero(x::ConcretePJRTArray{T,N}) where {T,N}
return ConcretePJRTArray(
zeros(T, size(x)...); client=XLA.client(x), device=XLA.device(x), x.sharding
Expand Down Expand Up @@ -464,3 +456,9 @@ function Base.mapreducedim!(
fn(f, op, R, A)
return R
end

function Base.map!(f, R::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, A::AbstractArray)
fn = compile(Base.map!, (f, R, A))
fn(f, R, A)
return R
end
60 changes: 43 additions & 17 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ end
end
end

@noinline function constant(
x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
) where {T,N}
return constant(collect(x); location)
end

@noinline function constant(x::Reactant.AbstractConcreteArray; kwargs...)
return constant(Base.convert(Array, x); kwargs...)
end

@noinline function constant(
x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
) where {T<:Number}
Expand All @@ -127,6 +137,10 @@ end
return TracedRNumber{T}((), res.mlir_data)
end

@noinline function constant(x::Reactant.AbstractConcreteNumber{T}; kwargs...) where {T}
return constant(Base.convert(T, x); kwargs...)
end

function fill(
v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
)
Expand Down Expand Up @@ -374,7 +388,7 @@ end
end

# shape ops
function reshape(x::TracedRArray, dims...; kwargs...)
function reshape(x::TracedRArray, dims::Integer...; kwargs...)
return reshape(x, collect(dims); kwargs...)
end

Expand Down Expand Up @@ -2377,7 +2391,7 @@ end
x::TracedRArray{T},
init_values::TracedRNumber{T},
dimensions::Vector{Int},
fn::Function,
fn::Function;
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
)

Expand Down Expand Up @@ -2409,25 +2423,43 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
- **CPU version & Julia's `reduce`**:
- Reduce along dimension 1 → `[(15) (21); (18) (24)]`
- Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]`

- **GPU version**:
- Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
- Reduce along dimension 3 → `[37 49]`
"""
@noinline function reduce(
x::TracedRArray{T},
init_values::TracedRNumber{T},
init_values::Union{TracedRNumber{T},Nothing},
dimensions::Vector{Int},
fn::Function,
fn::Function;
location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
) where {T}
elT = T
if init_values === nothing
if fn === min || fn === Base.FastMath.min_fast
init = typemax(elT)
elseif fn === max || fn === Base.FastMath.max_fast
init = typemin(elT)
else
init = Base.reduce_empty(Base.BottomRF(fn), elT)
end

initT = unwrapped_eltype(typeof(init))
if initT != elT # Bool, etc. reductions
elT = promote_type(initT, elT)
x = elT.(x)
end
init_values = Reactant.TracedUtils.promote_to(TracedRNumber{elT}, init)
end

reduced_shape = Tuple(deleteat!(collect(size(x)), dimensions))

result_type = mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)
result_type = mlir_type(TracedRArray{elT,length(reduced_shape)}, reduced_shape)

sample_inputs = [
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
Reactant.TracedUtils.promote_to(TracedRNumber{elT}, 0),
]

func =
Expand All @@ -2441,14 +2473,8 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
return_dialect=:stablehlo,
).f
@assert MLIR.IR.nregions(func) == 1
fn_name = String(
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
)
ftype_attr = MLIR.IR.attr(func, "function_type")
ftype = MLIR.IR.Type(ftype_attr)
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(T)) error (
"$fn return type is not tensor<i1>"
)
ftype = MLIR.IR.Type(MLIR.IR.attr(func, "function_type"))
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(elT)) "$fn return type is not tensor<i1>"
fn = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
MLIR.IR.rmfromparent!(func)
Expand All @@ -2466,7 +2492,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
),
)

return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
return TracedRArray{elT,length(reduced_shape)}((), res, reduced_shape)
end

end # module Ops
108 changes: 23 additions & 85 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,100 +468,29 @@ function Base.mapreduce(
dims=:,
init=nothing,
) where {T,N}
A = materialize_traced_array(A)
inp = broadcast(f, materialize_traced_array(A))

if dims isa Int
dims = [dims]
end

op_in_T = Core.Compiler.return_type(f, Tuple{T})

if init === nothing
if op === min
init = typemax(op_in_T)
elseif op === max
init = typemin(op_in_T)
else
init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
end
dims isa Number && (dims = (dims,))

if typeof(init) != op_in_T
op_in_T = typeof(init)
A = typeof(init).(A)
end
if init !== nothing && typeof(init) != unwrapped_eltype(inp)
inp = typeof(init).(inp)
end

init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]

inp = [broadcast(f, A).mlir_data]
rdims = dims == (:) ? collect(Int64, 1:N) : collect(Int64, dims)

rdims = Int64[]
reduction_result = Ops.reduce(inp, nothing, rdims, op)

if dims == (:)
for i in 0:(N - 1)
push!(rdims, i)
end
reduction_result = if dims != (:)
Ops.reshape(reduction_result, Int64[i ∈ rdims ? 1 : size(A, i) for i in 1:N])
else
for i in dims
push!(rdims, i - 1)
end
TracedRNumber{unwrapped_eltype(reduction_result)}((), reduction_result.mlir_data)
end

in_tys = [
MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
]

fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])

args = (
TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
)

resty = MLIR.IR.block!(fnbody) do
tmp = TracedUtils.broadcast_to_size(op(args...), ())
Ops.return_(tmp)
return eltype(MLIR.IR.type(tmp.mlir_data))
end

toonedims = Int[]
outdims = Int[]
for i in 1:N
tmp = if in(i - 1, rdims)
1
else
sz = size(A, i)
push!(outdims, sz)
sz
end
push!(toonedims, tmp)
end

TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]

body = MLIR.IR.Region()
push!(body, fnbody)
red = MLIR.Dialects.stablehlo.reduce(
inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
)

red = MLIR.IR.result(red, 1)
redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))

if dims != (:)
red = Ops.reshape(TracedRArray(red), toonedims...)
else
if length(outdims) == 0
red = TracedRNumber{redT}((), red)
else
red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
end
end
return red
init === nothing && return reduction_result
return broadcast(op, reduction_result, init)
end

function Base.mapreducedim!(
function Base._mapreducedim!(
@nospecialize(f),
@nospecialize(op),
@nospecialize(R::AnyTracedRArray),
Expand All @@ -573,9 +502,11 @@ function Base.mapreducedim!(
@assert sR == 1
return i
end

isempty(A) && return R

tmp = mapreduce(f, op, A; dims=filter(!isnothing, dims))
# set_mlir_data!(R, get_mlir_data(tmp))
R .= op.(R, tmp) # match native Julia's behavior
R .= op.(R, tmp)
return R
end

Expand Down Expand Up @@ -1084,4 +1015,11 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
return (values, linear_indices)
end

Base.map(f, x::AnyTracedRArray) = f.(x)

function Base.map!(f, y::AnyTracedRArray, x::AbstractArray)
y .= f.(x)
return y
end

end
12 changes: 12 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -938,3 +938,15 @@ end
rv
)
end

@testset "mapreduce with init" begin
x = reshape(collect(Float32, 1:12), 3, 4)
x_ra = Reactant.to_rarray(x)

init = 3.0
init_ra = Reactant.to_rarray(init; track_numbers=Number)

fn(x, init; kwargs...) = sum(x; init, kwargs...)

@test @jit(fn(x_ra, init_ra; dims=2)) ≈ fn(x, init; dims=2)
end
Loading