[CodeGen][SPIRV] Lowering for clustered reduce not implemented #20872

Open
@Muzammiluddin-Syed-ECE

Description

What happened?

Context

To make effective use of the DPP operations available on AMD GPUs, the PR below changed the warp-reduction implementation to preserve `gpu.subgroup_reduce` ops rather than immediately lowering them to butterfly shuffles built from `gpu.shuffle xor` ops. Preserving the `subgroup_reduce` ops makes it possible to lower them to target-specific ops later in the pipeline, where such ops exist. As a result, reductions both within a warp and across warps can be expressed with `subgroup_reduce` ops.

#20468

However, SPIR-V support for clustered `subgroup_reduce` is incomplete, which makes reduction across multiple warps difficult (a workgroup is not guaranteed to contain 64 x 64 threads), so being able to perform a `subgroup_reduce` over fewer than 64 threads is useful.
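For reference, the shuffle-based lowering mentioned above expands a clustered reduction into a butterfly of `gpu.shuffle xor` ops. A minimal sketch (value names are hypothetical) of what a `cluster(size = 16)` add-reduction of an `f32` value `%x` could expand to, assuming a subgroup width of 16 as in the target env below:

```mlir
// Butterfly reduction over a 16-lane cluster: xor-shuffle at
// strides 8, 4, 2, 1 and accumulate at each step.
%width = arith.constant 16 : i32
%c8 = arith.constant 8 : i32
%s8, %p8 = gpu.shuffle xor %x, %c8, %width : f32
%x8 = arith.addf %x, %s8 : f32
%c4 = arith.constant 4 : i32
%s4, %p4 = gpu.shuffle xor %x8, %c4, %width : f32
%x4 = arith.addf %x8, %s4 : f32
%c2 = arith.constant 2 : i32
%s2, %p2 = gpu.shuffle xor %x4, %c2, %width : f32
%x2 = arith.addf %x4, %s2 : f32
%c1 = arith.constant 1 : i32
%s1, %p1 = gpu.shuffle xor %x2, %c1, %width : f32
%sum = arith.addf %x2, %s1 : f32
```

After the four steps every lane in the cluster holds the full 16-lane sum; this is the expansion that preserving `subgroup_reduce` is meant to defer so that target-specific ops (e.g. DPP) can be used instead.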

Reproduction:

Input.mlir:

```mlir
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformArithmetic, DotProduct, DotProductInput4x8BitPacked, DotProductInputAll, DotProductInput4x8Bit], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product]>, ARM, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512 : i32, 512 : i32, 512 : i32], subgroup_size = 16, min_subgroup_size = 16, max_subgroup_size = 16, cooperative_matrix_properties_khr = []>>} {
  func.func @subgroup_reduce() {
    %c7_i32 = arith.constant 7 : i32
    %c0_i32 = arith.constant 0 : i32
    %c16_i32 = arith.constant 16 : i32
    %c128_i32 = arith.constant 128 : i32
    %c2 = arith.constant 2 : index
    %c256 = arith.constant 256 : index
    %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %thread_id_x = gpu.thread_id  x upper_bound 128
    %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>) binding(0) alignment(64) offset(%c0) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%c256}
    %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>) binding(1) alignment(64) offset(%c0) : memref<?xf32, #spirv.storage_class<StorageBuffer>>{%c2}
    %workgroup_id_x = hal.interface.workgroup.id[0] upper_bound 2 : index
    %2 = arith.index_castui %workgroup_id_x : index to i32
    %3 = arith.muli %2, %c128_i32 overflow<nsw> : i32
    %4 = arith.index_castui %thread_id_x : index to i32
    %5 = arith.addi %3, %4 : i32
    %6 = arith.index_castui %5 : i32 to index
    %7 = memref.load %0[%6] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
    %8 = arith.addf %7, %cst : vector<4xf32>
    %9 = vector.reduction <add>, %8 : vector<4xf32> into f32
    %10 = gpu.subgroup_reduce  add %9 cluster(size = 16) : (f32) -> f32
    %alloc = memref.alloc() : memref<8xf32, #spirv.storage_class<Workgroup>>
    %11 = arith.divui %4, %c16_i32 : i32
    %12 = arith.index_castui %11 : i32 to index
    %13 = arith.remui %4, %c16_i32 : i32
    %14 = arith.cmpi eq, %13, %c0_i32 : i32
    scf.if %14 {
      memref.store %10, %alloc[%12] : memref<8xf32, #spirv.storage_class<Workgroup>>
    }
    gpu.barrier
    %15 = arith.minui %13, %c7_i32 : i32
    %16 = arith.index_castui %15 : i32 to index
    %17 = memref.load %alloc[%16] : memref<8xf32, #spirv.storage_class<Workgroup>>
    %18 = gpu.subgroup_reduce  add %17 cluster(size = 8) : (f32) -> f32
    %19 = arith.addf %18, %cst_0 : f32
    %20 = arith.cmpi eq, %4, %c0_i32 : i32
    scf.if %20 {
      memref.store %19, %1[%workgroup_id_x] : memref<?xf32, #spirv.storage_class<StorageBuffer>>
    }
    return
  }
}
```

Command:

```shell
iree-opt --iree-gpu-test-target=valhall1 --pass-pipeline='builtin.module(iree-convert-to-spirv)' <Input.mlir>
```

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

Metadata

Labels: bug 🐞 (Something isn't working)