
Conversation


@lsy323 lsy323 commented May 17, 2024

Make the flag config cleaner

  • Move common configs to jetstream_pt/config.py
  • Remove the unused create_config in config.py

fixes #83

@lsy323 lsy323 requested review from FanhaiLu1, bhavya01, qihqi and wang2yn84 and removed request for FanhaiLu1, bhavya01 and qihqi May 17, 2024 23:46
@lsy323 lsy323 force-pushed the lsiyuan/refactor-flags branch from 10e1c71 to 75d7fc3 on May 17, 2024 23:51

def define_common_flags():
  """Add common config flags to global FLAG."""
  flags.DEFINE_string(
Collaborator

Just define those at the top level, and then whoever imports it will have those flags. see: https://source.corp.google.com/search?q=flags.DEFINE_string

Collaborator Author

Ok, done. The intention was to let scripts optionally import all common flags.
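
A minimal sketch of the module-level pattern being suggested; the flag names, defaults, and help strings below are illustrative, not the actual contents of jetstream_pt/config.py:

# jetstream_pt/config.py (illustrative sketch)
from absl import flags

FLAGS = flags.FLAGS

# Defined at import time at module scope; any script that imports this
# module picks the flags up automatically, without calling a helper.
flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer model.")
flags.DEFINE_string("sharding_config", None, "Path to the sharding config file.")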

run_server.py Outdated
devices = server_lib.get_devices()
print(f"devices: {devices}")
sharding_config_path = _SHARDING_CONFIG.value
engine = jetstream_pt.create_pytorch_engine(
Collaborator

why not use create_engine_from_flags here?

Collaborator Author

Missed this one, done
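
For reference, a rough sketch of the swap being requested; the helper's module path and signature here are assumptions for illustration only:

# run_server.py (illustrative sketch)
from jetstream_pt import config  # assumed location of the flag-based helper

# Instead of wiring each flag value into jetstream_pt.create_pytorch_engine
# by hand, one helper reads the common flags and builds the engine.
engine = config.create_engine_from_flags()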

run_server.py Outdated
)
server_config = ServerConfig(
-    interleaved_slices=(_PLATFORM.value,),
+    interleaved_slices=(FLAGS.platform,),
Collaborator

let's get rid of this:

let's do f"tpu={len(jax.devices())}" here.

Collaborator Author

Done
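
A sketch of what the suggestion amounts to; the ServerConfig import path and the assumption that only this field is set here are illustrative:

# run_server.py (illustrative sketch)
import jax
from jetstream.core.config_lib import ServerConfig  # assumed import path

server_config = ServerConfig(
    # Derive the slice spec from the visible devices rather than from a
    # platform flag, e.g. "tpu=8" on a host with 8 TPU chips.
    interleaved_slices=(f"tpu={len(jax.devices())}",),
)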

define_common_flags()
define_profiling_flags()

_PORT = flags.DEFINE_integer("port", 9000, "port to listen on")
Collaborator

we should leave the thread / port flags in this file instead of in the config file.

Collaborator Author

@lsy323 lsy323 May 18, 2024

Those 2 are still defined in this file, on lines 32 and 33; I just used a different way to define them so the module-level globals can be avoided. The flag values are accessible via FLAGS.port, and FLAGS is an existing global from absl.flags.
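
A small sketch of that pattern with absl.flags; the "threads" flag name and both default values are illustrative:

# illustrative sketch
from absl import app, flags

FLAGS = flags.FLAGS

# Define the flags without keeping the returned holders in module-level
# globals like _PORT; values are read back through FLAGS instead.
flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads")

def main(argv):
  del argv
  print(f"serving on port {FLAGS.port} with {FLAGS.threads} threads")

if __name__ == "__main__":
  app.run(main)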

@lsy323 lsy323 requested a review from qihqi May 18, 2024 00:17
Collaborator

@FanhaiLu1 FanhaiLu1 left a comment

Thanks for refactoring the flags!

@FanhaiLu1 FanhaiLu1 merged commit 0fe239b into AI-Hypercomputer:main May 20, 2024
@lsy323 lsy323 deleted the lsiyuan/refactor-flags branch May 20, 2024 17:39
wang2yn84 pushed a commit that referenced this pull request May 21, 2024
* refactor flags

* clean up:

* fix run_server

* move common flags to global

* format

* update

* update readme

* update run_interactive
wang2yn84 pushed a commit that referenced this pull request May 23, 2024
* refactor flags

* clean up:

* fix run_server

* move common flags to global

* format

* update

* update readme

* update run_interactive
wang2yn84 added a commit that referenced this pull request May 23, 2024
* Stable version of ragged attention.

* Converts the attention output types the same as q.

* Fixes the typo for the ragged attention.

* Provides the default value for partition_by_axis.

* Provides mesh to the shard_map.

* Fixes typo.

* Fixes typo, should be start instead of start_pos.

* Should use "//" instead of "/" to get int results.

* Use block size // 2 as the starting current position for better initial performance. Fix the typo that should use jax.lax.div instead of jnp.div

* Updates the run_interactive script to use the correct result token processing API from JetStream.

* Fix typo, should use token_utils.process_result_token.

* Fix typo.

* Fixes the sampled tokens list.

* Use text_tokens_to_str to convert the output tokens.

* Reshape the precomputed grid indices to 1D. Removes
dense_attention_quantized and uses an option to control
whether it's quantized or not. Use the new torch_xla2 API.

* Should check if X is None instead of if X

* Fix the dense_attention not returning data.

* Reshape the kv scaler to 3 dim for ragged attention.

* Cannot stop the input_pos counter from increasing since we are using a ring buffer. Will cause error.

* Adds starting_position and profiling_prefill for better testing and benchmarking.

* Move flags in scripts to a common function (#92)

* refactor flags

* clean up:

* fix run_server

* move common flags to global

* format

* update

* update readme

* update run_interactive

* Stable version of ragged attention.

* Fix the merge conflicts

* Fixes the missing pieces after merging conflicts. Adds couple of new flags for debugging and performance tuning.

* Integrates ragged attention to Gemma too.

* Somehow have some local changes to run_interactive, reverting them to align with main.

* Set the default value for the newly added parameters.

* Adds more descriptions to the ragged attention index precomputation function.

* Merges the quantized ragged attention kernel with the non quantized version.

* Moves the attention calculation to attention.py for better code structure.

* Fix run issues refactoring.

* Fix the quantized version for ragged attention.

* Fix test_attention by adding default value for the newly added arguments. The error message is missing positional arguments.

* Fixes unit tests, changes the Transformer model call argument order (input_pos) back to the original to avoid unnecessary issues.

* Format attention_kernel.py

* Add descriptions to ragged attention outputs.

* Fix quantization tests by adding default value to quantization kernel class.

* Reformat attention_kernel.py. Formatting with black doesn't comply with the pylint rules.

* Ignores R0913: Too many arguments lint error for ragged attention kernel. Fix other lint errors.

* Ignore R0903: Too few public methods. Fix lint errors.

* Fix the rest of the lint errors.

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.com>
