RDNA3 support #27
A CK discussion has shown a branch with a flash-attention kernel implementation that already works in AIT: ROCm/composable_kernel#1032
Hi @WilliamGazeley and @sdli1995, I wanted to update you on the attention kernels for the NAVI platform. My colleague @aska-0096 did activate them. However, these Flash-Attention kernels were initially developed for MI devices, which operate on a distinct set of CK kernels. In essence, the issue is that we haven't yet integrated these kernels into our current API. I plan to work on this integration in my spare time.
Great work on the fork, @howiejayz! Will it take long to port the kernels for gfx1100 and other non-MI architectures? I need the kernels for my project and I'd be willing to help out if I can.
Thanks for reaching out and offering to help, @AlpinDale! I'm currently tied up with a few other projects, so I can't give an exact timeframe for porting the kernels for gfx1100 and other architectures. But I'm definitely planning to tackle this soon. The first step will be creating a new code path for gfx110x, considering their CK kernels are only for forward ops. I'm totally open to suggestions or any help you can provide. It'd be great to have some extra hands on this. Let me know if you're interested!
I am a complete novice in this field, but a few months ago I managed to make Composable Kernel, Flash Attention and PyTorch work together for my RX 7900 XTX (see here, sort by performance, and look for the first one). Although I was able to get that Flash Attention implementation "working" in the end, the generated images were meaningless, and I gave up because I didn't know how to fix it. Here are the relevant branch links, and I hope they can be of some help to you:
I made a grouped fused kernel by Frankensteining the batched fused kernel, which matched the call signatures in this repo at that time. However, that self-made kernel might just be broken.
Hi @evshiron! First off, I must say I'm seriously impressed by your work! It's quite an achievement, and the resources you've provided are invaluable. I've had the opportunity to build your implementation on gfx1100, and I'm pleased to report that the build was successful. However, I encountered an issue with the unit tests not passing due to incorrect results in the forward pass: `assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()`, which is likely stemming from incorrect parameter settings in the CK kernels. I guess this should be the reason why the output images became meaningless. Despite this, your work has been immensely helpful! This will massively speed up the Navi porting process for the v2 implementation.
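The failing assertion above encodes the usual acceptance rule in flash-attn's tests: the kernel's maximum absolute error versus an fp32 reference must stay within 2x the error of PyTorch's own low-precision implementation. A minimal sketch of that rule, with made-up numbers and no torch dependency:

```python
def max_abs_err(a, b):
    """Maximum absolute elementwise difference between two sequences."""
    return max(abs(x - y) for x, y in zip(a, b))

output_ref = [0.1000, -0.2000, 0.3000]  # fp32 reference (hypothetical values)
output_pt  = [0.1002, -0.1998, 0.3001]  # PyTorch low-precision baseline
output     = [0.1003, -0.1997, 0.3002]  # candidate kernel output

# The kernel is accepted if its error is within 2x the baseline's error.
assert max_abs_err(output, output_ref) <= 2 * max_abs_err(output_pt, output_ref)
```

A wrong kernel parameter (e.g. a bad softmax scale) typically blows the error up by orders of magnitude, so this check fails loudly rather than marginally.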
I'm glad that my humble work could be of some help. I am indeed unfamiliar with this field, so I can only leave it to professionals.
Guys, I have added batched forward (consistent sequence lengths) support for gfx1100, gfx1101, and gfx1102 under this branch, thanks to @aska-0096's CK kernels. The implementation is still under development and there are a lot of things to fine-tune. For now I see the performance is generally better when …. To install just use …. I only had the chance to test it on gfx1100, but I expect it works as well for the other two. Let me know if there is any issue! The docker I used to test is …
Using a 7900XTX with ….

Results for `benchmark_flash_attention_forward.py`: …

Results for `test_flash_attn_wmma_rocm.py`: …

Full log: test_flash_attn_wmma_rocm.log

Error for `benchmark_flash_attention.py`: …
Some benchmark results.

RTX 4090: …

7900 XTX: …
The 4090's fp16-accumulate fp16 tensor-core performance is 330T, while the 7900 XTX's is 120T; the better reference NVIDIA card is the RTX 3090.
Thanks for the benchmark data. We are going to launch a new version of Composable Kernel with better flash-attention performance. Adapting the optimizations for RDNA3 is in my plan.
@sdli1995 here are the benchmarks with a 3090:
Any updates on this?
We need official support for flash attention.
Trust bro, be patient, don't rush them.
I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect. On an RX 7800 XT, it speeds generation up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.
Could you please provide more information about how? Did you just install the branch and it worked out of the box, or did you have to change code in the webui you are using?
@gel-crabs I failed to install flash-attn for Navi. Please give more info.
Alright, I'm going to try to give instructions on how I got this to work.

If you're on Arch, I have a very amateur PKGBUILD (requiring --skipinteg) that gets it to work. You need to go into the PKGBUILD, replace GPU_ARCHS=gfx1101 with your GPU's architecture, and set MAX_JOBS to however many CPU cores you have. I can only confirm it will work on gfx11+. The patch just changes the C++ standard from c++20 to c++17 to allow it to build. If you aren't on Arch, you can generally just follow the commands and install the Python wheel file afterwards, in your virtualenv if you're using one. You can clone the repo with ….

Now for webui. You will have to use a patch that has been closed, since it will be obsolete once AMD finishes xformers support. https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902.patch is the link to the raw patch file; see AUTOMATIC1111/stable-diffusion-webui#11902. In the discussion for this patch, I posted a long comment on getting it to work (AUTOMATIC1111/stable-diffusion-webui#11902 (comment)). The important info from that is the 2 code blocks you need to change manually.

After that, add --flash-attn to your command line arguments and it should work. If you get slower or the same speed, flash attention isn't working. You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128. If you get an error involving a flash-attn .so not loading, rebuild the PKGBUILD but change CC=clang and CXX=clang++ to CC=gcc and CXX=g++.
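For non-Arch setups, the build knobs described above can be sketched as a small shell wrapper. The default values and the commented-out pip invocation are assumptions based on the PKGBUILD workflow, not official install commands, so adjust them for your card and checkout:

```shell
# Assumption: override GPU_ARCHS for your card (gfx1100 / gfx1101 / gfx1102)
GPU_ARCHS="${GPU_ARCHS:-gfx1100}"
# Parallel build jobs; very high values can exhaust RAM during the CK build
MAX_JOBS="${MAX_JOBS:-$(nproc)}"

echo "Building flash-attn for ${GPU_ARCHS} with ${MAX_JOBS} jobs"

# Run inside the cloned flash-attention checkout, with your virtualenv active:
# GPU_ARCHS="$GPU_ARCHS" MAX_JOBS="$MAX_JOBS" pip install .
```

As noted above, building with the venv activated matters: the resulting wheel otherwise tends to fail to load with undefined-symbol errors.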
Switching setup.py to c++17, it built successfully on gfx1100. Seems to work in transformers: a 7B model OOMs at >8k context using the default attention, but doesn't even crack 20 gigs with FA2 enabled. Interestingly, I lose about 1 t/s though? I'll have to see if I can monkeypatch it into Diffusers...
@gel-crabs that, and tried `FLASH_ATTENTION_INTERNAL_USE_RTN=1 pip install .` (I use Debian.)
Do you mean the unit testing? For that you need to export FLASH_ATTENTION_INTERNAL_UNIT_TEST_MODE=1 and FLASH_ATTENTION_INTERNAL_DETERMINISTIC=1. You should also set your GPU_ARCHS to your GPU architecture (gfx1100, gfx1101, etc.) and try building with GCC and Clang. I can also only guarantee this will work on ROCm 5.7 and up. For anything other than SD webui, you will likely have to create a forward function yourself, as it is a PyTorch extension and isn't integrated into PyTorch yet. The implementation is here, but keep in mind it requires my changes as well: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902/files
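Putting those switches together, a unit-test environment might look like the following. The variable names come from the comment above; the gfx1100 value is a placeholder you would swap for your own architecture:

```shell
# Unit-test mode and deterministic kernels, as described above
export FLASH_ATTENTION_INTERNAL_UNIT_TEST_MODE=1
export FLASH_ATTENTION_INTERNAL_DETERMINISTIC=1
# Placeholder: set to your own architecture (gfx1100, gfx1101, ...)
export GPU_ARCHS=gfx1100
```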
I don't have the knowledge to contribute to this issue, but I'm really rooting for this feature!
Could you share what you found? What's the size you can go up to before it has to fall back to SDP?
I found larger resolutions benefit more. On SD1.5, 1088x1088 on an RX 7900 XTX went from 1.67 it/s to 3 it/s, while 512px was a more modest 16 it/s to 18 it/s. Note: if using a venv, I found I had to build the wheel with the venv activated; otherwise, the library complained about undefined symbols and failed to load.
I think I explained it wrong; it will always fall back to SDP at the very end of generation (after all the steps have finished), so resolution doesn't factor into it. What I meant is that the massively decreased VRAM usage will (currently..?) not allow you to use a higher resolution than with regular SDP attention, as the VRAM usage will jump back up at the end. However, it most likely could... if AMD hooked up CK's Navi branch to their new xformers port. ;)
Also note: the switch back to SDP at the end (I believe only with SDXL) can be prevented by switching from full VAE to TAESD, or (I assume) a tiled VAE implementation. This allows a 1024x1024 image to be upscaled to 2048x2048 with SDXL on an RX 7800 XT with 16GB of VRAM.
I have installed flash attention using this PKGBUILD and am getting this error when trying to load a model using exllamav2 in text-generation-webui:
Did you compile bitsandbytes with ROCm support? You will need to uninstall the auto-installed version. For RDNA3 GPUs, try:
The navi flash attention branch won't work with exllamav2 regardless. It returns garbage if you bypass the v2.2.1 check.
A clean rebuild should fix it. Additionally, make sure your PyTorch and system ROCm are the same version. If your system is on 6.0, you'll have to use PyTorch nightly, then compile exllama, llama_cpp, and others for ROCm 6, as text-gen-webui only provides 5.6 wheels. I know PyTorch statically links ROCm 5.6, but I believe things like …
What loaders do work then? AutoGPTQ?
I'm having some issues installing exllama from the repo:
Though torch nightly is installed, and I am able to import torch in the Python console.
AFAIK just Transformers, and even that's super rough. I see my VRAM go down, but it actually runs slower than without any flash attention... If you're on ROCm 6, the …. If you're on 5.7 or earlier, I think you need to check out the ….
Yea, some builds of ROCm torch nightly aren't correctly picked up by pip as satisfying the …. Also, if you're installing exllama2 from the repo, make sure to check out a tag instead of master, as master contains breaking API changes oobabooga-webui doesn't account for yet. TL;DR: there's no good way to get long context on AMD cards yet. The best option is still exllama2 and just accepting the high memory usage and long 8-bit cache builds.
Big update for Stable Diffusion WebUI users!! So as it turns out, it was actually super easy to replace SDP with Doggettx/sub-quadratic the whole time; I was just looking in the wrong place. XFormers does the exact same thing, just in the attnblock forward instead of the attention forward. Above is an updated version of the WebUI patch if you haven't applied it already (rename it to 11902.patch). If you've already applied it, you can just replace sd_hijack_optimizations.py with this copy (rename it to sd_hijack_optimizations.py): sd_hijack_optimizations.py.txt. Note: I chose sub-quadratic as Doggettx has similar VRAM usage to SDP, and it only switches at the end of generation anyway, so VRAM use matters more than speed here.
FYI, you can find the shader ISA (…). You can see it's a gfx1030. Seems like this issue is mainly concerning RX 7000 GPUs; does that mean 6000-series GPUs won't be supported?
Or simply `rocminfo | grep Name`, which will give you the board names for all ROCm devices.
Will varlen fwd be added?
I made a custom attention processor for use in Diffusers which falls back to SDP on >128 head dims. Additionally, I have a small guide going over setup. Seems to be about +30% throughput in SDXL 1024², ramping up to +80% at the comically large 3840x2160.
If anyone wants to use Flash Attention wherever Torch 2 SDP works, you can simply monkey patch it in before the call:

```python
import torch

if "AMD" in torch.cuda.get_device_name() or "Radeon" in torch.cuda.get_device_name():
    try:
        from flash_attn import flash_attn_func

        sdpa = torch.nn.functional.scaled_dot_product_attention

        def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
            # flash-attn only supports head dims up to 128 and no attention mask
            if query.shape[3] <= 128 and attn_mask is None:
                hidden_states = flash_attn_func(
                    q=query.transpose(1, 2),
                    k=key.transpose(1, 2),
                    v=value.transpose(1, 2),
                    dropout_p=dropout_p,
                    causal=is_causal,
                    softmax_scale=scale,
                ).transpose(1, 2)
            else:
                hidden_states = sdpa(
                    query=query,
                    key=key,
                    value=value,
                    attn_mask=attn_mask,
                    dropout_p=dropout_p,
                    is_causal=is_causal,
                    scale=scale,
                )
            return hidden_states

        torch.nn.functional.scaled_dot_product_attention = sdpa_hijack
        print("# # #\nHijacked SDPA with ROCm Flash Attention\n# # #")
    except ImportError as e:
        print(f"# # #\nCould not load Flash Attention for hijack:\n{e}\n# # #")
else:
    print(f"# # #\nCould not detect AMD GPU from:\n{torch.cuda.get_device_name()}\n# # #")
```

Then whenever something downstream requests the Torch 2 SDP attention, it should instead use flash attention where supported.
Hi there. I tried to implement this method for flash attention, and I am getting this error in webui: … I am on the latest ROCm and using a 7800 XT. Flash attention works with another repo, quickdif, but not with SD webui. Any suggestions to fix this? I mean, how do we even limit the head dim in webui?
@sancspro The patch is missing a check for head dims over 128. You can add this to `flash_attn_attention_forward`:

```diff
 def flash_attn_attention_forward(self, x, context=None, mask=None, **kwargs):
     h = self.heads
     q_in = self.to_q(x)
     context = default(context, x)
+    if q_in.shape[-1] // h > 128:
+        return sub_quad_attention_forward(self, x, context, mask, **kwargs)
```
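The guard above routes on the per-head dimension (`inner_dim // heads`), not on image resolution. A standalone sketch of the same decision, with hypothetical helper names and example dimensions:

```python
FLASH_MAX_HEAD_DIM = 128  # the Navi flash-attn kernels only handle head_dim <= 128

def pick_attention(inner_dim, heads, max_head_dim=FLASH_MAX_HEAD_DIM):
    """Hypothetical helper: choose flash attention when the per-head
    dimension fits the kernel's limit, else fall back to sub-quadratic."""
    head_dim = inner_dim // heads
    return "flash" if head_dim <= max_head_dim else "sub_quadratic"

# A typical SD1.5 cross-attention block: 320 channels over 8 heads -> head_dim 40
print(pick_attention(320, 8))   # flash
# A VAE-style single-head attention with 512 channels -> head_dim 512
print(pick_attention(512, 1))   # sub_quadratic
```

This is why the VAE stage keeps falling back even at small resolutions: its single-head attention has a large head dim regardless of image size.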
Hi @feffy380 …
Subquad uses transposed tensors compared to flash attention, so that needs to be accounted for or else it'll be garbage. If it's FA producing the garbage, make sure the q/k/v are ordered (batch, seq, nhead, dim). Other attentions use different orderings, and FA won't err if they're mixed up; it'll just produce soup.
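A shape-only sketch of the two layout conventions being mixed up (NumPy stands in for torch here; the layouts match the flash-attn convention described above):

```python
import numpy as np

batch, heads, seq, dim = 2, 8, 16, 64
q_sdp = np.zeros((batch, heads, seq, dim))   # SDP convention: (B, H, S, D)
q_flash = q_sdp.transpose(0, 2, 1, 3)        # flash-attn convention: (B, S, H, D)

assert q_flash.shape == (batch, seq, heads, dim)
# Both layouts hold the same number of elements, which is why a mixed-up
# ordering raises no error; the kernel just reads along the wrong strides.
assert q_sdp.size == q_flash.size
```

`transpose(0, 2, 1, 3)` is NumPy's equivalent of the `.transpose(1, 2)` calls in the SDPA hijack earlier in the thread.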
I am unable to make it work with SD webui. It either throws the head-dim-above-128 error or just generates garbage. I am a novice, so sorry for this question: what is head dim with respect to SD webui? Is it related to the image resolution? On Mint, 7800 XT, 7800X3D, 32GB RAM.
Oh crap, sorry. It was working on my setup, so I had no idea about these issues. I updated the files I uploaded above with fixes; is it still producing garbage now?
Hi @gel-crabs. It could be that I did not build the flash-attn wheel correctly. This is what I used to build it: …
Do you think this is correct?
Alright, I think I've got it. This just copy-pastes @Beinsezii's hijack method in, with a small fix for the SDP attnblock, so just change from flash-attn back to SDP, then remove the --flash-attn option from modules/cmd_args.py and your webui startup. This eschews sub-quadratic altogether (without a memory spike, somehow), so it shouldn't be producing garbage anymore. Thank you so much, @Beinsezii!
I forgot to mention here, but for the ComfyUI people in this thread: I already made a ComfyUI addon a few weeks ago that's basically just the SDPA hijack plonked into a fake node.
What's actually changed? It looks like my monkey patch is the same.
Subquad should use less memory than SDPA, but if the tile size is configured to be high enough, it can end up using about the same. I've never looked at Auto's code, so I can't say what goes on there. I'd still recommend SDPA since it's basically the same function, and really only the VAEs have memory issues when I tested up to 3840x2160. ComfyUI and Diffusers both have a tiled VAE flag which fixes that more or less flawlessly at a mild speed cost; I'm assuming Auto does too. I've tested the current SDPA hijack on SD 1.5, 2.1, XL, Stable Cascade, and PixArt Alpha to great success, so there shouldn't be anything broken as long as Auto's code is good. Even img2img and multi fused LoRA seem to behave well.
In the `sdp_attnblock_forward` function, at the end: …

Changed to: …

This was in the original flash-attn patch, and I had to add it due to this line in the patch producing a "tuple index out of range" error at the VAE stage: …
Finally!!! Wow, thank you so much, guys, @Beinsezii and @gel-crabs. It works, but please continue reading below. When I replaced the existing module with this latest hijack file from @gel-crabs, it did not work straightaway; Auto1111 continued to produce garbage. Refusing to give up, I uninstalled flash-attn and reinstalled it. This time I had overridden the gfx version as below before reinstalling flash-attn: … And let it build wheels for flash-attn. After this, Auto1111 started working fine with the flash-attn magic. For anyone who wants to try it, here are the commands I ran to get it working after replacing the sd_hijack_optimizations.py file: …
A quick comparison of how tremendous this is for AMD folks who use Auto1111.

SD1.5 base model on flash attention: …

SD1.5 base model on Doggettx (default): …

How ridiculous this looks when you see the VRAM usage.. LOL

My specs: Linux Mint, 7800 XT, torch==2.3.0.dev20240314+rocm6.0.
Just a note that this also works perfectly.
Has anyone had luck with gfx1030? I think I am out of luck :/
RDNA ≤ 2 cards don't have WMMAs, so I'm not sure they'll be supported anytime soon.
@xzuyn I'm having the same error; did you find a fix for it? It's on the backward pass...
Is this absolutely required? Could it be forked to use some other implementation that does functionally the same thing?
You could always use a pure PyTorch implementation like sub-quadratic. Tune the tile sizes for your GPU and it should perform okay-ish and not OOM on large shapes.
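As a sketch of what such a tiled, pure-framework implementation does (NumPy here for brevity, single head; the tile size `q_chunk` is the tunable knob):

```python
import numpy as np

def chunked_attention(q, k, v, q_chunk=64):
    """Naive sub-quadratic-style attention: process query rows in tiles so
    only a (q_chunk, seq) slice of the score matrix is ever materialized,
    instead of the full (seq, seq) matrix."""
    scale = q.shape[-1] ** -0.5
    out = np.empty_like(q)
    for i in range(0, q.shape[0], q_chunk):
        s = (q[i:i + q_chunk] @ k.T) * scale            # scores for one tile
        s = np.exp(s - s.max(axis=-1, keepdims=True))   # numerically stable softmax
        out[i:i + q_chunk] = (s / s.sum(axis=-1, keepdims=True)) @ v
    return out
```

Chunking over query rows is exact, so the result matches full attention; only the peak memory changes with `q_chunk`, which is why tuning the tile size trades speed against OOM headroom.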
Great work so far. I'm trying to run vLLM on my 7900 XTX cards and was wondering: are there any plans to support RDNA3?