Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -936,24 +936,24 @@ end
end

# broadcast ops
# function broadcast_in_dim(
# x::TracedRArray{T,N},
# dims::Vector{Int};
# location=mlir_stacktrace(
# "broadcast_in_dim", @__FILE__, @__LINE__
# ),
# ) where {T,N}
# rsize = restype = MLIR.IR.TensorType([...], mlir_type(T)) # mlir_type(TracedRArray{T,N}, size(x))
# res = MLIR.IR.result(
# stablehlo.broadcast_in_dim(
# x.mlir_data;
# result_0=restype,
# broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims),
# location,
# ),
# )
# return TracedRArray{T,N}((), res, size(x))
# end
function broadcast_in_dim(
x::TracedRArray{T,N},
dims::Vector{Int},
result_size::Vector{Int};
location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__),
) where {T,N}
@assert length(dims) == N

res = MLIR.IR.result(
stablehlo.broadcast_in_dim(
x.mlir_data;
result_0=MLIR.IR.TensorType(result_size, MLIR.IR.Type(T)),
broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims .- 1),
location,
),
)
return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size))
end

@noinline function sort(
x::TracedRArray{T,N};
Expand Down
17 changes: 15 additions & 2 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,21 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
return v
end

v = TracedUtils.broadcast_to_size(v, length.(indices))
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
if v isa Number
v = TracedUtils.broadcast_to_size(v, length.(indices))
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
else
v = TracedUtils.promote_to(TracedRArray{T,ndims(v)}, v)
non_integer_indices = [!(idx isa Integer) for idx in indices]
broadcast_dims = findall(non_integer_indices)
if length(broadcast_dims) == N
v = TracedUtils.broadcast_to_size(v, length.(indices))
else
v = Ops.broadcast_in_dim(
materialize_traced_array(v), broadcast_dims, Int64.(length.(indices))
)
end
end

indices = [
(
Expand Down
25 changes: 1 addition & 24 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,30 +496,7 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize)
end

@noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T}
dims = collect(Int64, 0:(length(size(x)) - 1))

if length(size(MLIR.IR.type(get_mlir_data(x)))) != length(dims)
@show x
@show arg
@show rsize
@show rsize2
@show dims
end
@assert length(size(MLIR.IR.type(get_mlir_data(x)))) == length(dims)
mlirty = MLIR.IR.type(get_mlir_data(x))

return TracedRArray{T,Int(length(rsize))}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.broadcast_in_dim(
get_mlir_data(x);
result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)),
broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims),
),
1,
),
collect(rsize),
)
return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize))
end

end
33 changes: 33 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,39 @@ end
# get_view_compiled = @compile get_view(x_concrete)
end

function write_with_broadcast1!(x, y)
x[1, :, :] .= reshape(y, 4, 3)
return x
end
function write_with_broadcast2!(x, y)
x[:, 1, :] .= view(y, :, 1:3)
return x
end

@testset "write_with_broadcast" begin
x_ra = Reactant.to_rarray(zeros(3, 4, 3))
y_ra = Reactant.to_rarray(rand(3, 4))

res = @jit write_with_broadcast1!(x_ra, y_ra)

@test res.data === x_ra.data

res = Array(res)
y = Array(y_ra)
@test res[1, :, :] ≈ reshape(y, 4, 3)

x_ra = Reactant.to_rarray(zeros(3, 4, 3))
y_ra = Reactant.to_rarray(rand(3, 4))

res = @jit write_with_broadcast2!(x_ra, y_ra)

@test res.data === x_ra.data

res = Array(res)
y = Array(y_ra)
@test res[:, 1, :] ≈ view(y, :, 1:3)
end

function masking(x)
y = similar(x)
y[1:2, :] .= 0
Expand Down
Loading