From 1f6198b04fa3d9acccee9bde60dd851f6b19be6c Mon Sep 17 00:00:00 2001
From: Paul
Date: Sun, 1 Dec 2024 22:01:12 +0100
Subject: [PATCH 1/2] =?UTF-8?q?implement=20NNlib.=E2=88=87conv=5Fdata=20an?=
 =?UTF-8?q?d=20NNlib.=E2=88=87conv=5Ffilter?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ext/ReactantNNlibExt.jl | 210 ++++++++++++++++++++++++++++++++++++++++
 src/Ops.jl              |   2 +-
 test/nn/nnlib.jl        |  46 ++++++++-
 3 files changed, 255 insertions(+), 3 deletions(-)

diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 90b1cd2ccc..88628c936a 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -377,4 +377,214 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
     return dst
 end
 
+dilate_shape(s, d) = max(0, 1 + d * (s - 1))
+
+# see lax._conv_general_dilated_transpose_rhs
+# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
+function NNlib.∇conv_filter!(
+    dw::Reactant.TracedRArray{T,N},
+    x::AnyTracedRArray,
+    dy::AnyTracedRArray,
+    cdims::NNlib.DenseConvDims,
+) where {T,N}
+    # (w, h, cin, b)
+    # (w, h, cout, b)
+    # -> (w, h, cin, cout)
+
+    x = T.(materialize_traced_array(x))
+    dy = T.(materialize_traced_array(dy))
+
+    num_spatial_dims = N - 2
+    input_batch_dim = N - 1
+    input_feature_dim = N
+
+    kernel_input_dim = N
+    kernel_output_dim = N - 1
+
+    output_batch_dim = N - 1
+    output_feature_dim = N
+
+    output_spatial_dims = kernel_spatial_dims = input_spatial_dims = 1:num_spatial_dims
+
+    padding = reshape(collect(NNlib.padding(cdims)), (2, num_spatial_dims))
+    stride = NNlib.stride(cdims)
+    dilation = NNlib.dilation(cdims)
+    feature_group_count = NNlib.groupcount(cdims)
+
+    padding =
+        let lhs_shape = first(size(x), num_spatial_dims),
+            rhs_shape = dilate_shape.(first(size(dw), num_spatial_dims), dilation),
+            out_shape = dilate_shape.(first(size(dy), num_spatial_dims), stride),
+
+            padding = reduce(
+                hcat,
+                (
+                    let pad_before = padding[1, i],
+                        pad_after = (
+                            out_shape[i] - lhs_shape[i] + rhs_shape[i] - pad_before - 1
+                        )
+
+                        [pad_before, pad_after]
+                    end for i in 1:num_spatial_dims
+                ),
+            )
+
+            Reactant.MLIR.IR.DenseElementsAttribute(padding')
+        end
+
+    batch_group_count = 1
+    if feature_group_count > 1
+        batch_group_count = feature_group_count
+        feature_group_count = 1
+    end
+
+    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
+        MLIR.IR.context(),
+        Int64(input_batch_dim - 1),
+        Int64(input_feature_dim - 1),
+        length(input_spatial_dims),
+        Int64[i - 1 for i in input_spatial_dims],
+        Int64(kernel_input_dim - 1),
+        Int64(kernel_output_dim - 1),
+        length(kernel_spatial_dims),
+        Int64[i - 1 for i in kernel_spatial_dims],
+        Int64(output_batch_dim - 1),
+        Int64(output_feature_dim - 1),
+        length(output_spatial_dims),
+        Int64[i - 1 for i in output_spatial_dims],
+    )
+
+    result_type = Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T))
+    conv = MLIR.Dialects.stablehlo.convolution(
+        x.mlir_data,
+        dy.mlir_data;
+        result_0=result_type,
+        window_strides=collect(dilation),
+        padding,
+        dimension_numbers,
+        rhs_dilation=collect(stride),
+        feature_group_count,
+        batch_group_count,
+    )
+
+    dw.mlir_data = MLIR.IR.result(conv)
+
+    dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data
+
+    return dw
+end
+
+# see lax._conv_general_dilated_transpose_lhs
+# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L457
+function NNlib.∇conv_data!(
+    dx::Reactant.TracedRArray{T,N},
+    dy::AnyTracedRArray,
+    w::AnyTracedRArray,
+    cdims::NNlib.DenseConvDims,
+) where {T,N}
+    # (w, h, cout, b)
+    # (w, h, cin, cout)
+    # -> (w, h, cin, b)
+
+    dy = T.(materialize_traced_array(dy))
+    w = T.(materialize_traced_array(w))
+
+    num_spatial_dims = N - 2
+    input_batch_dim = N
+    input_feature_dim = N - 1
+
+    kernel_input_dim = N
+    kernel_output_dim = N - 1
+
+    output_batch_dim = N
+    output_feature_dim = N - 1
+
+    output_spatial_dims = kernel_spatial_dims = input_spatial_dims = 1:num_spatial_dims
+
+    padding = reshape(collect(NNlib.padding(cdims)), (2, num_spatial_dims))
+    stride = NNlib.stride(cdims)
+    dilation = NNlib.dilation(cdims)
+    feature_group_count = NNlib.groupcount(cdims)
+
+    # jax does
+    # (cout, cin, h, w) -> (group, cout ÷ group, cin, h, w) -> (cout ÷ group, group, cin, h, w) -> (cout, cin * group, h, w)
+    # we perform the same operation but in transposed form
+    # (w, h, cin, cout) -> (w, h, cin, cout ÷ group, group) -> (w, h, cin, group, cout ÷ group) -> (w, h, cin * group, cout ÷ group)
+    if feature_group_count > 1
+        w = reshape(
+            w,
+            (size(w, i) for i in kernel_spatial_dims)...,
+            size(w, N - 1),
+            size(w, N) ÷ feature_group_count,
+            feature_group_count,
+        )
+        w = permutedims(w, (kernel_spatial_dims..., N - 1, N + 1, N))
+        w = reshape(
+            w,
+            (size(w, i) for i in kernel_spatial_dims)...,
+            size(w, N - 1) * feature_group_count,
+            size(w, N + 1),
+        )
+    end
+
+    padding =
+        let lhs_shape = first(size(dx), num_spatial_dims),
+            rhs_shape = dilate_shape.(first(size(w), num_spatial_dims), dilation),
+            out_shape = dilate_shape.(first(size(dy), num_spatial_dims), stride),
+
+            padding = reduce(
+                hcat,
+                (
+                    let pad_before = rhs_shape[i] - padding[2i - 1] - 1,
+                        pad_after =
+                            lhs_shape[i] + rhs_shape[i] - 1 - out_shape[i] - pad_before
+
+                        [pad_before, pad_after]
+                    end for i in input_spatial_dims
+                ),
+            )
+
+            Reactant.MLIR.IR.DenseElementsAttribute(padding')
+        end
+
+    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
+        MLIR.IR.context(),
+        Int64(input_batch_dim - 1),
+        Int64(input_feature_dim - 1),
+        length(input_spatial_dims),
+        Int64[i - 1 for i in input_spatial_dims],
+        Int64(kernel_input_dim - 1),
+        Int64(kernel_output_dim - 1),
+        length(kernel_spatial_dims),
+        Int64[i - 1 for i in kernel_spatial_dims],
+        Int64(output_batch_dim - 1),
+        Int64(output_feature_dim - 1),
+        length(output_spatial_dims),
+        Int64[i - 1 for i in output_spatial_dims],
+    )
+
+    result_type = Reactant.MLIR.IR.TensorType(size(dx), Reactant.MLIR.IR.Type(T))
+
+    if NNlib.flipkernel(cdims)
+        w = Reactant.Ops.reverse(w; dimensions=kernel_spatial_dims)
+    end
+
+    conv = MLIR.Dialects.stablehlo.convolution(
+        dy.mlir_data,
+        w.mlir_data;
+        result_0=result_type,
+        window_strides=1,
+        padding,
+        lhs_dilation=collect(stride),
+        rhs_dilation=collect(dilation),
+        dimension_numbers,
+        feature_group_count,
+        batch_group_count=1,
+    )
+
+    dx.mlir_data = MLIR.IR.result(conv)
+
+    return dx
+end
+
 end # module ReactantNNlibExt
diff --git a/src/Ops.jl b/src/Ops.jl
index 82888608b4..bf63df6924 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -885,7 +885,7 @@ function reverse(
         stablehlo.reverse(
             x.mlir_data;
             result=mlir_type(TracedRArray{T,N}, size(x)),
-            dimensions=MLIR.IR.DenseArrayAttribute(dimensions .- 1),
+            dimensions=MLIR.IR.DenseArrayAttribute(collect(dimensions .- 1)),
             location,
         ),
     )
diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl
index 3be3d97fce..5f9c92ef74 100644
--- a/test/nn/nnlib.jl
+++ b/test/nn/nnlib.jl
@@ -67,15 +67,26 @@ end
             conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
 
+            output_size = (
+                NNlib.output_size(conv_dims)...,
+                size(weight, ndims(weight)),
+                size(x, ndims(x)),
+            )
+            dy = randn(Float32, output_size)
+            dy_reactant = Reactant.to_rarray(dy)
+
             conv_compiled = Reactant.compile(
                 NNlib.conv, (x_reactant, weight_reactant, conv_dims)
             )
 
             @test conv_compiled(x_reactant, weight_reactant, conv_dims) ≈
                 NNlib.conv(x, weight, conv_dims)
-        end
-
-    # TODO: test for gradients
+            @test Reactant.@jit(NNlib.∇conv_data(dy_reactant, weight_reactant, conv_dims)) ≈
+                NNlib.∇conv_data(dy, weight, conv_dims)
+            @test Reactant.@jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims)) ≈
+                NNlib.∇conv_filter(x, dy, conv_dims)
+        end
 end
@@ -351,3 +351,34 @@ end
     @test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
 end
 end
+
+@testset "∇conv(D = $ndim)" for ndim in 1:3
+    x_spatial_dim = 4
+    batch_size = 2
+    n_in_features = 3
+    n_out_features = 4
+    kernel_size = Tuple((2 for _ in 1:ndim))
+
+    x = randn(Float32, (x_spatial_dim for _ in 1:ndim)..., n_in_features, batch_size)
+    x_reactant = Reactant.to_rarray(x)
+
+    w = randn(Float32, kernel_size..., n_in_features, n_out_features)
+    w_reactant = Reactant.to_rarray(w)
+
+    @testset "conv: padding=$padding stride=$stride dilation=$dilation groups=$groups" for (
+        padding, stride, dilation, groups
+    ) in Iterators.product(
+        (0, 2), (1, 2), (1,), (1,)
+    )
+        conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups)
+
+        output_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size)
+        dy = randn(Float32, output_size)
+        dy_reactant = Reactant.to_rarray(dy)
+
+        @test Reactant.@jit(NNlib.∇conv_data(dy_reactant, w_reactant, conv_dims)) ≈
+            NNlib.∇conv_data(dy, w, conv_dims)
+        @test Reactant.@jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims)) ≈
+            NNlib.∇conv_filter(x, dy, conv_dims)
+    end
+end

From a7347c967548c6cce2feab908e2d89fdd926bbb4 Mon Sep 17 00:00:00 2001
From: Paul Berg
Date: Mon, 2 Dec 2024 17:04:08 +0100
Subject: [PATCH 2/2] cond filter flipkernel

---
 ext/ReactantNNlibExt.jl | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 88628c936a..b78716d291 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -113,18 +113,14 @@ function NNlib.conv!(
     )
 
     result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))
-    weight = W.mlir_data
+    weight = W
     if !flipkernel
-        weight = Reactant.MLIR.IR.result(
-            Reactant.MLIR.Dialects.stablehlo.reverse(
-                weight; dimensions=collect(kernel_spatial_dims .- 1)
-            ),
-        )
+        weight = Reactant.Ops.reverse(weight; dimensions=kernel_spatial_dims)
     end
 
     conv = Reactant.MLIR.Dialects.stablehlo.convolution(
         x.mlir_data,
-        weight;
+        weight.mlir_data;
         result_0=result_type,
         window_strides=collect(stride),
         padding,
@@ -469,7 +465,9 @@ function NNlib.∇conv_filter!(
 
     dw.mlir_data = MLIR.IR.result(conv)
 
-    dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data
+    if !NNlib.flipkernel(cdims)
+        dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data
+    end
 
    return dw
 end
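A note on the padding arithmetic (a standalone sketch, not part of the patch): both gradients follow the JAX recipe referenced in the comments, re-expressing the backward pass as a forward stablehlo.convolution with recomputed padding and dilation. The plain-Julia check below verifies that the pad_before/pad_after formulas from ∇conv_data! always recover the original input length; conv_out_len is a hypothetical helper encoding the textbook output-size formula, not an API from the patch.

    # textbook output length of a convolution: input l, kernel k,
    # lo/hi padding, stride s, dilation d
    conv_out_len(l, k, lo, hi, s, d) = (l + lo + hi - (d * (k - 1) + 1)) ÷ s + 1

    # effective extent of a size-s axis after dilation by d (as in the patch)
    dilate_shape(s, d) = max(0, 1 + d * (s - 1))

    for (l, k, pad, s, d) in Iterators.product(4:8, 2:3, 0:2, 1:2, 1:2)
        o = conv_out_len(l, k, pad, pad, s, d)
        o > 0 || continue
        keff = dilate_shape(k, d)  # dilated kernel extent
        oeff = dilate_shape(o, s)  # dy after lhs_dilation by the forward stride
        pad_before = keff - pad - 1
        pad_after = l + keff - 1 - oeff - pad_before
        # data gradient: stride-1 conv of the dilated dy with the dilated kernel
        @assert conv_out_len(oeff, k, pad_before, pad_after, 1, d) == l
    end

Only the sum pad_before + pad_after enters the stride-1 output length, so the identity holds for any split; the particular pad_before chosen in the patch is what aligns the gradient with the correct input positions (pad_before may even be negative, which stablehlo permits).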
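The grouped-convolution weight shuffle in ∇conv_data! (the transposed form of the JAX regrouping described in the comment) can likewise be checked in isolation. This is a minimal sketch for one spatial dimension with toy sizes chosen here for illustration, assuming the group count divides cout; the index identity asserted in the loop is the property the reshape/permutedims/reshape chain is meant to guarantee.

    g, cin, cout, k = 2, 3, 4, 5
    w = randn(Float32, k, cin, cout)         # (spatial, cin, cout)

    wt = reshape(w, k, cin, cout ÷ g, g)     # split cout into (cout ÷ g, g)
    wt = permutedims(wt, (1, 2, 4, 3))       # move the group axis inward
    wt = reshape(wt, k, cin * g, cout ÷ g)   # fuse it with cin

    for gi in 1:g, c in 1:cin, j in 1:(cout ÷ g)
        # input channel c of group gi now lives at fused index (gi - 1) * cin + c,
        # paired with output channel j of that group
        @assert wt[:, (gi - 1) * cin + c, j] == w[:, c, (gi - 1) * (cout ÷ g) + j]
    end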
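Finally, a usage sketch of how the new overloads are reached: NNlib's out-of-place ∇conv_data/∇conv_filter allocate the result and dispatch to the ∇conv_data!/∇conv_filter! methods added here once the arrays are traced under Reactant.@jit. This mirrors the added tests; the shapes are illustrative, and the forward output is reused as a stand-in cotangent purely for brevity.

    using NNlib, Reactant

    x = randn(Float32, 8, 3, 2)    # (spatial, cin, batch): a 1-d conv
    w = randn(Float32, 3, 3, 4)    # (kernel, cin, cout)
    cdims = NNlib.DenseConvDims(x, w; stride=2, padding=1)

    dy = NNlib.conv(x, w, cdims)   # anything output-shaped works as dy

    x_r, w_r, dy_r = Reactant.to_rarray.((x, w, dy))
    dx = Reactant.@jit(NNlib.∇conv_data(dy_r, w_r, cdims))    # ≈ NNlib.∇conv_data(dy, w, cdims)
    dw = Reactant.@jit(NNlib.∇conv_filter(x_r, dy_r, cdims))  # ≈ NNlib.∇conv_filter(x, dy, cdims)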