Create loops according to storage scope and thread hierarchies #5190

Merged
merged 8 commits into apache:master on Apr 10, 2020

Conversation

yongfeng-nv
Contributor

@yongfeng-nv yongfeng-nv commented Mar 31, 2020

The following small example generates wrong IR and code:

from __future__ import absolute_import, print_function
import tvm
import tvm.testing  # for tvm.testing.assert_allclose below
import numpy as np

# Input declarations
m = 1
n = 3
p = 2
A = tvm.te.placeholder((m, n, p), name='A')
B = tvm.te.compute((m, n, p), lambda bi, bj, bk: A[bi, bj, bk], name="B")
C = tvm.te.compute((m, n, p), lambda ci, cj, ck: B[ci, cj, ck], name="C")
s = tvm.te.create_schedule(C.op)
bx = tvm.te.thread_axis("blockIdx.x")
tx = tvm.te.thread_axis("threadIdx.x")
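# Schedule: compute B under C's first axis, keep B in thread-local
# ("local") memory, bind C's axes to blockIdx.x/threadIdx.x, and bind
# B's split inner axis to threadIdx.x as well.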
s[B].compute_at(s[C], s[C].op.axis[0])
s[B].set_scope("local")
bno, bni = s[B].split(s[B].op.axis[1], n)
s[C].bind(s[C].op.axis[0], bx)
s[C].bind(s[C].op.axis[1], tx)
s[B].bind(bni, tx)
from tvm.contrib import tedd
tedd.viz_dataflow_graph(s, True, '/tmp/dfg.dot')
tedd.viz_schedule_tree(s, True, '/tmp/scheduletree.dot')
tedd.viz_itervar_relationship_graph(s, True, '/tmp/itervar.dot')
print(tvm.lower(s, [A, C], simple_mode=True))
fcuda = tvm.build(s, [A, C], "cuda")
print(fcuda.imported_modules[0].get_source())
ctx = tvm.context("cuda", 0)

# Random tensors for testing
a = tvm.nd.array(
    np.random.rand(A.shape[0].value, A.shape[1].value,
                   A.shape[2].value).astype("float32"), ctx)
c = tvm.nd.array(
    np.random.rand(C.shape[0].value, C.shape[1].value,
                   C.shape[2].value).astype("float32"), ctx)

fcuda(a, c)
result = c.asnumpy()
answer = a.asnumpy()
tvm.testing.assert_allclose(result, answer, rtol=1e-5)
evaluator = fcuda.time_evaluator(fcuda.entry_name, ctx, number=1)
print(fcuda.entry_name + ': %f ms' % (evaluator(a, c).mean * 1e3))

The IR:

produce C {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [B] storage_scope = "local"
  allocate B[float32 * 2]
  produce B {
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
    for (bk, 0, 2) {
      if (likely((threadIdx.x < 1))) {
        if (likely((threadIdx.x < 2))) {
          B[((threadIdx.x*2) + bk)] = A[((threadIdx.x*4) + bk)]
        }
      }
    }
  }
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 3
  for (ck, 0, 2) {
    C[((threadIdx.x*2) + ck)] = B[ck]
  }
}

The direct problem is a wrong predicate for stage B:

if (likely((threadIdx.x < 1)))

This predicate lets only threadIdx.x == 0 write B; since B is thread-local, every other thread ends up reading its own uninitialized copy of B.

A more fundamental cause is that loop nest creation doesn't consider the thread/memory hierarchies. In this example, stage B's scope is local. InferBound, needRelax in particular, doesn't relax the IterVar bound to threadIdx, because the memory is only used within a single thread. However, when creating loops, schedule_ops uses that IterVar to offset the index into the memory regardless of the hierarchies. This is wrong, as the generated code then uses threadIdx to address thread-local memory.
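
To make this concrete, here is a toy NumPy model (an illustration only, not TVM-generated code) of the flattened computation: each thread owns its own 2-element copy of B, so a correct program offsets A and C by the thread index but never B.

import numpy as np

n, p = 3, 2                                 # matches the example above (m = 1)
A_host = np.arange(n * p, dtype="float32")  # flattened global input
C_host = np.empty(n * p, dtype="float32")   # flattened global output

for tid in range(n):                        # one iteration per CUDA thread
    B_local = np.empty(p, dtype="float32")  # per-thread "local" copy of B
    for bk in range(p):
        B_local[bk] = A_host[tid * p + bk]  # global A is offset by the thread index
    for ck in range(p):
        C_host[tid * p + ck] = B_local[ck]  # local B is not

assert np.array_equal(C_host, A_host)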

The fix is straightforward -- MakeLoopNest adds the thread-bound IterVar to the index only if the storage rank is lower than the thread rank (the same logic as needRelax() in bound.cc). vthread doesn't need this change, as it doesn't reduce the storage scope, i.e. vthreads still share one local memory space.
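
As a rough sketch of that comparison (Python pseudocode rather than the actual C++ in MakeLoopNest, with rank values chosen for illustration; the real definitions live in thread_storage_scope.h):

# Illustrative ranks only: smaller means visible to a wider part of the hierarchy.
STORAGE_RANK = {"global": 0, "shared": 1, "warp": 2, "local": 3}
THREAD_RANK = {"blockIdx": 0, "threadIdx": 1}

def thread_var_offsets_index(storage_scope, thread_tag):
    # A thread-bound IterVar contributes to a buffer index only when the buffer
    # is shared across that thread level, i.e. the storage rank is not higher
    # than the thread rank (the same comparison needRelax makes).
    return STORAGE_RANK[storage_scope] <= THREAD_RANK[thread_tag]

# Stage B here: "local" storage bound to threadIdx, so threadIdx must not
# appear in B's index.
assert not thread_var_offsets_index("local", "threadIdx")
# A shared-memory buffer is shared by all threads in a block, so it must.
assert thread_var_offsets_index("shared", "threadIdx")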

A similar issue was identified in this thread: https://discuss.tvm.ai/t/how-to-avoid-allocating-the-wrong-amount-of-registers-in-cuda-scheduling-in-this-example/5310, although that thread has more to address. I have added this example as a new test.
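
For reference, one crude way to check the repro script above against this change is to append a scan of the printed IR for the bad predicate (a sketch only; string matching on the printed IR is brittle across TVM versions):

# Appended to the repro script above (reuses s, A, C).  With the fix applied,
# the bogus bound check on stage B should no longer appear; before the fix,
# this assertion fails.
lowered_ir = str(tvm.lower(s, [A, C], simple_mode=True))
assert "threadIdx.x < 1" not in lowered_ir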

Questions:

  1. I leave "warp" with its existing behavior to avoid one test failure (tests/python/unittest/test_tir_transform_lower_warp_memory.py). The IR from that test doesn't have the undesired predicate. Please confirm.
  2. The current storage ranks and thread ranks make sense for GPUs. How about other architectures? Should these ranks and their relationship be architecture specific?
  3. Should I refactor the rank comparison logic into a central location, such as thread_storage_scope.h?

@yongfeng-nv yongfeng-nv changed the title WIP: Fix thread local bound (Don't merge) Create loops according to storage scope and thread hierachies. Apr 6, 2020
@yongfeng-nv yongfeng-nv changed the title Create loops according to storage scope and thread hierachies. Create loops according to storage scope and thread hierarchies Apr 6, 2020
@tqchen
Member

tqchen commented Apr 6, 2020

@Hzfengsy @vinx13 @ZihengJiang can you help to take a look?

…al_stage_predicate; remove test_schedule_schedule_ops.py which was added by mistake.
@tqchen tqchen merged commit 3d09e64 into apache:master Apr 10, 2020
@tqchen
Member

tqchen commented Apr 10, 2020

Thanks @yongfeng-nv ! this PR is now merged

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
…e#5190)

* Set IterVar index to 0 for local thread bound IterVars.

* Lint fix

* Use rank instead of scope name to predicate.  Add tests.

* Handle cases other than local/threadIdx.

* Turn warp to the old behavior.

* Modify test to cover global/blockIdx.

* Fix a typo.

* Update test_te_schedule_ops.py with more testing coverage in test_local_stage_predicate; remove test_schedule_schedule_ops.py which was added by mistake.
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020
dpankratz pushed a commit to dpankratz/incubator-tvm that referenced this pull request Apr 24, 2020