Move MHAttention layer to Flux #141
Is NNAttentionLib.matmul simple enough that it can be ported without carrying over another library dependency?
If Flux is willing to take on NeuralAttentionlib as a dep, then we would need to rework it to accept more inputs. Currently, it only accepts 3D inputs.
Oh sorry, I posted my comment before the page refreshed with yours. I think all we need to provide is a parallelized version of …
NeuralAttentionlib already works with more than 3D inputs - one of the reasons I used it as a dep was that it would allow that functionality in the future (see #135 (comment)). The only concern is probably that while the GPU path is parallelised (it uses the same CUBLAS functions underneath as NNlib), the CPU path is not (#135 (comment)). And NeuralAttentionlib basically already provides a ready-made multi-headed self-attention layer. I had thought about a PR but decided against it because vanilla attention is hardly ever used anymore - most attention layers involve some novelty and so have to be written out in a custom manner (a 4D+ version of …
So you're saying that the attention layer in PyTorch is hardly used in practice? https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
I'm not sure what you mean. I'm pretty sure everyone in this thread wants to add the layer to Flux, if that's what you're getting at.
No, this one is used quite often... I meant the layer in the form it takes in Metalhead, which is quite task-specific, as opposed to, say, the NeuralAttentionlib functions (https://chengchingwen.github.io/NeuralAttentionlib.jl/stable/api/#NeuralAttentionlib.multihead_qkv_attention or https://chengchingwen.github.io/NeuralAttentionlib.jl/stable/api/#NeuralAttentionlib.generic_multihead_qkv_attention), although the PyTorch layer certainly exposes a lot more options that can be tweaked - something the Flux layer could possibly incorporate.
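For context, a minimal usage sketch of the first function linked above, assuming 3D (features × length × batch) inputs and a feature dimension that divides evenly by the head count; the sizes here are illustrative, not taken from Metalhead:

```julia
using NeuralAttentionlib

# Self-attention over a 3D input of shape (features, sequence length, batch).
# multihead_qkv_attention splits the feature dimension into heads, applies
# scaled dot-product attention per head, and concatenates the heads again.
x = randn(Float32, 64, 10, 4)                        # 64 features, 10 tokens, batch of 4
y = NeuralAttentionlib.multihead_qkv_attention(8, x, x, x)
size(y)                                              # (64, 10, 4)
```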
I can adapt the same multithreading approach that …
Sorry, I misread this remark:
Having the new NeuralAttentionlib dependency in Flux should be fine; it seems like a well-designed and well-maintained library. Maybe it contains more than what is strictly needed, though, so I was hoping we could just consolidate things in NNlib and avoid dispersion. I would be OK with either path forward.
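To make the "consolidate things in NNlib" option concrete, here is a rough sketch of what a batched scaled dot-product attention primitive could look like, built only from existing NNlib functions (`batched_mul`, `batched_transpose`, `softmax`); the name `dot_product_attention` and the shape convention are assumptions, not an existing NNlib API:

```julia
using NNlib: batched_mul, batched_transpose, softmax

# Hypothetical primitive. q, k, v have shape (head dim, length, batch),
# where the head dimension has already been split off and folded into batch.
function dot_product_attention(q, k, v)
    dh = size(q, 1)
    scores = batched_mul(batched_transpose(k), q) ./ sqrt(Float32(dh))  # (kv len, q len, batch)
    α = softmax(scores; dims = 1)                                       # normalise over keys
    return batched_mul(v, α)                                            # (head dim, q len, batch)
end
```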
@theabhirath @darsnack do you think the multi-head attention layer is now general enough that we can move it to Flux?
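As a reference point for what "general enough" might mean, here is a minimal sketch of a Flux layer wrapping the NeuralAttentionlib function discussed above; the struct name, the fused QKV projection, and the keyword choices are assumptions for illustration, not the actual Metalhead implementation:

```julia
using Flux, NeuralAttentionlib

# Hypothetical generic multi-head self-attention layer for 3D
# (features, length, batch) inputs, delegating the attention math
# to NeuralAttentionlib.
struct MHAttention{Q, O}
    nheads::Int
    qkv_proj::Q
    out_proj::O
end

Flux.@functor MHAttention (qkv_proj, out_proj)

function MHAttention(dim::Integer; nheads::Integer = 8)
    @assert dim % nheads == 0 "feature dimension must be divisible by nheads"
    MHAttention(nheads, Dense(dim => 3 * dim), Dense(dim => dim))
end

function (m::MHAttention)(x::AbstractArray{<:Real, 3})
    d = size(x, 1)
    qkv = m.qkv_proj(x)                                  # (3d, length, batch)
    q, k, v = qkv[1:d, :, :], qkv[d+1:2d, :, :], qkv[2d+1:3d, :, :]
    m.out_proj(NeuralAttentionlib.multihead_qkv_attention(m.nheads, q, k, v))
end

layer = MHAttention(64; nheads = 8)
layer(randn(Float32, 64, 10, 4))                         # (64, 10, 4)
```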