Move flags in scripts to a common function #92
Conversation
Force-pushed from 10e1c71 to 75d7fc3.
jetstream_pt/config.py (Outdated)
def define_common_flags():
  """Add common config flags to global FLAG."""
  flags.DEFINE_string(
Just define those at the top level; then whoever imports it will have those flags. See: https://source.corp.google.com/search?q=flags.DEFINE_string
OK, done. The intention was to let scripts optionally import all common flags.
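For context, a minimal sketch of the top-level pattern described above, assuming hypothetical flag names (the actual common flags in jetstream_pt/config.py may differ). Flags defined at module level are registered with absl's global FLAGS object the moment the module is imported, so any script that imports it picks them up without calling a define_common_flags() helper:

# jetstream_pt/config.py -- sketch only; flag names here are assumptions.
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer model.")
flags.DEFINE_string("checkpoint_path", None, "Path to the model checkpoint.")
flags.DEFINE_bool("quantize_weights", False, "Whether to quantize the weights.")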
run_server.py (Outdated)
devices = server_lib.get_devices()
print(f"devices: {devices}")
sharding_config_path = _SHARDING_CONFIG.value
engine = jetstream_pt.create_pytorch_engine(
why not use create_engine_from_flags here?
Missed this one, done
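For illustration, a sketch of what the suggested change might look like; where create_engine_from_flags lives and its exact signature are assumptions here:

# run_server.py -- build the engine from the globally parsed flags instead of
# forwarding each flag value to create_pytorch_engine by hand.
from absl import app
from jetstream_pt import config as jpt_config


def main(argv):
  del argv  # unused
  # Assumed helper: reads the already-parsed absl FLAGS and returns an engine.
  engine = jpt_config.create_engine_from_flags()
  return engine


if __name__ == "__main__":
  app.run(main)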
run_server.py (Outdated)
)
server_config = ServerConfig(
    interleaved_slices=(_PLATFORM.value,),
    interleaved_slices=(FLAGS.platform,),
let's get rid of this:
let's do f"tpu={len(jax.devices())}" here.
Done
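A sketch of the resulting configuration, deriving the slice string from the devices JAX actually sees rather than from a platform flag; the ServerConfig import path is an assumption:

import jax
from jetstream.core import config_lib  # import path is an assumption


def make_server_config():
  # Describe the interleaved slice from the attached devices instead of
  # reading it from a --platform flag.
  return config_lib.ServerConfig(
      interleaved_slices=(f"tpu={len(jax.devices())}",),
  )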
define_common_flags()
define_profiling_flags()

_PORT = flags.DEFINE_integer("port", 9000, "port to listen on")
we should leave the thread / port flags in this file instead of in the config file.
Those two are defined on lines 32 and 33; I used a different way to define them so the module-level globals can be avoided. The flag values are accessible via FLAGS.port, since FLAGS is an existing global from absl.flags.
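As a sketch of the pattern described in the reply (the threads flag name and default values are assumptions):

# run_server.py -- server-specific flags stay in this script. No module-level
# _PORT/_THREADS constants are needed: once absl parses argv, the values are
# available on the global FLAGS object as FLAGS.port and FLAGS.threads.
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads")


def main(argv):
  del argv  # unused
  print(f"listening on port {FLAGS.port} with {FLAGS.threads} threads")


if __name__ == "__main__":
  app.run(main)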
FanhaiLu1 left a comment:
Thanks for refactoring the flags!
* refactor flags
* clean up
* fix run_server
* move common flags to global
* format
* update
* update readme
* update run_interactive
* Stable version of ragged attention.
* Converts the attention output types to be the same as q.
* Fixes the typo for the ragged attention.
* Provides the default value for partition_by_axis.
* Provides mesh to the shard_map.
* Fixes typo.
* Fixes typo, should be start instead of start_pos.
* Should use "//" instead of "/" to get int results.
* Use block size // 2 as the starting current position for better initial performance. Fix the typo that should use jax.lax.div instead of jnp.div.
* Updates the run_interactive script to use the correct result token processing API from JetStream.
* Fix typo, should use token_utils.process_result_token.
* Fix typo.
* Fixes the sampled tokens list.
* Use text_tokens_to_str to convert the output tokens.
* Reshape the precomputed grid indices to 1D. Removes dense_attention_quantized and uses an option to control whether it is quantized or not. Use the new torch_xla2 API.
* Should check if X is None instead of if X.
* Fix dense_attention not returning data.
* Reshape the kv scaler to 3 dims for ragged attention.
* Cannot stop the input_pos counter from increasing since we are using a ring buffer; it will cause an error.
* Adds starting_position and profiling_prefill for better testing and benchmarking.
* Move flags in scripts to a common function (#92)
  * refactor flags
  * clean up
  * fix run_server
  * move common flags to global
  * format
  * update
  * update readme
  * update run_interactive
* Stable version of ragged attention.
* Fix the merge conflicts.
* Fixes the missing pieces after merging conflicts. Adds a couple of new flags for debugging and performance tuning.
* Integrates ragged attention into Gemma too.
* Somehow had some local changes to run_interactive; reverting them to align with main.
* Set the default value for the newly added parameters.
* Adds more descriptions to the ragged attention index precomputation function.
* Merges the quantized ragged attention kernel with the non-quantized version.
* Moves the attention calculation to attention.py for better code structure.
* Fix run issues after refactoring.
* Fix the quantized version of ragged attention.
* Fix test_attention by adding default values for the newly added arguments. The error message is about missing positional arguments.
* Fixes unit tests; changes the Transformer model call argument order (input_pos) back to the original to avoid unnecessary issues.
* Format attention_kernel.py.
* Add descriptions to ragged attention outputs.
* Fix quantization tests by adding a default value to the quantization kernel class.
* Reformat attention_kernel.py. Formatting with black doesn't comply with the pylint rules.
* Ignores R0913 (too many arguments) lint error for the ragged attention kernel. Fix other lint errors.
* Ignore R0903 (too few public methods). Fix lint errors.
* Fix the rest of the lint errors.

Co-authored-by: Siyuan Liu <lsiyuan@google.com>
Make the flag config cleaner: add create_config in jetstream_pt/config.py. Fixes #83.
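For reference, a minimal sketch of what a flag-driven engine factory in jetstream_pt/config.py could look like; the flag names and the exact keyword arguments accepted by create_pytorch_engine are assumptions:

# jetstream_pt/config.py -- hypothetical sketch of building the engine from
# the globally parsed absl FLAGS in one place.
from absl import flags
import jetstream_pt

FLAGS = flags.FLAGS


def create_engine_from_flags():
  """Builds a PyTorch engine from the common flags defined in this module."""
  return jetstream_pt.create_pytorch_engine(
      tokenizer_path=FLAGS.tokenizer_path,
      checkpoint_path=FLAGS.checkpoint_path,
      sharding_config_path=FLAGS.sharding_config_path,
  )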