Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 215 additions & 7 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,14 @@ function NNlib.conv!(
)
result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))

weight = W.mlir_data
weight = W
if !flipkernel
weight = Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.reverse(
weight; dimensions=collect(kernel_spatial_dims .- 1)
),
)
weight = Reactant.Ops.reverse(weight; dimensions=kernel_spatial_dims)
end

conv = Reactant.MLIR.Dialects.stablehlo.convolution(
x.mlir_data,
weight;
weight.mlir_data;
result_0=result_type,
window_strides=collect(stride),
padding,
Expand Down Expand Up @@ -377,4 +373,216 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
return dst
end

# Length of an axis of size `s` after dilation by factor `d`, i.e.
# `1 + d * (s - 1)`, clamped at zero so the result is never negative.
function dilate_shape(s, d)
    dilated = 1 + d * (s - 1)
    return max(0, dilated)
end

# see lax._conv_general_dilated_transpose_rhs
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
function NNlib.∇conv_filter!(
    dw::Reactant.TracedRArray{T,N},
    x::AnyTracedRArray,
    dy::AnyTracedRArray,
    cdims::NNlib.DenseConvDims,
) where {T,N}
    # Filter gradient emitted as a single `stablehlo.convolution` with `x` as the
    # left operand and `dy` as the right operand (same recipe as
    # `lax._conv_general_dilated_transpose_rhs` in JAX). The result is written into
    # `dw.mlir_data` and `dw` is returned.
    #
    # Array layouts, spatial axes first:
    #   x  : (w, h, cin, b)
    #   dy : (w, h, cout, b)
    #   dw : (w, h, cin, cout)

    x = T.(materialize_traced_array(x))
    dy = T.(materialize_traced_array(dy))

    num_spatial_dims = N - 2
    spatial_dims = 1:num_spatial_dims

    # 1-based axis roles handed to StableHLO. The channel axis of `x` (N - 1) is
    # passed as the convolution batch dimension and its batch axis (N) as the
    # feature dimension; `dy` plays the role of the kernel.
    lhs_batch_axis, lhs_feature_axis = N - 1, N
    rhs_out_axis, rhs_in_axis = N - 1, N
    res_batch_axis, res_feature_axis = N - 1, N

    stride = NNlib.stride(cdims)
    dilation = NNlib.dilation(cdims)
    groups = NNlib.groupcount(cdims)

    # One column per spatial axis; row 1 is read below as the leading pad.
    fwd_padding = reshape(collect(NNlib.padding(cdims)), (2, num_spatial_dims))

    x_extent = first(size(x), num_spatial_dims)
    dw_extent = dilate_shape.(first(size(dw), num_spatial_dims), dilation)
    dy_extent = dilate_shape.(first(size(dy), num_spatial_dims), stride)

    # (before, after) padding of spatial axis `i` for the gradient convolution.
    function pad_column(i)
        pad_before = fwd_padding[1, i]
        pad_after = dy_extent[i] - x_extent[i] + dw_extent[i] - pad_before - 1
        return [pad_before, pad_after]
    end
    pad_matrix = reduce(hcat, (pad_column(i) for i in spatial_dims))
    padding_attr = Reactant.MLIR.IR.DenseElementsAttribute(pad_matrix')

    # A grouped forward convolution turns into a batch-grouped one here.
    feature_group_count, batch_group_count = groups > 1 ? (1, groups) : (groups, 1)

    # The C API expects 0-based axis indices.
    spatial_axes0 = collect(Int64, spatial_dims .- 1)
    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
        MLIR.IR.context(),
        Int64(lhs_batch_axis - 1),
        Int64(lhs_feature_axis - 1),
        length(spatial_dims),
        spatial_axes0,
        Int64(rhs_in_axis - 1),
        Int64(rhs_out_axis - 1),
        length(spatial_dims),
        spatial_axes0,
        Int64(res_batch_axis - 1),
        Int64(res_feature_axis - 1),
        length(spatial_dims),
        spatial_axes0,
    )

    dw_type = Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T))

    # Forward dilation becomes the window stride and forward stride becomes the
    # dilation of the right operand.
    conv = MLIR.Dialects.stablehlo.convolution(
        x.mlir_data,
        dy.mlir_data;
        result_0=dw_type,
        window_strides=collect(dilation),
        padding=padding_attr,
        dimension_numbers=dimension_numbers,
        rhs_dilation=collect(stride),
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
    )

    dw.mlir_data = MLIR.IR.result(conv)

    # With `flipkernel` set the result is already in the expected orientation.
    NNlib.flipkernel(cdims) && return dw

    dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=spatial_dims).mlir_data
    return dw
end

# see lax._conv_general_dilated_transpose_lhs
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L457
function NNlib.∇conv_data!(
    dx::Reactant.TracedRArray{T,N},
    dy::AnyTracedRArray,
    w::AnyTracedRArray,
    cdims::NNlib.DenseConvDims,
) where {T,N}
    # Input gradient emitted as a single `stablehlo.convolution` with `dy` as the
    # left operand and `w` as the right operand (same recipe as
    # `lax._conv_general_dilated_transpose_lhs` in JAX). The result is written into
    # `dx.mlir_data` and `dx` is returned.
    #
    # Array layouts, spatial axes first:
    #   dy : (w, h, cout, b)
    #   w  : (w, h, cin, cout)
    #   dx : (w, h, cin, b)

    dy = T.(materialize_traced_array(dy))
    w = T.(materialize_traced_array(w))

    num_spatial_dims = N - 2
    input_batch_dim = N
    input_feature_dim = N - 1

    # The kernel's output-channel axis (N) is passed as its input dimension and
    # vice versa.
    kernel_input_dim = N
    kernel_output_dim = N - 1

    output_batch_dim = N
    output_feature_dim = N - 1

    output_spatial_dims = kernel_spatial_dims = input_spatial_dims = 1:num_spatial_dims

    # One column per spatial axis; row 1 is read below as the leading pad.
    padding = reshape(collect(NNlib.padding(cdims)), (2, num_spatial_dims))
    stride = NNlib.stride(cdims)
    dilation = NNlib.dilation(cdims)
    feature_group_count = NNlib.groupcount(cdims)

    # jax does
    # (cout, cin, h, w) -> (group, cout ÷ group, cin , h, w) -> (cout ÷ group, group, cin, h, w) -> (cout, cin * group, h, w)
    # we perform the same operation but in transposed form
    # (w, h, cin, cout) -> (w, h, cin, cout ÷ group, group) -> (w, h, cin, group, cout ÷ group) -> (w, h, cin * group, cout ÷ group)
    if feature_group_count > 1
        w = reshape(
            w,
            (size(w, i) for i in kernel_spatial_dims)...,
            size(w, N - 1),
            size(w, N) ÷ feature_group_count,
            feature_group_count,
        )
        w = permutedims(w, (kernel_spatial_dims..., N - 1, N + 1, N))
        w = reshape(
            w,
            (size(w, i) for i in kernel_spatial_dims)...,
            size(w, N - 1) * feature_group_count,
            size(w, N + 1),
        )
    end

    # Spatial extents: `dx` as is, kernel dilated by `dilation`, `dy` dilated by
    # `stride`.
    lhs_shape = first(size(dx), num_spatial_dims)
    rhs_shape = dilate_shape.(first(size(w), num_spatial_dims), dilation)
    out_shape = dilate_shape.(first(size(dy), num_spatial_dims), stride)

    # (before, after) padding of each spatial axis for the gradient convolution.
    pad_pairs = (
        let pad_before = rhs_shape[i] - padding[1, i] - 1,
            pad_after = lhs_shape[i] + rhs_shape[i] - 1 - out_shape[i] - pad_before

            [pad_before, pad_after]
        end for i in input_spatial_dims
    )
    padding_attr = Reactant.MLIR.IR.DenseElementsAttribute(reduce(hcat, pad_pairs)')

    # The C API expects 0-based axis indices.
    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
        MLIR.IR.context(),
        Int64(input_batch_dim - 1),
        Int64(input_feature_dim - 1),
        length(input_spatial_dims),
        Int64[i - 1 for i in input_spatial_dims],
        Int64(kernel_input_dim - 1),
        Int64(kernel_output_dim - 1),
        length(kernel_spatial_dims),
        Int64[i - 1 for i in kernel_spatial_dims],
        Int64(output_batch_dim - 1),
        Int64(output_feature_dim - 1),
        length(output_spatial_dims),
        Int64[i - 1 for i in output_spatial_dims],
    )

    result_type = Reactant.MLIR.IR.TensorType(size(dx), Reactant.MLIR.IR.Type(T))

    # `NNlib.conv!` reverses the kernel when `flipkernel` is false, so the reversal
    # is applied here in the opposite case.
    if NNlib.flipkernel(cdims)
        w = Reactant.Ops.reverse(w; dimensions=kernel_spatial_dims)
    end

    # `window_strides` needs one entry per spatial axis, like the other array
    # attributes; the forward stride is applied as `lhs_dilation` instead.
    conv = MLIR.Dialects.stablehlo.convolution(
        dy.mlir_data,
        w.mlir_data;
        result_0=result_type,
        window_strides=ones(Int64, num_spatial_dims),
        padding=padding_attr,
        lhs_dilation=collect(stride),
        rhs_dilation=collect(dilation),
        dimension_numbers,
        feature_group_count,
        batch_group_count=1,
    )

    dx.mlir_data = MLIR.IR.result(conv)

    return dx
end

end # module ReactantNNlibExt
2 changes: 1 addition & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ function reverse(
stablehlo.reverse(
x.mlir_data;
result=mlir_type(TracedRArray{T,N}, size(x)),
dimensions=MLIR.IR.DenseArrayAttribute(dimensions .- 1),
dimensions=MLIR.IR.DenseArrayAttribute(collect(dimensions .- 1)),
location,
),
)
Expand Down
46 changes: 44 additions & 2 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,26 @@ end

conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)

output_size = (
NNlib.output_size(conv_dims)...,
size(weight, ndims(weight)),
size(x, ndims(x)),
)
dy = randn(Float32, output_size)
dy_reactant = Reactant.to_rarray(dy)

conv_compiled = Reactant.compile(
NNlib.conv, (x_reactant, weight_reactant, conv_dims)
)

@test conv_compiled(x_reactant, weight_reactant, conv_dims) ≈
NNlib.conv(x, weight, conv_dims)
end

# TODO: test for gradients
@test Reactant.@jit(NNlib.∇conv_data(dy_reactant, weight_reactant, conv_dims)) ≈
NNlib.∇conv_data(dy, weight, conv_dims)
@test Reactant.@jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims)) ≈
NNlib.∇conv_filter(x, dy, conv_dims)
end
end

@testset "conv 1d: flip" begin
Expand Down Expand Up @@ -351,3 +362,34 @@ end
@test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
end
end

@testset "∇conv(D = $ndim)" for ndim in 1:3
    # Problem sizes shared by every configuration below.
    x_spatial_dim = 4
    batch_size = 2
    n_in_features = 3
    n_out_features = 4
    kernel_size = ntuple(_ -> 2, ndim)
    x_spatial_size = ntuple(_ -> x_spatial_dim, ndim)

    # Input and kernel, each with a Reactant counterpart.
    x = randn(Float32, x_spatial_size..., n_in_features, batch_size)
    x_reactant = Reactant.to_rarray(x)

    w = randn(Float32, kernel_size..., n_in_features, n_out_features)
    w_reactant = Reactant.to_rarray(w)

    @testset "conv: padding=$padding stride=$stride dilation=$dilation groups=$groups" for (
        padding, stride, dilation, groups
    ) in Iterators.product(
        (0, 2), (1, 2), (1,), (1,)
    )
        conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups)

        # Upstream gradient shaped like the forward output.
        dy_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size)
        dy = randn(Float32, dy_size)
        dy_reactant = Reactant.to_rarray(dy)

        # Compiled gradients must agree with the NNlib results on plain arrays.
        @test Reactant.@jit(NNlib.∇conv_data(dy_reactant, w_reactant, conv_dims)) ≈
            NNlib.∇conv_data(dy, w, conv_dims)
        @test Reactant.@jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims)) ≈
            NNlib.∇conv_filter(x, dy, conv_dims)
    end
end
Loading