Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTX-MMA] Add full PTX MMA code generation support #9909

Merged
merged 1 commit into from
Jan 24, 2022

Conversation

KnowingNothing
Copy link
Contributor

@KnowingNothing KnowingNothing commented Jan 12, 2022

This change adds full (although not all) PTX MMA code generation support for three generations of Tensor Core, including Volta, Turing, and Ampere. The generation logic is mainly implemented in ptx_mma.cc and should have no major influence on existing code. A test file is also provided in tests/python/unittest/test_tir_ptx_mma.py. Here is a list of limitations and further improvement is possible:

  1. Correctness tests for int4 and binary MMA instructions are missing because NumPy has no support for int4 and binary kernels.
  2. Tf32 and bf16 instructions are supported, but no tests are provided because as far as I know, these data types are not natively supported by TVM.
  3. Implementation for binary MMA generates mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc for uint1 and mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc for int1. This may not be a perfect decision.

Copy link
Member

@vinx13 vinx13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor issues otherwise LGTM. There is some (unrelated) test errors on CI, could you try pushing again to restart the CI?

namespace tvm {
namespace codegen {

std::string PrintPTXAssembly(const std::string& shape, const std::string& A_layout,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe PrintMMAAssembly would be a better name

golden = np.matmul(A_np.astype("float64"), B_np.astype("float64").T)

C_numpy = C_tvm.numpy()
from tvm import testing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not needed as tvm.testing is already imported at the beginning

/*
* TODO: add mma.m16n8k128
*/
return "";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return "";
ICHECK(0);
throw;

if this is unreachable, just raises an error

for i in range(4):
Accum[i] = T.float32(0)

for mma_multi_a_col in T.vectorized(4):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I wonder if you could elaborate more on the necessity of the declarations of MultiA, MultiB and Accum buffers here. Do buffers like A, B and C not work within the MMA assembly code generated below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To use MMA instructions, the multiplicands and accumulator should be placed in registers, otherwise, the behavior is undefined. I have tried to use global buffers (e.g., A, B, C) to invoke MMA instructions, and the results are all wrong.

MultiA[mma_multi_a_col] = A[
(tx % 32) // 4 + mma_multi_a_col // 2 * 8, (tx % 32) % 4 * 2 + mma_multi_a_col % 2
]
for mma_multi_b_col in T.vectorized(4):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can combine the three loops to initialize MulitA MultiB and possibly Accum given the loop invariant are the same

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it is more clear to make them separate because for people who are not familiar with CUDA or MMA, they can tell that the load of MultiA, MultiB, and the initialization of Accum are decoupled, which is also in accord with the pattern of the code generated by TVM.

"fp16",
"fp32",
MultiA,
0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the use of MultiA MultiB and Accum make the bias/offset here unnecessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe there is another way to implement the interface. I followed the existing manner of tvm_mma_sync.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally it is necessary if the buffer larger than required by mma.


A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
B_np = np.random.uniform(-1, 1, [8, 8]).astype("float16")
C_np = np.random.uniform(-1, 1, [16, 8]).astype("float32")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should't the value of C_np be zeros?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's OK to set C_np to random values, although the most standard way is to set C_np to zeros. The results are not affected by the initial value of C_np because the accumulators are always initialized to zeros.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I am worried about the implication may confuse people.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed to np.zeros.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that random initialization does follow the convention we have in the tvm repo, so i dont think it confuses anybody. changing to zeros should good too, so i dont have strong opinion

@junrushao
Copy link
Member

Some tests are failing (probably not relevant to this PR). Retriggering

@junrushao
Copy link
Member

Failed again. @KnowingNothing would you mind checking the unittests also on your side?

@KnowingNothing
Copy link
Contributor Author

It also failed on my local machine.

$ pytest tests/python/frontend/pytorch/qnn_test.py::test_serialized_modules
enabled targets: llvm; llvm -device=arm_cpu; cuda; cuda -model=unknown -libs=cudnn; nvptx; opencl; opencl -device=mali,aocl_sw_emu; opencl -device=intel_graphics
pytest marker:
============================================================== test session starts ===============================================================
platform linux -- Python 3.8.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/zchno/TVM/tvm-mirror-pr
collected 1 item

tests/python/frontend/pytorch/qnn_test.py Fatal Python error: Aborted

Current thread 0x00007f149c7e1740 (most recent call first):
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/torch/jit/_serialization.py", line 161 in load
  File "/home/zchno/TVM/tvm-mirror-pr/tests/python/frontend/pytorch/qnn_test.py", line 513 in test_serialized_modules
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/python.py", line 183 in pytest_pyfunc_call
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/python.py", line 1641 in runtest
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 162 in pytest_runtest_call
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 255 in <lambda>
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 311 in from_call
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 254 in call_runtest_hook
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 215 in call_and_report
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 126 in runtestprotocol
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/runner.py", line 109 in pytest_runtest_protocol
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/main.py", line 348 in pytest_runtestloop
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/main.py", line 323 in _main
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/main.py", line 269 in wrap_session
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/main.py", line 316 in pytest_cmdline_main
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/config/__init__.py", line 162 in main
  File "/home/zchno/venv/prime/lib/python3.8/site-packages/_pytest/config/__init__.py", line 185 in console_main
  File "/home/zchno/venv/prime/bin/pytest", line 8 in <module>
Aborted (core dumped)

@vinx13
Copy link
Member

vinx13 commented Jan 16, 2022

@KnowingNothing Can you try rebasing and testing again?

@vinx13
Copy link
Member

vinx13 commented Jan 18, 2022

I checked the unit test and confirmed it is caused by this commit. It seems the error only happens when std::regex is used, probably because of C++ ABI incompatibility with libtorch.
@junrushao1994 Do you have more insights on this?

@junrushao
Copy link
Member

@vinx13 My experience with std::regex is overwhelmingly negative. If it's the source of these bugs, let's consider other alternatives

@junrushao
Copy link
Member

CC: @jinhongyii

@KnowingNothing
Copy link
Contributor Author

I tried to replace std::regex with normal string operations. Hope this will work.

@vinx13 vinx13 merged commit d066441 into apache:main Jan 24, 2022
@junrushao
Copy link
Member

Thanks! This is huge

yuanfz98 pushed a commit to yuanfz98/tvm that referenced this pull request Jan 24, 2022
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
junrushao pushed a commit that referenced this pull request Apr 3, 2022
…y to warp memory (#10855)

We already have PTX mma and mma.sp builtin support in #9909  and #10339 . However, we have not supported corresponding data movement builtins for these mma instructions, so the data movement would not be as fast as wmma.

This PR brings the `ldmatrix` builtin, which is a native PTX warp-level instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix), and we can use it to load several (1/2/4) 8x8 matrices from shared memory to warp memory.
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
…y to warp memory (apache#10855)

We already have PTX mma and mma.sp builtin support in apache#9909  and apache#10339 . However, we have not supported corresponding data movement builtins for these mma instructions, so the data movement would not be as fast as wmma.

This PR brings the `ldmatrix` builtin, which is a native PTX warp-level instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix), and we can use it to load several (1/2/4) 8x8 matrices from shared memory to warp memory.
mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Apr 11, 2022
…y to warp memory (apache#10855)

We already have PTX mma and mma.sp builtin support in apache#9909  and apache#10339 . However, we have not supported corresponding data movement builtins for these mma instructions, so the data movement would not be as fast as wmma.

This PR brings the `ldmatrix` builtin, which is a native PTX warp-level instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix), and we can use it to load several (1/2/4) 8x8 matrices from shared memory to warp memory.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants