Integrates ragged attention to JetStream Pytorch #93
Conversation
mask,
start,
input_pos,
pre_batch,
A few questions: what are pre_batch and pre_block?
Also, should gemma/model_exportable.py be modified as well?
Thanks for pointing this out! Pushed commits to update the Gemma model.
slot,
)

def precompute_ragged_block_indices(self, decode_state: DecodeState):
Nit: if you don't need any other attributes from decode_state besides start and input_pos, pass start and input_pos directly instead of the heavy decode_state object.
Only start and input_pos are used, but since we only pass a reference to decode_state, it should not affect performance at all.
sounds good
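For context, here is a minimal sketch of the two calling conventions discussed above. The field names, shapes, and block size are illustrative assumptions rather than the repo's actual definitions; the point is simply that only start and input_pos are read, and that passing the whole state passes a reference, never a copy.

```python
from typing import NamedTuple, Tuple
import jax.numpy as jnp


class DecodeState(NamedTuple):
  # Only the last two fields are read by the helpers below; the others stand in
  # for the "heavy" per-slot buffers mentioned in the review comment.
  tokens: jnp.ndarray
  caches: Tuple
  start: jnp.ndarray      # [B] per-slot start position in the KV ring buffer
  input_pos: jnp.ndarray  # [B] per-slot current decode position


def precompute_ragged_block_indices(start, input_pos, bk: int = 128):
  # Number of full KV blocks between start and input_pos (bk is an assumed
  # block size used only for this illustration).
  return (input_pos - start) // bk


def precompute_from_state(decode_state: DecodeState, bk: int = 128):
  # Same computation, but taking the whole state; only a reference is passed.
  return (decode_state.input_pos - decode_state.start) // bk


state = DecodeState(
    tokens=jnp.zeros((8, 1), dtype=jnp.int32),
    caches=(),
    start=jnp.zeros((8,), dtype=jnp.int32),
    input_pos=jnp.full((8,), 256, dtype=jnp.int32),
)

# Passing the two fields directly or passing the whole state gives the same
# result, and neither copies the state object.
a = precompute_ragged_block_indices(state.start, state.input_pos)
b = precompute_from_state(state)
assert (a == b).all()
```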
end_ref,
line_end_ref,
pre_b_ref,
pre_i_ref,
q, k, v and related names are easy to read. What are b, i, o, m, l, bk, and pre? Can you add a brief description for them?
Added. Done.
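To make the naming question above concrete, here is one plausible reading of the kernel argument names, following common flash-/ragged-attention conventions. Treat the signature and the descriptions as assumptions about the intent, not the PR's actual docstring.

```python
def ragged_attention_kernel(       # illustrative signature only
    start_ref,     # [B] first valid KV position for each batch slot
    end_ref,       # [B] one past the last valid KV position for each batch slot
    line_end_ref,  # [B] end position wrapped onto the ring-buffer line
    pre_b_ref,     # precomputed batch index ("pre_batch") for each grid step
    pre_i_ref,     # precomputed KV-block index ("pre_block") for each grid step
    q_ref,         # query block
    k_ref,         # key block
    v_ref,         # value block
    o_ref,         # output block: accumulated attention result
    m_ref,         # running row-wise max of the logits (flash-attention state)
    l_ref,         # running row-wise sum of exp(logits - max), the normalizer
    *,
    bk: int,       # KV block size along the sequence dimension
):
  """b = batch index and i = KV-block index in the kernel grid; o, m, and l are
  the flash-attention output, max, and sum accumulators."""
  ...
```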
…al performance. Fix the typo that should use jax.lax.div instead of jnp.div
…ocessing API from JetStream.
dense_attention_quantized and use an option to control whether quantization is enabled. Use the new torch_xla2 API.
…a ring buffer. Will cause an error.
* refactor flags
* clean up
* fix run_server
* move common flags to global
* format
* update
* update readme
* update run_interactive
…flags for debugging and performance tuning.
… align with main.
Force-pushed cbb2fe9 to ab38726
…nts. The error message is missing positional arguments.
…nput_pos) back to original to avoid unnecessary issues.
…nel. Fix other lint errors.
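A quick note on the jax.lax.div fix mentioned in one of the commits above: jax.numpy has no div function (only divide and floor_divide), so jnp.div raises an AttributeError, while jax.lax.div performs truncating integer division, which is what block-index arithmetic needs. A minimal example, with a block size of 128 assumed for illustration:

```python
import jax
import jax.numpy as jnp

pos = jnp.array([0, 130, 257], dtype=jnp.int32)   # example decode positions
bk = jnp.full_like(pos, 128)                      # assumed KV block size
block_idx = jax.lax.div(pos, bk)                  # integer division -> [0, 1, 2]
print(block_idx)
```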
Currently the performance is on par with dense attention. We can keep improving the performance in follow-up PRs.
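To make the comparison concrete, below is a small, self-contained sketch (my own illustration, not the PR's kernel) of the dense decode-time baseline being compared against: it scores every slot of the padded KV cache and masks the invalid ones, whereas a ragged kernel only visits KV blocks inside each sequence's [start, input_pos) range, which is where further speedups over the dense path can come from.

```python
import jax
import jax.numpy as jnp


def dense_decode_attention(q, k_cache, v_cache, start, input_pos):
  """Dense baseline: q [B, H, D], k_cache/v_cache [B, H, S, D], start/input_pos [B].
  Scores every cache slot, then masks positions outside [start, input_pos)."""
  seq_len = k_cache.shape[2]
  pos = jnp.arange(seq_len)[None, :]                             # [1, S]
  valid = (pos >= start[:, None]) & (pos < input_pos[:, None])   # [B, S]
  scores = jnp.einsum("bhd,bhsd->bhs", q, k_cache) / jnp.sqrt(q.shape[-1])
  scores = jnp.where(valid[:, None, :], scores, -1e9)            # mask padding
  probs = jax.nn.softmax(scores, axis=-1)
  return jnp.einsum("bhs,bhsd->bhd", probs, v_cache)


# Tiny smoke test: a padded cache of length 16, with two sequences occupying
# positions [0, 5) and [0, 9).
B, H, S, D = 2, 4, 16, 8
key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (B, H, D))
k_cache = jax.random.normal(key, (B, H, S, D))
v_cache = jax.random.normal(key, (B, H, S, D))
out = dense_decode_attention(
    q, k_cache, v_cache,
    start=jnp.array([0, 0]), input_pos=jnp.array([5, 9]))
print(out.shape)  # (2, 4, 8)
```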