-
Notifications
You must be signed in to change notification settings - Fork 12.2k
ggml : fix FA mask dim 2 and 3 #14505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Edit: nvm, I think my approach was not correct. |
2a20a7e
to
89ee2f1
Compare
@JohannesGaessler @jeffbolznv For now I will disable the broadcast of To summarize, the correct broadcast logic that we have to support is as follow: Lines 1983 to 2003 in 89ee2f1
In practice, the The This way, the mask that we pass to Merging this for now so I can continue working on top of this and later on we'll hopefully add support for these cases. |
a65fa3a
to
b1b22ae
Compare
In #14500, @JohannesGaessler correctly noted that the FA did not utilize dim 3 of the mask. I overlooked this and now as I was updating #14363 realized that we need to align the dimensions.
The fix is simple, I will try to update it myself across the Vulkan and CUDA backends later today.Also small fix for the
ggml_soft_max_ext()
: was incorrectly requiring the mask to be a 3D array + test for this.