Description
What happened?
Context
To make effective use of the DPP operations available on AMD GPUs, the PR below changed the implementation of warp reduction to preserve gpu.subgroup_reduce ops rather than immediately lowering them to butterfly shuffles built from gpu.shuffle xor ops. The goal of preserving the subgroup_reduce ops is to enable lowering to target-specific ops later in the pipeline, if such ops exist. So this PR allows reductions both within warps and across warps to be expressed with subgroup_reduce ops.
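For reference, a minimal hand-written sketch of the butterfly shape that the immediate lowering produces, here for an f32 add over a 16-lane cluster (illustrative only; the function and value names are made up and the actual pass output differs):
func.func @butterfly_add_sketch(%x: f32) -> f32 {
  // segment width for the shuffles; 16 matches the subgroup size in the repro below
  %width = arith.constant 16 : i32
  %c1 = arith.constant 1 : i32
  %c2 = arith.constant 2 : i32
  %c4 = arith.constant 4 : i32
  %c8 = arith.constant 8 : i32
  // each step exchanges values with the lane whose id differs by a power of
  // two, then accumulates, halving the number of distinct partial sums
  %s1, %p1 = gpu.shuffle xor %x, %c1, %width : f32
  %a1 = arith.addf %x, %s1 : f32
  %s2, %p2 = gpu.shuffle xor %a1, %c2, %width : f32
  %a2 = arith.addf %a1, %s2 : f32
  %s3, %p3 = gpu.shuffle xor %a2, %c4, %width : f32
  %a3 = arith.addf %a2, %s3 : f32
  %s4, %p4 = gpu.shuffle xor %a3, %c8, %width : f32
  %a4 = arith.addf %a3, %s4 : f32
  return %a4 : f32
}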
However, SPIR-V support for clustered subgroup_reduce is incomplete, which makes reduction across multiple warps difficult (a workgroup is not guaranteed to have 64 x 64 threads), so being able to perform a subgroup_reduce over fewer than 64 threads is useful.
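Concretely, with the shapes in the reproduction below (a 128-thread workgroup and subgroup_size = 16, so 8 per-warp partial sums), the second-level reduction only spans 8 lanes. A minimal standalone sketch of the two cluster sizes involved (hypothetical function, not the reproduction itself):
func.func @two_level_sketch(%x: f32, %partial: f32) -> (f32, f32) {
  // level 1: reduce across a full 16-lane warp
  %warp_sum = gpu.subgroup_reduce add %x cluster(size = 16) : (f32) -> f32
  // level 2: after the partials are exchanged through workgroup memory,
  // reduce across only the 8 per-warp results, i.e. fewer lanes than the
  // subgroup size, which is where clustered subgroup_reduce support matters
  %total = gpu.subgroup_reduce add %partial cluster(size = 8) : (f32) -> f32
  return %warp_sum, %total : f32, f32
}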
Reproduction:
Input.mlir
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformArithmetic, DotProduct, DotProductInput4x8BitPacked, DotProductInputAll, DotProductInput4x8Bit], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product]>, ARM, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512 : i32, 512 : i32, 512 : i32], subgroup_size = 16, min_subgroup_size = 16, max_subgroup_size = 16, cooperative_matrix_properties_khr = []>>} {
func.func @subgroup_reduce() {
%c7_i32 = arith.constant 7 : i32
%c0_i32 = arith.constant 0 : i32
%c16_i32 = arith.constant 16 : i32
%c128_i32 = arith.constant 128 : i32
%c2 = arith.constant 2 : index
%c256 = arith.constant 256 : index
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_0 = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%thread_id_x = gpu.thread_id x upper_bound 128
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>) binding(0) alignment(64) offset(%c0) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%c256}
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>) binding(1) alignment(64) offset(%c0) : memref<?xf32, #spirv.storage_class<StorageBuffer>>{%c2}
%workgroup_id_x = hal.interface.workgroup.id[0] upper_bound 2 : index
%2 = arith.index_castui %workgroup_id_x : index to i32
%3 = arith.muli %2, %c128_i32 overflow<nsw> : i32
%4 = arith.index_castui %thread_id_x : index to i32
%5 = arith.addi %3, %4 : i32
%6 = arith.index_castui %5 : i32 to index
%7 = memref.load %0[%6] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
%8 = arith.addf %7, %cst : vector<4xf32>
%9 = vector.reduction <add>, %8 : vector<4xf32> into f32
%10 = gpu.subgroup_reduce add %9 cluster(size = 16) : (f32) -> f32
%alloc = memref.alloc() : memref<8xf32, #spirv.storage_class<Workgroup>>
%11 = arith.divui %4, %c16_i32 : i32
%12 = arith.index_castui %11 : i32 to index
%13 = arith.remui %4, %c16_i32 : i32
%14 = arith.cmpi eq, %13, %c0_i32 : i32
scf.if %14 {
memref.store %10, %alloc[%12] : memref<8xf32, #spirv.storage_class<Workgroup>>
}
gpu.barrier
%15 = arith.minui %13, %c7_i32 : i32
%16 = arith.index_castui %15 : i32 to index
%17 = memref.load %alloc[%16] : memref<8xf32, #spirv.storage_class<Workgroup>>
%18 = gpu.subgroup_reduce add %17 cluster(size = 8) : (f32) -> f32
%19 = arith.addf %18, %cst_0 : f32
%20 = arith.cmpi eq, %4, %c0_i32 : i32
scf.if %20 {
memref.store %19, %1[%workgroup_id_x] : memref<?xf32, #spirv.storage_class<StorageBuffer>>
}
return
}
}
Command:
iree-opt --iree-gpu-test-target=valhall1 --pass-pipeline='builtin.module(iree-convert-to-spirv)' <Input.mlir>
Steps to reproduce your issue
- Save the IR above as Input.mlir
- Run the iree-opt command above on that file
- See error
What component(s) does this issue relate to?
No response
Version information
No response
Additional context
No response