[Unity][Dlight] Fix DecodeGeMV rule for spatial-inner with grouping (#15340)

This PR fixes a bug in the DecodeGeMV dlight rule that is triggered when the
innermost tensor dimension is spatial and has an `unroll_spatial_factor`
(for example, the grouping used in group quantization).

Prior to this PR, a reduction loop bound to threadIdx was reordered to sit
outside a split spatial loop, which prevented the TIR
LowerCrossThreadReduction pass from applying successfully because of a
safety-guard requirement in that pass.

This PR fixes the issue by no longer reordering the split spatial loop after
the reduction loop, so that the pass can be applied. This is safe because the
relative order of thread-binding loops does not matter.
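The claim that the order of thread-binding loops does not matter can be illustrated with a small standalone Python sketch (not TVM code; the loop extents 16, 16, and 8 mirror the test diff in this commit and are otherwise assumptions): enumerating the block instances produced by the pre-fix and post-fix loop orders yields exactly the same set of (reduction-thread, spatial-index) pairs.

```python
def instances_reduction_outer(n_threads=16, spatial_outer=16, unroll=8):
    """Pre-fix order: the thread-bound reduction loop sits outside
    the unrolled remainder of the split spatial loop."""
    out = set()
    for s in range(spatial_outer):      # bound to threadIdx.x
        for r in range(n_threads):      # bound to threadIdx.y (reduction)
            for inner in range(unroll): # unrolled spatial remainder
                out.add((r, s * unroll + inner))
    return out


def instances_reduction_inner(n_threads=16, spatial_outer=16, unroll=8):
    """Post-fix order: the split spatial loop is kept outside the
    thread-bound reduction loop."""
    out = set()
    for s in range(spatial_outer):      # bound to threadIdx.x
        for inner in range(unroll):     # unrolled spatial remainder
            for r in range(n_threads):  # bound to threadIdx.y (reduction)
                out.add((r, s * unroll + inner))
    return out


# Both orders visit the same set of block instances, so skipping the
# reorder changes nothing semantically while satisfying the
# LowerCrossThreadReduction safety guard.
assert instances_reduction_outer() == instances_reduction_inner()
```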
MasterJH5574 committed Jul 18, 2023
1 parent 63b170d commit 1e1ff66
Showing 2 changed files with 3 additions and 4 deletions.
python/tvm/dlight/gpu/decode_gemv.py (3 changes: 1 addition & 2 deletions)

@@ -220,8 +220,7 @@ def _sch_inner_spatial(
         s = sch.fuse(*s)
         sch.reorder(s, r)
         if unroll_spatial_factor:
-            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
-            sch.reorder(s, r, inner)
+            s, _ = sch.split(s, factors=[None, unroll_spatial_factor])
         sch.bind(s, "threadIdx.x")
         sch.bind(r, "threadIdx.y")
         # Schedule epilogue
tests/python/dlight/test_gpu_decode_gemv.py (4 changes: 2 additions & 2 deletions)

@@ -259,8 +259,8 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16")
                         vk_fused_0 = T.axis.reduce(256, k_fused_0)
                         C_rf_local[vk_fused_1, 0, 0, v_i2] = C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1, v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32])
             for ax1_ax2_ax3_fused_0 in T.thread_binding(16, thread="threadIdx.x"):
-                for ax0_fused in T.thread_binding(16, thread="threadIdx.y"):
-                    for ax1_ax2_ax3_fused_1 in range(8):
+                for ax1_ax2_ax3_fused_1 in range(8):
+                    for ax0_fused in T.thread_binding(16, thread="threadIdx.y"):
                         with T.block("matmul"):
                             vk_fused_1 = T.axis.reduce(16, ax0_fused)
                             v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)