
[MLIR] Bufferization of tensor.generate does not take into account repetitions #141667

Open
@erick-xanadu


Hello,

I believe there is an error in the bufferization of tensor.generate. When tensor.generate is bufferized, its body region is bufferized with the same rules as code outside the body, even though the body runs once per element of the result. In the following example, circuit_0 should be called with a different argument on each iteration. That argument is obtained by extracting an element of %arg0 (a tensor defined outside the tensor.generate), adding 0.3 to it, and inserting the sum back into %arg0. Because tensors have value semantics, every iteration should start from the original %arg0 and differ from it only at index %arg2.

  func.func private @circuit_0.finitediff0(%arg0: tensor<2xf64>) -> tensor<2x2xf64> {
    %cst = arith.constant 3.000000e-01 : f64
    %cst_0 = arith.constant dense<3.000000e-01> : tensor<2x2xf64>
    %0 = call @circuit_0(%arg0) : (tensor<2xf64>) -> tensor<2xf64>
    %generated = tensor.generate  {
    ^bb0(%arg1: index, %arg2: index):

      // important bit
      %extracted = tensor.extract %arg0[%arg2] : tensor<2xf64>
      %2 = arith.addf %extracted, %cst : f64
      %inserted = tensor.insert %2 into %arg0[%arg2] : tensor<2xf64>
      // a new value is passed here on each iteration of tensor.generate
      %3 = func.call @circuit_0(%inserted) : (tensor<2xf64>) -> tensor<2xf64>

      %4 = arith.subf %3, %0 : tensor<2xf64>
      %extracted_1 = tensor.extract %4[%arg1] : tensor<2xf64>
      tensor.yield %extracted_1 : f64
    } : tensor<2x2xf64>
    %1 = arith.divf %generated, %cst_0 : tensor<2x2xf64>
    return %1 : tensor<2x2xf64>
  }

However, after bufferization we get the following code, in which the tensor.insert has become a memref.store directly into %arg0:

  func.func private @circuit_0.finitediff0(%arg0: memref<2xf64>) -> memref<2x2xf64> {
    %cst = arith.constant 3.000000e-01 : f64
    %0 = memref.get_global @__constant_2x2xf64 : memref<2x2xf64>
    %1 = call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x2xf64>
    linalg.map outs(%alloc : memref<2x2xf64>)
      () {
        %2 = linalg.index 0 : index
        %3 = linalg.index 1 : index
        %4 = memref.load %arg0[%3] : memref<2xf64>
        %5 = arith.addf %4, %cst : f64
        memref.store %5, %arg0[%3] : memref<2xf64>

        // the store above writes directly into %arg0, so the value passed
        // here accumulates the changes made by all previous iterations
        %6 = func.call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>

        linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %1 : memref<2xf64>, memref<2xf64>) outs(%6 : memref<2xf64>) {
        ^bb0(%in: f64, %in_0: f64, %out: f64):
          %8 = arith.subf %in, %in_0 : f64
          linalg.yield %8 : f64
        }
        %7 = memref.load %6[%2] : memref<2xf64>
        linalg.yield %7 : f64
      }
    linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc, %0 : memref<2x2xf64>, memref<2x2xf64>) outs(%alloc : memref<2x2xf64>) {
    ^bb0(%in: f64, %in_0: f64, %out: f64):
      %2 = arith.divf %in, %in_0 : f64
      linalg.yield %2 : f64
    }
    return %alloc : memref<2x2xf64>
  }
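To make the effect concrete, here is a worked example of the argument x(i, j) passed to circuit_0 at iteration (i, j). It assumes %arg0 = [a, b] on entry. It also assumes the linalg.map runs as a sequential loop nest in row-major order (0,0), (0,1), (1,0), (1,1). That ordering is an assumption for illustration; linalg.map itself does not fix one.

$$
\begin{aligned}
\text{intended:}\quad & x^{(i,0)} = (a + 0.3,\; b), \qquad x^{(i,1)} = (a,\; b + 0.3), \qquad i \in \{0, 1\} \\
\text{observed:}\quad & x^{(0,0)} = (a + 0.3,\; b), \qquad x^{(0,1)} = (a + 0.3,\; b + 0.3), \\
& x^{(1,0)} = (a + 0.6,\; b + 0.3), \qquad x^{(1,1)} = (a + 0.6,\; b + 0.6)
\end{aligned}
$$

Under these assumptions only the first iteration sees the intended input. After the function returns, the caller's %arg0 buffer also holds [a + 0.6, b + 0.6].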

It looks like this may stem from how the linalg.map op (which tensor.generate is lowered to) is bufferized, but I am not entirely sure.
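For comparison, below is a hand-written sketch (not compiler output) of one bufferization that would preserve the tensor semantics. Each iteration copies %arg0 into a fresh buffer and applies the store to that copy, which is roughly what an out-of-place bufferization of the tensor.insert would look like. The memref.global and the external declaration of @circuit_0 are hypothetical stand-ins, added only so the snippet is self-contained. Deallocation of the per-iteration copy is omitted.

```mlir
// Hypothetical stand-ins so that the sketch is self-contained.
memref.global "private" constant @__constant_2x2xf64 : memref<2x2xf64> = dense<3.000000e-01>
func.func private @circuit_0(memref<2xf64>) -> memref<2xf64>

func.func private @circuit_0.finitediff0(%arg0: memref<2xf64>) -> memref<2x2xf64> {
  %cst = arith.constant 3.000000e-01 : f64
  %0 = memref.get_global @__constant_2x2xf64 : memref<2x2xf64>
  %1 = call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x2xf64>
  linalg.map outs(%alloc : memref<2x2xf64>)
    () {
      %2 = linalg.index 0 : index
      %3 = linalg.index 1 : index
      // Fresh copy of %arg0 for this iteration: the store goes into %copy,
      // so this body never writes to %arg0 and every iteration starts from
      // the original values.
      %copy = memref.alloc() : memref<2xf64>
      memref.copy %arg0, %copy : memref<2xf64> to memref<2xf64>
      %4 = memref.load %copy[%3] : memref<2xf64>
      %5 = arith.addf %4, %cst : f64
      memref.store %5, %copy[%3] : memref<2xf64>
      %6 = func.call @circuit_0(%copy) : (memref<2xf64>) -> memref<2xf64>
      // Same as the compiler output: %6 - %1, written into %6.
      linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %1 : memref<2xf64>, memref<2xf64>) outs(%6 : memref<2xf64>) {
      ^bb0(%in: f64, %in_0: f64, %out: f64):
        %sub = arith.subf %in, %in_0 : f64
        linalg.yield %sub : f64
      }
      %7 = memref.load %6[%2] : memref<2xf64>
      // Deallocation of %copy is omitted in this sketch.
      linalg.yield %7 : f64
    }
  // Same as the compiler output: element-wise division by the constant.
  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc, %0 : memref<2x2xf64>, memref<2x2xf64>) outs(%alloc : memref<2x2xf64>) {
  ^bb0(%in: f64, %in_0: f64, %out: f64):
    %div = arith.divf %in, %in_0 : f64
    linalg.yield %div : f64
  }
  return %alloc : memref<2x2xf64>
}
```

The copy has to happen inside the body. A single copy hoisted out of the linalg.map would not be enough, because the stores would still accumulate in that one buffer across iterations.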
