Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ext/ReactantArrayInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module ReactantArrayInterfaceExt

using ArrayInterface: ArrayInterface
using Reactant:
Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray
Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops

ArrayInterface.can_setindex(::Type{<:RArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false
Expand All @@ -14,7 +14,7 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where
end

function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T}
return reshape(vcat(x...), size(x))
return Ops.reshape(vcat(x...), size(x)...)
end

end
145 changes: 62 additions & 83 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@ module ReactantNNlibExt
using NNlib
using GPUArraysCore: @allowscalar
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
Reactant,
Ops,
TracedRArray,
AnyTracedRArray,
materialize_traced_array,
MLIR,
TracedRNumber,
get_mlir_data,
set_mlir_data!
using ReactantCore: @trace
using LinearAlgebra: LinearAlgebra, triu

Expand All @@ -12,14 +20,7 @@ for (jlop, hloop) in (
(:(NNlib.sigmoid_fast), :logistic),
(:(NNlib.sigmoid), :logistic),
)
@eval function $(jlop)(x::TracedRNumber{T}) where {T}
return TracedRNumber{T}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
),
)
end
@eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x)
end

function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
Expand Down Expand Up @@ -82,13 +83,6 @@ function NNlib.conv!(
kernel_input_dim = N - 1
kernel_output_dim = N

output_spatial_shapes = map(input_spatial_dims) do i
K = kernel_size[i]
pl, pr = padding[2i - 1], padding[2i]
d = dilation[i]
s = stride[i]
return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
end
output_batch_dim = input_batch_dim
output_feature_dim = input_feature_dim
output_spatial_dims = input_spatial_dims
Expand Down Expand Up @@ -119,8 +113,8 @@ function NNlib.conv!(
end

conv = Reactant.MLIR.Dialects.stablehlo.convolution(
x.mlir_data,
weight.mlir_data;
get_mlir_data(x),
get_mlir_data(weight);
result_0=result_type,
window_strides=collect(stride),
padding,
Expand All @@ -130,7 +124,7 @@ function NNlib.conv!(
feature_group_count,
batch_group_count=1,
)
y.mlir_data = Reactant.MLIR.IR.result(conv)
set_mlir_data!(y, Reactant.MLIR.IR.result(conv))
return y
end

Expand Down Expand Up @@ -165,7 +159,9 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N))
result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))

unranked = Reactant.MLIR.IR.TensorType((), eltype(Reactant.MLIR.IR.type(x.mlir_data)))
unranked = Reactant.MLIR.IR.TensorType(
(), eltype(Reactant.MLIR.IR.type(get_mlir_data(x)))
)
body =
let body = Reactant.MLIR.IR.Region(),
loc = Reactant.MLIR.IR.Location(),
Expand All @@ -189,7 +185,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
Reactant.MLIR.Dialects.stablehlo.constant(; value=attr)
)
reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window(
[x.mlir_data],
[get_mlir_data(x)],
[init_value];
result_0=[result_type],
window_dimensions,
Expand All @@ -205,24 +201,24 @@ end
function NNlib.maxpool!(
y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
y.mlir_data =
reduce_window(
Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
).mlir_data
res = reduce_window(
Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this can already be replaced by Ops.maximum?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first argument here is being used directly with block arguments. It would be an unnecessary in-direction to go down the Ops route here

)
set_mlir_data!(y, get_mlir_data(res))
return y
end

function NNlib.meanpool!(
y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
res = reduce_window(Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T))
y.mlir_data = (res ./ T(prod(NNlib.kernel_size(pdims)))).mlir_data
set_mlir_data!(y, get_mlir_data(res ./ T(prod(NNlib.kernel_size(pdims)))))
return y
end

NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = PermutedDimsArray(x, (2, 1, 3))
function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
y = permutedims(x, (2, 1, 3))
y = NNlib.batched_transpose(x)
conj!(y)
return y
end
Expand All @@ -238,64 +234,47 @@ function NNlib.batched_mul!(
),
)
end

if size(x, 3) != size(y, 3)
B = max(size(x, 3), size(y, 3))
if size(x, 3) == 1
x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
elseif size(y, 3) == 1
y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
end
end

x = permutedims(x, (3, 1, 2))
y = permutedims(y, (3, 1, 2))

B = max(size(x, 1), size(y, 1))
out_shape = (B, size(x, 2), size(y, 3))
resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data)))

if size(x, 1) != size(y, 1)
B = max(size(x, 1), size(y, 1))
if size(x, 1) == 1
x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
elseif size(y, 1) == 1
y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
end
end

dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
MLIR.IR.context(), 1, [0], 1, [0], 1, [2], 1, [1]
tmp = Ops.dot_general(
T1.(materialize_traced_array(x)),
T1.(materialize_traced_array(y));
contracting_dimensions=([3], [2]),
batching_dimensions=([1], [1]),
)
set_mlir_data!(res, get_mlir_data(permutedims(tmp, (2, 3, 1))))

prec = MLIR.IR.Attribute(
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
)
tmp = TracedRArray{T1,3}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.dot_general(
x.mlir_data,
y.mlir_data;
result_0=resty,
dot_dimension_numbers=dot_dimension_numbers,
precision_config=prec,
),
1,
),
size(resty),
)
res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data
return res
end

function NNlib.pad_constant(
x::TracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
) where {T,N}
value = Reactant.promote_to(TracedRNumber{T}, value)
edge_padding_low = [i[1] for i in pad]
edge_padding_high = [i[2] for i in pad]
interior_padding = [0 for i in pad]
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.pad(
x.mlir_data,
value.mlir_data;
edge_padding_low,
edge_padding_high,
interior_padding,
),
1,
)
return TracedRArray{T,N}((), res, size(MLIR.IR.type(res)))
low = [i[1] for i in pad]
high = [i[2] for i in pad]
interior = [0 for i in pad]
return Ops.pad(materialize_traced_array(x), value; low, high, interior)
end

# XXX: reevaluate this manual optimization once
Expand All @@ -305,7 +284,7 @@ function NNlib.gather!(
src::AnyTracedRArray{T2,2},
idxs::Union{AbstractUnitRange{<:Number}},
) where {T1,T2}
dst.mlir_data = src[:, idxs].mlir_data
set_mlir_data!(dst, get_mlir_data(src[:, idxs]))
return dst
end

Expand All @@ -314,8 +293,8 @@ function NNlib.gather!(
) where {T1,T2}
dims = NNlib.scatter_dims(src, dst, idxs)
@assert dims == 1 # scatter_dims lets us do some size checks so we call that function
idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data
slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data
idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1)
slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]))

#! format: off
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
Expand All @@ -331,11 +310,11 @@ function NNlib.gather!(

res = MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
src.mlir_data, idxs, slice_sizes; dimension_numbers
get_mlir_data(src), idxs, slice_sizes; dimension_numbers
),
1,
)
dst.mlir_data = res
set_mlir_data!(dst, res)
return dst
end

Expand All @@ -354,7 +333,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
return reshape(res, start_sizes..., :)
end
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
dst.mlir_data = res.mlir_data
set_mlir_data!(dst, get_mlir_data(res))
return dst
end

Expand All @@ -363,7 +342,7 @@ dilate_shape(s, d) = max(0, 1 + d * (s - 1))
# see lax._conv_general_dilated_transpose_rhs
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
function NNlib.∇conv_filter!(
dw::Reactant.TracedRArray{T,N},
dw::TracedRArray{T,N},
x::AnyTracedRArray,
dy::AnyTracedRArray,
cdims::NNlib.DenseConvDims,
Expand Down Expand Up @@ -437,8 +416,8 @@ function NNlib.∇conv_filter!(

result_type = Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T))
conv = MLIR.Dialects.stablehlo.convolution(
x.mlir_data,
dy.mlir_data;
get_mlir_data(x),
get_mlir_data(dy);
result_0=result_type,
window_strides=collect(dilation),
padding,
Expand All @@ -447,11 +426,12 @@ function NNlib.∇conv_filter!(
feature_group_count,
batch_group_count,
)

dw.mlir_data = MLIR.IR.result(conv)
set_mlir_data!(dw, MLIR.IR.result(conv))

if !NNlib.flipkernel(cdims)
dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data
set_mlir_data!(
dw, get_mlir_data(Reactant.Ops.reverse(dw; dimensions=output_spatial_dims))
)
end

return dw
Expand Down Expand Up @@ -553,8 +533,8 @@ function NNlib.∇conv_data!(
end

conv = MLIR.Dialects.stablehlo.convolution(
dy.mlir_data,
w.mlir_data;
get_mlir_data(dy),
get_mlir_data(w);
result_0=result_type,
window_strides=1,
padding,
Expand All @@ -564,8 +544,7 @@ function NNlib.∇conv_data!(
feature_group_count,
batch_group_count=1,
)

dx.mlir_data = MLIR.IR.result(conv)
set_mlir_data!(dx, MLIR.IR.result(conv))

return dx
end
Expand Down
Loading
Loading