-
Notifications
You must be signed in to change notification settings - Fork 38
Open
Description
I'd like to accelerate custom IIR DSP filters (biquads, but also more application-specific ones — in any case inherently serial along the signal samples but parallel across signals) via Reactant. From what I understand, the correct resulting StableHLO pattern should be a while loop along time over operations broadcast across signals (e.g. the way RNNs are handled)?
As a test, I tried implementing a cumulative sum manually, but the way I wrote it, Reactant unrolls the loop:
julia> using Reactant; Reactant.set_default_backend("cpu")
julia> A = rand(Float32, 7, 5); rA = Reactant.to_rarray(A);
julia> function foo!(A)
# In-place cumulative sum along the first dimension of `A`: after the loop,
# row i holds the sum of the original rows 1..i. Mutates and returns `A`.
xs = eachrow(A)  # row slices of `A`; assigning to xs[i] writes back into `A`
# Start at the second row; the first row is the seed of the running sum.
for i in eachindex(xs)[begin+1:end]
# Loop-carried dependency: row i-1 must already hold its running sum, so the
# iterations are serial along this axis (each step is elementwise across columns).
# This is an ordinary Julia loop, so Reactant executes it while tracing and emits
# one slice/reshape/add per iteration (the unrolled @code_hlo output in this transcript).
# NOTE(review): Reactant's `@trace for` is presumably what lowers a loop like this to
# a `stablehlo.while` instead of unrolling it — confirm against the Reactant docs.
xs[i] += xs[i-1]
end
return A
end
foo! (generic function with 1 method)
julia> foo!(A)
7×5 Matrix{Float32}:
0.385297 0.487759 0.478373 0.591901 0.940462
1.33295 0.890929 0.481439 1.03631 0.970324
1.96804 1.06544 1.42751 1.55177 1.84208
2.96404 1.2932 2.05575 1.83932 2.3411
3.18363 1.68394 2.55661 1.8492 3.05771
4.15348 1.74053 3.53721 1.88457 3.54661
4.59739 1.75737 3.95119 1.94483 4.04031
julia> @code_hlo foo!(rA)
module @"reactant_foo!" attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<5x7xf32> {tf.aliasing_output = 0 : i32}) -> tensor<5x7xf32> {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x7xf32>) -> tensor<7x5xf32>
%1 = stablehlo.slice %0 [1:2, 0:5] : (tensor<7x5xf32>) -> tensor<1x5xf32>
%2 = stablehlo.slice %0 [0:1, 0:5] : (tensor<7x5xf32>) -> tensor<1x5xf32>
%3 = stablehlo.slice %0 [2:3, 0:5] : (tensor<7x5xf32>) -> tensor<1x5xf32>
%4 = stablehlo.slice %0 [3:4, 0:5] : (tensor<7x5xf32>) -> tensor<1x5xf32>
%5 = stablehlo.slice %0 [4:5, 0:5] : (tensor<7x5xf32>) -> tensor<1x5xf32>
%6 = stablehlo.slice %0 [5:6, 0:5] : (tensor<7x5xf32>) -> tensor<1x5xf32>
%7 = stablehlo.slice %0 [6:7, 0:5] : (tensor<7x5xf32>) -> tensor<1x5xf32>
%8 = stablehlo.reshape %2 : (tensor<1x5xf32>) -> tensor<5x1xf32>
%9 = stablehlo.reshape %1 : (tensor<1x5xf32>) -> tensor<5x1xf32>
%10 = stablehlo.add %9, %8 : tensor<5x1xf32>
%11 = stablehlo.reshape %3 : (tensor<1x5xf32>) -> tensor<5x1xf32>
%12 = stablehlo.add %11, %10 : tensor<5x1xf32>
%13 = stablehlo.reshape %4 : (tensor<1x5xf32>) -> tensor<5x1xf32>
%14 = stablehlo.add %13, %12 : tensor<5x1xf32>
%15 = stablehlo.reshape %5 : (tensor<1x5xf32>) -> tensor<5x1xf32>
%16 = stablehlo.add %15, %14 : tensor<5x1xf32>
%17 = stablehlo.reshape %6 : (tensor<1x5xf32>) -> tensor<5x1xf32>
%18 = stablehlo.add %17, %16 : tensor<5x1xf32>
%19 = stablehlo.reshape %7 : (tensor<1x5xf32>) -> tensor<5x1xf32>
%20 = stablehlo.add %19, %18 : tensor<5x1xf32>
%21 = stablehlo.concatenate %8, %10, %12, %14, %16, %18, %20, dim = 1 : (tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
return %21 : tensor<5x7xf32>
}
}
Obviously this won't scale to larger arrays (more time steps per signal). Should Reactant do this automatically, or do I need to give it some hints somehow?
Metadata
Metadata
Assignees
Labels
No labels