
Benchmark effect of merging query and keys matrices in transformers #3

Open
Ayushk4 opened this issue Mar 19, 2023 · 1 comment

@Ayushk4
Member

Ayushk4 commented Mar 19, 2023

For certain architectures (like GPT-J and LLaMA), it may be possible to replace the query ($Q$) and key ($K$) matrices with a single matrix, saving one of the seven or eight matrix multiplications in the transformer. I don't see an obvious way to do this for GPT-NeoX and OPT.
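Whether merging saves work may depend on matrix shapes: per head, $Q$ and $K$ are each $d_{head} \times d_{model}$, while the merged matrix $Q^\top K$ is $d_{model} \times d_{model}$. As a rough sketch, assuming standard multi-head attention and counting only the multiply-adds (MACs) needed for the pre-softmax scores of one head over $n$ tokens, the two approaches can be compared as follows. The function names and the sizes in the example are hypothetical, not taken from any specific model config.

```python
# Rough multiply-add (MAC) counts for the pre-softmax attention scores of ONE head
# over a sequence of n tokens. Assumptions (not from the issue): Q and K are each
# d_head x d_model, and QKMerge = Q^T K is d_model x d_model.


def macs_separate(n: int, d_model: int, d_head: int) -> int:
    """Project every token with Q and K, then take all n * n dot products q_i . k_j."""
    projections = 2 * n * d_head * d_model  # q_i = Q x_i and k_j = K x_j
    scores = n * n * d_head                 # q_i^T k_j for every (i, j)
    return projections + scores


def macs_merged(n: int, d_model: int) -> int:
    """Project every token with QKMerge, then take all n * n dot products x_i . (QKMerge x_j)."""
    projection = n * d_model * d_model      # QKMerge x_j for every j
    scores = n * n * d_model                # x_i^T (QKMerge x_j) for every (i, j)
    return projection + scores


if __name__ == "__main__":
    n, d_model, n_heads = 2048, 4096, 32    # hypothetical sizes
    d_head = d_model // n_heads
    sep = macs_separate(n, d_model, d_head)
    mer = macs_merged(n, d_model)
    print(f"separate Q, K:  {sep:,} MACs per head")
    print(f"merged QKMerge: {mer:,} MACs per head")
    print(f"ratio merged / separate: {mer / sep:.1f}")
```

Under these assumptions the merged form drops one weight multiplication but replaces two thin $d_{head} \times d_{model}$ products with one square $d_{model} \times d_{model}$ product, so the net effect is worth measuring rather than assuming.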

Task: take a standard benchmark and run the model before and after merging the query and key matrices, then compare the results.
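Before running a full benchmark, a toy version of the before/after comparison might look like the sketch below. It is not tied to GPT-J or LLaMA: the weights and inputs are random, the sizes are hypothetical, and plain Python stands in for a real tensor library, so the timings say little about GPU performance. The helper names (`scores_separate`, `scores_merged`, `compare`) are illustrative only. It checks that both paths produce the same score matrix up to floating-point error and times each one.

```python
import random
import time

# Toy "before vs after" comparison for ONE attention head. Everything here is a
# hypothetical stand-in for a real model and benchmark: random weights, random
# inputs, and plain-Python matrix code (matrices are lists of rows).


def matmul(a, b):
    """Multiply matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def scores_separate(X, Q, K):
    """Before merging: S = (X Q^T)(X K^T)^T, so S[i][j] = q_i . k_j."""
    q = matmul(X, transpose(Q))  # n x d_head, row i is (Q x_i)^T
    k = matmul(X, transpose(K))  # n x d_head, row j is (K x_j)^T
    return matmul(q, transpose(k))


def scores_merged(X, qk_merge):
    """After merging: S = X QKMerge X^T, so S[i][j] = x_i^T QKMerge x_j."""
    return matmul(matmul(X, qk_merge), transpose(X))


def compare(n=16, d_model=32, d_head=8, seed=0):
    rng = random.Random(seed)
    X = rand_matrix(n, d_model, rng)    # rows are x_1 ... x_n
    Q = rand_matrix(d_head, d_model, rng)
    K = rand_matrix(d_head, d_model, rng)
    qk_merge = matmul(transpose(Q), K)  # QKMerge = Q^T K, precomputed once

    t0 = time.perf_counter()
    before = scores_separate(X, Q, K)
    t1 = time.perf_counter()
    after = scores_merged(X, qk_merge)
    t2 = time.perf_counter()

    max_diff = max(abs(a - b) for ra, rb in zip(before, after) for a, b in zip(ra, rb))
    return {"time_before_s": t1 - t0, "time_after_s": t2 - t1, "max_abs_diff": max_diff}


if __name__ == "__main__":
    print(compare())
```

A real run would swap in the model's actual weights and a task metric, but the same two checks (agreement of outputs, then speed) would apply.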

Details:
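$(\cdot)^\top$ denotes transpose.

Consider the input representation $X = \{x_1, \dots, x_i, \dots, x_j, \dots, x_n\}$, and let

$$q_i = Q x_i, \qquad k_j = K x_j.$$

Then the attention score between positions $i$ and $j$ (the unnormalized logit for one attention head) is

$$
\begin{aligned}
\mathrm{score}_{i,j} &= q_i^\top k_j \\
&= (Q x_i)^\top (K x_j) \\
&= (x_i^\top Q^\top)(K x_j) \\
&= x_i^\top Q^\top K x_j.
\end{aligned}
$$

Let $\mathrm{QKMerge} = Q^\top K$. Then

$$\mathrm{score}_{i,j} = x_i^\top \, \mathrm{QKMerge} \, x_j.$$

The usual $1/\sqrt{d}$ scaling is a constant factor, so it can be folded into $\mathrm{QKMerge}$ as well.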
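The derivation can be checked numerically. This minimal sketch uses random stand-in matrices with hypothetical sizes ($Q$ and $K$ of shape $d_{head} \times d_{model}$) and evaluates $\mathrm{score}_{i,j}$ three ways, following the steps of the derivation; all three should agree up to floating-point rounding. The helper `check` is illustrative only.

```python
import random

# Numerical check of score_{i,j} = q_i^T k_j = x_i^T Q^T K x_j = x_i^T QKMerge x_j.
# Q, K, x_i, x_j are random stand-ins; d_model and d_head are hypothetical sizes.
# Matrices are lists of rows, vectors are plain lists.


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(M, v):
    return [dot(row, v) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul(A, B):
    cols = list(zip(*B))
    return [[dot(row, col) for col in cols] for row in A]


def check(d_model=6, d_head=3, seed=1):
    rng = random.Random(seed)
    Q = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(d_head)]
    K = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(d_head)]
    x_i = [rng.uniform(-1, 1) for _ in range(d_model)]
    x_j = [rng.uniform(-1, 1) for _ in range(d_model)]

    # score = q_i^T k_j with q_i = Q x_i and k_j = K x_j
    via_qk = dot(matvec(Q, x_i), matvec(K, x_j))
    # score = x_i^T (Q^T (K x_j)): the same chain product, grouped differently
    via_chain = dot(x_i, matvec(transpose(Q), matvec(K, x_j)))
    # score = x_i^T QKMerge x_j, with QKMerge = Q^T K precomputed (d_model x d_model)
    qk_merge = matmul(transpose(Q), K)
    via_merge = dot(x_i, matvec(qk_merge, x_j))
    return via_qk, via_chain, via_merge


if __name__ == "__main__":
    print(check())
```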

Ayushk4 changed the title from "Benchmark after merging query and keys matrices in transformers" to "Benchmark effect of merging query and keys matrices in transformers" on Mar 19, 2023.
@Ayushk4
Member Author

Ayushk4 commented Mar 21, 2023

The above formula will have to be modified for rotary embeddings.
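As a hedged sketch of one such modification: with rotary embeddings, positions enter as rotations, $q_i = R_i Q x_i$ and $k_j = R_j K x_j$, where $R_m$ is a block-diagonal rotation matrix. Because $R_i^\top R_j = R_{j-i}$, the score becomes $x_i^\top Q^\top R_{j-i} K x_j$, so a single position-independent $\mathrm{QKMerge}$ no longer works; one would instead need a merged matrix $Q^\top R_{j-i} K$ per relative offset $j - i$. The code below checks this numerically. It assumes all head dimensions are rotated, paired as $(2k, 2k+1)$ with angle $m\,\theta_k$ and $\theta_k = 10000^{-2k/d_{head}}$; real implementations differ in how dimensions are paired, and some rotate only part of each head. The helper names are hypothetical.

```python
import math
import random

# Sketch: rotary position embeddings (RoPE) make the merged matrix depend on j - i.
# Assumptions (not from the issue): pairs (2k, 2k+1) of each query/key vector are
# rotated by angle m * theta_k at position m, theta_k = base ** (-2k / d_head),
# base = 10000, and all head dimensions are rotated.


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(M, v):
    return [dot(row, v) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul(A, B):
    cols = list(zip(*B))
    return [[dot(row, col) for col in cols] for row in A]


def rope_matrix(m, d_head, base=10000.0):
    """Block-diagonal rotation R_m of shape d_head x d_head (d_head must be even)."""
    R = [[0.0] * d_head for _ in range(d_head)]
    for k in range(d_head // 2):
        angle = m * base ** (-2 * k / d_head)
        c, s = math.cos(angle), math.sin(angle)
        a, b = 2 * k, 2 * k + 1
        R[a][a], R[a][b] = c, -s
        R[b][a], R[b][b] = s, c
    return R


def merged_for_offset(Q, K, offset, d_head):
    """Position-dependent merge Q^T R_offset K (d_model x d_model), offset = j - i."""
    return matmul(matmul(transpose(Q), rope_matrix(offset, d_head)), K)


def check(i=3, j=7, d_model=6, d_head=4, seed=2):
    rng = random.Random(seed)
    Q = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(d_head)]
    K = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(d_head)]
    x_i = [rng.uniform(-1, 1) for _ in range(d_model)]
    x_j = [rng.uniform(-1, 1) for _ in range(d_model)]

    # Direct RoPE score: (R_i Q x_i)^T (R_j K x_j)
    q_i = matvec(rope_matrix(i, d_head), matvec(Q, x_i))
    k_j = matvec(rope_matrix(j, d_head), matvec(K, x_j))
    direct = dot(q_i, k_j)
    # Merged form with the offset-dependent matrix Q^T R_{j-i} K
    merged = dot(x_i, matvec(merged_for_offset(Q, K, j - i, d_head), x_j))
    # The position-independent QKMerge = Q^T K from the original formula
    plain = dot(x_i, matvec(matmul(transpose(Q), K), x_j))
    return direct, merged, plain


if __name__ == "__main__":
    direct, merged, plain = check()
    print(f"direct={direct:.6f} merged={merged:.6f} plain QKMerge={plain:.6f}")
```

In this sketch, a causal model with context length $L$ would need $L$ such matrices per head (offsets $j - i = 0, -1, \dots, -(L-1)$).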
