[TIR] Fix Inlining of Non-Output Consumers in TileWithTensorIntrin with Padding #17161
Closed
YXY-0922 wants to merge 1 commit into apache:main from
Conversation
…nsorIntrin

In the TileWithTensorIntrin function, modified the inlining behavior of consumer blocks. Now, when padding is applied, the function inlines only non-output consumer blocks. This ensures that the padding and inlining process is correctly handled for both producers and consumers.

Changes:
- Added a check to ensure only non-output consumer blocks are inlined, using tir::IsOutputBlock.
- Updated the loop iterating over consumers to include the new check.

This fix addresses issues where output blocks were being inappropriately inlined, maintaining the correct block shapes and dependencies.
Contributor
Thank you for the contribution!
cbalint13
reviewed
Jul 16, 2024
```diff
 for (const auto& consumer : consumers) {
-  sch->ComputeInline(consumer);
+  auto sref = sch->GetSRef(consumer);
+  if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
```
Contributor
Could you add a simple test case to check the resulting IR's validity under this new condition?
Contributor
Author
Sure, I encountered this bug while using meta_schedule to tune a conv2d operator. Here is the TIR example:

```python
import tvm
from tvm import te, topi, tir
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN


@tvm.script.ir_module
class conv2d_Module:
    @T.prim_func
    def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
        conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
        pad_temp_reindex = T.alloc_buffer((200704, 147), "float16")
        B_reindex = T.alloc_buffer((64, 147), "float16")
        for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
        for ax0, ax1 in T.grid(200704, 147):
            with T.block("pad_temp_reindex_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                T.writes(pad_temp_reindex[v0, v1])
                pad_temp_reindex[v0, v1] = pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7]
        for ax0, ax1 in T.grid(64, 147):
            with T.block("B_reindex_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                T.writes(B_reindex[v0, v1])
                B_reindex[v0, v1] = B[v0, v1 // 49, v1 % 49 // 7, v1 % 7]
        for ax0, ax1, ax2 in T.grid(200704, 64, 147):
            with T.block("conv2d_nchw"):
                v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
                T.reads(pad_temp_reindex[v0, v2], B_reindex[v1, v2])
                T.writes(conv2d_nchw_reindex[v0, v1])
                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                with T.init():
                    conv2d_nchw_reindex[v0, v1] = T.float16(0)
                conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex[v0, v2] * B_reindex[v1, v2]
        for ax0, ax1 in T.grid(200704, 64):
            with T.block("conv2d_nchw_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(conv2d_nchw_reindex[v0, v1])
                T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]


sch = tvm.tir.Schedule(conv2d_Module)
intrin = WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
block = sch.get_block("conv2d_nchw")
tiled_loop = tile_with_tensor_intrin(sch, block, intrin, True)
print(sch.mod)
```

And the output is:
```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((16, 3, 224, 224), "float16"), B: T.Buffer((64, 3, 7, 7), "float16"), conv2d_nchw: T.Buffer((16, 64, 112, 112), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((16, 3, 230, 230), "float16")
        conv2d_nchw_reindex = T.alloc_buffer((200704, 64), "float16")
        pad_temp_reindex_pad = T.alloc_buffer((200704, 160), "float16")
        B_reindex_pad = T.alloc_buffer((64, 160), "float16")
        for i0, i1, i2, i3 in T.grid(16, 3, 230, 230):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2 - 3, v_i3 - 3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(3 <= v_i2 and v_i2 < 227 and 3 <= v_i3 and v_i3 < 227, A[v_i0, v_i1, v_i2 - 3, v_i3 - 3], T.float16(0))
        for i0, i1 in T.grid(200704, 160):
            with T.block("pad_temp_reindex_pad"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7])
                T.writes(pad_temp_reindex_pad[v0, v1])
                pad_temp_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, pad_temp[v0 // 12544, v1 // 49, v0 % 12544 // 112 * 2 + v1 % 49 // 7, v0 % 112 * 2 + v1 % 7], T.float16(0))
        for i0, i1 in T.grid(64, 160):
            with T.block("B_reindex_pad"):
                v0, v1 = T.axis.remap("SS", [i0, i1])
                T.reads(B[v0, v1 // 49, v1 % 49 // 7, v1 % 7])
                T.writes(B_reindex_pad[v0, v1])
                B_reindex_pad[v0, v1] = T.if_then_else(v1 < 147, B[v0, v1 // 49, v1 % 49 // 7, v1 % 7], T.float16(0))
        for ax0_0, ax1_0, ax2_0, ax0_1, ax1_1, ax2_1 in T.grid(12544, 4, 10, 16, 16, 16):
            with T.block("conv2d_nchw"):
                v0 = T.axis.spatial(200704, ax0_0 * 16 + ax0_1)
                v1 = T.axis.spatial(64, ax1_0 * 16 + ax1_1)
                v2 = T.axis.reduce(160, ax2_0 * 16 + ax2_1)
                T.reads(pad_temp_reindex_pad[v0, v2], B_reindex_pad[v1, v2])
                T.writes(conv2d_nchw_reindex[v0, v1])
                T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                with T.init():
                    conv2d_nchw_reindex[v0, v1] = T.float16(0)
                conv2d_nchw_reindex[v0, v1] = conv2d_nchw_reindex[v0, v1] + pad_temp_reindex_pad[v0, v2] * B_reindex_pad[v1, v2]
        for ax0, ax1 in T.grid(200704, 64):
            with T.block("conv2d_nchw_reindex"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(conv2d_nchw_reindex[v0, v1])
                T.writes(conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112])
                conv2d_nchw[v0 // 12544, v1, v0 % 12544 // 112, v0 % 112] = conv2d_nchw_reindex[v0, v1]
```

The product of the three reduction axes is 3 × 7 × 7 = 147, which is not a multiple of the intrinsic's reduction extent of 16, hence padding (to 160) is required.
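As a quick illustration of where the padded extent of 160 comes from (an illustrative calculation only, not code from the PR):

```python
import math

reduction_extent = 3 * 7 * 7   # product of the conv2d reduction axes (C=3, KH=7, KW=7) -> 147
wmma_k = 16                    # reduction (k) extent of the WMMA 16x16x16 intrinsic
padded_extent = math.ceil(reduction_extent / wmma_k) * wmma_k
print(reduction_extent, padded_extent)  # 147 160, matching the 160-wide pad_temp_reindex_pad / B_reindex_pad buffers
```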
Contributor
Could you add it as a simple test case script, e.g. under tests/python/meta_schedule?
During a tuning process, similar (padding) issues might be overlooked, but a test case will always catch them in CI.
Contributor
Author
OK, I will do it later.
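A minimal test along those lines might look like the sketch below. This is only a sketch under two assumptions: it reuses the conv2d_Module definition from the reproducer above, and it is not necessarily the test case that was eventually committed.

```python
import tvm
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN


def test_tile_with_tensor_intrin_padding_keeps_output_block():
    # conv2d_Module is the IRModule defined in the reproducer above.
    sch = tvm.tir.Schedule(conv2d_Module)
    block = sch.get_block("conv2d_nchw")
    tile_with_tensor_intrin(sch, block, WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN, allow_padding=True)
    # With the fix, the output consumer block is no longer inlined,
    # so looking it up after the transformation should not raise.
    sch.get_block("conv2d_nchw_reindex")
```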
cbalint13
approved these changes
Jul 16, 2024
Contributor
Bug Fix
In the `TileWithTensorIntrin` function, when the `allow_padding` parameter is enabled, the original implementation inlines all consumer blocks. This behavior can lead to incorrect inlining of output blocks, causing issues with block shapes and dependencies. To ensure correct inlining operations, only non-output consumer blocks should be inlined.
Changes Made
- Use the `tir::IsOutputBlock` function to determine whether a block is an output block.
- Call `sch->ComputeInline` only if the block is not an output block.
Specific Code Changes
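A reconstructed sketch of the updated consumer-inlining loop in the `TileWithTensorIntrin` function, based on the review hunk quoted earlier in this thread (the braces and the position of the guarded `ComputeInline` call are inferred from the commit message rather than copied verbatim from the diff):

```cpp
for (const auto& consumer : consumers) {
  auto sref = sch->GetSRef(consumer);
  // Only inline consumers that are not output blocks of the scope root;
  // inlining an output block would break block shapes and dependencies.
  if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true))) {
    sch->ComputeInline(consumer);
  }
}
```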
Impact
These changes ensure that when padding is enabled, only non-output blocks will be inlined, maintaining correct block shapes and dependencies. This fixes the issue in previous versions where output blocks might be incorrectly inlined.
Please review these changes and provide feedback for further improvements. Thank you for your time and assistance!