
Conversation

@wang2yn84
Collaborator

Currently, performance is on par with dense attention. We can keep improving it in follow-up PRs.

@wang2yn84 wang2yn84 requested review from FanhaiLu1 and qihqi May 20, 2024 21:43
mask,
start,
input_pos,
pre_batch,
Collaborator

A few comments: what are pre_batch and pre_block?

Also, should gemma/model_exportable.py be modified as well?

Collaborator Author

Thanks for pointing this out! Pushed commits to update the Gemma model.

slot,
)

def precompute_ragged_block_indices(self, decode_state: DecodeState):
Collaborator

Nit: if you don't need any other attributes from decode_state besides start and input_pos, pass start and input_pos directly instead of the heavy decode_state object.

Collaborator Author

Only start and input_pos are used, but since we're just passing a reference to decode_state, it shouldn't affect performance at all.

Collaborator

sounds good
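
For readers following the thread, here is a minimal, hypothetical sketch of the narrower signature the reviewer is suggesting; the function body and the block-size argument are illustrative assumptions, not code from this PR:

```python
def precompute_ragged_block_indices(start, input_pos, bk=256):
  """Hypothetical variant taking only the fields it needs instead of DecodeState.

  start, input_pos: per-sequence integers (or integer arrays of shape (batch,))
  holding the cache start offset and the current decode position.
  Returns the number of bk-sized KV blocks each sequence has to visit.
  """
  length = input_pos - start        # valid tokens per sequence
  return (length + bk - 1) // bk    # ceil-divide by the block size


# Call site: unpack at the boundary so the helper stays decoupled from DecodeState,
# e.g. precompute_ragged_block_indices(decode_state.start, decode_state.input_pos).
```

Either way, as the author notes, decode_state is passed by reference, so the choice is about readability rather than performance.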

end_ref,
line_end_ref,
pre_b_ref,
pre_i_ref,
Collaborator

q, k, v and related names are easy to read, but what are b, i, o, m, l, bk, and pre? Can you add a brief description of each?

Collaborator Author

Added. Done.
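
For context, short names like these usually follow flash-attention kernel conventions; the glossary below is an assumption based on those conventions, not the exact comments added in this PR:

```python
# Assumed meanings of the short kernel names discussed above (illustrative only):
#   b      - batch index for the current grid step
#   i      - index of the current KV block within the sequence
#   bk     - block size along the KV (sequence-length) dimension
#   q/k/v  - query / key / value blocks
#   o      - output accumulator block
#   m      - running row-wise max of the attention logits (for a stable softmax)
#   l      - running row-wise sum of exp(logits - m), i.e. the softmax denominator
#   pre_*  - values precomputed outside the kernel, e.g. the (batch, block) index
#            pairs produced by precompute_ragged_block_indices
```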


wang2yn84 and others added 25 commits May 23, 2024 05:46
…al performance. Fix the typo: use jax.lax.div instead of jnp.div.
dense_attention_quantized and use an option to control whether quantization is enabled. Use the new torch_xla2 API.
* refactor flags

* clean up:

* fix run_server

* move common flags to global

* format

* update

* update readme

* update run_interactive
@wang2yn84 wang2yn84 force-pushed the ragged-attention-final2 branch from cbb2fe9 to ab38726 Compare May 23, 2024 18:33
@wang2yn84 wang2yn84 merged commit 517d847 into main May 23, 2024
@wang2yn84 wang2yn84 deleted the ragged-attention-final2 branch May 23, 2024 21:35