feat: add dynamic_update_slice_const_prop pass + tril + triu
#334
Conversation
oh wow, that's a big change in the generated IR. Cool!
@avik-pal we definitely should add the pass, but I'm not sure about removing the overload per se. At a minimum we need to add overloads of triu and tril for TracedRArrays.
for some inspiration: https://github.com/jax-ml/jax/blob/baedb62b71d9cf32d1922254d8faa3b03903ad77/jax/_src/lax/lax.py#L1780
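The linked JAX code builds the triangular mask by comparing two iotas. A minimal plain-Julia sketch of that idea (dense stand-ins for the traced iotas; the iota_dimension naming mirrors Reactant's Ops and the 1-based offsets are an assumption):

row = [i for i in 1:4, _ in 1:4]  # varies along dimension 1, like Ops.iota(...; iota_dimension=1)
col = [j for _ in 1:4, j in 1:4]  # varies along dimension 2
k = 1
mask = row .< col .- (k - 1)      # true strictly above the diagonal for k = 1
# select(mask, X, 0) then implements triu(X, k) with no scalar indexing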
I will add some dispatches for these.
julia> @code_hlo triu(x_ra, 1)
module {
func.func @main(%arg0: tensor<4x4xf64>) -> tensor<4x4xf64> {
%c = stablehlo.constant dense<3> : tensor<i64>
%c_0 = stablehlo.constant dense<2> : tensor<i64>
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<0> : tensor<i64>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<1x1xf64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
%1 = stablehlo.dynamic_update_slice %0, %cst, %c_2, %c_2 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%2 = stablehlo.dynamic_update_slice %1, %cst, %c_1, %c_2 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%3 = stablehlo.dynamic_update_slice %2, %cst, %c_0, %c_2 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%4 = stablehlo.dynamic_update_slice %3, %cst, %c, %c_2 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%5 = stablehlo.dynamic_update_slice %4, %cst, %c_1, %c_1 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%6 = stablehlo.dynamic_update_slice %5, %cst, %c_0, %c_1 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%7 = stablehlo.dynamic_update_slice %6, %cst, %c, %c_1 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%8 = stablehlo.dynamic_update_slice %7, %cst, %c_0, %c_0 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%9 = stablehlo.dynamic_update_slice %8, %cst, %c, %c_0 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%10 = stablehlo.dynamic_update_slice %9, %cst, %c, %c : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
%11 = stablehlo.transpose %10, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
return %11 : tensor<4x4xf64>
}
}
huh yeah, I presume that Julia 1.11 sets whole rows/slices and thus doesn't hit scalar indexing
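For reference, the generic fallback that produces the IR above is essentially a doubly nested scalar loop; on a TracedRArray each scalar setindex! traces to one dynamic_update_slice. A hedged sketch of that fallback's shape (not Base's literal code):

function naive_triu!(X::AbstractMatrix{T}, k::Integer=0) where {T}
    m, n = size(X)
    for j in 1:n
        # zero every entry below the k-th superdiagonal, one element at a time
        for i in max(1, j - k + 1):m
            X[i, j] = zero(T)
        end
    end
    return X
end

For a 4x4 input with k = 1 this touches ten elements, matching the ten dynamic_update_slice ops in the dump above.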
src/linear_algebra.jl
Outdated
function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
    idxs =
        Ops.iota(Int64, [size(X)...]; iota_dimension=1) .<
        Ops.iota(Int64, [size(X)...]; iota_dimension=2) .- (k - 1)
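The diff is truncated here; a hedged guess at how the function continues (the select step is an assumption based on the IR below, not necessarily the PR's exact code):

    # keep X where the mask holds, zero elsewhere; the broadcast should trace
    # to the single stablehlo.select visible in the generated IR (assumption)
    X .= ifelse.(idxs, X, zero(T))
    return X
end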
once again, having Ops here is such a quality-of-life improvement (thanks @mofeing)
julia> @code_hlo triu(x_ra, 1)
module {
func.func @main(%arg0: tensor<4x4xf64>) -> tensor<4x4xf64> {
%c = stablehlo.constant dense<[[false, true, true, true], [false, false, true, true], [false, false, false, true], [false, false, false, false]]> : tensor<4x4xi1>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<4x4xf64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
%1 = stablehlo.select %c, %0, %cst : tensor<4x4xi1>, tensor<4x4xf64>
%2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
return %2 : tensor<4x4xf64>
}
}
Looks like we should probably make a transpose / select / transpose optimization (or perhaps, even more generally, a transpose [a bunch of elementwise ops and constants] transpose optimization).
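A quick plain-Julia check of the identity such a pass would rely on (the pass itself is hypothetical; the arrays here are ordinary dense stand-ins):

# transpose distributes over an elementwise select, so constant operands can
# be transposed at compile time and the two data-side transposes cancel
x = rand(4, 4)
c = rand(Bool, 4, 4)
z = zeros(4, 4)
@assert permutedims(ifelse.(c, x, z)) == ifelse.(permutedims(c), permutedims(x), permutedims(z))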
force-pushed from fc60d68 to 49128c3
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
dynamic_update_slice_const_prop pass + tril + triu
Removing the previous NNlib workaround: