Skip to content

Commit

Permalink
[docs] fixed codebase_walkthrough document bug (#13008)
Browse files Browse the repository at this point in the history
When I was studying the "TVM Codebase Walkthrough by Example" document, I found that the code didn't work, so I fixed it.

Bind the iteration axis to threads in the GPU.
  • Loading branch information
wufeng15226 committed Oct 8, 2022
1 parent 189338c commit d92d47a
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion docs/dev/tutorial/codebase_walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,14 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.

``Schedule`` and ``Stage`` are defined in ``tvm/python/te/schedule.py``, ``include/tvm/te/schedule.h``, and ``src/te/schedule/schedule_ops.cc``.

To keep it simple, we call ``tvm.build(...)`` on the default schedule created by ``create_schedule()`` function above.
To keep it simple, we call ``tvm.build(...)`` on the default schedule created by ``create_schedule()`` function above, and we must add necessary thread bindings to make it runnable on GPU.

::

target = "cuda"
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.te.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.te.thread_axis("threadIdx.x"))
fadd = tvm.build(s, [A, B, C], target)

``tvm.build()``, defined in ``python/tvm/driver/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a :py:class:`tvm.runtime.Module` object. A :py:class:`tvm.runtime.Module` object contains a compiled function which can be invoked with function call syntax.
Expand Down

0 comments on commit d92d47a

Please sign in to comment.