cannot match attention layer output to pytorch's one #231

Closed
CarloLucibello opened this issue May 5, 2023 · 2 comments · Fixed by #230

Comments

@CarloLucibello
Member

CarloLucibello commented May 5, 2023

I'm struggling to import the weights from torchvision's ViT into ours. The problem is that the correct mapping from torch's attention layers to Metalhead's seems non-trivial.

using PythonCall, Metalhead
torch = pyimport("torch")

function th2jl(x::Py)
    xj = pyconvert(Array, x.detach().numpy())
    xj = permutedims(xj, ndims(xj):-1:1)
    return xj
end

m = torch.nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=true, bias=false, add_bias_kv=false)
# python forward pass
x = torch.randn(1, 3, 2)
y, a = m(x, x, x, need_weights=true)

mj = Metalhead.MHAttention(2, 1, qkv_bias=false)
# copy weights
mj.qkv_layer.weight .= th2jl(m.in_proj_weight)'       # transpose back: Flux Dense and PyTorch Linear both store weights as (out, in), and th2jl reverses the dims
mj.projection.layers[1].weight .= th2jl(m.out_proj.weight)'
# julia forward pass
xj = th2jl(x)
yj = mj(xj)

@assert yj ≈ th2jl(y)  # fails: outputs do not match

Probably this is due to the permutations and chunking in our initial projection; possibly we should rearrange them so that the natural weight mapping from PyTorch just works.
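
For reference, one way to sanity-check the layout (a hedged sketch: it assumes PyTorch concatenates W_q, W_k and W_v along the first dimension of `in_proj_weight`, and the `Wq`/`Wk`/`Wv`/`q_th` names are only for illustration):

W = th2jl(m.in_proj_weight)'                    # back to the PyTorch layout: (3 * embed_dim, embed_dim)
Wq, Wk, Wv = W[1:2, :], W[3:4, :], W[5:6, :]    # q, k, v blocks for embed_dim = 2
q_th = Wq * th2jl(x)[:, :, 1]                   # reference query projection for the single batch element

Comparing `q_th` against the corresponding chunk of `mj.qkv_layer(xj)` should show whether the mismatch comes from the chunk order or from the permutations.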

Pinging @theabhirath for more insights.

@theabhirath
Member

The attention layers need a lot of TLC. They were written before a lot of functionality landed in upstream libraries such as FluxML/NNlib.jl#455, and so are presumably not only slower but also doing more work than they need to. This is one aspect that someone could take up and rewrite. Given that NNlib now has the functionality we need, only a couple of questions need to be answered:

  1. Does NNlib's attention also have an NNlibCUDA equivalent that gives us good performance on GPUs?
  2. Metalhead should probably not use the attention function directly, and instead use something like TensorCast along with the extended batched_mul to write the MHAttention layer (see the sketch after this list). Is this GPU-friendly, AD-friendly, and performant enough? If not, we can always fall back to the NNlib version, but that prevents Metalhead from adding its own goodies.
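
A minimal sketch of the batched_mul route from item 2, assuming q, k and v have already been reshaped so the head dimension is folded into the batch dimension; `naive_sdpa` is a hypothetical helper, not existing Metalhead API:

using NNlib: batched_mul, batched_transpose, softmax

# q, k, v are (head_dim, seq_len, nheads * batch)
function naive_sdpa(q, k, v)
    scale = sqrt(eltype(q)(size(q, 1)))
    α = softmax(batched_mul(batched_transpose(k), q) ./ scale; dims = 1)  # (kv_len, q_len, nheads * batch)
    return batched_mul(v, α)                                              # (head_dim, q_len, nheads * batch)
end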

@CarloLucibello
Member Author

Let's use NNlib as a backend then. It is GPU- and AD-friendly.
The forward pass has to be implemented as follows (single-head case, to be generalized below):

using NNlib, MLUtils   # MLUtils provides `chunk`; NNlib provides `dot_product_attention`

qkv = mj.qkv_layer(xj)
q, k, v = chunk(qkv, 3, dims=1)
yj, aj = NNlib.dot_product_attention(q, k, v)
yj = mj.projection(yj)

@assert yj ≈ th2jl(y)  # passes: compatible with the PyTorch layer
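
The multi-head case should then be a small step away. A hedged sketch, assuming NNlib's `nheads` keyword handles the head split and merge internally (whether the per-head layout of `in_proj_weight` then maps over cleanly still needs checking; `nheads = 1` here only mirrors the example above):

qkv = mj.qkv_layer(xj)
q, k, v = chunk(qkv, 3, dims=1)
yj, aj = NNlib.dot_product_attention(q, k, v; nheads = 1)  # pass the layer's head count here
yj = mj.projection(yj)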
