RDNA3 support #27

Open
WilliamGazeley opened this issue Dec 5, 2023 · 62 comments
Labels: navi hardware

@WilliamGazeley

Great work so far. I'm trying to run vLLM on my 7900XTX cards and was wondering if there were any plans to support RDNA3?

@sdli1995

sdli1995 commented Dec 6, 2023

A CK discussion has shown a branch with a flash-attention kernel implementation that already works in AIT: ROCm/composable_kernel#1032
Are there any barriers to RDNA3 support?

@howiejayz
Member

Hi @WilliamGazeley and @sdli1995, I wanted to update you on the attention kernels for the NAVI platform. My colleague @aska-0096 did get them working. However, these Flash-Attention kernels were initially developed for MI devices, which rely on a distinct set of CK kernels.

In essence, the issue is that we haven't yet integrated these kernels into our current API. I plan to work on this integration in my spare time.

@AlpinDale

Great work on the fork, @howiejayz

Will it take too long to port the kernels for gfx1100 and other non-MI architectures? I need the kernels for my project and I'd be willing to help out if I can.

@howiejayz
Member

Thanks for reaching out and offering to help @AlpinDale! I'm currently tied up with a few other projects, so I can't give an exact timeframe for porting the kernels for gfx1100 and other architectures. But I'm definitely planning to tackle this soon. The first step will be creating a new code path for gfx110x, considering their CK kernels are only for forward ops.

I'm totally open to suggestions or any help you can provide. It'd be great to have some extra hands on this. Let me know if you're interested!

@evshiron

evshiron commented Dec 7, 2023

I am a complete novice in this field, but a few months ago I managed to make Composable Kernel, Flash Attention and PyTorch work together for my RX 7900 XTX (see here, sort by performance, and look for the first one). Although I was able to get that Flash Attention implementation "working" in the end, the generated images were meaningless, and I gave up because I didn't know how to fix it. Here are the relevant branch links, and I hope they can be of some help to you:

I made a grouped fused kernel by Frankensteining the batched fused kernel, which matched the call signatures in this repo at that time. However, that self-made kernel might just be broken.

@howiejayz
Member

Hi @evshiron! First off, I must say I'm seriously impressed by your work! It's quite an achievement, and the resources you've provided are invaluable.

I've had the opportunity to build your implementation on gfx1100, and I'm pleased to report that the build was successful. However, I encountered an issue with the unit tests not passing due to incorrect results in the forward pass:

assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
AssertionError: assert 0.109619140625 <= (2 * 0.000244140625)

which likely stems from incorrect parameter settings in the CK kernels. I guess this is why the output images became meaningless.

Despite this, your work has been immensely helpful! This will massively speed up the navi porting process for the v2 implementation.

@evshiron

evshiron commented Dec 7, 2023

@howiejayz

I'm glad that my humble work could be of some help. I am indeed unfamiliar with this field, so I can only leave it to the professionals.
Furthermore, as you can see, even though I managed to compile it, the improvement in the benchmark is quite limited (I didn't use the specific commit shown here). I hope it's just an issue with my implementation, and I look forward to better performance in future implementations.

@howiejayz
Member

howiejayz commented Dec 9, 2023

Guys, I have added batched forward (consistent sequence lengths) support for gfx1100, gfx1101, and gfx1102 under this branch, thanks to @aska-0096's CK kernels. The implementation is still under development and there is a lot to fine-tune. For now I see that performance is generally better when head dim = 64.

To install, just use pip install .

I only had the chance to test it on gfx1100, but I expect it works for the other two as well. Let me know if there is any issue! The Docker image I used for testing is rocm/pytorch:latest with torch==2.1.0.
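
If you want a quick smoke test after installing, here is a minimal forward-only sketch (assuming the branch keeps the usual flash_attn_func API; the shapes are arbitrary):

# Minimal forward-only smoke test; assumes the branch keeps the upstream
# flash_attn_func signature and that only the forward pass is implemented.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64  # head dim 64 is the faster path here
q, k, v = (torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
           for _ in range(3))

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
print(out.shape)  # expected: torch.Size([2, 1024, 16, 64])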

@xzuyn

xzuyn commented Dec 10, 2023

under this branch.

benchmark_flash_attention_forward.py works, but benchmark_flash_attention.py doesn't. Forward speeds look pretty nice.

Using a 7900XTX with torch 2.2.0.dev20231209+rocm5.7.

Results for `benchmark_flash_attention_forward.py`
### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 38.52 TFLOPs/s
Pytorch fwd: 13.37 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 39.33 TFLOPs/s
Pytorch fwd: 14.22 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 41.23 TFLOPs/s
Pytorch fwd: 16.04 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 42.06 TFLOPs/s
Pytorch fwd: 18.02 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 42.02 TFLOPs/s
Pytorch fwd: 19.16 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 37.27 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=32, seqlen=512 ###
Flash2 fwd: 28.27 TFLOPs/s
Pytorch fwd: 20.43 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 29.38 TFLOPs/s
Pytorch fwd: 21.05 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=8, seqlen=2048 ###
Flash2 fwd: 30.49 TFLOPs/s
Pytorch fwd: 25.23 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=4, seqlen=4096 ###
Flash2 fwd: 31.00 TFLOPs/s
Pytorch fwd: 26.99 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 27.50 TFLOPs/s
Pytorch fwd: 28.47 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 20.67 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 24.02 TFLOPs/s
Pytorch fwd: 5.07 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 29.08 TFLOPs/s
Pytorch fwd: 5.48 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 33.49 TFLOPs/s
Pytorch fwd: 5.84 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 36.44 TFLOPs/s
Pytorch fwd: 6.21 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 38.54 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 39.70 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=32, seqlen=512 ###
Flash2 fwd: 17.89 TFLOPs/s
Pytorch fwd: 8.42 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 21.69 TFLOPs/s
Pytorch fwd: 8.68 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=8, seqlen=2048 ###
Flash2 fwd: 25.64 TFLOPs/s
Pytorch fwd: 9.78 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=4, seqlen=4096 ###
Flash2 fwd: 27.06 TFLOPs/s
Pytorch fwd: 10.01 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 27.43 TFLOPs/s
Pytorch fwd: 9.87 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 24.68 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
Results for `test_flash_attn_wmma_rocm.py`

=============== 125 failed, 2148 passed, 4606 skipped in 46.01s ================

Full Log: test_flash_attn_wmma_rocm.log

Error for `benchmark_flash_attention.py`
> python benchmarks/benchmark_flash_attention.py

Traceback (most recent call last):
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/flash-attention/benchmarks/benchmark_flash_attention.py", line 97, in <module>
    f, b = time_fwd_bwd(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/flash-attention/benchmarks/benchmark_flash_attention.py", line 66, in time_fwd_bwd
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 99, in benchmark_fwd_bwd
    benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 53, in benchmark_backward
    m = t.timeit(repeats)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 274, in timeit
    self._timeit(number=max(int(number // 100), 2))
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 264, in _timeit
    return max(self._timer.timeit(number), 1e-9)
  File "/usr/lib/python3.10/timeit.py", line 178, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 46, in f
    y.backward(grad, retain_graph=True)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/_tensor.py", line 503, in backward
    torch.autograd.backward(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 109, in backward
    _flash_attn_backward(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
TypeError: bwd(): incompatible function arguments. The following argument types are supported:
    1. () -> None

@AlpinDale

Some benchmark results.

RTX 4090

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 150.90 TFLOPs/s,
Pytorch fwd: 20.23 TFLOPs/s,
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 154.49 TFLOPs/s,
Pytorch fwd: 23.69 TFLOPs/s,
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 171.80 TFLOPs/s,
Pytorch fwd: 26.21 TFLOPs/s,
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 172.81 TFLOPs/s,
Pytorch fwd: 27.89 TFLOPs/s,
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 172.96 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 173.04 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,

7900 XTX

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 42.98 TFLOPs/s
Pytorch fwd: 14.38 TFLOPs/s
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 44.31 TFLOPs/s
Pytorch fwd: 14.83 TFLOPs/s
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 48.25 TFLOPs/s
Pytorch fwd: 17.32 TFLOPs/s
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 47.55 TFLOPs/s
Pytorch fwd: 19.40 TFLOPs/s
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 38.40 TFLOPs/s
Pytorch fwd: 20.19 TFLOPs/s
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 41.01 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s

@sdli1995

Some benchmark results.

[RTX 4090 and 7900 XTX forward benchmarks quoted from the comment above]

The 4090's FP16-accumulate tensor-core peak is 330 TFLOPs, while the 7900 XTX's is 120 TFLOPs; the better NVIDIA reference card is the RTX 3090.

@aska-0096

Some benchmark results.

[RTX 4090 and 7900 XTX forward benchmarks and the tensor-core comparison quoted from the comments above]

Thanks for the benchmark data. We are going to launch a new version of Composable Kernel with better flash-attention performance. Adapting those optimizations to RDNA3 is in my plan.

@AlpinDale

@sdli1995 here are the benchmarks from a 3090:

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 65.38 TFLOPs/s,
Pytorch fwd: 18.38 TFLOPs/s,
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 72.94 TFLOPs/s,
Pytorch fwd: 21.69 TFLOPs/s,
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 74.11 TFLOPs/s,
Pytorch fwd: 18.92 TFLOPs/s,
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 74.98 TFLOPs/s,
Pytorch fwd: 22.27 TFLOPs/s,
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 75.06 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 75.12 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,

@AlpinDale

Any updates on this?

@Wintoplay

We need official support for flash attention

@ewof

ewof commented Dec 24, 2023

trust bro, be patient don't rush them

@gel-crabs

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds generation up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

@Kademo15

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

Could you please provide more information about how? Did you just install the branch and it worked out of the box, or did you have to change the code of the webui you are using?

@Wintoplay

@gel-crabs I failed to install flash-attn for Navi. Please give more info.

@gel-crabs

gel-crabs commented Jan 4, 2024

@Kademo15 @Wintoplay

Alright, I'm going to try to give instructions on how I got this to work. If you're on Arch, I have a very amateur PKGBUILD (requiring --skipinteg) that gets it to work. You need to go into the PKGBUILD, replace GPU_ARCHS=gfx1101 with your GPU's architecture, and set MAX_JOBS to however many CPU cores you have. I can only confirm it will work on gfx11+. The patch just changes the C++ standard from c++20 to c++17 to allow it to build.

python-flash-attention.tar.gz

If you aren't on Arch, you can generally just follow the commands and install the python wheel file afterwards, in your virtualenv if you're using one. You can clone the repo with git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git -b howiejayz/navi_support --depth=1 in this case.

Now for webui. You will have to use a patch that has been closed since it will be obsolete once AMD finishes xformers support.

https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902.patch

That's the link to the raw patch file, use patch -p1 < 11902.patch in the webui directory to apply it to your webui. Please run patch -p1 --dry-run < 11902.patch first so it won't screw up your installation if it doesn't apply correctly. We're not done yet, however.

AUTOMATIC1111/stable-diffusion-webui#11902

In the discussion for this patch, I posted a long comment on getting it to work. (AUTOMATIC1111/stable-diffusion-webui#11902 (comment)). The important info from that is the 2 code blocks you need to change manually.

After that, add --flash-attn to your command line arguments and it should work. If you get slower or the same speed, flash attention isn't working. You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

If you get an error involving a flash-attn so not loading, rebuild the PKGBUILD but change CC=clang and CXX=clang++ to CC=gcc and CXX=g++.

@Beinsezii

Beinsezii commented Jan 4, 2024

Switching setup.py to c++17 built successfully on gfx1100

Seems to work in transformers, as a 7b model OOMs @ >8k context using the default attention but doesn't even crack 20 gigs with FA2 enabled. Interestingly I lose about 1 t/s though?

I'll have to see if I can monkeypatch it into Diffusers...

@Wintoplay

@gel-crabs I did that and also tried FLASH_ATTENTION_INTERNAL_USE_RTN=1 pip install .
It just gives FFFFF results for the testing.

(I use Debian.)
I have not tried the SD patch though, because I want it for LLM inference.

@gel-crabs

@gel-crabs that and tried FLASH_ATTENTION_INTERNAL_USE_RTN=1 pip install . it just say FFFFF result for the testing

(I use Debian.) I have not tried the SD patch though cuz I want it for inference of LLM.

Do you mean the unit testing? For that you need to export FLASH_ATTENTION_INTERNAL_UNIT_TEST_MODE=1 and FLASH_ATTENTION_INTERNAL_DETERMINISTIC=1.

You should also set your GPU_ARCHS to your GPU architecture (gfx1100, gfx1101, etc.) and try building with GCC and Clang. I can also only guarantee this will work on ROCm 5.7 and up.

For anything other than SD webui, you will likely have to create a forward function yourself, as it is a PyTorch extension and isn't integrated into PyTorch yet. The implementation is here, but keep in mind it requires my changes as well:

https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902/files
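
As a rough illustration of what such a forward function can look like outside of the webui, here is a minimal sketch (not the linked patch itself; the helper name, head-count argument, and (batch, seqlen, hidden) layout are assumptions):

# Sketch: wrap flash_attn_func into a hand-rolled multi-head attention forward.
# mha_forward, num_heads, and the (batch, seqlen, hidden) layout are assumptions.
import torch
from flash_attn import flash_attn_func

def mha_forward(q, k, v, num_heads, causal=False):
    # q, k, v: (batch, seqlen, num_heads * head_dim), fp16/bf16, on the GPU
    batch, seqlen, hidden = q.shape
    head_dim = hidden // num_heads

    def split(t):
        # flash_attn_func expects (batch, seqlen, nheads, headdim)
        return t.view(t.shape[0], t.shape[1], num_heads, head_dim)

    out = flash_attn_func(split(q), split(k), split(v), dropout_p=0.0, causal=causal)
    return out.reshape(batch, seqlen, hidden)

q = torch.randn(1, 4096, 8 * 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
print(mha_forward(q, k, v, num_heads=8).shape)  # torch.Size([1, 4096, 512])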

@j-dominguez9

I don't have the knowledge to contribute to this issue, but I'm really rooting for this support feature!

@Kademo15

Kademo15 commented Jan 5, 2024

You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

Could you share what you found? What's the size you can go up to before it has to fall back to SDP?

@feffy380

feffy380 commented Jan 5, 2024

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

I found larger resolutions benefit more. On SD1.5, 1088x1088 on an RX 7900 XTX went from 1.67 it/s to 3 it/s while 512px was a more modest 16 it/s to 18 it/s.
VRAM usage also drops dramatically. Generating a 1088x1088 image goes from 18GB down to about 6GB. I don't see any spike at the end, though. Is it specific to SDXL maybe?

Note: If using a venv I found I had to build the wheel with the venv activated. Otherwise, the library complained about undefined symbols and failed to load.

@gel-crabs

gel-crabs commented Jan 5, 2024

You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

Could you provide me what you found. What‘s the size you can go before it has to fallback to sdp.

I think I explained wrong; it will always fall back to SDP at the very end of generation (after all the steps have finished), resolution doesn't factor into it.

What I meant is that the massively decreased VRAM usage will (currently..?) not allow you to use a higher resolution than with regular SDP attention, as the VRAM usage will jump back up at the end.

However, it most likely could... if AMD hooked up CK's Navi branch to their new xformers port. ;)

@gel-crabs

gel-crabs commented Jan 14, 2024

Also note: the switch back to SDP at the end (I believe only with SDXL) can be prevented by switching from full VAE to TAESD, or (I assume) a tiled VAE implementation.

This allows a 1024x1024 image to be upscaled to 2048x2048 with SDXL on an RX 7800 XT with 16GB of VRAM.

@sabreshao added the navi hardware label Jan 16, 2024
@feffy380

feffy380 commented Feb 8, 2024

Yes. The forward pass output is within rounding error of a plain PyTorch implementation, but the LSE varies wildly (mean value shown), and without a correct LSE we can't even substitute a pure PyTorch implementation for the missing backward pass.
[image]

@ZhenyaPav

python-flash-attention.tar.gz

I have installed flash attention using this PKGBUILD and am getting this error when trying to load a model using exllamav2 in text-generation-webui:

19:47:00-728633 ERROR    Failed to load the model.                                                      
Traceback (most recent call last):
  File "/home/zhenyapav/Projects/text-generation-webui/modules/ui_model_menu.py", line 242, in load_model_wrapper
    shared.model, shared.tokenizer = load_model(selected_model, loader)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhenyapav/Projects/text-generation-webui/modules/models.py", line 87, in load_model
    output = load_func_map[loader](model_name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhenyapav/Projects/text-generation-webui/modules/models.py", line 371, in ExLlamav2_loader
    from modules.exllamav2 import Exllamav2Model
  File "/home/zhenyapav/Projects/text-generation-webui/modules/exllamav2.py", line 5, in <module>
    from exllamav2 import (
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/exllamav2/__init__.py", line 3, in <module>
    from exllamav2.model import ExLlamaV2
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/exllamav2/model.py", line 21, in <module>
    from exllamav2.attn import ExLlamaV2Attention
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/exllamav2/attn.py", line 19, in <module>
    import flash_attn
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import flash_attn_func
  File "/home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 4, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv

@mega-ice

ImportError: /home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv

Did you compile bitsandbytes with ROCm support? You will need to uninstall the auto-installed version. For an RDNA3 GPU try:
(I am using this version of bitsandbytes on ROCm 6.0.2, torch 2.3.0.dev20240219+rocm6.0)

git clone https://github.com/arlo-phoenix/bitsandbytes-rocm-5.6
ROCM_HOME=/opt/rocm ROCM_TARGET=gfx1100 make hip
pip install .

@Beinsezii

I have installed flash attention using this PKGBUILD and am getting this error when trying to load a model using exllamav2 in text-generation-webui:

The navi flash attention branch won't work with exllamav2 regardless. It returns garbage if you bypass the v2.2.1 check.

ImportError: /home/zhenyapav/Projects/text-generation-webui/venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol

Clean rebuild should fix. git clean -fd; python setup.py clean; make clean; python setup.py bdist_wheel

Additionally, make sure your PyTorch and system ROCm versions are the same. If your system is on 6.0 you'll have to use PyTorch nightly, then compile exllama, llama_cpp, and others for ROCm 6, as text-generation-webui only provides 5.6 wheels.

I know pytorch statically links to rocm 5.6 but I believe things like pip install and the exllamav2 JIT will still look in /opt/rocm/ first.
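
One quick way to check which ROCm/HIP version your installed PyTorch wheel was built against (assuming a ROCm build of PyTorch):

# Quick check of what the installed PyTorch wheel was built against.
import torch

print(torch.__version__)              # e.g. 2.2.0.dev...+rocm5.7
print(torch.version.hip)              # HIP/ROCm version baked into the wheel (None on CUDA builds)
print(torch.cuda.get_device_name(0))  # should report your Radeon card when ROCm is working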

@ZhenyaPav

ZhenyaPav commented Feb 20, 2024

The navi flash attention branch won't work with exllamav2 regardless. It returns garbage if you bypass the v2.2.1 check.

What loaders do work then? AutoGPTQ?

Additionally make sure your pytorch and system rocm are the same. If your system is on 6.0 you'll have to use pytorch nightly then compile exllama, llama_cpp, + others for rocm 6 as text gen webui only provides 5.6 wheels.

I'm having some issues installing exllama from the repo:

ModuleNotFoundError: No module named 'torch'

Though torch nightly is installed, and I am able to import torch in python console.

@Beinsezii

Beinsezii commented Feb 20, 2024

What loaders do work then? AutoGPTQ?

AFAIK just Transformers and even that's super rough. I see my VRAM go down but it actually runs slower than without any flash attention...

If you're on ROCm 6, the rocm_enabled bitsandbytes branch should compile and run allowing 8/4 bit models in Transformers, but it seems to randomly cut some responses short. pytest shows a couple tests as failing due to produced infs so I'm wondering if that's why. Also it's slow, 1/10th the speed of Exllama for an equivalent bit depth.

If you're on 5.7 or earlier I think you need to checkout the rocm_enabled branch to the commit before they changed the HIPBLAS enum names.

I'm having some issues installing exllama from the repo:

ModuleNotFoundError: No module named 'torch'

Yea some builds of rocm torch nightly aren't correctly picked up by pip as satisfying the torch requirement. I have no idea why, but you can just comment out "torch" from the setup.py's dependencies list and it'll install fine. Or you could try an earlier torch nightly build, some of them don't have that issue. It seems to come and go. Maybe update pip?

Also if you're installing exllama2 from the repo make sure to check out on a tag instead of master as master contains breaking API changes oobabooga-webui doesn't account for yet.

TL;DR: There's no good way to get long context on AMD cards yet. The best option is still exllama2 and just accepting the high memory usage and long 8bit cache builds.

@gel-crabs

gel-crabs commented Feb 26, 2024

Big update for Stable Diffusion WebUI users!!

So as it turns out, it was actually super easy to replace SDP with Doggettx/sub-quadratic the whole time, I was just looking in the wrong place. XFormers does the exact same thing, just in the attnblock forward instead of the attention forward.

11902.patch.txt

Above is an updated version of the WebUI patch if you haven't applied it already (rename it to 11902.patch).

If you've already applied it, you can just replace sd_hijack_optimizations.py with this copy (rename it to sd_hijack_optimizations.py):

sd_hijack_optimizations.py.txt

Note: I chose Sub-quadratic as Doggettx has similar VRAM usage as SDP, and it only switches at the end of generation anyway so VRAM use matters more than speed here.

@xhluca

xhluca commented Feb 28, 2024

FYI You can find the shader ISA (gfxOOOO number) on techpowerup, e.g.: https://www.techpowerup.com/gpu-specs/radeon-rx-6900-xt.c3481

You can see it's a gfx1030.

Seems like this issue is mainly concerning RX 7000 GPUs; does that mean 6000-series GPUs won't be supported?

@Beinsezii

Or simply

rocminfo | grep Name

which will give you the board names for all ROCm devices

@ewof

ewof commented Feb 28, 2024

will varlen fwd be added?

@Beinsezii

Beinsezii commented Mar 1, 2024

I made a custom attention processor for use in Diffusers which falls back to SDP on > 128 head dims. Additionally I have a small guide going over setup.

Seems to be about +30% throughput in SDXL 1024², ramping up to +80% at the comically large 3840x2160.

@Beinsezii

Beinsezii commented Mar 1, 2024

If anyone wants to use Flash Attention wherever Torch 2 SDP works, you can simply monkey patch it in before the call

import torch

if "AMD" in torch.cuda.get_device_name() or "Radeon" in torch.cuda.get_device_name():
    try:
        from flash_attn import flash_attn_func

        sdpa = torch.nn.functional.scaled_dot_product_attention

        def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
            if query.shape[3] <= 128 and attn_mask is None:
                hidden_states = flash_attn_func(
                    q=query.transpose(1, 2),
                    k=key.transpose(1, 2),
                    v=value.transpose(1, 2),
                    dropout_p=dropout_p,
                    causal=is_causal,
                    softmax_scale=scale,
                ).transpose(1, 2)
            else:
                hidden_states = sdpa(
                    query=query,
                    key=key,
                    value=value,
                    attn_mask=attn_mask,
                    dropout_p=dropout_p,
                    is_causal=is_causal,
                    scale=scale,
                )
            return hidden_states

        torch.nn.functional.scaled_dot_product_attention = sdpa_hijack
        print("# # #\nHijacked SDPA with ROCm Flash Attention\n# # #")
    except ImportError as e:
        print(f"# # #\nCould not load Flash Attention for hijack:\n{e}\n# # #")
else:
    print(f"# # #\nCould not detect AMD GPU from:\n{torch.cuda.get_device_name()}\n# # #")

Then whenever something downstream requests the Torch 2 SDP attention, it should instead use flash attention where supported.
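
A quick sanity check that the hijack is active (a sketch; it assumes the snippet above has already run, a head dim of at most 128, and no attention mask):

# With the hijack applied, this SDP call should route through flash_attn_func.
import torch
import torch.nn.functional as F

# torch SDP layout: (batch, nheads, seqlen, headdim)
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 4096, 64]); speed/VRAM should differ from stock SDP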

@sancspro

Big update for Stable Diffusion WebUI users!!

[gel-crabs' WebUI patch instructions quoted from above]

Hi there. I tried to implement this method for flash attention, and I am getting this error in webui:
RuntimeError: FlashAttention forward only supports head dimension at most 128

I am on the latest ROCm and using a 7800 XT. Flash attention works with another repo, quickdif, but not with SD webui. Any suggestions to fix this? I mean, how do we even limit the head dim in webui?

@feffy380

@sancspro The patch is missing a check for head dim 128. You can add this to flash_attn_attention_forward to fall back to sub_quad_attention:

 def flash_attn_attention_forward(self, x, context=None, mask=None, **kwargs):
     h = self.heads
     q_in = self.to_q(x)
     context = default(context, x)
 
+    if q_in.shape[-1] // h > 128:
+        return sub_quad_attention_forward(self, x, context, mask, **kwargs)

@sancspro

sancspro commented Mar 11, 2024

Hi @feffy380
Thanks for your response. I added the code as you suggested and now the error is gone. But the generated image is unusable noise. I thought maybe sub-quad attention was causing this, but when I loaded webui with --opt-sub-quad-attention it works fine.

@Beinsezii

Beinsezii commented Mar 12, 2024

subquad uses transposed tensors compared to flash attention so that needs to be accounted for or else it'll be garbage.

If it's FA producing the garbage, make sure the q/k/v are ordered batch seq nhead dim. Other attentions use different orderings, and FA won't err if they're mixed up; it'll just produce soup.
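
For reference, a small sketch of the layout difference: torch SDP uses (batch, nheads, seqlen, headdim) while flash_attn_func expects (batch, seqlen, nheads, headdim); the tensor sizes here are arbitrary:

# Layout reminder: transpose between the SDP ordering and the flash-attention ordering.
import torch
from flash_attn import flash_attn_func

q_sdp = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)  # (b, nheads, seq, dim)
k_sdp, v_sdp = torch.randn_like(q_sdp), torch.randn_like(q_sdp)

out = flash_attn_func(
    q_sdp.transpose(1, 2),  # -> (b, seq, nheads, dim), what flash_attn_func expects
    k_sdp.transpose(1, 2),
    v_sdp.transpose(1, 2),
).transpose(1, 2)           # back to the (b, nheads, seq, dim) SDP layout
print(out.shape)  # torch.Size([1, 8, 4096, 64])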

@sancspro

I am unable to make it work with SD webui. It either throws the head-dimension-above-128 error or just generates garbage.

I am a novice, so sorry for this question: what is head dim with respect to SD webui? Is it related to the image resolution?

On Mint, 7800xt, 7800x3d, 32GB RAM

@gel-crabs

Oh crap, sorry. It was working on my setup so I had no idea about these issues. I updated the files I uploaded above with fixes, is it producing garbage now?

@sancspro

sancspro commented Mar 13, 2024

Oh crap, sorry. It was working on my setup so I had no idea about these issues. I updated the files I uploaded above with fixes, is it producing garbage now?

Hi @gel-crabs
It is producing garbage even after replacing with your updated file.

Could be that I did not build flash-attn wheel correctly.

This is what I used to build it:

cd stable-diffusion-webui
python -m venv venv
source venv/bin/activate
pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support

Do you think this is correct?

@gel-crabs

gel-crabs commented Mar 14, 2024

[sancspro's previous comment quoted in full]

Alright, I think I've got it. This just copy-pastes @Beinsezii's hijack method in with a small fix for the SDP attnblock, so just change from flash-attn back to SDP, then remove the --flash-attn option from modules/cmd_args.py and your webui startup.

This eschews sub-quadratic altogether (without a memory spike, somehow), so it shouldn't be producing garbage anymore. Thank you so much, @Beinsezii

sd_hijack_optimizations.py.txt

@Beinsezii

I forgot to mention here but for ComfyUI people in this thread I already made a ComfyUI addon a few weeks ago that's basically just the SDPA hijack plonked into a fake node.

This just copy-pastes @Beinsezii's hijack method in with a small fix for the SDP attnblock

What's actually changed? It looks like my monkey patch is the same.

This eschews sub-quadratic altogether (without a memory spike, somehow), so it shouldn't be producing garbage anymore. Thank you so much, @Beinsezii

Subquad should use less memory than SDPA, but if the tile size is configured to be high enough it can end up using about the same. I've never looked at Auto's code so I can't say what goes on there.

I'd still recommend SDPA since it's basically the same function and really only the VAEs have memory issues when I tested up to 3840x2160. ComfyUI and Diffusers both have a tiled VAE flag which fixes that more or less flawlessly at a mild speed cost; I'm assuming Auto does too.

I've tested the current SDPA hijack on SD 1.5, 2.1, XL, Stable Cascade, and Pixart Alpha to great success so there shouldn't be anything broken as long as Auto's code is good. Even img2img and multi fused Lora seems to behave well.
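
For Diffusers users, the tiled-VAE flag mentioned above is a one-liner (a sketch assuming an SDXL pipeline; the model ID is just an example):

# Diffusers: enable tiled VAE decoding to tame the VAE memory spike at the end.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.enable_vae_tiling()  # decodes the latent in tiles at a mild speed cost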

@gel-crabs

gel-crabs commented Mar 14, 2024

What's actually changed? It looks like my monkey patch is the same.

In the sdp_attnblock_forward function at the end:

q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
out = rearrange(out, 'b (h w) c -> b c h w', h=h)

Changed to:

q, k, v = (rearrange(t, 'b c h w -> b (h w) 1 c') for t in (q, k, v))
out = rearrange(out, 'b (h w) 1 c -> b c h w', h=h)

This was in the original flash-attn patch, and I had to add it due to this line in the patch producing a "tuple index out of range" error at the VAE stage:

if query.shape[3] <= 128 and attn_mask is None:

@sancspro

sancspro commented Mar 14, 2024

sd_hijack_optimizations.py.txt

Finally!!! Wow thank you so much guys. @Beinsezii and @gel-crabs

It works but please continue to read below.

When I replaced the existing module with this latest hijack file from @gel-crabs, it did not work straight away. Auto1111 continued to produce garbage. Refusing to give up, I uninstalled flash-attn and reinstalled it. This time I overrode the gfx version as below before reinstalling flash-attn:
export HSA_OVERRIDE_GFX_VERSION=11.0.0

And let it build wheels for flash-attn. After this auto1111 started working fine with the flash-attn magic.

For anyone who wants to try it. Here's the commands that I ran to get it working after replacing sd_hijack_optimizations.py file:

export HSA_OVERRIDE_GFX_VERSION=11.0.0
cd stable-diffusion-webui
python -m venv venv
source venv/bin/activate

pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support

python launch.py --opt-sdp-attention

A quick comparison showing how tremendous this is for AMD folks who use auto1111.

SD1.5 base model on Flash attention
1024 – 2.31 it/s VRAM used 5.5 GB
768 – 5.56 it/s VRAM 3.5 GB
512 – 15 it/s VRAM 3.5 GB
SDXL Juggernaut on Flash attention
1024 – 1.90 it/s VRAM used 9.0 GB

SD1.5 base model on Doggettx (default)
1024 – 1.50 it/s VRAM used 15.3 GB
768 – 3.86 it/s VRAM 8.5 GB
512 – 12 it/s VRAM 4.5 GB
SDXL Juggernaut on Doggettx (default)
1024 – 1.50 it/s VRAM used 9.5 GB

How ridiculous this looks when you see the VRAM usage.. LOL

My specs: Linux Mint, 7800XT, torch==2.3.0.dev20240314+rocm6.0.

@sancspro

sancspro commented Mar 14, 2024

Big update for Stable Diffusion WebUI users!!

[gel-crabs' WebUI patch instructions quoted from above]

Just a note that this also works perfectly.

@nonetrix

Has anyone had luck with GFX 1030? I think I am out of luck :/

@Beinsezii

RDNA ≤ 2 cards don't have WMMAs so I'm not sure they'll be supported anytime soon.

@hbfreed

hbfreed commented Apr 22, 2024

@xzuyn I'm having the same error, did you find a fix for it? it's on the backward pass...

@nonetrix

RDNA ≤ 2 cards don't have WMMAs so I'm not sure they'll be supported anytime soon.

Is this absolutely required? Could it be forked to use some other implementation that does functionally the same thing?

@Beinsezii

Is this absolutely required? Could it be forked to use some other implementation that does functionally the same thing?

You could always use a pure pytorch impl like sub quadratic. Tune the tile sizes to your GPU and it should perform okayish and not OOM on large shapes.
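
As an illustration of that kind of fallback, here is a minimal query-chunked attention sketch in plain PyTorch (the chunk size is an assumption to tune per GPU; real sub-quadratic implementations also chunk the key/value dimension):

# Pure-PyTorch fallback: chunk the query dimension so the full
# (seqlen_q x seqlen_k) attention matrix is never materialized at once.
import math
import torch

def chunked_attention(q, k, v, chunk_size=1024):
    # q, k, v: (batch, nheads, seqlen, headdim)
    scale = 1.0 / math.sqrt(q.shape[-1])
    out = torch.empty_like(q)
    for start in range(0, q.shape[2], chunk_size):
        q_blk = q[:, :, start:start + chunk_size]
        attn = torch.softmax(q_blk @ k.transpose(-2, -1) * scale, dim=-1)
        out[:, :, start:start + chunk_size] = attn @ v
    return out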
