
Conversation

@qihqi (Collaborator) commented on May 10, 2024:

For Gemma 2B we need to change the shardings because the dimension we usually shard, num_kv_heads, happens to be 1 for Gemma 2B, so we pick a different dimension to shard.
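As a rough illustration of the idea (a hedged sketch, not the repo's actual helper: the cache layout and the names pick_kv_shard_axis and shard_on_axis are assumptions), with a KV cache laid out as (batch, num_kv_heads, seq_len, head_dim) the usual target is axis 1, but a size-1 axis cannot be split across devices, so a fallback such as head_dim is used instead:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def pick_kv_shard_axis(num_kv_heads: int, num_devices: int) -> int:
    # Usual case: shard the kv-heads axis (axis 1) when it divides evenly.
    if num_kv_heads > 1 and num_kv_heads % num_devices == 0:
        return 1
    # Gemma 2B: num_kv_heads == 1 cannot be split, so fall back to head_dim.
    return 3

def shard_on_axis(x, axis, mesh):
    # Partition only `axis` across the mesh axis "x"; replicate the rest.
    spec = [None] * x.ndim
    spec[axis] = "x"
    return jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec)))

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
# e.g. keys = shard_on_axis(keys, pick_kv_shard_axis(num_kv_heads, len(jax.devices())), mesh)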

@qihqi requested review from @FanhaiLu1 and @lsy323 on May 10, 2024 00:09
@qihqi force-pushed the hanq_add_model branch from e4c6542 to c4679e7 on May 10, 2024 02:04

with jax.named_scope("attn_insert_cache"):
    keys, values = cache.update(xk, xv)
    self.env.apply_sharding(keys, axis=1)
A collaborator commented on this snippet:
Any reason to remove the keys, values sharding?

@FanhaiLu1 (Collaborator) left a comment:

Please fix the lint error:

************* Module benchmarks.analyze_sharegpt
benchmarks/analyze_sharegpt.py:75:19: W0621: Redefining name 'prefill_bucket_size_to_ms' from outer scope (line 56) (redefined-outer-name)
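For reference, W0621 fires when a local binding shadows a module-level name. A hypothetical sketch of the fix (do_simulation is an assumed function name; only prefill_bucket_size_to_ms and the line numbers come from the pylint output):

# Before: the parameter at line 75 shadows the module-level dict from line 56,
# triggering W0621 (redefined-outer-name).
prefill_bucket_size_to_ms = {}  # module-level dict (line 56)

def do_simulation(prefill_bucket_size_to_ms):  # W0621 fires here
    ...

# After: renaming the parameter removes the shadowing.
def do_simulation(bucket_to_ms):
    ...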

@FanhaiLu1 merged commit 48a8a22 into main on May 10, 2024.