Add gemma support #69
Conversation
b77f88c to eb739eb
shard config for llama gemma 3 gemma4 formatter
Thanks, Hans, for adding gemma support and for refactoring engine and env to make the code clearer! Overall, it looks great!
| # "replicated" to signify "replicated". | ||
| # Integer signify axis to shard: 0 <= shard axis < rank | ||
|
|
||
| freqs_cis : null # torch.complex64 (16384, 128) |
Can we keep the replicated sharding consistent? If either null or -1 is fine, shall we just keep -1 in our code base? (gemma uses null, but llama uses -1.)
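For context, here is a minimal sketch (not the PR's actual code) of how a single convention could serve both models: treat `null`/None and `-1` identically as "replicated" when turning a per-tensor shard axis into a JAX sharding, so whichever spelling the config uses maps to the same `PartitionSpec`. The helper name `spec_for` and the 1-D device mesh are assumptions for illustration only.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def spec_for(shard_axis, rank, axis_name="x"):
    """Map a shard-config value to a PartitionSpec.

    None (null) or -1      -> fully replicated
    0 <= shard_axis < rank -> shard that dimension across the mesh axis
    """
    if shard_axis is None or shard_axis == -1:
        return PartitionSpec()  # replicated on every dimension
    if not 0 <= shard_axis < rank:
        raise ValueError(f"shard axis {shard_axis} out of range for rank {rank}")
    dims = [None] * rank
    dims[shard_axis] = axis_name
    return PartitionSpec(*dims)

# Example: freqs_cis stays replicated, a weight matrix is sharded on dim 0.
mesh = Mesh(np.array(jax.devices()), ("x",))
freqs_cis_sharding = NamedSharding(mesh, spec_for(None, rank=2))
weight_sharding = NamedSharding(mesh, spec_for(0, rank=2))
```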
        )
        return caches

    def sharding_by_name(self, name):
Great, this is clearer than the previous hardcoded one.
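As a rough illustration of the name-based approach (the patterns and axes below are made up, not the PR's actual rules), a lookup like this replaces per-model hardcoded sharding with a small table keyed on parameter names:

```python
import re

# Hypothetical name -> shard-axis rules; -1 means replicated,
# an integer names the dimension to shard.
_SHARDING_RULES = [
    (r".*\.wq\.weight", 0),
    (r".*\.wo\.weight", 1),
    (r".*norm\.weight", -1),
]

def sharding_by_name(name, default=-1):
    """Return the shard axis for a parameter name via regex lookup."""
    for pattern, axis in _SHARDING_RULES:
        if re.fullmatch(pattern, name):
            return axis
    return default

# e.g. sharding_by_name("layers.0.attention.wq.weight") -> 0
```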
| """Attention module.""" | ||
|
|
||
| def __init__(self, args, env): | ||
| def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): |
Great, it's nice to see the layers are decoupled from args.
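The new signature above comes from the diff; the body below is only a sketch of how a layer built from explicit dimensions (rather than a model-args object) might look, which is what lets llama and gemma share the same module:

```python
import torch.nn as nn

class Attention(nn.Module):
    """Attention module (illustrative body; signature taken from the diff above)."""

    def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.env = env
        # Projections depend only on explicit dimensions, not on a shared args object.
        self.wq = nn.Linear(hidden_size, n_heads * head_dim, bias=False, device=device)
        self.wk = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False, device=device)
        self.wv = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False, device=device)
        self.wo = nn.Linear(n_heads * head_dim, hidden_size, bias=False, device=device)

# Each model unpacks its own config at the call site, e.g. (hypothetical):
# attn = Attention(args.n_heads, args.n_kv_heads, args.head_dim, args.dim, device, env)
```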
    print(f"---------> {jax.devices()}")

    env, model_arg = helpers.make_env_tiny(bf16_enable=False)
    torch.set_default_dtype(torch.float32)
We can remove this line; helpers.make_env_tiny already does the set_default_dtype:

    torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
    torch.set_default_dtype(torch_dtype)
Same for the other torch.set_default_dtype calls in this file.
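To make the redundancy concrete, here is a self-contained sketch using a simplified stand-in for helpers.make_env_tiny, built from the snippet quoted above (the real helper also constructs env and model_arg, which is elided here):

```python
import torch

def make_env_tiny(bf16_enable=True):
    """Simplified stand-in for helpers.make_env_tiny (dtype logic from the quoted snippet)."""
    torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
    torch.set_default_dtype(torch_dtype)
    env, model_arg = None, None  # placeholder for the real env/config construction
    return env, model_arg

env, model_arg = make_env_tiny(bf16_enable=False)
# The default dtype is already float32 here, so a following
# torch.set_default_dtype(torch.float32) in the test adds nothing.
assert torch.get_default_dtype() == torch.float32
```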
Changes include:
- Run interactive and run server both work with random weights.