
[Bug] The optimal implementation of reduce_sum searched by Ansor is more than 30x slower than torch.sum #15342

Open
MrJungle1 opened this issue Jul 18, 2023 · 10 comments
Labels: needs-triage, type: bug

Comments

@MrJungle1

Expected behavior

The optimal implementation of reduce_sum found by Ansor should have performance similar to that of torch.sum.

Actual behavior

Instead, the optimal implementation of reduce_sum found by Ansor is more than 30x slower than torch.sum.

Environment

TVM version: 0.12.0 release
NVCC: 11.0

Steps to reproduce

[Screenshots of the reproduction script and profiling results were attached as images.]

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • tune:auto_scheduler
@yzh119
Member

yzh119 commented Jul 18, 2023

Hi @MrJungle1, first, the previous CUDA codegen was suboptimal; some recent improvements such as multi-warp reduction (#15327) should help.

Besides, the way you profile PyTorch CUDA kernels is problematic: CUDA kernels are launched asynchronously, so time.time() returns before the kernel execution finishes. Please either insert torch.cuda.synchronize() barriers or use PyTorch's native profiler.
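
For reference, here is a minimal sketch of timing torch.sum with explicit synchronization (the shape, warm-up count, and iteration count below are illustrative, not taken from your script):

import time
import torch

# Illustrative input matching the shape discussed in this issue.
x = torch.randn(1024, 8, 17, 17, device="cuda")

# Warm up so one-time CUDA initialization costs are excluded.
for _ in range(10):
    torch.sum(x, dim=(0, 2, 3))
torch.cuda.synchronize()

# Kernels launch asynchronously, so synchronize before reading the clock.
start = time.time()
for _ in range(100):
    torch.sum(x, dim=(0, 2, 3))
torch.cuda.synchronize()
print("{:.4f} ms per call".format((time.time() - start) / 100 * 1e3))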

@MrJungle1
Author

MrJungle1 commented Jul 19, 2023

Hi @yzh119, thank you for your patient reply. Yes, I have also seen the improvements in #15327, but they seem to target TIR and Ansor is not mentioned; in addition, when I search with MetaSchedule, they have no effect.

Thanks for the correction. I have added torch.cuda.synchronize(), and Ansor is still about 6x slower than torch.
[Screenshot of the updated profiling results was attached as an image.]

@yzh119
Member

yzh119 commented Jul 19, 2023

Hi @MrJungle1, TE is being deprecated; there is a create_prim_func function that turns a TE compute into a TIR PrimFunc.

Ansor is no longer the preferred auto-scheduler because it lacks support for Tensor Cores. MetaSchedule should work; I tried using MetaSchedule to tune the operator, and the script is attached below:

import tvm
import tvm.te as te
from tvm import meta_schedule as ms

a, b, c, d = 1024, 8, 17, 17

A = te.placeholder((a, b, c, d), name="A", dtype="float32")
k1 = te.reduce_axis((0, a), name="k1")
k2 = te.reduce_axis((0, c), name="k2")
k3 = te.reduce_axis((0, d), name="k3")

C = te.compute(
    (b,),
    lambda i: te.sum(A[k1, i, k2, k3], axis=[k1, k2, k3]),
    name="C",
)

f = te.create_prim_func([A, C])
mod = tvm.IRModule.from_expr(f)

target = tvm.target.Target("nvidia/geforce-rtx-3090", host="llvm")

database = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    space="cuda",
    work_dir="./tune_tmp",
)

sch = ms.tir_integration.compile_tir(database, mod, target)
f = tvm.build(sch.mod["main"], target=target)
print(f.imported_modules[0].get_source())

#15327 indeed improves performance. Before this PR, the TVM kernel tuned by MetaSchedule had a latency of 38.7868 us on my machine; after #15327 was merged, it dropped to 35.7256 us. As a reference, the PyTorch kernel runs in 23.62 us on my machine. The script used to profile the PyTorch kernel is attached below for your reference:

import torch
from typing import List, Callable, Any, Tuple, Union

def profile_pytorch_ms(f: Callable[[], None]) -> float:
    r"""
    Use Triton's profiler that flushes L2 cache.
    """
    n_wait = 1
    n_warmup = 10
    n_repeat = 100
    """The following code copied from Triton profiler."""
    cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
    start_event = [
        torch.cuda.Event(enable_timing=True) for i in range(n_repeat)
    ]
    end_event = [
        torch.cuda.Event(enable_timing=True) for i in range(n_repeat)
    ]
    # Warm-up
    for _ in range(n_warmup):
        f()
    # Benchmark
    for i in range(n_repeat):
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        f()
        end_event[i].record()
    # Record clocks
    torch.cuda.synchronize()
    times = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_event, end_event)])
    dur = torch.mean(times).item()
    return dur

x = torch.randn(1024, 8, 17, 17).float().to(0)

print("{} ms".format(profile_pytorch_ms(lambda: torch.sum(x, (0, 2, 3)))))

@yzh119
Member

yzh119 commented Jul 19, 2023

I tried manually merging the last two dimensions of your input tensor (which has the same semantics as the original operator), and the TVM kernel tuned by MetaSchedule then reached a latency of 23.4150 us, slightly better than PyTorch's (the PyTorch kernel's speed does not change if we reshape the tensor to (1024, 8, 289)). The script is attached below (289 = 17 * 17):

import tvm
import tvm.te as te
import tvm.dlight as dl
from tvm import meta_schedule as ms

a, b, c = 1024, 8, 289

A = te.placeholder((a, b, c), name="A", dtype="float32")
k1 = te.reduce_axis((0, a), name="k1")
k2 = te.reduce_axis((0, c), name="k2")

C = te.compute(
    (b,),
    lambda i: te.sum(A[k1, i, k2], axis=[k1, k2]),
    name="C",
)

f = te.create_prim_func([A, C])
mod = tvm.IRModule.from_expr(f)

target = tvm.target.Target("nvidia/geforce-rtx-3090", host="llvm")

database = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    space="cuda",
    work_dir="./tune_tmp",
)

sch = ms.tir_integration.compile_tir(database, mod, target)
f = tvm.build(sch.mod["main"], target=target)

print(f.imported_modules[0].get_source())
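
As a quick sanity check (not part of the tuning flow), the merged-axis formulation computes the same values as the original 4-D reduction; the snippet below is illustrative:

import torch

x = torch.randn(1024, 8, 17, 17, device="cuda")
ref = torch.sum(x, dim=(0, 2, 3))                         # original reduction over axes (0, 2, 3)
merged = torch.sum(x.reshape(1024, 8, 289), dim=(0, 2))   # last two axes merged (17 * 17 = 289)
# Expected to agree up to floating-point rounding.
print(torch.max(torch.abs(ref - merged)).item())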

@MrJungle1
Author

Hi @yzh119, thank you for patiently working through this problem with me and for providing the script. Still, from your conclusion, in the example I provided even the MetaSchedule result is about 1.5x slower than torch. Regarding the experiment where you merged the last two dimensions: why can't MetaSchedule find that itself? In my opinion it should be searchable; I don't know whether my understanding is wrong, and I would like to hear yours.

@yzh119
Member

yzh119 commented Jul 19, 2023

It suggests that our schedule rules are not optimal :).

I think merging the reduction axes is a general optimization that is worth considering as a normalization step before we perform auto-tuning (I believe @spectrometerHBH has done something similar with merging spatial axes in dlight). You are welcome to contribute to this part if you are interested.
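
For illustration only, here is a rough manual sketch of that normalization idea at the TIR level (this is not the dlight rule or an existing MetaSchedule pass): fuse the two innermost reduction loops of the original 4-D PrimFunc before tuning.

import tvm
import tvm.te as te

# Rebuild the original 4-D reduction from earlier in this thread.
a, b, c, d = 1024, 8, 17, 17
A = te.placeholder((a, b, c, d), name="A", dtype="float32")
k1 = te.reduce_axis((0, a), name="k1")
k2 = te.reduce_axis((0, c), name="k2")
k3 = te.reduce_axis((0, d), name="k3")
C = te.compute((b,), lambda i: te.sum(A[k1, i, k2, k3], axis=[k1, k2, k3]), name="C")

mod = tvm.IRModule.from_expr(te.create_prim_func([A, C]))

# Fuse the two innermost reduction loops so the loop nest looks like the
# (1024, 8, 289) variant before any tuning is applied.
sch = tvm.tir.Schedule(mod)
i, r1, r2, r3 = sch.get_loops(sch.get_block("C"))
sch.fuse(r2, r3)
print(sch.mod.script())  # the fused module could then be passed to ms.tune_tir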

@MrJungle1
Author

Hi @yzh119, thank you for your patient reply. I will try to optimize this part, and if I make progress I would be very happy to contribute my code. In addition, merging the reduction axes should happen in the first step when Ansor generates the sketch; from the printed initial stages, my understanding is that Ansor fuses all of the reduction axes together. I have not looked at MetaSchedule carefully; I assumed it would do the same as Ansor. Also, I tried the PyTorch kernel script you provided. Why are the results different across runs? The gap is quite large: 0.105 ms, 0.127 ms, and 0.100 ms. What is the reason?

@MrJungle1
Author

MrJungle1 commented Jul 19, 2023

In addition, in another case, shape = (120, 128, 128, 128) with axis = (1, 2) generates very ordinary (unoptimized) code.
[Screenshot of the generated code was attached as an image.]

shape = (128, 8, 128, 128) with axis = (1,)
[Screenshot of the generated code was attached as an image.]

shape = (128, 8, 127, 127) with axis = (0, 2, 3)
[Screenshot of the generated code was attached as an image.]

@yzh119
Member

yzh119 commented Jul 19, 2023

@MrJungle1, do you have other programs (e.g. xorg) running on the same GPU you are using to profile kernels? If so, please kill them before profiling.

Also, you can lock your GPU frequency via:

sudo nvidia-smi -pm 1
sudo nvidia-smi -ac MEMORY_CLOCK,GRAPHICS_CLOCK

The values of MEMORY_CLOCK and GRAPHICS_CLOCK are GPU dependent; you can query the supported values via nvidia-smi: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries

Then the profiling results should be stable.

@MrJungle1
Author

OK, thank you for your patient reply.
