
Conversation

@wang2yn84
Collaborator

No description provided.

@wang2yn84 wang2yn84 requested review from FanhaiLu1 and qihqi June 21, 2024 23:04
# fill mask first
mask = decode_state.mask.at[:, decode_state.current_position].set(0)
if self.env.ring_buffer:
  input_indexes = jnp.full((1,), pos)
Collaborator

What if we change current_position to [batch_size, 1]? Could we then use the same masking logic for both the ring_buffer and non_ring_buffer cases?

Collaborator Author

Not really. In the non ring buffer case, there is a single current_position value that indicates the decoding position for all batch entries. But in the ring buffer case, every batch entry has a different position, so we cannot use current_position here.
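To make the distinction concrete, here is a small illustrative sketch (the shapes and variable names are assumptions for illustration, not the actual decode_state fields): a single scalar position can set one mask column for every batch entry, whereas per-batch positions need one row index per entry.

```python
import jax.numpy as jnp

batch, cache_len = 4, 8
mask = jnp.full((batch, cache_len), float("-inf"))

# Non ring buffer: one scalar position shared by every batch entry.
current_position = 3
mask_scalar = mask.at[:, current_position].set(0)

# Ring buffer: each batch entry decodes at its own position, so a scalar
# column index cannot be used; we need one index per row instead.
positions = jnp.array([1, 5, 2, 7])  # shape (batch,)
rows = jnp.arange(batch)
mask_per_batch = mask.at[rows, positions].set(0)
```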

Collaborator

Yes. I mean that if we change current_position to [batch_size, 1], different slots can have different current_position values. In the non ring buffer case, current_position should be the same as input_pos.

Collaborator Author

It would cause a performance regression. Please check jax_experiments.py/test7: inserting with batching plus a position array takes much longer, roughly 4x to 5x.
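For context, a minimal sketch of the two insertion patterns being compared (the actual benchmark lives in jax_experiments.py/test7 and is not reproduced here; the shapes and function names below are illustrative assumptions). Updating at one shared scalar position can lower to a simple dynamic-update-slice, while a per-batch position array generally forces a scatter, which is the slowdown being measured:

```python
import jax
import jax.numpy as jnp

batch, cache_len, dim = 8, 16, 4
cache = jnp.zeros((batch, cache_len, dim))
new_vals = jnp.ones((batch, dim))

@jax.jit
def insert_shared(cache, new_vals, pos):
  # One scalar position for every batch entry (non ring buffer).
  return cache.at[:, pos, :].set(new_vals)

@jax.jit
def insert_per_batch(cache, new_vals, positions):
  # A different position per batch entry (ring buffer); this is the
  # batching + position array pattern the comment refers to.
  rows = jnp.arange(cache.shape[0])
  return cache.at[rows, positions, :].set(new_vals)

out_shared = insert_shared(cache, new_vals, jnp.int32(3))
out_batched = insert_per_batch(cache, new_vals, jnp.arange(batch) % cache_len)
```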

input_pos = jnp.where(
    decode_state.input_pos == 0,
    0,
    (decode_state.input_pos + 1) % self.env.cache_len,
)
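One subtlety worth flagging here: in Python, `%` binds tighter than `+`, so the unparenthesized form `input_pos + 1 % cache_len` evaluates as `input_pos + (1 % cache_len)` and never wraps; the intended wrap-around needs `(input_pos + 1) % cache_len`. A minimal illustration:

```python
# Operator-precedence pitfall: % binds tighter than +, so the
# unparenthesized form never wraps (for any cache_len > 1).
cache_len = 8
input_pos = 7

unparenthesized = input_pos + 1 % cache_len  # 7 + (1 % 8) == 8
intended = (input_pos + 1) % cache_len       # (7 + 1) % 8 == 0
```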
Collaborator

In the non ring buffer case, can input_pos be larger than the cache length?

Collaborator Author

No.

Collaborator

If not, I feel we don't need the % since input_pos never reaches the cache length.

Collaborator Author

We don't have control over this: Generate() will keep running if no new prefill results are inserted, so input_pos can keep growing.
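A tiny illustration of that scenario (hypothetical names; the real loop lives in Generate()): with no new prefill insertions, an unguarded counter walks past the cache length, while the modulo form wraps back into range.

```python
# Six decode steps with no new prefill inserted. cache_len and the step
# count are illustrative values, not taken from the actual code.
cache_len = 4
pos_unwrapped = 0
pos_wrapped = 0
for _ in range(6):
    pos_unwrapped = pos_unwrapped + 1          # exceeds cache_len
    pos_wrapped = (pos_wrapped + 1) % cache_len  # stays in [0, cache_len)
```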

Collaborator

Thanks for sharing the details!

@qihqi qihqi merged commit 175d956 into main Jun 28, 2024