Conversation

@qihqi
Collaborator

@qihqi qihqi commented May 3, 2024

Changes include:

  • Refactor the engine to read sharding annotations from a file instead of using hardcoded ones (a sketch of the config-driven loading appears below).
  • Replace the GemmaAttention class used in Gemma with layers.Attention: they seem to be doing the same thing.
  • Added a model_name arg to create_engine.

Running interactive and running the server both work with random weights.
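For illustration, here is a minimal sketch of how such a file-based sharding config could be consumed. The YAML layout mirrors the excerpt reviewed further down; the loader itself, its name, and the rank_by_name parameter are hypothetical, not this PR's actual code.

import yaml
from jax.sharding import NamedSharding, PartitionSpec

def load_shardings(path, mesh, rank_by_name):
    """Map each weight name to a NamedSharding (illustrative only).

    An integer annotation is the tensor axis to shard over mesh axis "x";
    null (None) or -1 means the weight stays replicated.
    """
    with open(path) as f:
        raw = yaml.safe_load(f)
    shardings = {}
    for name, axis in raw.items():
        if axis is None or axis == -1:  # replicated
            spec = PartitionSpec()
        else:  # shard the given tensor axis; 0 <= axis < rank
            dims = [None] * rank_by_name[name]
            dims[axis] = "x"
            spec = PartitionSpec(*dims)
        shardings[name] = NamedSharding(mesh, spec)
    return shardings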

@qihqi qihqi requested review from FanhaiLu1, lsy323 and wang2yn84 May 3, 2024 22:45
@qihqi qihqi force-pushed the hanq_add_model branch 2 times, most recently from b77f88c to eb739eb Compare May 4, 2024 00:22
shard config for llamA

gemma 3

gemma4

formatter
Collaborator

@FanhaiLu1 FanhaiLu1 left a comment


Thanks, Hans, for adding Gemma support and for refactoring the engine and env to make the code clearer! Overall, it looks great!

# "replicated" to signify "replicated".
# Integer signify axis to shard: 0 <= shard axis < rank

freqs_cis : null # torch.complex64 (16384, 128)

Can we keep the replicated-sharding notation consistent? If either null or -1 is fine, shall we just use -1 throughout the code base? (Gemma uses null, but Llama uses -1.)
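One possible way to accept both spellings while standardizing on -1 internally is a small normalizer like the hedged sketch below; the helper name and exact rules are assumptions, not code from this PR.

def normalize_axis(value):
    # Treat null/None, "replicated", and -1 the same, so older configs that
    # use null still parse while new ones standardize on -1.
    if value is None or value == "replicated" or value == -1:
        return -1  # replicated
    if isinstance(value, int) and value >= 0:
        return value  # axis to shard; caller still checks axis < rank
    raise ValueError(f"unexpected sharding annotation: {value!r}")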

)
return caches

def sharding_by_name(self, name):

Great, this is much clearer than the previous hardcoded version.
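For context, a minimal sketch of what a name-based lookup like this can look like; the matching rules and the _sharding_config attribute below are assumptions, not this PR's exact implementation.

def sharding_by_name(self, name):
    # Assumed: self._sharding_config is the dict parsed from the YAML file,
    # e.g. {"freqs_cis": None, "tok_embeddings.weight": 1, ...}
    if name in self._sharding_config:
        return self._sharding_config[name]
    # Fall back to suffix matching so a per-layer name such as
    # "layers.0.attention.wq.weight" can reuse a generic
    # "attention.wq.weight" entry.
    for key, axis in self._sharding_config.items():
        if name.endswith(key):
            return axis
    return -1  # default: replicated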

"""Attention module."""

def __init__(self, args, env):
def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):

Great, it's nice to see that the layers are decoupled from args.
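A hedged example of what the new signature enables: the shared layer can be constructed directly from plain integers in a test or a new model, without building a model-specific args object first. The values below are illustrative only.

attn = layers.Attention(
    n_heads=8,
    n_kv_heads=1,       # grouped/multi-query attention when n_kv_heads < n_heads
    head_dim=256,
    hidden_size=2048,
    device="meta",      # placeholder device, just to build shapes
    env=env,            # runtime/sharding environment, as in the PR
)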

print(f"---------> {jax.devices()}")

env, model_arg = helpers.make_env_tiny(bf16_enable=False)
torch.set_default_dtype(torch.float32)

We can remove this line; helpers.make_env_tiny already has code that sets the default dtype:

torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
torch.set_default_dtype(torch_dtype)
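So, assuming make_env_tiny keeps the behavior quoted above, the test can rely on the helper alone; a quick sketch of the intent:

env, model_arg = helpers.make_env_tiny(bf16_enable=False)
# No explicit torch.set_default_dtype(torch.float32) needed here:
assert torch.get_default_dtype() == torch.float32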


Same for the other torch.set_default_dtype calls in this file.

@qihqi qihqi force-pushed the hanq_add_model branch from eb739eb to f5de8d4 Compare May 4, 2024 00:56
@FanhaiLu1 FanhaiLu1 merged commit 9353640 into main May 7, 2024