[Bug] The optimal implementation of reduce_sum searched by Ansor is more than 30x slower than torch.sum #15342
Hi @MrJungle1, first, the previous CUDA codegen is suboptimal; some recent improvements such as multi-warp reduction (#15327) should help. Besides, the way you profile PyTorch CUDA kernels is problematic: kernel launches are asynchronous, so you need to synchronize the device before reading timestamps.
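(As an illustration, not from the original thread: a minimal sketch of how asynchronous kernel launches distort naive timing; the shape mirrors the workload in this issue.)

```python
import time
import torch

x = torch.randn(1024, 8, 17, 17, device="cuda")
torch.sum(x, (0, 2, 3))   # warm-up
torch.cuda.synchronize()

# Without synchronization we only time the asynchronous kernel launch.
t0 = time.perf_counter()
torch.sum(x, (0, 2, 3))
t_launch = time.perf_counter() - t0

# With synchronization we wait until the kernel has actually finished.
t0 = time.perf_counter()
torch.sum(x, (0, 2, 3))
torch.cuda.synchronize()
t_sync = time.perf_counter() - t0

print(f"launch only: {t_launch * 1e3:.3f} ms, synchronized: {t_sync * 1e3:.3f} ms")
```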
Hi @yzh119, thank you for your patient reply. Yes, I have also seen the improvements in #15327, but they seem to target TIR, and Ansor is not mentioned; in addition, when I use metaschedule search, it has no effect. Thanks for your correction: I have added torch.cuda.synchronize(), and Ansor is still 6x slower than torch.
Hi @MrJungle1, TE is being deprecated; there is a create_prim_func function that turns a TE compute into TIR. Ansor is no longer the preferred auto-scheduler because it lacks support for Tensor Cores. Metaschedule should work. I tried using metaschedule to tune the operator; the script is attached below:

```python
import tvm
import tvm.te as te
from tvm import meta_schedule as ms

a, b, c, d = 1024, 8, 17, 17
A = te.placeholder((a, b, c, d), name="A", dtype="float32")
k1 = te.reduce_axis((0, a), name="k1")
k2 = te.reduce_axis((0, c), name="k2")
k3 = te.reduce_axis((0, d), name="k3")
C = te.compute(
    (b,),
    lambda i: te.sum(A[k1, i, k2, k3], axis=[k1, k2, k3]),
    name="C",
)

f = te.create_prim_func([A, C])
mod = tvm.IRModule.from_expr(f)
target = tvm.target.Target("nvidia/geforce-rtx-3090", host="llvm")
database = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    space="cuda",
    work_dir="./tune_tmp",
)
sch = ms.tir_integration.compile_tir(database, mod, target)
f = tvm.build(sch.mod["main"], target=target)
print(f.imported_modules[0].get_source())
```

#15327 indeed improves performance; before this PR, the TVM kernel tuned by metaschedule had a higher latency. The PyTorch side is profiled with the script below:
```python
import torch
from typing import Callable


def profile_pytorch_ms(f: Callable[[], None]) -> float:
    r"""Use Triton's profiler approach, which flushes the L2 cache."""
    n_wait = 1  # kept from the Triton profiler; unused here
    n_warmup = 10
    n_repeat = 100
    # The following code is copied from the Triton profiler.
    cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
    start_event = [
        torch.cuda.Event(enable_timing=True) for i in range(n_repeat)
    ]
    end_event = [
        torch.cuda.Event(enable_timing=True) for i in range(n_repeat)
    ]
    # Warm-up
    for _ in range(n_warmup):
        f()
    # Benchmark
    for i in range(n_repeat):
        # Clear the L2 cache before each run
        cache.zero_()
        # Record the time of `f`
        start_event[i].record()
        f()
        end_event[i].record()
    # Wait for all recorded events before reading the clocks
    torch.cuda.synchronize()
    times = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_event, end_event)])
    dur = torch.mean(times).item()
    return dur


x = torch.randn(1024, 8, 17, 17).float().to(0)
print("{} ms".format(profile_pytorch_ms(lambda: torch.sum(x, (0, 2, 3)))))
```
I tried to manually merge the last two dimensions of your input tensor (which has the same semantics as the original operator), and then the TVM kernel tuned by metaschedule gets a lower latency:

```python
import tvm
import tvm.te as te
from tvm import meta_schedule as ms

a, b, c = 1024, 8, 289
A = te.placeholder((a, b, c), name="A", dtype="float32")
k1 = te.reduce_axis((0, a), name="k1")
k2 = te.reduce_axis((0, c), name="k2")
C = te.compute(
    (b,),
    lambda i: te.sum(A[k1, i, k2], axis=[k1, k2]),
    name="C",
)

f = te.create_prim_func([A, C])
mod = tvm.IRModule.from_expr(f)
target = tvm.target.Target("nvidia/geforce-rtx-3090", host="llvm")
database = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    space="cuda",
    work_dir="./tune_tmp",
)
sch = ms.tir_integration.compile_tir(database, mod, target)
f = tvm.build(sch.mod["main"], target=target)
print(f.imported_modules[0].get_source())
```
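(A quick sanity check, added here for illustration rather than taken from the comment, that the merged layout computes the same result:)

```python
import torch

# Summing an (a, b, c, d) tensor over axes (0, 2, 3) is equivalent to
# reshaping the last two axes into one and summing over axes (0, 2).
x = torch.randn(1024, 8, 17, 17)
y_orig = torch.sum(x, (0, 2, 3))
y_merged = torch.sum(x.reshape(1024, 8, 17 * 17), (0, 2))
# Loose tolerance: float32 accumulation order differs between the two forms.
assert torch.allclose(y_orig, y_merged, rtol=1e-4, atol=1e-4)
```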
Hi @yzh119, thank you for patiently solving the problem for me and providing the script. But from your conclusion, in the example I provided, even the metaschedule search result is 1.5x slower than torch. About the experiment where you merged the two dimensions: why can't metaschedule find that transformation itself? In my opinion it should be searchable. I don't know if my understanding is wrong, and I would like to hear yours.
It suggests that our schedule rules are not optimal :). I think merging the reduction axes is a general optimization that it's worthwhile to consider as a schedule rule.
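(A hypothetical sketch of what such a rule would do, expressed manually with the TIR schedule API; the workload is rebuilt from the 4D script above, and the loop layout is an assumption based on that script.)

```python
import tvm
import tvm.te as te

# Rebuild the 4D reduce_sum PrimFunc from the earlier script.
a, b, c, d = 1024, 8, 17, 17
A = te.placeholder((a, b, c, d), name="A", dtype="float32")
k1 = te.reduce_axis((0, a), name="k1")
k2 = te.reduce_axis((0, c), name="k2")
k3 = te.reduce_axis((0, d), name="k3")
C = te.compute((b,), lambda i: te.sum(A[k1, i, k2, k3], axis=[k1, k2, k3]), name="C")
mod = tvm.IRModule.from_expr(te.create_prim_func([A, C]))

sch = tvm.tir.Schedule(mod)
block = sch.get_block("C")
i, k1, k2, k3 = sch.get_loops(block)  # one spatial loop, three reduction loops
sch.fuse(k2, k3)  # merge the two innermost reduction axes into a single loop
print(sch.mod)    # the fused reduction loop now has extent 17 * 17 = 289
```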
Hi @yzh119, thank you for your patience and reply. I will try to optimize this part, and if there is progress I will be very happy to contribute my code. In addition, merging the reduction axes should happen in the first step when Ansor generates the sketch: judging from the printout of the initial stage, Ansor merges all the reduction axes together. I haven't looked at metaschedule carefully; I assumed it would behave the same as Ansor. I also tried the PyTorch profiling script you provided. Why do the results differ across runs? The gap is quite large: 0.105 ms, 0.127 ms, and 0.100 ms, respectively. What is the reason?
@MrJungle1, do you have other programs running on the same GPU you are using for profiling kernels (e.g. xorg)? If so, please kill them before profiling. Also, you can lock your GPU frequency via:

```bash
sudo nvidia-smi -pm 1
sudo nvidia-smi -ac MEMORY_CLOCK,GRAPHICS_CLOCK
```

The values of MEMORY_CLOCK and GRAPHICS_CLOCK are GPU dependent, and you can query the possible values via nvidia-smi: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries Then the profiling results should be stable.
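(For example, assuming a reasonably recent driver, the supported memory/graphics clock pairs can be listed with:)

```bash
nvidia-smi -q -d SUPPORTED_CLOCKS
```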
OK, thank you for your patience and reply.
Expected behavior
The optimal implementation of reduce_sum searched by Ansor should have performance similar to that of torch.sum.
Actual behavior
The optimal implementation of reduce_sum searched by Ansor is more than 30x slower than torch.sum.
Environment
TVM version: 0.12.0 release
NVCC: 11.0
Steps to reproduce
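(The original reproduction script is not included above; the following is a minimal sketch, under the assumption that the workload was tuned with Ansor via tvm.auto_scheduler, with illustrative trial counts and log-file name.)

```python
import tvm
import tvm.te as te
from tvm import auto_scheduler

@auto_scheduler.register_workload
def reduce_sum(a, b, c, d):
    A = te.placeholder((a, b, c, d), name="A", dtype="float32")
    k1 = te.reduce_axis((0, a), name="k1")
    k2 = te.reduce_axis((0, c), name="k2")
    k3 = te.reduce_axis((0, d), name="k3")
    C = te.compute(
        (b,), lambda i: te.sum(A[k1, i, k2, k3], axis=[k1, k2, k3]), name="C"
    )
    return [A, C]

target = tvm.target.Target("cuda")
task = auto_scheduler.SearchTask(func=reduce_sum, args=(1024, 8, 17, 17), target=target)
log_file = "reduce_sum.json"  # illustrative log-file name
task.tune(auto_scheduler.TuningOptions(
    num_measure_trials=512,   # illustrative trial budget
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
))
sch, args = task.apply_best(log_file)
f = tvm.build(sch, args, target)  # the tuned reduce_sum kernel
```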