-
Notifications
You must be signed in to change notification settings - Fork 220
[tx] fix type checking in layers folder #681
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@pcmoritz @tyler-griggs Please take a look at your earliest convenience. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request addresses type checking issues in tx/layers/lora.py by adding necessary imports, type hints, and runtime checks. The changes are a good step towards improving type safety. My review includes a few suggestions to further enhance the code. I've recommended combining imports for better readability and replacing assert statements with more robust runtime checks using if/raise to prevent potential issues in production environments where assertions might be disabled.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
pcmoritz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the contribution! I removed the tx/models/__init__.py since it often leads to circular imports and simplified the handling of the model class :)
| *shape_A, | ||
| dtype=dtype, | ||
| kernel_init=nnx.with_partitioning(nnx.initializers.he_uniform(), sharding_A), | ||
| kernel_init=nnx.with_partitioning(nnx.initializers.he_uniform(), tuple(sharding_A)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately this was invalid, it generates W1119 05:41:01.754788 47420 spmd_partitioner.cc:4935] You have to use Shardy for RaggedDot. If not, the behavior is undefined.
Will fix it with #682
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Never mind, this warning seems to actually come from the recently released Jax 0.8.1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I created an issue here jax-ml/jax#33403
Some additional fixes on top of #681 and also some more errors that happen when upgrading ty to the recently released 0.0.1a27 Makes the signatures slightly stricter with less defaults too.
This PR fixes all type checking errors using `ty` inside the `tx/layers` folder as requested in NovaSky-AI#673. ``` (base) ray@4de7f1d7011b:/workspace/SkyRL/skyrl-tx$ uv run --extra dev ty check tx/layers/util.py All checks passed! (base) ray@4de7f1d7011b:/workspace/SkyRL/skyrl-tx$ uv run --extra dev ty check tx/layers/lora.py All checks passed! ``` Once I finish with with the entire `skyrl-tx` folder, I'll add it to CI. --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Some additional fixes on top of NovaSky-AI#681 and also some more errors that happen when upgrading ty to the recently released 0.0.1a27 Makes the signatures slightly stricter with less defaults too.
This PR fixes all type checking errors using
tyinside thetx/layersfolder as requested in #673.Once I finish with with the entire
skyrl-txfolder, I'll add it to CI.