Move MHAttention layer to Flux #141

Closed
CarloLucibello opened this issue Apr 4, 2022 · 9 comments · Fixed by FluxML/Flux.jl#2146

Comments

@CarloLucibello (Member)

@theabhirath @darsnack do you think the multi-head attention layer is now general enough that we can move it to Flux?

@CarloLucibello changed the title from "Move MHAttention to Flux" to "Move MHAttention layer to Flux" on Apr 4, 2022
@CarloLucibello (Member, Author)

Is NeuralAttentionlib.matmul simple enough that it can be ported without carrying over another library dependency?

@darsnack (Member) commented Apr 4, 2022

If Flux is willing to take on NeuralAttentionlib as a dep, yes, but we would need to rework it to accept more inputs. Currently, it only accepts 3D inputs.

@darsnack (Member) commented Apr 4, 2022

Oh sorry, I posted my comment before the page refreshed with yours. I think all we need to provide is a parallelized version of NNlib.batched_mul for 4D inputs. It could be something specific to this layer that calls batched_mul under the hood, since I know there were some concerns about a generic 4D implementation when it was brought up.
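
For illustration, a minimal sketch of the kind of layer-specific wrapper being described here, assuming the usual feature × length × heads × batch layout: fold the trailing two dimensions into one batch dimension and defer to the existing 3D NNlib.batched_mul. The name `batched_mul4` is hypothetical, not an existing NNlib function.

```julia
using NNlib: batched_mul

# Hypothetical helper: multiply 4D arrays slice-wise by folding the trailing
# two dimensions (e.g. heads × batch) into a single batch dimension and
# calling the existing 3D batched_mul.
function batched_mul4(A::AbstractArray{T,4}, B::AbstractArray{T,4}) where {T}
    @assert size(A, 3) == size(B, 3) && size(A, 4) == size(B, 4)
    A3 = reshape(A, size(A, 1), size(A, 2), :)
    B3 = reshape(B, size(B, 1), size(B, 2), :)
    C3 = batched_mul(A3, B3)   # C3[:, :, k] = A3[:, :, k] * B3[:, :, k]
    reshape(C3, size(C3, 1), size(C3, 2), size(A, 3), size(A, 4))
end
```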

@theabhirath (Member) commented Apr 4, 2022

NeuralAttentionlib already works with more than 3D inputs; one of the reasons I used it as a dep was that it would allow that functionality in the future (see #135 (comment)). The only concern could probably be that while the GPU path is parallelised (it uses the same CUBLAS functions underneath as NNlib), the CPU path is not (#135 (comment)). And NeuralAttentionlib basically already provides a ready-made multi-headed self-attention layer. I had thought about a PR but decided against it because vanilla attention is hardly ever used anymore; most attention layers involve some novelty and so have to be written out in a custom manner. (A 4D+ version of NNlib.batched_mul wouldn't hurt though, as I brought up in FluxML/NNlib.jl#391.)
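
A rough usage sketch of that ready-made functionality, assuming the q/k/v arrays are laid out feature × length × batch and the head count is the first argument; see the NeuralAttentionlib API docs for the authoritative signature.

```julia
using NeuralAttentionlib: multihead_qkv_attention

# Assumed layout: feature × length × batch; first argument is the head count.
q = rand(Float32, 64, 10, 8)   # queries
k = rand(Float32, 64, 12, 8)   # keys
v = rand(Float32, 64, 12, 8)   # values
y = multihead_qkv_attention(4, q, k, v)   # 4 heads; y is 64 × 10 × 8
```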

@CarloLucibello (Member, Author)

So you're saying that the attention layer in PyTorch is hardly used in practice? https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

@darsnack (Member) commented Apr 6, 2022

I'm not sure what you mean. I'm pretty sure everyone in this thread wants to add the layer to Flux if that's what you're getting at.

@theabhirath (Member)

> So you're saying that the attention layer in PyTorch is hardly used in practice? https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

No, this one is used quite often... I meant the layer in the form it takes in Metalhead, which is quite task-specific, as opposed to, say, the NeuralAttentionlib functions (https://chengchingwen.github.io/NeuralAttentionlib.jl/stable/api/#NeuralAttentionlib.multihead_qkv_attention or https://chengchingwen.github.io/NeuralAttentionlib.jl/stable/api/#NeuralAttentionlib.generic_multihead_qkv_attention), although the PyTorch function certainly exposes a lot more stuff that can be tweaked, something the Flux layer could possibly incorporate.

@chengchingwen (Member)

> The only concern could probably be that while the GPU path is parallelised (it uses the same CUBLAS functions underneath as NNlib), the CPU path is not (#135 (comment)).

I can adapt the same multithreading approach that batched_mul uses if the CPU part is really a concern.
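
For reference, a simplified sketch of that approach: spread the independent batch-slice multiplications across Julia threads (the real NNlib code also has to avoid oversubscribing BLAS threads, which is omitted here).

```julia
using LinearAlgebra

# Simplified sketch: parallelise the independent batch slices on the CPU.
function threaded_batched_mul!(C::Array{T,3}, A::Array{T,3}, B::Array{T,3}) where {T}
    Threads.@threads for k in axes(C, 3)
        @views mul!(C[:, :, k], A[:, :, k], B[:, :, k])
    end
    return C
end
```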

@CarloLucibello (Member, Author) commented Apr 7, 2022

> I'm not sure what you mean. I'm pretty sure everyone in this thread wants to add the layer to Flux if that's what you're getting at.

Sorry, I misread this remark:

> I had thought about a PR but decided against it because vanilla attention is hardly ever used anymore; most attention layers involve some novelty and so have to be written out in a custom manner.

Having the new NeuralAttentionlib dependency in Flux should be fine; it seems like a well-designed and well-maintained library. Maybe it contains more than what is strictly needed, though, so I was hoping we could just consolidate things in NNlib and avoid dispersion. I would be OK with either path forward.
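
As a postscript, the layer that eventually closed this issue (FluxML/Flux.jl#2146, listed above) can be used roughly as sketched here; shapes are feature × length × batch, and exact keyword names may differ across Flux versions.

```julia
using Flux

# Rough usage of the MultiHeadAttention layer added by FluxML/Flux.jl#2146
# (keyword names may vary by Flux version); inputs are feature × length × batch.
mha = MultiHeadAttention(64; nheads = 4)
x = rand(Float32, 64, 10, 8)
y, attn = mha(x)   # self-attention: returns the output and the attention scores
```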
