
Fix the bug of using loss before assignment #700

Open

LiuXTao wants to merge 1 commit into main

Conversation


@LiuXTao LiuXTao commented Feb 22, 2024

Bug Description
Using MoE in conjunction with pipeline_parallel_size > 1 triggers a 'referenced before assignment' error.
The complete traceback is as follows:

Traceback (most recent call last):  
  File "/mnt/nanjing3cephfs/mm-base-plt2/dev-xtl/temp-test/Megatron-LM-fix/pretrain_gpt.py", line 207, in <module>  
    pretrain(train_valid_test_datasets_provider,
  File "/mnt/nanjing3cephfs/mm-base-plt2/dev-xtl/temp-test/Megatron-LM-fix/megatron/training.py", line 258, in pretrain
    iteration, num_floating_point_operations_so_far = train(
  File "/mnt/nanjing3cephfs/mm-base-plt2/dev-xtl/temp-test/Megatron-LM-fix/megatron/training.py", line 970, in train
    train_step(forward_step_func,
  File "/mnt/nanjing3cephfs/mm-base-plt2/dev-xtl/temp-test/Megatron-LM-fix/megatron/training.py", line 535, in train_step
    losses_reduced = forward_backward_func(
  File "/mnt/nanjing3cephfs/mm-base-plt2/dev-xtl/temp-test/Megatron-LM-fix/megatron/core/pipeline_parallel/schedules.py", line 1212, in forward_backward_pipelining_without_interleaving
    output_tensor = forward_step(
  File "/mnt/nanjing3cephfs/mm-base-plt2/dev-xtl/temp-test/Megatron-LM-fix/megatron/core/pipeline_parallel/schedules.py", line 216, in forward_step
    config.grad_scale_func(torch.tensor(1.0, device=loss.device))
UnboundLocalError: local variable 'loss' referenced before assignment

Script
The bug above can be reproduced with my script examples/pretrain_gpt_moe_demo.sh.

Solution
Examining the code at the error site, I found a potential issue at line 216 of megatron/core/pipeline_parallel/schedules.py:

config.grad_scale_func(torch.tensor(1.0, device=loss.device))

On non-final pipeline stages the variable loss is never assigned, so it risks being referenced before assignment here. I therefore suggest changing it as in this pull request:

config.grad_scale_func(torch.tensor(1.0, device=output_tensor.device))
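
For context, here is a minimal, self-contained sketch (hypothetical, simplified; not the actual code in schedules.py) of why loss.device can fail on intermediate pipeline stages while output_tensor.device is always available:

import torch

def forward_step_sketch(is_last_stage: bool, grad_scale_func):
    # Hypothetical, simplified stand-in for the failing path -- not the real
    # forward_step() in megatron/core/pipeline_parallel/schedules.py.
    output_tensor = torch.randn(4, 8)  # produced on every pipeline stage

    if is_last_stage:
        # `loss` is only ever assigned on the last pipeline stage.
        loss = output_tensor.mean()

    # The MoE aux-loss scaling runs on every stage, so the old line
    #     grad_scale_func(torch.tensor(1.0, device=loss.device))
    # raises UnboundLocalError whenever is_last_stage is False.
    # `output_tensor` always exists, so its device is safe to use:
    return grad_scale_func(torch.tensor(1.0, device=output_tensor.device))

# On an intermediate stage the old code would crash; the fixed line does not:
print(forward_step_sketch(is_last_stage=False, grad_scale_func=lambda t: t * 2.0))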

fix: remove useless params

Marking as stale. No activity in 60 days.

@github-actions bot added the stale label (No activity in 60 days on issue or PR) on Apr 26, 2024