GPTSFTChatDataset loss_mask becomes all False when prompt length > max_seq_length #8025
Comments
I ran into the same issue and hacked the code to zero out the loss on problematic micro-batches: odelalleau@b383e6a (obviously not a proper fix, but it can be useful to get unblocked).
Yeah, this will also fix the problem.
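For reference, a minimal sketch of the kind of guard described in the comment above (not the linked commit itself); the tensor names and shapes are assumptions:

```python
import torch

def masked_token_loss(per_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Average the per-token loss over masked-in tokens, zeroing out
    micro-batches where no token contributes to the loss.

    per_token_loss: [batch, seq] unreduced loss values
    loss_mask:      [batch, seq] bool/float mask, True where the loss applies
    """
    loss_mask = loss_mask.float()
    num_valid = loss_mask.sum()
    if num_valid == 0:
        # Every token is masked out (e.g. the first turn alone exceeded
        # max_seq_length), so return a zero loss instead of 0 / 0 = nan.
        # Multiplying by 0.0 keeps the graph connected, so gradients are
        # well-defined (all zeros) rather than nan.
        return per_token_loss.sum() * 0.0
    return (per_token_loss * loss_mask).sum() / num_valid
```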
This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.
This issue should stay active, since the same problem has been encountered by others as well.
This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.
This issue was closed because it has been inactive for 7 days since being marked as stale.
Describe the bug
In the GPTSFTChatDataset, if the first prompt's length exceeds max_seq_length, all following turns are truncated out and the loss_mask becomes all False for the example.
https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py#L359
This is problematic because, when the loss_mask is all False, the loss of the MegatronGPTModel becomes nan.
https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L1015
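To make the failure mode concrete, here is a minimal sketch in plain PyTorch (not the NeMo code itself) of how an all-False loss_mask produces a nan loss when the masked average divides by the mask sum:

```python
import torch

per_token_loss = torch.rand(1, 8)  # unreduced cross-entropy, shape [batch, seq]
loss_mask = torch.zeros(1, 8)      # all tokens masked out (all False)

# Masked-average pattern used by such loss functions:
loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()
print(loss)  # tensor(nan) -- 0 / 0
```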
Steps/Code to reproduce bug
This can be reproduced by passing into GPTSFTChatDataset an example whose first-turn prompt has more than 2048 tokens when max_seq_length=2048. Then use the GPTSFTChatDataset in a supervised fine-tuning job (e.g., train_gpt_sft.py in NeMo-Aligner).
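A rough sketch of generating such an example; the JSONL field names below are assumptions about the conversational format GPTSFTChatDataset consumes and may need to be adapted:

```python
import json

# Build a single-turn conversation whose user prompt alone is far longer than
# max_seq_length=2048 tokens. NOTE: the field names ("system", "mask",
# "conversations", "from", "value") are assumptions about the chat JSONL
# layout and may differ from what GPTSFTChatDataset actually expects.
long_prompt = "word " * 5000  # well over 2048 tokens for any common tokenizer

example = {
    "system": "",
    "mask": "User",
    "conversations": [
        {"from": "User", "value": long_prompt},
        {"from": "Assistant", "value": "short answer"},
    ],
}

with open("overlong_first_turn.jsonl", "w") as f:
    f.write(json.dumps(example) + "\n")

# Point the SFT / NeMo-Aligner training config's dataset file at this JSONL;
# every turn of this example ends up truncated, leaving loss_mask all False.
```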
Expected behavior
In the collate_fn function, check that the loss_mask of every example contains at least one True entry; if any does not, raise an error.
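A minimal sketch of such a check, assuming collate_fn receives a list of per-example dicts containing a loss_mask tensor (the names here are assumptions, not the actual NeMo signature):

```python
import torch

def validate_loss_masks(batch: list[dict]) -> None:
    """Raise a descriptive error if any example contributes no loss tokens.

    Intended to be called at the top of collate_fn, before padding/stacking.
    `batch` is assumed to be a list of per-example dicts with a 1-D
    `loss_mask` tensor (True/1 where the loss is computed).
    """
    for i, example in enumerate(batch):
        loss_mask = torch.as_tensor(example["loss_mask"])
        if not loss_mask.bool().any():
            raise ValueError(
                f"Example {i} has an all-False loss_mask, likely because its "
                f"first prompt exceeds max_seq_length and every response turn "
                f"was truncated away. Drop or shorten this example, otherwise "
                f"the training loss will be nan."
            )
```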