Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TE] Fix MakeLoopNest for warp memory #5382

Merged
merged 4 commits into from
May 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/te/operation/op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,21 @@ MakeLoopNest(const Stage& stage,
value_map[iv] = dom->min;
} else {
runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
if (stage->scope == "" || stage->scope == "warp" ||
if (stage->scope == "" ||
static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
value_map[iv] = var;
} else if (stage->scope == "warp" && ts.rank == 1) {
// To determine whether a thread index is inside or outside a warp, we need
// to know the thread extent. We leave a warning for now.
if (ts.dim_index == 0) {
value_map[iv] = var;
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this condition mean that "warp" is actually treated as "local"? I don't know whether there is a such use case. If so, it may be better to do sanity check, convert from "warp" to "local", and give warning before making loop nests. It is at least worth a test to lock the use case down. If not, giving error message here seems more reasonable. A test is also worthy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me explain this piece of code:

Line 167: If we are using a local storage, then each iteration variable iv accounts to the accessing offset, but thread indices or block indices do not. For example, when accessing a local buffer a, it should be something like a[i], instead of a[blockIdx.x, threadIdx.x, i].

Line 168: There are two cases:

  1. If we are using a global storage, then all indices account to the offset. For example, b[blockIdx.x, threadIdx.x, i].
  2. If we are using a shared storage, then variables inside a block, i.e. iteration variables and thread indices (storage rank <= thread rank) account to the offset, but block indices (storage rank > thread rank) do not. For example, when accessing a shared buffer c, it should be c[threadIdx.x, i], instead of c[blockIdx.x, threadIdx.x, i].

Now we come to warp storage. Ideally, all the variables inside a warp should account to the offset, but those outside a warp do not. It is complex to determine whether a variable is inside a warp, but usually only threadIdx.x is inside, so we made the assumption, and give a warning otherwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for the reason of making the message a warning rather than an error, consider 3 situations:

  1. One uses threadIdx.x only. No threadIdx.y or threadIdx.z, no warning is shown.
  2. One uses threadIdx.x and threadIdx.y, and the extent of threadIdx.x is 32 (warp size), which is a common use case. Now a warning is shown, but the code is correct.
  3. One uses threadIdx.x and threadIdx.y, but the extent of threadIdx.x < 32, now a warning is shown, and the code is wrong.

We can still proceed in Situation 2.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the elaboration.

I added line 167/168 a few weeks ago. I am glad that this PR is fixing for the "warp" scope, as I wasn't very sure about it and left it with the old behavior to avoid one test failure. Your explanation is almost same as what I thought: when a StorageScope is only visible to a ThreadScope (such as shared to block or local to thread), we don't access it with thread IterVars (such as blockIdx or threadIdx), i.e. that they are not account to the offset.

If my idea makes sense, I am confused by this statement

but those outside a warp do not

and the second situation. In my mind, the conceptual "warp" storage is accessible by multiple threads, it makes sense to keep threadIdx IterVar to offset the access. But I don't understand why things change when it comes to outside a warp. In the 2nd situation, is cross "warp" access allowed?

In the 1st situation, does threadIdx.x's extent matter here or is it handled by warp memory lowering?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is an example for the 2nd situation: Consider extent of threadIdx.x = 32, extent of threadIdx.y = 4. Now there are 4 warps, each consisting of 32 threads.

In the 2nd situation, is cross "warp" access allowed?

No, all of the 128 threads access a warp storage a, but threads in different warps are actually accessing different group of registers, although they are all called a. Only threads in the same warp are accessing the same group of registers. This is what I mean by inside/outside. Just like shared memory, different blocks access different piece of SRAM, even their variable name is the same.

In the 1st situation, does threadIdx.x's extent matter here or is it handled by warp memory lowering?

It dose not matter. lower_warp_memory pass will handle different exntents.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

threadIdx.x means a lower level of hierarchy than threadIdx.y and threadIdx.z for "warp" scope -- "warp" memory is local to threadIdx.y and threadIdx.z, but shared by threadIdx.x. One idea for you to consider is to change threadIdx.x's rank to 2 and keep threadIdx.y and threadIdx.z's rank 1. It will avoid special casing "warp" here, but may cause other regression.

In fact, things are more complicated because threadIdx can be less than the warp size. To determine whether an index is inside or outside a warp, we need to know the extent of that index. But this information is not available in MakeLoopNest.

Does dom and/or bind_iv->dom have the extent for the checks?

Here is an example for the 2nd situation: Consider extent of threadIdx.x = 32, extent of threadIdx.y = 4. Now there are 4 warps, each consisting of 32 threads.

This example clarifies a lot. Is it possible to make it a test to cover the "warp" specific code here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

threadIdx.x means a lower level of hierarchy than threadIdx.y and threadIdx.z for "warp" scope -- "warp" memory is local to threadIdx.y and threadIdx.z, but shared by threadIdx.x. One idea for you to consider is to change threadIdx.x's rank to 2 and keep threadIdx.y and threadIdx.z's rank 1. It will avoid special casing "warp" here, but may cause other regression.

I agree. Currently there are some hard-coded rank. It may be difficult to modify the rank definition.

Does dom and/or bind_iv->dom have the extent for the checks?

Yes, dom does. But where can we get the warp size? Warp size is not a constant value, it is 32 for NVIDIA GPU and 64 for AMD GPU. The warp size information is stored in Target class. Since MakeLoopNest is designed to be a target-irrelevant function, I don't think it a good idea to add an argument to specify the target.

This example clarifies a lot. Is it possible to make it a test to cover the "warp" specific code here?

Actually I did try to run this example, but TVM somehow generated a correct code. It only led to a wrong boundary checking bug in a very complex program.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, dom does. But where can we get the warp size? Warp size is not a constant value, it is 32 for NVIDIA GPU and 64 for AMD GPU. The warp size information is stored in Target class. Since MakeLoopNest is designed to be a target-irrelevant function, I don't think it a good idea to add an argument to specify the target.

I agree with you about MakeLoopNest being target independent. My first thought is to defer the check and threadIdx.x vs. threadIdx.y/z handling till LowerWarpMemory(), such as adjusting indices in WarpAccessRewriter(), since it does target specific "warp" lowering. Maybe substitute threadIdx.y/z with 0 in the group indices for the supported cases and give error otherwise? However, it seems not the case, as your 3rd situation ends with incorrect code instead of an error from LowerWarpMemory(). But I don't know the reason.

Actually I did try to run this example, but TVM somehow generated a correct code. It only led to a wrong boundary checking bug in a very complex program.

I see. Does your simplified case trigger the warning? If so, checking for the warning can guard your changes from being accidentally deleted or skipped.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, it seems not the case, as your 3rd situation ends with incorrect code instead of an error from LowerWarpMemory(). But I don't know the reason.

Actually I mean:

  1. The 1st situation has no problem, before and after this PR.
  2. The 2nd situation led to incorrect code before this PR, and correct code after this PR. Plus, we will see a warning after this PR.
  3. The 3rd situation is currently not supported by lower_warp_memory, which will lead to an error. Plus, we will see a warning after this PR.

No matter how, I still have no idea how the 2nd situation ends up with incorrect code.

I see. Does your simplified case trigger the warning? If so, checking for the warning can guard your changes from being accidentally deleted or skipped.

Actually no. Let's have a look at some details.

Here's my simplified example:

import tvm
import topi
import numpy as np

from tvm import te

n = 32
A = te.placeholder((n, n), name='A', dtype="float32")
C = te.compute((n, n), lambda i, j: A(i, (j + 1) % n), name='C')

s = te.create_schedule(C.op)
th_y = te.thread_axis("threadIdx.y")
th_x = te.thread_axis("threadIdx.x")
B = s.cache_read(A, "warp", [C])
ci, cj = C.op.axis
bi, bj = B.op.axis
s[C].bind(ci, th_y)
s[C].bind(cj, th_x)
s[B].compute_at(s[C], ci)
s[B].bind(bj, th_x)

print(tvm.lower(s, [A, C]))

And here's the result, which is unexpectedly correct before this PR.

PrimFunc([A, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
  // attr [A.warp] storage_scope = "warp"
  allocate A.warp[float32 * 32]
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
  A.warp[threadIdx.x] = A[((threadIdx.y*32) + threadIdx.x)]
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
  C[((threadIdx.y*32) + threadIdx.x)] = A.warp[floormod((threadIdx.x + 1), 32)]
}

The if (stage->scope == "warp" && ts.rank == 1) branch in the modified code is only triggered once, where ts.dim_index == 0. I don't know why the ts.dim_index == 1 IterVar is ignored.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modified the simplified example a little to bind threadIdx.y in the warp stage to let threadIdx.y pass through the new code.

import tvm
import topi
import numpy as np

from tvm import te

n = 32
A = te.placeholder((2, n, n), name='A', dtype="float32")
C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')

s = te.create_schedule(C.op)
bk_x = te.thread_axis("blockIdx.x")
th_y = te.thread_axis("threadIdx.y")
th_x = te.thread_axis("threadIdx.x")
B = s.cache_read(A, "warp", [C])
cx, ci, cj = C.op.axis
bx, bi, bj = B.op.axis
# s[C].bind(ci, th_y)
s[C].bind(cj, th_x)
s[C].bind(cx, bk_x)
s[B].compute_at(s[C], cx)
s[B].bind(bi, th_y)
s[B].bind(bj, th_x)

print(tvm.lower(s, [A, C]))
func = tvm.build(s, [A, C], target="cuda", name='tid')
print(func.imported_modules[0].get_source())

The three situations make a good summary.
1st one already has at least one test in tests/python/unittest/test_tir_transform_lower_warp_memory.py.
I hope the above code can lock down the 2nd situation and probably the error for the 3rd one by reducing threadIdx.x's extent.

Warp has been special cased in several places, e.g. in bound.cc and here before this PR. I tried to push back to add more special case code, but I am Ok the accept the current change. Please try to add tests.

LOG(WARNING)
<< "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
<< "TVM assumes only threadIdx.x indicates threads inside a warp, "
<< "while threadIdx.y and threadIdx.z indicates different warps.";
value_map[iv] = dom->min;
}
} else {
value_map[iv] = dom->min;
}
Expand Down
37 changes: 37 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,42 @@ def test_lower_warp_memory_local_scope():
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)

def test_lower_warp_memory_correct_indices():
n = 32
A = te.placeholder((2, n, n), name='A', dtype="float32")
C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')

s = te.create_schedule(C.op)
bk_x = te.thread_axis("blockIdx.x")
th_y = te.thread_axis("threadIdx.y")
th_x = te.thread_axis("threadIdx.x")
B = s.cache_read(A, "warp", [C])
cx, ci, cj = C.op.axis
bx, bi, bj = B.op.axis
s[C].bind(cj, th_x)
s[C].bind(cx, bk_x)
s[B].compute_at(s[C], cx)
s[B].bind(bi, th_y)
s[B].bind(bj, th_x)

bounds = tvm.te.schedule.InferBound(s)
ir = tvm.te.schedule.ScheduleOps(s, bounds)
inner_func = ir.body.body.body.body
store_A_warp = inner_func.body.seq[0].body.body
indices = list(store_A_warp.args)

# A.warp is actually many buffers, one for each warp, although they are all called A.warp
# 1. If we are accessing from different threads within a same warp (different
# threadIdx.x), we need to distinguish between each elements using threadIdx.x,
# so threadIdx.x is one if the indices.
# 2. If we are accessing from different warps (different threadIdx.y), we are actually
# assessing different buffers, so there is no need to distinguish from elements,
# and therefore threadIdx.y is NOT a index.
idx_names = map(lambda x: x.name,
filter(lambda x: type(x) is tvm.tir.expr.Var, indices))
assert "threadIdx.x" in idx_names
assert "threadIdx.y" not in idx_names

def test_lower_warp_memory_cuda_end_to_end():
def check_cuda(dtype):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
Expand Down Expand Up @@ -182,6 +218,7 @@ def check_cuda(dtype):

if __name__ == "__main__":
test_lower_warp_memory_local_scope()
test_lower_warp_memory_correct_indices()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()
test_lower_warp_memory_cuda_2_buffers()