Conversation
    if update_workspace and quantizer is not None and quantizer.columnwise_usage:
        tensor.update_usage(
            rowwise_usage=quantizer.rowwise_usage,
            columnwise_usage=quantizer.columnwise_usage,
This change makes the update_usage logic redundant, so we might as well remove it:
    if update_workspace and quantizer is not None and quantizer.columnwise_usage:
        tensor.update_usage(
            rowwise_usage=quantizer.rowwise_usage,
            columnwise_usage=quantizer.columnwise_usage,
That leads us to a more fundamental challenge: inference and validation have different data requirements, and we don't expose a good way to distinguish between them. If you're doing MXFP8 inference, you don't want MXFP8 column-wise data, since that's wasted memory. If you're loading a BF16 checkpoint, doing MXFP8 validation, and then doing MXFP8 training, you need to cast to both row-wise and column-wise MXFP8 data when loading the checkpoint. Removing this update_usage logic means that however you've initialized the weights, you're stuck with those buffers forever. This is unfortunate for inference because the default behavior is to initialize weights for training, i.e. with both row-wise and column-wise data (#1827 does add an option to initialize for inference).
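To make the trade-off concrete, here is a minimal toy sketch. It is not Transformer Engine code: `ToyQuantizedWeight`, `forward`, `backward`, and `run_validation_then_training` are hypothetical names, and the sketch assumes that the forward pass reads row-wise data, the backward pass reads column-wise data, and a dropped buffer cannot be rebuilt once the high-precision source is gone.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyQuantizedWeight:
    """Hypothetical stand-in for a quantized weight with two layouts."""

    rowwise_data: Optional[list] = None
    columnwise_data: Optional[list] = None

    def update_usage(self, rowwise_usage: bool, columnwise_usage: bool) -> None:
        # Assumed behavior: a layout that is no longer requested is freed.
        # It cannot be recreated here because the high-precision source
        # (e.g. the BF16 checkpoint value) is not kept around.
        if not rowwise_usage:
            self.rowwise_data = None
        if not columnwise_usage:
            self.columnwise_data = None


def forward(weight: ToyQuantizedWeight) -> str:
    # Assumption: the forward pass only needs row-wise data.
    if weight.rowwise_data is None:
        raise RuntimeError("row-wise data missing")
    return "forward ok"


def backward(weight: ToyQuantizedWeight) -> str:
    # Assumption: the backward pass needs column-wise data.
    if weight.columnwise_data is None:
        raise RuntimeError("column-wise data missing")
    return "backward ok"


def run_validation_then_training(drop_unused: bool) -> str:
    # Weights start out initialized for training: both layouts present.
    weight = ToyQuantizedWeight(rowwise_data=[1, 2], columnwise_data=[1, 2])
    if drop_unused:
        # Validation only runs forward passes, so column-wise data looks unused.
        weight.update_usage(rowwise_usage=True, columnwise_usage=False)
    forward(weight)  # validation step
    forward(weight)  # training forward step
    try:
        return backward(weight)  # training backward step
    except RuntimeError as err:
        return f"error: {err}"
```

In this toy model, dropping the unused layout saves memory for a pure inference run but breaks any training step that follows validation, while never dropping it keeps training correct at the cost of holding a column-wise buffer that inference does not need.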
I think removing this logic does make sense eventually, especially after Mcore adds support for #1827, but I'm wary of breaking users doing inference in the meantime. For now, is it possible to fix the bug by configuring distopt to avoid overlapping the param all-gather (AG) with validation?
Thanks! Why do you think this bug has to do with overlapping the param AG with validation? I thought this error would occur whenever we do training after validation.
Ah, you're right. I'll incorporate this logic in #1827.
I've included these fixes in #1827.
Description
Fix two bugs in MXFP8 training.
(1) The following error occurred after validation because column-wise data is cleared by update_usage() during validation.
(2) This error occurred after training completed, while copying a tensor to CPU:
Type of change
[ ] Bug fix (non-breaking change which fixes an issue)
Changes
Please list the changes introduced in this PR:
Checklist: