Skip to content

[Unity] Fix FuseTIR when the same buffer is read multiple times with different access pattern#14603

Merged
tqchen merged 2 commits intoapache:unityfrom
masahi:fix-fuse-tir
Apr 12, 2023
Merged

[Unity] Fix FuseTIR when the same buffer is read multiple times with different access pattern#14603
tqchen merged 2 commits intoapache:unityfrom
masahi:fix-fuse-tir

Conversation

@masahi
Copy link
Member

@masahi masahi commented Apr 12, 2023

When the same buffer is read multiple times with different access patterns in a single expression, the check below fails
https://github.com/apache/tvm/blob/unity/src/relax/transform/fuse_tir.cc#L266

But this case should be allowed, for example in the following subgraph inp_0 is used twice in different read regions. See the test case for details. This subgraph arises if we run ConvertLayout on SD UNet from web-stable diffusion.

lv: R.Tensor((2, 4, 64, 64), dtype="float32") = R.concat((inp_0, inp_0), axis=0)
lv_2: R.Tensor((2, 64, 64, 4), dtype="float32") = R.permute_dims(lv, axes=[0, 2, 3, 1])

@Hzfengsy

@tvm-bot
Copy link
Collaborator

tvm-bot commented Apr 12, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@masahi masahi changed the title [Unity] Fix FuseTIR when the same buffer is read multiple times with different acess pattern [Unity] Fix FuseTIR when the same buffer is read multiple times with different access pattern Apr 12, 2023
ret.push_back(region);
buffer_region_set[region->buffer.get()] = region->region;
} else {
ICHECK(structural_equal_(region->region, it->second));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry the diff is noisy due to the unrelated style updates, but this is the only important diff.

v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(
rxplaceholder_1[v_ax0 - T.int64(1), v_ax1, v_ax2, v_ax3],
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that in this test case, rxplaceholder_1 and rxplaceholder refer to the same buffer inp_0. But they are used with different access patterns.

@tqchen tqchen merged commit 61fbf42 into apache:unity Apr 12, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants