
Conversation

@avik-pal (Collaborator) commented Dec 6, 2024

Removing the previous NNlib workaround:

julia> x = ConcreteRArray(rand(10, 10))
10×10 ConcreteRArray{Float64, 2}:
 0.0924137  0.582144   0.661002   0.403855   0.149149  0.695333   0.118886    0.0510647  0.711194   0.13487
 0.405112   0.228939   0.592257   0.730104   0.20901   0.240011   0.0733813   0.73111    0.640295   0.0959099
 0.572328   0.0694854  0.995223   0.0445609  0.930034  0.834714   0.694936    0.136905   0.797604   0.671286
 0.486331   0.552139   0.0900191  0.642302   0.582156  0.0716531  0.432385    0.91194    0.229147   0.314344
 0.639372   0.263377   0.337498   0.295831   0.943187  0.203924   0.993949    0.552247   0.95561    0.716701
 0.0716346  0.414622   0.851568   0.166069   0.548155  0.955265   0.522439    0.630945   0.176363   0.529564
 0.817722   0.913406   0.269638   0.713552   0.061855  0.12266    0.00595915  0.156552   0.985733   0.874415
 0.397864   0.701692   0.472078   0.445713   0.35231   0.479703   0.0666829   0.100765   0.0321458  0.321973
 0.812532   0.651512   0.482309   0.550568   0.90786   0.29578    0.95864     0.224822   0.351912   0.410754
 0.351531   0.393741   0.704482   0.304211   0.841324  0.168506   0.386456    0.794857   0.0337528  0.772713

julia> @code_hlo NNlib.make_causal_mask(x)
module {
  func.func @main(%arg0: tensor<10x10xf64>) -> tensor<10x10xi1> {
    %c = stablehlo.constant dense<[[true, false, false, false, false, false, false, false, false, false], [true, true, false, false, false, false, false, false, false, false], [true, true, true, false, false, false, false, false, false, false], [true, true, true, true, false, false, false, false, false, false], [true, true, true, true, true, false, false, false, false, false], [true, true, true, true, true, true, false, false, false, false], [true, true, true, true, true, true, true, false, false, false], [true, true, true, true, true, true, true, true, false, false], [true, true, true, true, true, true, true, true, true, false], [true, true, true, true, true, true, true, true, true, true]]> : tensor<10x10xi1>
    return %c : tensor<10x10xi1>
  }
}

@mofeing (Collaborator) commented Dec 6, 2024

oh wow, that's a big change in the generated IR. Cool!

@wsmoses (Member) commented Dec 7, 2024

@avik-pal we definitely should add the pass, but I'm not sure about removing the overload, per this test failure:

Test threw exception
  Expression: #= /home/runner/work/Reactant.jl/Reactant.jl/test/nn/nnlib.jl:185 =# @jit(causal_mask2(x_ra)) ≈ causal_mask2(x)
  Scalar indexing is disallowed.
  Invocation of getindex(::TracedRArray, ::Vararg{Int, N}) resulted in scalar indexing of a GPU array.
  This is typically caused by calling an iterating implementation of a method.
  Such implementations *do not* execute on the GPU, but very slowly on the CPU,
  and therefore should be avoided.
  
  If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
  to enable scalar iteration globally or for the operations in question.
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] errorscalar(op::String)
      @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151
    [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
      @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124
    [4] assertscalar(op::String)
      @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112
    [5] getindex(::Reactant.TracedRArray{Bool, 2}, ::Int64, ::Int64)
      @ Reactant ~/work/Reactant.jl/Reactant.jl/src/TracedRArray.jl:86
    [6] triu!(M::Reactant.TracedRArray{Bool, 2}, k::Int64)
      @ LinearAlgebra /opt/hostedtoolcache/julia/1.10.7/x64/share/julia/stdlib/v1.10/LinearAlgebra/src/dense.jl:139
    [7] triu!
      @ /opt/hostedtoolcache/julia/1.10.7/x64/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:435 [inlined]
    [8] triu
      @ /opt/hostedtoolcache/julia/1.10.7/x64/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:341 [inlined]
    [9] #make_causal_mask#162
      @ ~/.julia/packages/NNlib/mRRJu/src/attention.jl:151 [inlined]
   [10] make_causal_mask
      @ ~/.julia/packages/NNlib/mRRJu/src/attention.jl:149 [inlined]
   [11] causal_mask2
      @ ~/work/Reactant.jl/Reactant.jl/test/nn/nnlib.jl:184 [inlined]
   [12] (::Tuple{})(none::Reactant.TracedRArray{Float64, 2})

At minimum we need to add overloads of triu and tril for TracedRArrays.

@wsmoses (Member) commented Dec 7, 2024

for some inspiration: https://github.com/jax-ml/jax/blob/baedb62b71d9cf32d1922254d8faa3b03903ad77/jax/_src/lax/lax.py#L1780

so select(iota dim1 > iota dim2, x, 0) or something like that would presumably do the trick; see the sketch below
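
A minimal plain-Julia sketch of that iota/select masking idea (illustrative only; this is not the PR's TracedRArray overload, and the name triu_via_mask plus the use of repeat/ifelse are assumptions for the example):

function triu_via_mask(X::AbstractMatrix{T}, k::Integer=0) where {T}
    rows = repeat(1:size(X, 1), 1, size(X, 2))      # row index at every position ("iota" along dim 1)
    cols = repeat((1:size(X, 2))', size(X, 1), 1)   # column index at every position ("iota" along dim 2)
    keep = rows .<= cols .- k                        # true on and above the k-th superdiagonal
    return ifelse.(keep, X, zero(T))                 # select(keep, X, 0)
end

julia> using LinearAlgebra; A = rand(4, 4);

julia> triu_via_mask(A, 1) == triu(A, 1)
true

On a TracedRArray, the comparison and ifelse broadcasts should presumably lower to stablehlo compare and select ops, which is the pattern suggested above.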

@avik-pal (Collaborator, Author) commented Dec 7, 2024

I will add some dispatches for triu and tril. Scalar indexing seems to be a problem only on 1.10, but the 1.11 version generates quite inefficient IR:

julia> @code_hlo triu(x_ra, 1)
module {
  func.func @main(%arg0: tensor<4x4xf64>) -> tensor<4x4xf64> {
    %c = stablehlo.constant dense<3> : tensor<i64>
    %c_0 = stablehlo.constant dense<2> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<1x1xf64>
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
    %1 = stablehlo.dynamic_update_slice %0, %cst, %c_2, %c_2 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %2 = stablehlo.dynamic_update_slice %1, %cst, %c_1, %c_2 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %3 = stablehlo.dynamic_update_slice %2, %cst, %c_0, %c_2 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %4 = stablehlo.dynamic_update_slice %3, %cst, %c, %c_2 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %5 = stablehlo.dynamic_update_slice %4, %cst, %c_1, %c_1 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %6 = stablehlo.dynamic_update_slice %5, %cst, %c_0, %c_1 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %7 = stablehlo.dynamic_update_slice %6, %cst, %c, %c_1 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %8 = stablehlo.dynamic_update_slice %7, %cst, %c_0, %c_0 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %9 = stablehlo.dynamic_update_slice %8, %cst, %c, %c_0 : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %10 = stablehlo.dynamic_update_slice %9, %cst, %c, %c : (tensor<4x4xf64>, tensor<1x1xf64>, tensor<i64>, tensor<i64>) -> tensor<4x4xf64>
    %11 = stablehlo.transpose %10, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
    return %11 : tensor<4x4xf64>
  }
}

@wsmoses (Member) commented Dec 7, 2024

huh yeah I presume that 1.11 sets rows/slices and thus doesn't hit scalar indexing

From the diff under review:

function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
    idxs =
        Ops.iota(Int64, [size(X)...]; iota_dimension=1) .<
        Ops.iota(Int64, [size(X)...]; iota_dimension=2) .- (k - 1)

A Member commented on this snippet:

once again having ops here is such a quality of life improvement (thanks @mofeing )

@avik-pal (Collaborator, Author) commented Dec 7, 2024

julia> @code_hlo triu(x_ra, 1)
module {
  func.func @main(%arg0: tensor<4x4xf64>) -> tensor<4x4xf64> {
    %c = stablehlo.constant dense<[[false, true, true, true], [false, false, true, true], [false, false, false, true], [false, false, false, false]]> : tensor<4x4xi1>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x4xf64>
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
    %1 = stablehlo.select %c, %0, %cst : tensor<4x4xi1>, tensor<4x4xf64>
    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
    return %2 : tensor<4x4xf64>
  }
}

@wsmoses (Member) commented Dec 7, 2024


Looks like we should probably add a transpose / select / transpose optimization for the IR above (or perhaps, even more generally, a transpose [a bunch of elementwise ops and constants] transpose optimization).
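
For intuition, such a rewrite is sound because select is purely elementwise, so transposition commutes with it: transpose(select(c, transpose(x), z)) == select(transpose(c), x, transpose(z)), which lets both transposes fold into the constants. A quick plain-Julia check of that identity (illustrative only, not part of the PR):

julia> x = rand(4, 4); z = zeros(4, 4); c = [j > i for i in 1:4, j in 1:4];  # strict upper-triangular mask

julia> permutedims(ifelse.(c, permutedims(x), z)) == ifelse.(permutedims(c), x, permutedims(z))
true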

@avik-pal force-pushed the ap/dynamic_slice_cprop branch from fc60d68 to 49128c3 on December 7, 2024 05:29
@avik-pal changed the title from "feat: add dynamic_update_slice_const_prop pass" to "feat: add dynamic_update_slice_const_prop pass + tril + triu" on Dec 7, 2024
@avik-pal requested a review from wsmoses on December 7, 2024 06:17
@avik-pal merged commit 3a0710d into main on Dec 7, 2024 (22 of 37 checks passed)
@avik-pal deleted the ap/dynamic_slice_cprop branch on December 7, 2024 06:56