Open
Labels
bug: Something isn't working
Description
Describe the bug
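With PyTorch 2.1.1, running bash examples/bert/train_bert_340m_distributed.sh fails with a TorchScript (JIT) compilation error. The traceback points at the Sequence[int] return annotation on vocab_range_from_per_partition_vocab_size, which is compiled because calculate_predicted_logits calls it: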
    return torch.jit.script(fn, _rcb=rcb)
  File "/root/sharedDisk/home/tanwenxuan/miniconda3/lib/python3.8/site-packages/torch/jit/_script.py", line 1381, in script
    fn = torch._C._jit_script_compile(
RuntimeError:
Unknown type constructor Sequence:
File "/root/sharedDisk/home/tanwenxuan/Megatron-LM/megatron/core/tensor_parallel/utils.py", line 106
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
~~~~~~~~~~~~~ <--- HERE
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
'vocab_range_from_per_partition_vocab_size' is being compiled since it was called from 'calculate_predicted_logits'
File "/root/sharedDisk/home/tanwenxuan/Megatron-LM/megatron/core/tensor_parallel/cross_entropy.py", line 41
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
'calculate_predicted_logits' is being compiled since it was called from 'calculate_predicted_logits'
To Reproduce
bash examples/bert/train_bert_340m_distributed.sh
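Expected behavior
The script starts training without a TorchScript compilation error.
Stack trace/logs
See the traceback under "Describe the bug".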
Environment (please complete the following information):
- Megatron-LM e33c8f7
- PyTorch 2.1.1
- CUDA 12.1
Proposed fix
Disable JIT compilation (torch.jit.script) for this code path when the error occurs.
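A minimal sketch of what "disable JIT if this occurs" could look like, not the Megatron-LM implementation. The helper name jit_script_or_eager is hypothetical, and the function below is a simplified copy of the one in the traceback, with rank annotated as int for illustration. The helper tries torch.jit.script and returns the plain Python function if compilation fails for any reason.

```python
import warnings
from typing import Callable, Sequence


def jit_script_or_eager(fn: Callable) -> Callable:
    """Hypothetical helper: compile fn with TorchScript, or return fn unchanged on failure."""
    try:
        import torch  # imported lazily so the fallback also covers a missing install

        return torch.jit.script(fn)
    except Exception as exc:  # e.g. RuntimeError: Unknown type constructor Sequence
        # Broad catch on purpose: any compilation problem falls back to eager mode.
        warnings.warn(
            f"TorchScript compilation of {fn.__name__} failed, running eagerly: {exc}"
        )
        return fn


# Simplified copy of the function from the traceback (assumption: rank is an int).
def vocab_range_from_per_partition_vocab_size(
    per_partition_vocab_size: int, rank: int, world_size: int
) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


get_vocab_range = jit_script_or_eager(vocab_range_from_per_partition_vocab_size)
first, last = get_vocab_range(4, 2, 8)  # partition size 4, rank 2, world size 8
```

The fallback trades the fused kernel for correctness. An alternative that might keep JIT enabled is to annotate the return type with a type TorchScript supports, such as Tuple[int, int], but that has not been verified against this commit.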