-
Notifications
You must be signed in to change notification settings - Fork 438
Improve forward_pass_logit_checker.py to perform mutual conversion check #1839
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
hengtaoguo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent work!
|
Hi @gagika ! I've heard this might be interesting to you for loading/saving HF checkpoints. Would you like to take a look when you got a chance? Thanks a lot for your time! |
shralex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Yixuan! Added a few comments
shralex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing the comments! I have 1 small comment and also a question -- did you test both directions -- to and from HF ? if so can you add both to the PR description testing section, currently it includes 1 example. Thanks!
from HF conversion with examples is pushed in previous PR: #1785 and #1821. And I have revised the run name. |
|
@YixuanWang-99 thank you for consolidating these files. Before merging this, lets make sure that end-to-end tests using forward logits checker still work - can you please run a couple of these tests. |
b2faae0 to
015c6cf
Compare
b2faae0 to
d8de947
Compare
Thank you for the constructive feedback! The new flag Workload that runs a full test_gemma.sh: link |
a0725a2 to
49da068
Compare
Description
This update to forward_pass_logit_checker.py enables direct comparisons between MaxText and Hugging Face model checkpoints.
Previously, the script could only compare a single checkpoint (either Hugging Face or MaxText) against a set of "golden logits." This was problematic for fine-tuned models, as their outputs often diverge from the original golden logits. Additionally, when converting models between MaxText and Hugging Face formats, it was difficult to verify the conversion's accuracy.
forward_pass_logit_checker .pywill run both MaxText and HuggingFace models on-the-fly and compare their output logits, including evaluating output logits for the last token prediction, top-k predicted tokens and their corresponding scores, and KL-divergence between the full logit distributions, ensuring similarity.This enhancement is crucial for verifying that model conversions accurately preserve predictive behavior.
Tests
Tested on Gemma-2b Model, with an example to comparing MaxText/Hugging Face models runs:
A successful check between huggingface and MaxText checkpoints like this. And the similarity and KL div check should be with no errors
Checklist
Before submitting this PR, please make sure (put X in square brackets):